In [1]:
import os

# Work from the project root so the `src.*` imports and the relative
# `models/` and `data/` paths used by later cells resolve.
# The guard keeps the cell idempotent: an unconditional chdir('..') climbs one
# more level every time the cell is re-run.
# NOTE(review): assumes the project root is the directory that contains `src/`.
if not os.path.isdir('src'):
    os.chdir('..')

In [2]:
from allennlp.models.archival import load_archive
from src.models.vae import VAE
from src.models.lyrics_generator import LyricsGenerator
from src.models.lyrics_discriminator import LyricsDiscriminator
from src.models.lyrics_gan import LyricsGan
from src.modules.encoders import VariationalEncoder
from src.modules.decoders.variational_decoder import VariationalDecoder

# Location of the trained model archive, relative to the current working directory.
MODEL_ARCHIVE_PATH = 'models/lyrics_gan_mse_with_spec_id/model.tar.gz'

# Restore the archived model and switch it to evaluation mode.
# `model.eval()` is left as the final expression so the cell displays the
# module structure.
archive = load_archive(MODEL_ARCHIVE_PATH)
model = archive.model
model.eval()

LyricsGan(
  (generator): LyricsGenerator(
    (_latent_mapper): Sequential(
      (0): Linear(in_features=128, out_features=256, bias=True)
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=256, out_features=256, bias=True)
      (4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Linear(in_features=256, out_features=128, bias=True)
    )
    (_ce_loss): BCEWithLogitsLoss()
  )
  (discriminator): LyricsDiscriminator(
    (_activation): LeakyReLU(negative_slope=0.2)
    (_classifier): Sequential(
      (0): Linear(in_features=256, out_features=128, bias=True)
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
      (3): Linear(in_features=128, out_features=128, bias=True)
      (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stat

In [6]:
import logging
import random
from typing import Iterable, Optional

import torch
from allennlp.data.instance import Instance

logger = logging.getLogger(__name__)

class LyricsLatentGen(object):
    """
    Predicts latent codes for a collection of records with a trained model.

    Every record is indexed by key and must provide ``'spec_id'`` together
    with the NumPy arrays ``'spec_mu'``, ``'spec_std'`` and ``'lyrics_mu'``
    (those three are handed to ``torch.from_numpy``).

    NOTE(review): the records are annotated as ``Instance`` but are used like
    plain mappings of NumPy arrays -- confirm the intended element type.
    """
    def __init__(self,
                 validation_data: "Iterable[Instance]",
                 num_replies: int = 1,
                 num_samples: Optional[int] = None):
        """
        :param validation_data: records to draw from. Any iterable is
            accepted; it is copied into a list once.
        :param num_replies: stored on the instance; not used by this class.
        :param num_samples: number of records drawn per ``generate_sample``
            call. ``None`` (the default) means every record. Values larger
            than the number of records are clamped to that number.
        :raises ValueError: if ``num_samples`` is negative.
        """
        # len() and random.sample() both need a sized sequence, which a bare
        # iterable (e.g. a generator) does not provide.
        self.instances = list(validation_data)
        available = len(self.instances)
        if num_samples is None:
            self.num_samples = available
        elif num_samples < 0:
            raise ValueError(
                'num_samples must be non-negative, got {}'.format(num_samples))
        else:
            self.num_samples = min(num_samples, available)
        self.num_replies = num_replies

    def generate_sample(self, model):
        """
        Lazily yield one prediction per sampled record.

        ``self.num_samples`` records are drawn without replacement, in random
        order, each time this method is called.

        :param model: object exposing ``forward(source_mu=..., source_std=...,
            target_mu=..., stage=...)`` and ``decode(output)``.
        :return: generator of dicts ``{'spec_id': ..., 'z': ...}`` where ``z``
            is whatever ``model.decode`` returns for that record.
        """
        sample_instances = random.sample(self.instances, self.num_samples)
        for sample in sample_instances:
            model_input = {'source_mu': torch.from_numpy(sample['spec_mu']),
                           'source_std': torch.from_numpy(sample['spec_std']),
                           'target_mu': torch.from_numpy(sample['lyrics_mu']),
                           'stage': ['generator']}
            # Inference only, so skip autograd bookkeeping. The guard wraps
            # just the model call: holding it across the `yield` would leave
            # gradients disabled in the caller while the generator is paused.
            with torch.no_grad():
                predicted_latent = model.decode(model.forward(**model_input))
            yield {'spec_id': sample['spec_id'], 'z': predicted_latent}
    
from allennlp.data.iterators import HomogeneousBatchIterator
from allennlp.data.vocabulary import Vocabulary
from allennlp.common.params import Params
from src.data.dataset import LyricsGanDatasetReader
import pickle

# Load the pickled dataset. The context manager closes the file handle, which
# a bare `pickle.load(open(...))` leaves open.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
with open('data/processed/test_data.pkl', 'rb') as f:
    valid_ds = pickle.load(f)

# Run the model over the sampled records and collect the yielded
# {'spec_id', 'z'} dicts into a list.
dsampler = LyricsLatentGen(valid_ds)
predicted_latents = list(dsampler.generate_sample(model))
print(predicted_latents[0].keys())
len(predicted_latents)

dict_keys(['spec_id', 'z'])


808

In [7]:
# Persist the list of predicted latents as a pickle file.
with open('data/outputs/lyrics_latent_test.pkl', 'wb') as out_file:
    pickle.dump(predicted_latents, out_file)