In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Location of the mint checkout. Defaults to the original cluster path; set the
# MINT_ROOT environment variable to run the notebook from a different checkout.
MINT_ROOT = os.environ.get('MINT_ROOT', '/coc/scratch/anarayanan68/mint/')

# Guarded so that re-running this cell does not keep appending duplicate entries.
if MINT_ROOT not in sys.path:
    sys.path.append(MINT_ROOT)

In [3]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, initializers, regularizers

from mint.core import fact_model, base_model_util, base_models, primitive_models
from mint.utils import inputs_util, config_util



In [8]:
class BlendController(keras.Model):
    """Turn an audio sequence plus a conditioning id into a blend vector over primitives.

    The audio input is linearly embedded, given position embeddings and passed
    through a transformer. A learned embedding of the conditioning id is added
    to every audio time step, the result is average-pooled over time, and an
    MLP followed by a softmax yields one weight per primitive.

    Args:
        num_primitives: Size of the output blend vector.
        cond_vocab_size: Number of rows in the conditioning embedding table.
        config_dict: Mapping with a 'transformer' entry (keys 'hidden_size',
            'sequence_length', 'num_hidden_layers', 'num_attention_heads',
            'intermediate_size') and an 'output_block' entry (key 'hidden_dim').
        name: Keras model name.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, num_primitives, cond_vocab_size, config_dict, name="BlendController", **kwargs):
        super().__init__(name=name, **kwargs)

        self.num_primitives = num_primitives

        # --- audio branch: linear embedding -> position embedding -> transformer ---
        xf_cfg = config_dict['transformer']
        hidden_size = xf_cfg['hidden_size']

        self.audio_linear_embedding = base_models.LinearEmbedding(hidden_size)
        self.audio_pos_embedding = base_models.PositionEmbedding(
            xf_cfg['sequence_length'], hidden_size)
        self.transformer = base_models.Transformer(
            hidden_size=hidden_size,
            num_hidden_layers=xf_cfg['num_hidden_layers'],
            num_attention_heads=xf_cfg['num_attention_heads'],
            intermediate_size=xf_cfg['intermediate_size'],
        )

        # --- conditioning branch: one embedding row per conditioning id ---
        # Random-normal init so that, hopefully, each embedding differs from the start.
        self.conditioning_block = layers.Embedding(
            cond_vocab_size,
            hidden_size,
            input_length=1,
            embeddings_initializer=initializers.RandomNormal(),
            embeddings_regularizer=None,
            name='cond_input_embedding',
        )

        # --- head: pool over time, project to num_primitives, normalise ---
        mlp_hidden_dim = config_dict['output_block']['hidden_dim']
        self.output_block = keras.Sequential([
            layers.GlobalAveragePooling1D(),
            base_models.MLP(out_dim=num_primitives, hidden_dim=mlp_hidden_dim),
            layers.Softmax(),
        ])

    def call(self, inputs):
        """Compute the blend vector for a batch.

        Args:
            inputs: Mapping with keys 'audio_input' and 'conditioning_input'.
                Per the original author's shape notes these are
                (batch_size, seq_len, audio_feature_dim) and
                (batch_size, conditioning_input_dim) respectively.

        Returns:
            Softmax output of the head; per the original shape notes,
            (batch_size, num_primitives).
        """
        # (batch_size, seq_len, hidden) after embedding, positions and transformer.
        audio_tokens = self.transformer(
            self.audio_pos_embedding(
                self.audio_linear_embedding(inputs['audio_input'])))

        # (batch_size, 1, hidden): the singleton axis broadcasts over seq_len below.
        cond_embedding = self.conditioning_block(inputs['conditioning_input'])

        return self.output_block(audio_tokens + cond_embedding)

In [5]:
# Resolve the encoder config relative to the mint checkout instead of pinning one
# user's absolute path. With MINT_ROOT unset this yields exactly the original path.
mint_root = os.environ.get('MINT_ROOT', '/coc/scratch/anarayanan68/mint/')
enc_config_yaml_path = os.path.join(
    mint_root, 'configs', 'audio_based_blending__embed_based_conditioning-enc_config.yml')
enc_config_yaml = config_util.read_yaml_config(enc_config_yaml_path)

# Display the parsed config so the sizes used by later cells are visible.
enc_config_yaml

{'num_primitives': 100,
 'conditioning_dim': 1,
 'conditioning_vocab_size': 10,
 'audio_to_blend_vec': {'transformer': {'sequence_length': 60,
   'hidden_size': 256,
   'num_hidden_layers': 1,
   'num_attention_heads': 8,
   'intermediate_size': 1024},
  'output_block': {'hidden_dim': 256}},
 'blend_vec_to_seq': {'target_shape': '120,147'}}

In [9]:
# Audio -> blend-vector controller, sized entirely from the loaded encoder config.
a2b = BlendController(
    enc_config_yaml['num_primitives'],            # num_primitives
    enc_config_yaml['conditioning_vocab_size'],   # cond_vocab_size
    enc_config_yaml['audio_to_blend_vec'],        # config_dict
)

In [10]:
# Smoke-test the controller on synthetic inputs whose sizes follow the config,
# so this cell cannot silently drift out of sync with the model built from it.
TEST_AUDIO_FEATURE_DIM = 35  # width of each synthetic audio frame used for this check
test_batch_size = enc_config_yaml['conditioning_vocab_size']  # one sample per conditioning id
test_seq_len = enc_config_yaml['audio_to_blend_vec']['transformer']['sequence_length']

# Seeded generator: the synthetic inputs are identical on every run.
# (Model weights are still randomly initialised, so output values can differ.)
test_rng = np.random.default_rng(0)
test_audios = test_rng.standard_normal(
    (test_batch_size, test_seq_len, TEST_AUDIO_FEATURE_DIM), dtype=np.float32)
test_cond = np.arange(test_batch_size).reshape((-1, 1))

test_ins = {
    'audio_input': test_audios,
    'conditioning_input': test_cond
}
test_outs = a2b(test_ins)

# The head ends in a softmax: expect one row per sample, each summing to 1.
assert tuple(test_outs.shape) == (test_batch_size, enc_config_yaml['num_primitives'])
np.testing.assert_allclose(test_outs.numpy().sum(axis=-1), 1.0, atol=1e-4)

test_outs

<tf.Tensor: shape=(10, 100), dtype=float32, numpy=
array([[0.01170535, 0.01079346, 0.00968513, 0.01035765, 0.00944266,
        0.01025323, 0.00821733, 0.01215233, 0.00861848, 0.01071779,
        0.0118113 , 0.0087884 , 0.00981206, 0.01228856, 0.00941273,
        0.0096304 , 0.01045018, 0.00959557, 0.00934233, 0.00923613,
        0.00847851, 0.00967992, 0.00929915, 0.00890301, 0.00842417,
        0.00981407, 0.01289271, 0.00912185, 0.00999935, 0.01246821,
        0.00742049, 0.01167622, 0.00959626, 0.00694625, 0.00841384,
        0.00995899, 0.01132338, 0.00792442, 0.01104884, 0.01151963,
        0.0122895 , 0.01073207, 0.00873414, 0.00938094, 0.00847098,
        0.01282649, 0.00891778, 0.00905286, 0.01034872, 0.00996136,
        0.00849223, 0.01116323, 0.01140295, 0.00972154, 0.01341921,
        0.00828702, 0.00961407, 0.00973543, 0.00942487, 0.00945527,
        0.00925506, 0.00908624, 0.01231977, 0.01080496, 0.00917186,
        0.0095695 , 0.0099464 , 0.00814201, 0.01108983, 0.0112592

In [11]:
# Build the blend-vector -> sequence stage from the 'blend_vec_to_seq' section of
# the loaded config (which carries target_shape '120,147').
b2s = primitive_models.BlendVecToSeq(
    num_primitives=enc_config_yaml['num_primitives'],
    config_dict=enc_config_yaml['blend_vec_to_seq']
)

# Feed the controller's blend vectors through it; the recorded result has
# shape (10, 120, 147).
test_seq_outs = b2s(test_outs)
# NOTE(review): every entry printed in the recorded output is the same value
# (0.00752923) — confirm whether that is expected for a freshly initialised model.
test_seq_outs

<tf.Tensor: shape=(10, 120, 147), dtype=float32, numpy=
array([[[0.00752923, 0.00752923, 0.00752923, ..., 0.00752923,
         0.00752923, 0.00752923],
        [0.00752923, 0.00752923, 0.00752923, ..., 0.00752923,
         0.00752923, 0.00752923],
        [0.00752923, 0.00752923, 0.00752923, ..., 0.00752923,
         0.00752923, 0.00752923],
        ...,
        [0.00752923, 0.00752923, 0.00752923, ..., 0.00752923,
         0.00752923, 0.00752923],
        [0.00752923, 0.00752923, 0.00752923, ..., 0.00752923,
         0.00752923, 0.00752923],
        [0.00752923, 0.00752923, 0.00752923, ..., 0.00752923,
         0.00752923, 0.00752923]],

       [[0.00752923, 0.00752923, 0.00752923, ..., 0.00752923,
         0.00752923, 0.00752923],
        [0.00752923, 0.00752923, 0.00752923, ..., 0.00752923,
         0.00752923, 0.00752923],
        [0.00752923, 0.00752923, 0.00752923, ..., 0.00752923,
         0.00752923, 0.00752923],
        ...,
        [0.00752923, 0.00752923, 0.00752923, ..., 0.