In [1]:
# Reload edited project modules (e.g. the `mint` package imported below) before
# each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Make the local `mint` checkout importable. The membership guard keeps this
# cell idempotent: re-running it does not append the same entry to sys.path again.
MINT_ROOT = '/coc/scratch/anarayanan68/mint/'
if MINT_ROOT not in sys.path:
    sys.path.append(MINT_ROOT)

In [3]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, initializers, regularizers

from mint.core import fact_model, base_model_util, base_models, primitive_models, model_builder
from mint.utils import inputs_util, config_util



In [4]:
# Pipeline config (model / eval / dataset sections) and its companion encoder
# config. Both files share one stem, so they are derived from a single
# directory + stem pair and cannot drift apart when switching experiments.
CONFIG_DIR = '/coc/scratch/anarayanan68/mint/configs'
CONFIG_STEM = 'audio_based_blending__embed_cond_vary_num_prims'
config_path = os.path.join(CONFIG_DIR, f'{CONFIG_STEM}.config')
enc_config_path = os.path.join(CONFIG_DIR, f'{CONFIG_STEM}-enc_config.yml')

In [5]:
# Experiment selection. `num_primitives` is passed to the model builder and picks
# the per-run subdirectory `<N>prims` under the experiment base directory
# (presumably it must match the value the checkpoint was trained with).
num_primitives = 20
EXPT_BASE_DIR = '/coc/scratch/anarayanan68/mint/_expts/audio_based_blending__embed_cond__L2_on_prims'
expt_root = os.path.join(EXPT_BASE_DIR, f'{num_primitives}prims')

checkpoint_dir = os.path.join(expt_root, 'checkpoints/')

In [6]:
# Parse the pipeline file, then pull out the three sections this notebook
# needs: the model definition, the eval settings and the eval dataset.
configs = config_util.get_configs_from_pipeline_file(config_path)
model_config, eval_config, eval_dataset_config = (
    configs[section] for section in ('model', 'eval_config', 'eval_dataset'))

In [7]:
# Encoder config read from the companion .yml file; it is handed to
# model_builder.build below as `encoder_config_yaml`.
enc_config_yaml = config_util.read_yaml_config(enc_config_path)

In [8]:
# Model build & restore
#
# Build the model in inference mode and load the latest checkpoint found in
# `checkpoint_dir`. `restore_or_initialize()` returns None and restores nothing
# when the directory has no checkpoint. For an export notebook that would
# silently dump primitives from an untrained model, so fail loudly instead.

model = model_builder.build(model_config, is_training=False,
    num_primitives=num_primitives, encoder_config_yaml=enc_config_yaml, dataset_config=eval_dataset_config)

checkpoint_manager = tf.train.CheckpointManager(
    tf.train.Checkpoint(model=model),
    directory=checkpoint_dir,
    max_to_keep=None)  # None: never delete checkpoints (this notebook only reads)

restored_ckpt_path = checkpoint_manager.restore_or_initialize()
if restored_ckpt_path is None:
    raise FileNotFoundError(
        f'No checkpoint found in {checkpoint_dir!r}; refusing to continue with an '
        'untrained model (check num_primitives / expt_root).')
# Last expression: display which checkpoint was restored.
restored_ckpt_path

'/coc/scratch/anarayanan68/mint/_expts/audio_based_blending__embed_cond__L2_on_prims/20prims/checkpoints/ckpt-9999'

In [9]:
# Feed the rows of an identity matrix as blend vectors. Each row is one-hot,
# so row k presumably decodes primitive k on its own (TODO confirm against
# blend_vec_to_seq).
one_hot_blends = np.eye(num_primitives)
all_primitives = model.blend_vec_to_seq(one_hot_blends)
all_primitives

<tf.Tensor: shape=(20, 120, 147), dtype=float32, numpy=
array([[[-4.7377771e-33, -4.7378359e-33, -4.7384736e-33, ...,
         -4.7377771e-33, -4.7378359e-33, -4.7384662e-33],
        [-4.7377771e-33, -4.7384662e-33, -4.7384736e-33, ...,
         -4.7378359e-33, -4.7378359e-33, -4.7384662e-33],
        [-4.7377771e-33, -4.7378359e-33, -4.7384736e-33, ...,
         -4.7378359e-33, -4.7378359e-33, -4.7384662e-33],
        ...,
        [-4.7377771e-33, -4.7378359e-33, -4.7377830e-33, ...,
         -4.7377771e-33, -4.7378359e-33, -4.7384662e-33],
        [-4.7377771e-33, -4.7384662e-33, -4.7377830e-33, ...,
         -4.7378359e-33, -4.7378359e-33, -4.7384662e-33],
        [-4.7378359e-33, -4.7384662e-33, -4.7384736e-33, ...,
         -4.7377771e-33, -4.7378359e-33, -4.7384662e-33]],

       [[-6.9793252e-08,  7.6786797e-08,  4.6918940e-14, ...,
         -1.3940189e-32,  8.2696909e-32,  2.5367493e-08],
        [-1.1678607e-07,  2.6478352e-15,  1.6153719e-16, ...,
         -4.9009876e-22,  5

In [10]:
# Write one .npy file per primitive to <expt_root>/primitives/, named
# PRIMITIVE--<k>_of_<N>.npy with a 1-based index k.
save_dir = os.path.join(expt_root, 'primitives')
os.makedirs(save_dir, exist_ok=True)

# Convert the TF tensor to numpy once instead of slicing it and implicitly
# converting each slice inside np.save.
primitives_np = np.asarray(all_primitives)
# Guard against mislabeled files: the "_of_<N>" suffix must match what was decoded.
if primitives_np.shape[0] != num_primitives:
    raise ValueError(
        f'Expected {num_primitives} primitives, got leading dim {primitives_np.shape[0]} '
        f'(shape {primitives_np.shape}).')
for i, primitive in enumerate(primitives_np):
    np.save(os.path.join(save_dir, f"PRIMITIVE--{i+1}_of_{num_primitives}.npy"), primitive)