In [1]:
import torch
import torch.nn.functional as F
import sys
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm_notebook as tqdm

sys.path.append('../')
from shared.models.basic_lstm import BasicLSTM
from shared.models.multi_layer_lstm import MultiLayerLSTM
from shared.process.pa4_dataloader import build_all_loaders
from shared.process.PA4Trainer import get_computing_device

In [3]:
computing_device = get_computing_device()
all_loaders, infos = build_all_loaders('../pa4Data/')
char2ind, ind2char = infos['char_2_index'], infos['index_2_char']

# Single-layer LSTM with 100 hidden units over the character vocabulary,
# restored from a saved checkpoint (directory name suggests 300 epochs).
# Alternative architecture that was tried:
#   MultiLayerLSTM(len(char2ind), 120, num_layers=7, num_output=len(char2ind))
vocab_size = len(char2ind)
model = BasicLSTM(vocab_size, 100, vocab_size)
model.load_state_dict(
    torch.load('./lstm100_300epochs/model_state.pt', map_location='cpu')
)
model.eval()
model.to(computing_device)

BasicLSTM(
  (lstm): LSTM(93, 100, batch_first=True)
  (h2o): Linear(in_features=100, out_features=93, bias=True)
)

In [4]:
prime_str = "<start>"
prime_tensor = torch.zeros(len(prime_str), len(char2ind)).to(computing_device)
# One-hot encode the priming string: one row per character, a single 1 per row.
for row, ch in enumerate(prime_str):
    prime_tensor[row, char2ind[ch]] = 1

## Sample Music

In [1]:
# Generate a tune by sampling the model one character at a time
def sample(model, T=None, max_length = 2000, stop_on_end_tag=False):
    """Generate text from ``model`` one character at a time.

    The model's hidden state is reset and primed with the notebook-level
    ``prime_tensor`` (the one-hot encoding of ``prime_str``). After that, each
    generated character is fed back in as the next one-hot input. The function
    also reads the notebook globals ``char2ind``, ``ind2char`` and
    ``computing_device``.

    Args:
        model: character model providing ``reset_hidden(device)``. When called
            on a (1, seq_len, vocab) one-hot tensor, its ``[-1]`` is used as
            the logit vector for the next character (assumed from how the
            output is indexed below -- TODO confirm against the model class).
        T: sampling temperature. ``None`` selects greedy decoding (argmax).
            Otherwise the next character is drawn from ``softmax(logits / T)``.
            Must be positive.
        max_length: maximum number of characters to generate.
        stop_on_end_tag: if True, stop as soon as the generated text ends with
            ``"<end>"`` or ``"<start>"``.

    Returns:
        The generated string. The priming string is not included.

    Raises:
        ValueError: if ``T`` is not None and not positive.
    """
    if T is not None and T <= 0:
        raise ValueError("temperature T must be positive, got {0}".format(T))

    vocab_size = len(char2ind)
    sample_music = ""

    with torch.no_grad():  # no need to track history in sampling
        model.reset_hidden(computing_device)

        # Prime with <start>; the hidden state now reflects the prime string.
        logits = model(prime_tensor.unsqueeze(dim=0))[-1]

        for i in range(1, max_length + 1):
            if T is None:
                # torch.argmax works on any device; np.argmax cannot read CUDA tensors.
                res_ind = torch.argmax(logits).item()
            else:
                # Hand NumPy a CPU float64 array. Renormalising guards against
                # float rounding tripping np.random.choice's "p must sum to 1" check.
                prob = F.softmax(logits / T, dim=0).cpu().double().numpy()
                prob /= prob.sum()
                res_ind = np.random.choice(vocab_size, 1, p=prob)[0]
            sample_music += ind2char[res_ind]
            if i % 50 == 0:
                print(i)

            if stop_on_end_tag and (sample_music.endswith("<end>") or sample_music.endswith("<start>")):
                print("Found <end>, stop making music at i = {0}.".format(i))
                break

            # Feed the chosen character back in as a (1, 1, vocab) one-hot input.
            next_char_tensor = torch.zeros(vocab_size, device=computing_device)
            next_char_tensor[res_ind] = 1
            logits = model(next_char_tensor.view(1, 1, -1))[-1]
    return sample_music

In [None]:
# Temperature-1 sampling of up to 2000 characters (stop_on_end_tag defaults to False);
# prints a progress count every 50 characters.
m1 = sample(model, T=1, max_length=2000)

In [7]:
# Drop the framing tags before displaying the generated tunes.
for tag in ("<end>", "<start>"):
    m1 = m1.replace(tag, "")
print(m1)


X: 1
T:Brellead Ornoin
O:Frandensonde
I:ab crads
Z:id:hn-harn-25
M:3/4
L:1/8
K:A
~E3 ADF|AF EF|B/B/A FDE||
FA FE|FED F2|BF GB/A/|GF GF|ED DB/e/|dc d2:|


