In [27]:
import torch
import torch.nn.functional as F
import sys
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm_notebook as tqdm

sys.path.append('../')
from shared.models.basic_lstm import BasicLSTM
from shared.process.pa4_dataloader import build_all_loaders
from shared.process.PA4Trainer import get_computing_device

In [29]:
computing_device = get_computing_device()
all_loaders, infos = build_all_loaders('../pa4Data/')

char2ind = infos['char_2_index']
ind2char = infos['index_2_char']

model = BasicLSTM(len(char2ind), 300, len(char2ind))
model.load_state_dict(torch.load('./lstm300adam0.001/model_state.pt', map_location='cpu'))
model.eval()
model.to(computing_device)

BasicLSTM(
  (lstm): LSTM(93, 300, batch_first=True)
  (h2o): Linear(in_features=300, out_features=93, bias=True)
)

In [30]:
prime_str = "<start>"
prime_tensor = torch.zeros(len(prime_str), len(char2ind)).to(computing_device)
        
for i in range(len(prime_str)):
    char = prime_str[i]
    prime_tensor[i, char2ind[char]] = 1    

## Sample Music

In [31]:
# Sample from a category and starting letter
def sample(model, T=None, max_length = 2000, stop_on_end_tag=False):
    
    sample_music = ""
    
    with torch.no_grad():  # no need to track history in sampling
        model.reset_hidden(computing_device)
        
        # Prime with <start>, hidden state is now ready
        logits = model(prime_tensor.unsqueeze(dim=0))[-1]
        
        i = 0
        while i < max_length:
            res_ind = None
            if T is None:
                res_ind = np.argmax(logits).item()
            else:
                prob = np.array(F.softmax(logits/T, dim=0))
                res_ind = np.random.choice(len(char2ind), 1, p=prob)[0]
            final_char = ind2char[res_ind]            
            sample_music += final_char
            i+=1
            if i % 50 == 0:
                print(i)
                
            if stop_on_end_tag and (sample_music[-5:] == "<end>" or sample_music[-7:] == "<start>"):
                print("Found <end>, stop making music at i = {0}.".format(i))
                break
                
            next_char_tensor = torch.zeros(len(char2ind)).to(computing_device)
            next_char_tensor[res_ind] = 1
            next_char_tensor = next_char_tensor.view(1,1,-1)
            logits = model(next_char_tensor)[-1]

        return sample_music#.split("<end>")

In [42]:
m1 = sample(model, T=0.8, max_length=2_000, stop_on_end_tag=True)

50
100
150
200
Found <end>, stop making music at i = 240.


In [43]:
m1 = m1.replace("<end>", "")
print(m1)


X:77
T:Meis Erlister Dol achaine Crossey, The
T:Callyce Rither, The
R:polka
H:See also #77, #16
D:Replay Danole, Tun Fons Doler
Z:id:hn-polka-75
M:2/4
L:1/8
K:D
E>A BA|FA AB/c/|dA AB|^c/d/ eg gf|eg fg|fd Bd|Bd dB|AB cd|ef a>g|fA A2:|



In [44]:
import pyabc
def is_valid_abc_syntax(sample):
    try:
        _ = pyabc.Tune(sample)
        # Valid
        return True
    except Exception as e:
        print("Invalid abc syntax!", e)
        return False
is_valid_abc_syntax(m1)

True

In [41]:
for m in m1.split("<end>"):
    print(m)


X:71
T:Bolly Berder Crosser Polka
R:polka
Z:id:hn-polka-77
Z:id:hn-polka
7:Mary Bergin: Feadoga Stain 21
Z:id:hn-polka-92
M:2/4
L:1/8
K:A
A>B cE|cB Bc|ed dB|AB c/B/A/B/|AF DF|AF AB|1 d2 B2:|2 d2 de||
|:fbrea af|ef/e/ cd|ef g2|dB cB|AB cd|ef e2|dB AB|1 BA GA:|2 A2 D2:|2 d2 d2||



In [78]:
m1 = """<start>
X:1
T:Be so nom Bangis of tom the Corne to Malse #158
Z:id:hn-polka-89
M:2/4
L:1/8
K:A
A>A B/A/G/A/|BA BA/B/|dB dB/B/|AB cB/A/|GA Bd/e/|dB B/A/G/A/|1 AG GB/c/:|2 BG GA||
P:variation 29
D:D Danne
Z:id:hn-polka-86
M:2/4
L:1/8
K:D
de dB|AB cB|AB BA|BA BA|BA Bd|ef/e/ dB|1 BA AB/c/:|2 BA Bd|e2 ef|ed e>f|ed cB|A2 AB/c/|dB AB/c/|ed cB|cA AB/c/|dB AB/c/|
dB Bd/B/|dB BA/B/|dB AB/c/|Bd Bc/d/|eA cB/A/|BA GB/d/|ef ec/e/|fd dB|AB AG|
AB AB/c/|dB AB|Bc AB/c/|dB AB/A/|BA BG/B/|AB cB/A/|BA Bd/e/|fd ef|fe fd|ef ed|e2 ed/B/|A2 AG||
<end>"""

In [79]:
cur_str = m1
input = torch.zeros(len(cur_str), len(char2ind)).to(computing_device)
        
for i in range(len(cur_str)-1):
    char = cur_str[i]
    input[i, char2ind[char]] = 1
input.unsqueeze_(0)
lstm_out = model.lstm(input, model.hidden)[0][0].detach().numpy()

In [80]:
lstm_out.shape

(533, 100)

In [81]:
def heatmap_at_k(k, save_fig=False, path_to_save=None):
    N_PER_ROW = 20

    def scale(x):
        return ((x - x.min())/(x.max() - x.min())) * 2 - 1
    lstm_out_k = lstm_out[:,k]

    n_activations, n_hiddens = lstm_out.shape
    n_activations = n_activations
    target_length = int(np.ceil(n_activations / N_PER_ROW) * N_PER_ROW)
    n_padding = target_length - n_activations

    num_rows = target_length//N_PER_ROW
    final_frame = scale(np.array(list(lstm_out_k.flatten()) + list(np.zeros(n_padding))).reshape((target_length//N_PER_ROW, N_PER_ROW)))

    fig, ax = plt.subplots(figsize=(20,10))
    im = ax.imshow(final_frame, cmap='RdBu_r')

    # We want to show all ticks...
    ax.set_xticks(np.arange(N_PER_ROW))
    ax.set_yticks(np.arange(num_rows))

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    for i in range(num_rows):
        for j in range(N_PER_ROW):
            true_pos = i * N_PER_ROW + j
            if true_pos < len(cur_str):
                char = cur_str[true_pos]
                if char == "\n":
                    char = "\\n"
                text = ax.text(j, i, char,
                               ha="center", va="center", color="black", fontsize=12)

    ax.set_title("Activation map")

    cbar = ax.figure.colorbar(im, ax=ax)

    fig.tight_layout()
    if save_fig:
        plt.savefig(os.path.join(path_to_save, 'heatmap-{0}.png'.format(k)), bbox_inches='tight')
        plt.close()
    else:
        plt.show()

In [82]:
PATH_TO_SAVE = './figs'
os.makedirs(PATH_TO_SAVE, exist_ok=True)
N_K = lstm_out.shape[1]
for i in tqdm(range(N_K)):
    heatmap_at_k(i, save_fig=True, path_to_save=PATH_TO_SAVE)
#     print("Done", i)

HBox(children=(IntProgress(value=0), HTML(value='')))