In [3]:
import torch
import pandas as pd
import numpy as np
import os
from dataloader import data_split,onehot, SongData
from models import GenLSTM
from configs import cfg

In [5]:
# Load the state_dict of a model given path
def load_model(path):
    """Build a fresh GenLSTM and load saved parameters into it.

    Args:
        path: Path of a checkpoint file containing a model state_dict.

    Returns:
        The GenLSTM instance with the checkpoint parameters loaded.
    """
    # Input and output widths are both the vocabulary size
    vocab_size = len(SongData.vocab)
    model = GenLSTM(input_size=vocab_size, output_size=vocab_size)

    # Map every stored tensor to CPU so that a checkpoint saved from a GPU
    # model still loads on a machine without CUDA. load_state_dict copies the
    # values into the freshly built model's own parameters either way.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)

    return model

# Take a pre-trained model and generate new text
def generateText(data, model, temp, num_samples, max_len):
    """Sample `num_samples` texts from a trained model, one character per step.

    Every sample starts from the SOS ordinal (0). At each step the latest
    character of every sample is one-hot encoded and passed through the model
    together with the previous hidden/cell state; the next character is drawn
    from the temperature-scaled softmax of the model output. Sampling stops
    once every sample contains the EOS ordinal (1) or after `max_len` steps.

    Relies on a module-level `device` (torch.device) being defined before the
    call; the model and each input batch are moved to it.

    Args:
        data: Object exposing `encode` (its length is the one-hot width) and
            `decode` (indexed by ordinal to obtain a character).
        model: Called as `model(x)` on the first step and `model(x, hc)`
            afterwards; must return `(output, hc)`.
        temp: Softmax temperature; the model output is divided by it.
        num_samples: Number of texts generated in parallel.
        max_len: Maximum number of sampling steps.

    Returns:
        A list of `num_samples` strings. SOS characters (chr(2)) are removed,
        and each string is cut right after its first EOS character (chr(3)),
        which is kept, when one is present.
    """
    # Each sample begins with a single SOS ordinal (0)
    text = torch.zeros((num_samples, 1), dtype=torch.long)

    hc = None
    step = 0
    all_finished = False

    model.to(device)

    # Put into eval mode
    model.eval()

    # No need for gradient
    with torch.no_grad():

        while not all_finished and step < max_len:

            # One-hot the newest character of every sample
            oh_input = onehot(text[:, step:step + 1], len(data.encode)).to(device)

            # No recurrent state exists before the first step. Test against
            # None explicitly rather than relying on the truthiness of the
            # state object returned by the model.
            if hc is None:
                output, hc = model(oh_input)
            else:
                output, hc = model(oh_input, hc)

            # Convert outputs to probabilities using softmax with temperature
            probs = torch.nn.functional.softmax(output / temp, dim=2)

            # Sample the next ordinal for every sample and append it
            sampled = torch.distributions.categorical.Categorical(probs).sample()
            text = torch.cat((text, sampled.cpu()), dim=1)

            step += 1

            # Done when every row contains at least one EOS ordinal (1).
            # `text` lives on the CPU, so .numpy() is a view, not a copy.
            all_finished = bool((text.numpy() == 1).any(axis=1).all())

    sos, eos = '\2', '\3'
    generated = []
    for row in text.tolist():
        decoded = ''.join(data.decode[c] for c in row)

        # str.find returns -1 when EOS is absent; index 0 is a valid hit
        eos_at = decoded.find(eos)
        if eos_at != -1:
            decoded = decoded[:eos_at + 1]

        generated.append(decoded.replace(sos, ''))

    print("Finished text generation...", flush=True)

    return generated