X:56
T:Seat W\'acharkigh, The
T:Flarx ul Flond
Z:id:hn-polka-811
M:2/4
L:1/8
K:Am
DE/D/ DF|BA F2|BB BG/A/|dB/A/ dB/2F/|1 G2G D:|2 AG A||
B2 d>e d|g/f/d/c/ dd|ed d>d|eB GE|D>D GA|
AG GG/E/|AB/c/ BG|FE E>A|Bd e>d|A^G A2:|2


X:70
T:Mattic Fart\'O Tril's Mairolin
R:jig
D:Prady: Mdylye an Cuofly Cop\'i Bran: The M Sores'e Real,
O:Cranhe
T:Noll ath Elvig "Un Glurk or the Moog Spony-#21
D:Deve to Alac.
Z:id:hn-polka-33
M:C|
L:1/8
K:Em
|: B2d GFE | E3 EED | GED A2D | B2c Bc | d2c A2B | AGF G3A ||
V: 
|: Ad | DFD DED | DED D2F | D2D F2 :|
|: cFA efg | fdA BAG | ABc dcB | cA A2F | 
G3 | ~B3 A2d | g2f gfe | dcd FAF |1 A3- AA/ :|2 FED DEA | DFG A2 :|
|: G3 | ~g3 ggd | b2g f2c | gfd cAF | fdc dcd |
~c3 BAF |1 BAA BAG | d2c d2d | 
eAA BAF | def gfe | fef gfe | fAG A2f :|2 GAF G2e ||
|: faf ac | deB dcB | ded cAF |1 B2c d2 :|2 cAG A2 || 
Peve fa |

## Some samples

In [30]:
# A saved sample framed by the same <start>/<end> tags used for priming and stopping.
# NOTE(review): this overwrites the m1 generated by the sampling cell above.
m1 = """<start>
X:67
T:Ewwly PrevEundie, F Lage (1990-#17, #79, Thin's Foglen Spaistom
Z:id:hn-polka-77
M:2/4
L:1/8
K:G
DB/G/ D2|GE EF|GE DG/E/|B,D/D/B, |
G,E,3|G,2D ||
|:dF d>B|A2 A2:|
|:f2 d2|ef/g/ fe|a2 af|gf/e/ dB|AG F2|EF GA/B/|AB AB|GB F2:|
|:df e/d/d|fB e/f/g|ag g>g|af ba|ba ga|ba g>f|ed ed|dg ag|ae ed|cd e/f/a|gb ag/e/|fA dc|e2 eA/A/|1 de d2:|2 B2 dd||
<end>"""

In [17]:
# A saved sample framed by the same <start>/<end> tags used for priming and stopping.
# NOTE(review): m2 is reassigned in a later cell before the heatmap call, so on
# Restart & Run All this value is not the one that gets plotted.
m2 = """<start>
X:58
T:McHopTeer m'ur Sthe
T:Moree's Joonpipe
R:march
Z:id:hn-polka-23
M:2/4
L:1/8
K:D
dB/A/ BB|BA FG|CE FE|FA A>c|dB c/B/AB|cA E2:|
|:BA GB|ef e2|ea b/a/a|ge f/g/d/e/|fc de|f/e/d/e/ dB|1 A2 BG:|2 e2 BG||
|:cB AG/G/|FA B>B|AF A/B/c|de fd|e>f gf|ec B/A/G/A/|1 GE GF:|2 a2 bg/g/||
|:ef/f/ ga|gb ag|ga gf|ga ge|B/A/B/A/ B/A/g|1 E2 D2:|2 GA G3||
|:gg g>g|f/g/f/d/ e/d/B/c/|dB G//F/|DF AB/c/|dB G//G/G/G/|
D/C/C/D/ D>F|DA GE/F/E/F/|EG FG|A>B d2:|
|:A/A/B/A/ e/d/d/c/|dB AB/c/|dB G2|a/f/d/B/ A/G/A|1 G2 G2:|2 dG G>A||
<end>"""

