In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to local packages are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Make the local `mint` checkout importable. The location can be overridden
# with the MINT_ROOT environment variable; the default is the original
# checkout path, so behaviour is unchanged when the variable is unset.
MINT_ROOT = os.environ.get('MINT_ROOT', '/coc/scratch/anarayanan68/mint/')
# Guard the append so that re-running this cell does not pile up duplicate
# entries on sys.path.
if MINT_ROOT not in sys.path:
    sys.path.append(MINT_ROOT)

In [3]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, initializers, regularizers

from mint.core import fact_model, base_model_util, base_models, primitive_models
from mint.utils import inputs_util, config_util



In [14]:
class BlendController(keras.Model):
    """Predicts a softmax-normalised blend vector from audio and conditioning.

    The audio sequence goes through `base_models.LinearEmbedding`,
    `base_models.PositionEmbedding` and `base_models.Transformer`. The
    conditioning vector goes through its own `base_models.LinearEmbedding` of
    the same hidden size and is added to every time step of the transformer
    output. The sum is then fed to the output block: global average pooling
    over the sequence axis, a `base_models.MLP`, and a softmax.

    Args:
        num_primitives: Size of the output vector (passed as the MLP `out_dim`).
        config_dict: Dict with two entries:
            'transformer' -- keys 'hidden_size', 'sequence_length',
                'num_hidden_layers', 'num_attention_heads',
                'intermediate_size';
            'output_block' -- key 'hidden_dim'.
        name: Keras model name.
        **kwargs: Forwarded to `keras.Model.__init__`.
    """

    def __init__(self, num_primitives, config_dict, name="BlendController", **kwargs):
        super().__init__(name=name, **kwargs)

        self.num_primitives = num_primitives

        # Audio branch: embedding -> position embedding -> transformer.
        # All three share one hidden size, as does the conditioning embedding.
        tf_cfg = config_dict['transformer']
        hidden_size = tf_cfg['hidden_size']
        self.audio_linear_embedding = base_models.LinearEmbedding(hidden_size)
        self.audio_pos_embedding = base_models.PositionEmbedding(
            tf_cfg['sequence_length'], hidden_size)
        self.transformer = base_models.Transformer(
            hidden_size=hidden_size,
            num_hidden_layers=tf_cfg['num_hidden_layers'],
            num_attention_heads=tf_cfg['num_attention_heads'],
            intermediate_size=tf_cfg['intermediate_size'],
        )

        # Conditioning branch: projected to the transformer hidden size so it
        # can be summed with the audio features.
        self.conditioning_block = base_models.LinearEmbedding(hidden_size)

        # Output head: pool over the sequence axis, then MLP + softmax.
        out_cfg = config_dict['output_block']
        pooling = layers.GlobalAveragePooling1D()
        mlp = base_models.MLP(out_dim=num_primitives, hidden_dim=out_cfg['hidden_dim'])
        self.output_block = keras.Sequential([pooling, mlp, layers.Softmax()])

    def call(self, inputs):
        """Computes the blend vector for a batch.

        Args:
            inputs: Dict with
                'audio_input': (batch_size, seq_len, audio_feature_dim)
                'conditioning_input': (batch_size, conditioning_input_dim)

        Returns:
            Tensor of shape (batch_size, num_primitives).
        """
        # (batch_size, seq_len, transformer_hidden_size) at every stage below.
        embedded_audio = self.audio_linear_embedding(inputs['audio_input'])
        positioned_audio = self.audio_pos_embedding(embedded_audio)
        encoded_audio = self.transformer(positioned_audio)

        # (batch_size, transformer_hidden_size)
        cond_embedding = self.conditioning_block(inputs['conditioning_input'])
        # Insert a length-1 sequence axis so the conditioning vector
        # broadcasts across all time steps: (batch_size, 1, hidden_size).
        cond_per_step = tf.expand_dims(cond_embedding, axis=1)

        # (batch_size, seq_len, hidden_size) -> (batch_size, num_primitives)
        return self.output_block(encoded_audio + cond_per_step)

In [5]:
# Encoder config for the audio-based blending pilot. The repository root can
# be overridden with the MINT_ROOT environment variable; with the variable
# unset the path resolves to exactly the original hard-coded location.
enc_config_yaml_path = os.path.join(
    os.environ.get('MINT_ROOT', '/coc/scratch/anarayanan68/mint/'),
    'configs',
    'audio_based_blending__pilot-enc_config.yml',
)
enc_config_yaml = config_util.read_yaml_config(enc_config_yaml_path)