if __name__ == "__main__":

    # Generation settings (hard-coded here; nothing is parsed from the command line)
    temp = 0.8
    num_samples = 10
    max_len = 5000

    print("Loading necessary files to generate...", flush=True)
    # Load the dataset so the vocabulary can be initialised before the model
    # is built (load_model sizes the network from SongData.vocab)
    data = pd.read_csv('songdata.csv')
    SongData.init_vocab(data)
    train_set, val_set = data_split(data)

    # Location of the trained checkpoint
    checkpointPath = 'model_checkpoints/training_session0/'
    savedModePath = 'LSTMmodel_Final.fnl'

    # Load in trained model
    model = load_model(os.path.join(checkpointPath, savedModePath))

    useCuda = cfg['use_cuda']
    # Use CUDA only when it is both available and enabled in the config.
    # `device` is read as a module-level name by generateText.
    if torch.cuda.is_available() and useCuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    print("Using %s for compute..." % device.type, flush=True)

    print("Beginning text generation...", flush=True)

    # Generate text
    gen = generateText(train_set, model, temp, num_samples, max_len)

    # Write out generated texts into result file (and echo them to stdout).
    # Previously the path and file name were built but never used, so the
    # samples were only printed.
    resultPath = "./results/"
    resultFile = "generatedsamples_temp_{0}".format(temp)

    os.makedirs(resultPath, exist_ok=True)
    with open(os.path.join(resultPath, resultFile), "w", encoding="utf-8") as result_fh:
        for idx, text in enumerate(gen):
            print(text, flush=True)
            print("---------------", flush=True)
            result_fh.write(text + "\n---------------\n")

Loading necessary files to generate...
Using cuda for compute...
Beginning text generation...
Finished text generation...
Here we go in the sky  
But he says it's the feeling that's always too late  
  
What about me.  
How to do you win  
Love is the things that would be strong  
She ain't got no nails  
We'll spare her shoot is strange  
Holding on to the back  
  
So he meats the dreams  
Drive in angels the heat  
He matters the woman is the way  
She was denying good  
Where the heat of love  
I know when I think  
I am more of him  
He likes a loving could help it  
And the guiltary to her  
(People was a chance)  
  
(Chorus)  
  
Chorus  
  
The changesh end  
  
Of a corner little mishol of from my soul  
Because you breathe free  
I really have to find me  
Take dish now  
And you cry  
I wonder what to do  
  
I'll be sure you work as made  
For ask me the land  
Kill of winter my said  
While closer, flesh  
I will surrender then I haven't better  
I wish I was standin'  
 

---------------
Yeah evening me for me  
  
A mirror in a straighty spell  
For left to make uh lovers  
Roll to the door of me  
Run to the last car  
  
1941  
I dream of your hand  
Ambongian will crack  
Who are a three  
Heartbeat you  
(Give me the way and open the air  
He's the aninats  
My heart is gathered  
(Ah hey)  
He's in the way hard  
He ain't gonna girl  
That they're still the moon  
But we got a too  
So come  
We but know what he begins  
  
Then the ends me in the blues  
My horses don't have been running  
But he was the mission  
  
Gather with her list to crawl  
He's a ringing dishine  
They'll be the being are cold  
We don't change to the world  
Yes I sit by one  
  
Today, what you going  
  
Tell me what they don't make this  
Just keep gonna dropped me a milly  
  
I don't wanna see you goin  
I'll break the money  
And if you took the world  
Echoluate  
I only hear that  
I may it running out of Merrack  
I got the woods  
Until the same  
Death who ca

---------------
I got to fight all the stars  
You can chainfuses at your tears  
That you love me  
To make love  
  
Oh puble the Lord  
So the moment to me  
To tar what I mean  
And I don't want to  
And it's time we want to go  
I'm alone  
It came to say  
And I'm gonna take the other  
And I can't turn things  
I ran a treat like the boy  
I want to see me good  
The party was a wanted beg't look back  
It's the good bunnun  
You can't see the day ain't so body  
Why don't you promise  
I've got the darker  
I'll be better than keep on calling me  
People try to be free  
You take my life  
  
[Chorus]  
  
I got one so if I say  
So I got the same me in the night  
And I don't know she says I  
I got my heart is back  
And I'm just, 'til I do  
Oh, you see her six shame  
I can do it?  
  
[Chorus]


---------------
The morning streets I fall to through the way  
I kiss you tomorrow so he could have been still  
I never know like I don't know why  
Spend drink back the sheets 