# Listen to training-time generations

### Import libraries

In [None]:
import sys
# Prepend the parent directory to the import path — presumably so a local
# audiocraft checkout is picked up instead of an installed copy (TODO confirm).
# This must run before the audiocraft imports below.
sys.path.insert(0, '../')

import IPython.display as ipd

from audiocraft.models import encodec, loaders, builders
from audiocraft.utils import utils
import omegaconf
import torch
import os
import random
import json
# Fixed seed, intended to make the random.sample() selection of generations
# further down repeatable between runs.
random.seed(3)

### Prepare sample paths

In [None]:
# Root folder of the training-time samples. It holds one sub-directory per
# epoch, named by the epoch number. Each sub-directory contains generated *.pt
# files, each with a *.json file of the same stem alongside it.
base_folder = "/home/karlos/Documents/workspace/proj/music/trained_models/v0_22_apr26/dea3706f/samples"
# Number of generations drawn at random from each epoch directory.
samples_per_epoch = 3
# Epochs to inspect: 50, 60, ..., 100 (inclusive).
epochs = [*range(50, 101, 10)]
epochs

In [None]:
# Pick up to `samples_per_epoch` random generations from every epoch directory.
# Each entry of `files` is an (epoch, path_to_pt_file) tuple.
files = []
for e in epochs:
    epoch_dir = os.path.join(base_folder, str(e))
    # os.listdir returns entries in an arbitrary, filesystem-dependent order.
    # Sort them so the seeded random.sample() picks the same files on every run.
    all_generations = [(e, os.path.join(epoch_dir, f))
                       for f in sorted(os.listdir(epoch_dir)) if f.endswith('.pt')]
    # random.sample raises ValueError when asked for more items than exist,
    # so cap the request at the number of available generations.
    n_pick = min(samples_per_epoch, len(all_generations))
    if n_pick < samples_per_epoch:
        print(f'Epoch {e}: only {len(all_generations)} generation(s) found in {epoch_dir}')
    files.extend(random.sample(all_generations, n_pick))

files

### Load the EnCodec compression model

In [None]:
musicgen_model_name = 'facebook/musicgen-small'

# Load the EnCodec compression model used by MusicGen. The model is rebuilt
# from the config stored in the checkpoint, and the checkpoint weights are then
# loaded into it.
pkg = loaders.load_compression_model_ckpt(musicgen_model_name)
cfg = omegaconf.OmegaConf.create(pkg['xp.cfg'])

# Constructor arguments for EncodecModel, taken from the `encodec` config section.
# `kwargs['sample_rate']` is also read by the playback cell below.
kwargs = utils.dict_from_config(cfg.encodec)

encoder_name = kwargs.pop('autoencoder')
quantizer_name = kwargs.pop('quantizer')

encoder, decoder = builders.get_encodec_autoencoder(encoder_name, cfg)
quantizer = builders.get_quantizer(quantizer_name, cfg, encoder.dimension)
# Frame rate of the token stream: sample rate divided by the encoder's hop length.
frame_rate = kwargs['sample_rate'] // encoder.hop_length
renormalize = kwargs.pop('renormalize', False)
# Drop the 'renorm' key (if present) so it is not forwarded to EncodecModel.
kwargs.pop('renorm', None)
model = encodec.EncodecModel(encoder, decoder, quantizer,
                             frame_rate=frame_rate, renormalize=renormalize, **kwargs)
model.load_state_dict(pkg['best_state'])
# Inference mode for the decoding below.
model = model.eval()

In [None]:
# Decode each selected generation back to audio and play it inline.
# Inference only, so no autograd graph is needed.
with torch.no_grad():
    for e, file in files:
        # Load the generated tokens onto the CPU. The model above is never moved
        # off the CPU, and tensors saved from a GPU would otherwise fail to load
        # or end up on the wrong device.
        encoded_generation = torch.load(file, map_location='cpu')
        # The .json file with the same stem holds the conditioning used for this generation.
        with open(os.path.splitext(file)[0] + '.json', encoding='utf-8') as j:
            original_input = json.load(j)['conditioning']['condition']
        print(f'Epoch: {e}, Original Input: {original_input}, Filename: {file}')
        # Add a leading batch dimension for decode().
        # Assumes the saved tensor is (n_codebooks, n_frames) — TODO confirm.
        melody_waveform_reconstructed = model.decode(encoded_generation.unsqueeze(0), None)
        ipd.display(ipd.Audio(melody_waveform_reconstructed[0].cpu().numpy(), rate=kwargs['sample_rate']))