In [1]:
import torch
import torch.nn.functional as F
import sys
import numpy as np
sys.path.append('../')
from shared.models.basic_lstm import BasicLSTM
from shared.process.pa4_dataloader import build_all_loaders
from shared.process.PA4Trainer import get_computing_device

In [45]:
computing_device = get_computing_device()
all_loaders, infos = build_all_loaders('../pa4Data/')

# Character <-> index lookup tables produced alongside the data loaders.
char2ind, ind2char = infos['char_2_index'], infos['index_2_char']
vocab_size = len(char2ind)

# Input and output sizes are the vocabulary size; the hidden size (100) has to
# match the checkpoint below or load_state_dict will reject it.
checkpoint_path = './session_train_pass_hidden_between_epochs/model_state_min_val_so_far.pt'
model = BasicLSTM(vocab_size, 100, vocab_size)
# map_location='cpu' lets the checkpoint load regardless of where it was saved;
# the model is moved to `computing_device` afterwards.
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
model.eval()
model.to(computing_device)

BasicLSTM(
  (lstm): LSTM(93, 100, batch_first=True)
  (h2o): Linear(in_features=100, out_features=93, bias=True)
)

In [46]:
# One-hot encode the priming string: row `row` has a single 1 in the column
# of the character prime_str[row].
prime_str = "<start>"
prime_tensor = torch.zeros(len(prime_str), len(char2ind)).to(computing_device)
for row, ch in enumerate(prime_str):
    prime_tensor[row, char2ind[ch]] = 1

## Sample Music

In [76]:
# Generate a tune character-by-character from the trained model.
def sample(model, T=None, max_length = 2000):
    """Generate text from ``model``, primed with the global ``prime_tensor``.

    After priming, each generated character is fed back in as a one-hot
    vector of shape ``(1, 1, vocab)`` to produce the logits for the next one.

    Args:
        model: Object exposing ``reset_hidden(device)`` and a forward call whose
            output indexed with ``[-1]`` gives the next-character logits
            (presumably shape ``(vocab,)`` -- TODO confirm against
            ``BasicLSTM.forward``).
        T: Softmax temperature. ``None`` or ``0`` (the zero-temperature limit)
            picks the arg-max character; any other value draws from
            ``softmax(logits / T)`` using NumPy's global RNG.
        max_length: Maximum number of characters to generate.

    Returns:
        The generated string. The priming ``<start>`` is not included; a
        terminating ``<end>`` or ``<start>`` token, if generated, is included.

    Reads the notebook globals ``computing_device``, ``prime_tensor``,
    ``char2ind`` and ``ind2char``.
    """
    stop_tokens = ("<end>", "<start>")
    vocab_size = len(char2ind)
    sample_music = ""

    with torch.no_grad():  # no need to track history in sampling
        model.reset_hidden(computing_device)

        # Prime with <start>; the hidden state now conditions on it.
        logits = model(prime_tensor.unsqueeze(dim=0))[-1]

        for i in range(1, max_length + 1):
            if T is None or T == 0:
                # torch.argmax works on CPU and CUDA tensors alike (np.argmax needs host memory).
                res_ind = torch.argmax(logits).item()
            else:
                prob = F.softmax(logits / T, dim=0).cpu().numpy().astype(np.float64)
                # Renormalise in float64 so float32 rounding can't trip
                # np.random.choice's "probabilities must sum to 1" check.
                prob /= prob.sum()
                res_ind = int(np.random.choice(vocab_size, p=prob))
            sample_music += ind2char[res_ind]
            if i % 50 == 0:
                print(i)

            # endswith instead of a fixed [-5:] slice: "<start>" is 7 characters long.
            found = next((tok for tok in stop_tokens if sample_music.endswith(tok)), None)
            if found is not None:
                print("Found {0}, stop making music at i = {1}.".format(found, i))
                break

            # Feed the chosen character back as a (1, 1, vocab) one-hot batch.
            next_char_tensor = torch.zeros(len(char2ind)).to(computing_device)
            next_char_tensor[res_ind] = 1
            logits = model(next_char_tensor.view(1, 1, -1))[-1]

        return sample_music

In [77]:
# Temperature 1: sample from the model's unscaled softmax distribution.
# `sample` draws with NumPy's global RNG, so seed it to make this cell reproducible on re-run.
np.random.seed(0)
m1 = sample(model, T=1, max_length=2_000)

30
60
90
120
150
180
210
240
270
Found <end>, stop making music at i = 286.


In [78]:
# print() rather than a bare `m1` so the tune's line breaks render instead of showing as '\n'.
print(m1, end="\n")


X: 7DTrime GourZ:id:: 5wiol tre surolind pos the Ilen
T:Fory ane tu Pllther Jfs an He. see cangce
H:Whes wa dertime tog
T:magy atl Alscim
R:jig
Z:id:hn-po-ka-19
M:2/4
L:1/8
K:D
DE EF|EE G2|B/c/B/G/ AF|BA AB/c/|cA BAc|Ad BA|GA c2|
Ac d>e|cA d/g/e/f/|ed/B/ dB|1 AG/G/ GG:|2 GE D2||
<end>


In [66]:
# Scratch check: BasicLSTM can also be built from a keyword config dict.
# These are toy sizes, unrelated to the trained `model` loaded above.
model_config = dict(num_input=10, num_hidden=100, num_output=19)
BasicLSTM(**model_config)

BasicLSTM(
  (lstm): LSTM(10, 100, batch_first=True)
  (h2o): Linear(in_features=100, out_features=19, bias=True)
)

In [58]:
# Low temperature (0.1) sharpens the distribution towards the arg-max character.
# `sample` draws with NumPy's global RNG, so seed it to make this cell reproducible on re-run.
np.random.seed(0)
m2 = sample(model, T=0.1, max_length=2_000)

30
60
90
120
150
180
210
Found <end>, stop making music at i = 210.


In [59]:
# print() rather than a bare `m2` so the tune's line breaks render instead of showing as '\n'.
print(m2, end="\n")


X: TT
T:Dannan Holly:
Z:id:hn-polka-77
M:2/4
L:1/8
K:D
A>B AB|AB AB|cB AB/c/|dB AB/c/|dB AB/c/|dB AB/c/|dB AB|AB AB|1 BA GB/d/:|2 G2 GB||
|:BA BA|BA BA|BA BA|BA BA|BA BA|BA BA|BA AB/A/|1 GE GE:|2 GE GE||
<end>
