In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Make the `mint` package importable. Point MINT_ROOT (environment variable) at a
# different checkout instead of editing this path; the membership check keeps
# re-runs of this cell from piling duplicate entries onto sys.path.
MINT_ROOT = os.environ.get('MINT_ROOT', '/coc/scratch/anarayanan68/mint/')
if MINT_ROOT not in sys.path:
    sys.path.append(MINT_ROOT)

In [23]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, initializers, regularizers

from mint.core import fact_model, base_model_util, base_models, primitive_models
from mint.utils import inputs_util, config_util

In [24]:
class AudioToBlend(keras.Model):
    """Predicts softmax blend weights over a fixed set of primitives from audio.

    Pipeline: linear embedding -> position embedding -> transformer encoder ->
    average over the sequence axis (GlobalAveragePooling1D) -> MLP -> softmax.
    The MLP outputs `num_primitives` values, so the result has one weight per
    primitive and each row sums to 1.

    Args:
        num_primitives: Number of primitives; width of the output vector.
        config_dict: Dict with two sections:
            'transformer': 'hidden_size', 'num_hidden_layers',
                'num_attention_heads', 'intermediate_size', 'sequence_length'.
            'output_block': 'hidden_dim' (hidden width of the MLP head).
        name: Keras model name.
        **kwargs: Forwarded unchanged to `keras.Model`.
    """

    def __init__(self, num_primitives, config_dict, name="AudioToBlend", **kwargs):
        super().__init__(name=name, **kwargs)

        self.num_primitives = num_primitives

        enc_cfg = config_dict['transformer']
        hidden = enc_cfg['hidden_size']

        # Sub-layers are constructed in the original order (transformer, position
        # embedding, linear embedding, head); Keras auto-naming depends on it.
        self.transformer = base_models.Transformer(
            hidden_size=hidden,
            num_hidden_layers=enc_cfg['num_hidden_layers'],
            num_attention_heads=enc_cfg['num_attention_heads'],
            intermediate_size=enc_cfg['intermediate_size'],
        )
        # NOTE(review): presumably the input seq_len must not exceed
        # 'sequence_length' for the position embedding — TODO confirm.
        self.audio_pos_embedding = base_models.PositionEmbedding(enc_cfg['sequence_length'], hidden)
        self.audio_linear_embedding = base_models.LinearEmbedding(hidden)

        head_cfg = config_dict['output_block']
        pool = layers.GlobalAveragePooling1D()
        head_mlp = base_models.MLP(out_dim=num_primitives, hidden_dim=head_cfg['hidden_dim'])
        self.output_block = keras.Sequential([pool, head_mlp, layers.Softmax()])

    def call(self, audio_seq):
        """Map audio_seq (batch_size, seq_len, input_feature_dim) to (batch_size, num_primitives)."""
        # Linear projection to transformer_hidden_size, then the position embedding.
        embedded = self.audio_pos_embedding(self.audio_linear_embedding(audio_seq))
        # Still (batch_size, seq_len, transformer_hidden_size) after encoding.
        encoded = self.transformer(embedded)
        return self.output_block(encoded)

In [25]:
enc_config_yaml_path = '/coc/scratch/anarayanan68/mint/configs/audio_based_blending__pilot-enc_config.yml'
enc_config_yaml = config_util.read_yaml_config(enc_config_yaml_path)

# Fail fast if the config lacks a section the cells below index into, rather than
# surfacing a bare KeyError halfway through model construction.
_missing_sections = {'num_primitives', 'audio_to_blend', 'blend_vec_to_seq'} - set(enc_config_yaml)
assert not _missing_sections, f"{enc_config_yaml_path} is missing: {sorted(_missing_sections)}"

enc_config_yaml

{'num_primitives': 10,
 'audio_to_blend': {'transformer': {'sequence_length': 60,
   'hidden_size': 256,
   'num_hidden_layers': 1,
   'num_attention_heads': 8,
   'intermediate_size': 1024},
  'output_block': {'hidden_dim': 256}},
 'blend_vec_to_seq': {'target_shape': '120,147', 'wt_seed': 101}}

In [26]:
# Audio encoder: audio feature sequence -> softmax blend weights over the primitives.
a2b = AudioToBlend(num_primitives=enc_config_yaml['num_primitives'],
                   config_dict=enc_config_yaml['audio_to_blend'])