In [18]:
def heatmap_at_k(cur_str, lstm_out, k, save_fig=False, path_to_save=None, N_PER_ROW = 20):
    """Plot the activation of hidden unit ``k`` over ``cur_str`` as a text heatmap.

    The activations ``lstm_out[:, k]`` are min-max scaled to [-1, 1]. They are
    then laid out row by row, ``N_PER_ROW`` positions per row, and each cell
    is labelled with the character at that position (a newline is drawn as a
    backslash-n). Cells past the end of the sequence are padded with 0, the
    midpoint of the colour scale.

    Args:
        cur_str: text whose characters label the cells. Expected to have one
            character per row of ``lstm_out``.
        lstm_out: 2-D array of shape (sequence_length, n_hidden).
        k: index of the hidden unit (column of ``lstm_out``) to plot.
        save_fig: if True, save to ``<path_to_save>/heatmap-<k>.png`` and close
            the figure. Otherwise show it inline.
        path_to_save: directory for the saved figure. Defaults to the current
            directory when None.
        N_PER_ROW: number of positions per heatmap row.
    """
    activations = np.asarray(lstm_out[:, k], dtype=float).ravel()

    lo, hi = activations.min(), activations.max()
    if hi > lo:
        scaled = (activations - lo) / (hi - lo) * 2 - 1
    else:
        # A constant unit would otherwise divide by zero and render as all-NaN.
        scaled = np.zeros_like(activations)

    # Pad up to a whole number of rows so the sequence reshapes into a grid.
    n_activations = activations.shape[0]
    num_rows = int(np.ceil(n_activations / N_PER_ROW))
    n_padding = num_rows * N_PER_ROW - n_activations
    final_frame = np.concatenate([scaled, np.zeros(n_padding)]).reshape(num_rows, N_PER_ROW)

    fig, ax = plt.subplots(figsize=(15, 8))
    # Fixed [-1, 1] range keeps colours comparable across units.
    im = ax.imshow(final_frame, cmap='RdBu_r', vmin=-1, vmax=1)

    # Show every tick, with rotated x labels.
    ax.set_xticks(np.arange(N_PER_ROW))
    ax.set_yticks(np.arange(num_rows))
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Annotate each cell with its character; positions beyond the grid are ignored.
    for pos, char in enumerate(cur_str[:num_rows * N_PER_ROW]):
        row, col = divmod(pos, N_PER_ROW)
        label = "\\n" if char == "\n" else char
        ax.text(col, row, label,
                ha="center", va="center", color="black", fontsize=15)

    ax.set_title("Activation map")
    fig.colorbar(im, ax=ax)
    fig.tight_layout()

    if save_fig:
        out_dir = path_to_save if path_to_save is not None else '.'
        fig.savefig(os.path.join(out_dir, 'heatmap-{0}.png'.format(k)), bbox_inches='tight')
        # Close this specific figure so looping over many units does not leak memory.
        plt.close(fig)
    else:
        plt.show()

In [19]:
def create_and_save_all_heatmaps(sample, PATH_TO_SAVE = './figs', test_run=False, inds=None):
    """Run ``sample`` through ``model.lstm`` and plot per-unit activation heatmaps.

    Uses the notebook globals ``model``, ``char2ind`` and ``computing_device``.

    Args:
        sample: text to feed through the LSTM. Every character must be a key
            of ``char2ind``.
        PATH_TO_SAVE: directory for the saved figures (created if missing).
        test_run: if True, plot only the units in ``inds`` and show them
            inline. Otherwise save a heatmap for every hidden unit to
            ``PATH_TO_SAVE``.
        inds: hidden-unit indices to plot when ``test_run`` is True. Defaults
            to none.
    """
    if inds is None:
        inds = []
    cur_str = sample

    # One-hot encode every character, including the last, so each plotted
    # position is driven by its real input rather than an all-zero vector.
    one_hot = torch.zeros(len(cur_str), len(char2ind), device=computing_device)
    for pos, char in enumerate(cur_str):
        one_hot[pos, char2ind[char]] = 1

    with torch.no_grad():
        model.reset_hidden(computing_device)
        # [0] -> output sequence, [0] -> first (only) batch element.
        # .cpu() is required before .numpy() when running on a GPU.
        lstm_out = model.lstm(one_hot.unsqueeze(0), model.hidden)[0][0].detach().cpu().numpy()

    os.makedirs(PATH_TO_SAVE, exist_ok=True)
    if test_run:
        for k in inds:
            heatmap_at_k(cur_str, lstm_out, k, save_fig=False, path_to_save=PATH_TO_SAVE)
    else:
        for k in tqdm(range(lstm_out.shape[1])):
            heatmap_at_k(cur_str, lstm_out, k, save_fig=True, path_to_save=PATH_TO_SAVE)

In [16]:
# Another saved sample framed by <start>/<end> (note the trailing newline after <end>).
# NOTE(review): this reassigns m2 from an earlier cell; on Restart & Run All this
# is the value the heatmap call below plots.
m2 = """<start>
X:71
T:Urnd (er mant Macqunett an !
R:polka
Z:id:hn-polha-p217
M:6/8
L:1/8
K:A
AG D4|Bc A>B|cd FA|BA AB|AG FD|Bd f/e/|ce cG|FE FE:|
D:|
DF DF | FE FD FE | DD F,DA, |
A2A c2A :|d f2g ffd | A2E D3B | BAF GAB |1 c2G F2D :|2 ED FG ||
<end>
"""

In [None]:
# Preview a handful of hidden units inline (test_run=True shows instead of saving).
create_and_save_all_heatmaps(m2, PATH_TO_SAVE='./figs_300', test_run=True, inds=[0, 1, 2, 44])

HBox(children=(IntProgress(value=0), HTML(value='')))

In [46]:
# Concatenate every training text chunk into one string. str.join builds it in
# a single pass instead of re-copying the growing string on each `+=`.
all_s = "".join(all_loaders['train'].dataset.text_chunks)