# Display the parsed config so the values used by the cells below are visible.
enc_config_yaml

{'num_primitives': 4,
 'audio_to_blend_vec': {'transformer': {'sequence_length': 60,
   'hidden_size': 256,
   'num_hidden_layers': 1,
   'num_attention_heads': 8,
   'intermediate_size': 1024},
  'output_block': {'hidden_dim': 256}},
 'blend_vec_to_seq': {'target_shape': '120,147'}}

In [15]:
# Build the audio -> blend-vector controller from the loaded encoder config.
a2b = BlendController(
    enc_config_yaml['num_primitives'],
    config_dict=enc_config_yaml['audio_to_blend_vec'],
)

In [17]:
# Smoke-test the controller on random inputs. A seeded generator makes this
# cell reproducible across kernel restarts, and drawing float32 directly
# matches the float32 model output instead of relying on an implicit cast
# from NumPy's default float64.
test_rng = np.random.default_rng(0)
# (batch_size=8, seq_len=60, audio_feature_dim=35); 60 equals the
# 'sequence_length' in the transformer config displayed above.
test_audios = test_rng.standard_normal((8, 60, 35), dtype=np.float32)
# (batch_size=8, conditioning_input_dim=2)
test_cond = test_rng.standard_normal((8, 2), dtype=np.float32)

test_ins = {
    'audio_input': test_audios,
    'conditioning_input': test_cond,
}
test_outs = a2b(test_ins)

# Expected: shape (8, num_primitives), each row a softmax distribution.
test_outs

<tf.Tensor: shape=(8, 4), dtype=float32, numpy=
array([[0.2872154 , 0.2312548 , 0.2329574 , 0.2485724 ],
       [0.26555488, 0.24848545, 0.26498887, 0.22097085],
       [0.26632044, 0.25842467, 0.30879375, 0.1664612 ],
       [0.33855763, 0.17743748, 0.28232503, 0.20167986],
       [0.2971854 , 0.22945176, 0.257187  , 0.21617585],
       [0.27057102, 0.17500177, 0.28820422, 0.26622298],
       [0.30671242, 0.22543874, 0.2509774 , 0.21687144],
       [0.33105534, 0.23257144, 0.19516602, 0.24120717]], dtype=float32)>

In [10]:
# Build the blend-vector -> sequence model from the 'blend_vec_to_seq' section
# of the encoder config (presumably its 'target_shape' sets the per-sample
# output shape — confirm against primitive_models.BlendVecToSeq).
b2s = primitive_models.BlendVecToSeq(
    num_primitives=enc_config_yaml['num_primitives'],
    config_dict=enc_config_yaml['blend_vec_to_seq']
)

# Feed the blend vectors produced by the controller smoke test through it.
# NOTE(review): every entry visible in the saved output below is the same
# constant, even across batch items with different blend vectors — presumably
# the primitives are all initialised to the same value; confirm this is
# intended before relying on this output.
test_seq_outs = b2s(test_outs)
test_seq_outs

<tf.Tensor: shape=(8, 120, 147), dtype=float32, numpy=
array([[[0.00752923, 0.00752923, 0.00752923, ..., 0.00752923,
         0.00752923, 0.00752923],
        [0.00752923, 0.00752923, 0.00752923, ..., 0.00752923,
         0.00752923, 0.00752923],
        [0.00752923, 0.00752923, 0.00752923, ..., 0.00752923,
         0.00752923, 0.00752923],
        ...,
        [0.00752923, 0.00752923, 0.00752923, ..., 0.00752923,
         0.00752923, 0.00752923],
        [0.00752923, 0.00752923, 0.00752923, ..., 0.00752923,
         0.00752923, 0.00752923],
        [0.00752923, 0.00752923, 0.00752923, ..., 0.00752923,
         0.00752923, 0.00752923]],

       [[0.00752923, 0.00752923, 0.00752923, ..., 0.00752923,
         0.00752923, 0.00752923],
        [0.00752923, 0.00752923, 0.00752923, ..., 0.00752923,
         0.00752923, 0.00752923],
        [0.00752923, 0.00752923, 0.00752923, ..., 0.00752923,
         0.00752923, 0.00752923],
        ...,
        [0.00752923, 0.00752923, 0.00752923, ..., 0.0