In [27]:
# Smoke test with a random batch shaped (batch, seq_len, feature_dim). A seeded
# generator makes the input reproducible. seq_len comes from the config so it
# matches the position embedding. 35 is assumed to be the audio feature
# dimension used elsewhere — TODO confirm.
test_rng = np.random.default_rng(0)
test_seq_len = enc_config_yaml['audio_to_blend']['transformer']['sequence_length']
test_ins = test_rng.standard_normal((8, test_seq_len, 35)).astype(np.float32)
test_outs = a2b(test_ins)

# The head ends in a softmax, so every row must be a probability distribution.
np.testing.assert_allclose(tf.reduce_sum(test_outs, axis=-1).numpy(), 1.0, rtol=1e-5)

test_outs, tf.reduce_sum(test_outs, axis=-1)

(<tf.Tensor: shape=(8, 10), dtype=float32, numpy=
 array([[0.09079967, 0.09859749, 0.09884806, 0.1179162 , 0.08714379,
         0.1028384 , 0.1120392 , 0.08785363, 0.08625776, 0.11770577],
        [0.08388319, 0.09766105, 0.1091526 , 0.11005241, 0.10903998,
         0.08448704, 0.10793084, 0.08343273, 0.08686956, 0.12749058],
        [0.08138527, 0.11689258, 0.09156946, 0.09183931, 0.10789804,
         0.09526936, 0.11078382, 0.08774048, 0.09458779, 0.12203394],
        [0.09270462, 0.11047561, 0.12352896, 0.09548024, 0.10819941,
         0.1078762 , 0.0818181 , 0.09399509, 0.0817726 , 0.10414915],
        [0.08114699, 0.09779686, 0.10253843, 0.10997347, 0.10502414,
         0.11275101, 0.09661342, 0.08905603, 0.10094567, 0.10415403],
        [0.07180721, 0.10969441, 0.12554705, 0.09265181, 0.09760097,
         0.09966142, 0.10811237, 0.08201694, 0.10200477, 0.11090305],
        [0.0993668 , 0.11980567, 0.08112411, 0.10295081, 0.10283118,
         0.08881556, 0.10981584, 0.07350355, 0.

In [30]:
b2s = primitive_models.BlendVecToSeq(
    num_primitives=enc_config_yaml['num_primitives'],
    config_dict=enc_config_yaml['blend_vec_to_seq']
)

test_seq_outs = b2s(test_outs)

# target_shape is stored as a comma-separated string (e.g. '120,147'); accept an
# already-parsed sequence too. Output should be (batch,) + target_shape.
_target_shape = enc_config_yaml['blend_vec_to_seq']['target_shape']
if isinstance(_target_shape, str):
    _target_shape = _target_shape.split(',')
expected_seq_shape = (test_outs.shape[0], *(int(d) for d in _target_shape))
assert tuple(test_seq_outs.shape) == expected_seq_shape, (test_seq_outs.shape, expected_seq_shape)

test_seq_outs

<tf.Tensor: shape=(8, 120, 147), dtype=float32, numpy=
array([[[-0.01116452, -0.00605964,  0.01113016, ...,  0.02368946,
         -0.00804726,  0.00970953],
        [-0.01516559,  0.0179451 , -0.0018675 , ...,  0.02008962,
          0.02980702,  0.01009487],
        [-0.01411252,  0.01633983,  0.01735257, ..., -0.00031472,
          0.02716686,  0.01474576],
        ...,
        [-0.02071053,  0.05081129,  0.00598954, ...,  0.01156972,
          0.04686275,  0.02126425],
        [ 0.02015749,  0.0191913 , -0.00886635, ...,  0.0240495 ,
          0.02729913, -0.03569255],
        [ 0.02464804,  0.00617509, -0.01686607, ...,  0.02765157,
          0.01559675, -0.01129722]],

       [[-0.0074795 , -0.0063454 ,  0.012312  , ...,  0.02201549,
         -0.0066855 ,  0.00919778],
        [-0.01296428,  0.01701553,  0.00118419, ...,  0.01931095,
          0.02868771,  0.01077864],
        [-0.01466077,  0.01378675,  0.01567617, ..., -0.0014971 ,
          0.02811928,  0.01439935],
        ...,