# Melody Generation using a Simple Feed Forward Net

In [None]:
import sys
sys.path.append("..") 

import matplotlib.pyplot as plt
import music21 as m21
import torch
import torch.nn.functional as F
from preprocess import load_songs_in_kern, NoteEncoder, TERM_SYMBOL

In [None]:
# Seed PyTorch's global random number generator so that the random weight
# initialisation and the multinomial sampling below are repeatable.
# NOTE(review): the trailing semicolon is presumably there to stop the
# notebook from echoing the call's return value.
torch.manual_seed(0);

In [None]:
# Load the song corpus and encode every song as a sequence of symbols.
encoder = NoteEncoder()
scores = load_songs_in_kern('./../deutschl/erk')
enc_songs = encoder.encode_songs(scores)

# Vocabulary: every distinct symbol in the corpus, sorted for a stable order.
# TERM_SYMBOL is excluded here because it is pinned to index 0 below; if it
# also occurred in the data it would first receive an index i+1 and then be
# overwritten, leaving a hole in the index range (len(stoi) would no longer
# equal the largest index + 1, which breaks one-hot encoding and itos).
symbols = sorted({sym for song in enc_songs for sym in song} - {TERM_SYMBOL})

# Symbol -> index (1..N for corpus symbols, 0 for the start/stop marker)
# and the inverse index -> symbol table.
stoi = {s: i + 1 for i, s in enumerate(symbols)}
stoi[TERM_SYMBOL] = 0
itos = {i: s for s, i in stoi.items()}


In [None]:
# Sanity check: render the first loaded song using the 'midi' format.
# NOTE(review): presumably `scores` holds music21 streams, whose
# show('midi') embeds a playable MIDI widget — confirm in preprocess.
scores[0].show('midi')

In [None]:
# Choose the compute backend in order of preference: CUDA, then MPS, then
# CPU. The conditional expression is evaluated lazily, so the MPS check only
# runs when CUDA is unavailable.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

print(f'{device=}')

In [None]:
# Build the bigram training set: for every song, wrapped in TERM_SYMBOL at
# both ends, each symbol (input) is paired with the symbol that follows it
# (target). Both are stored as vocabulary indices.
xs, ys = [], []
for song in enc_songs:
    padded = [TERM_SYMBOL] + song + [TERM_SYMBOL]
    indices = [stoi[sym] for sym in padded]
    xs.extend(indices[:-1])  # every position except the last is an input
    ys.extend(indices[1:])   # ... and its successor is the target

xs = torch.tensor(xs, device=device)
ys = torch.tensor(ys, device=device)

# Debugging tip: slice xs and ys to [:1] here to inspect the gradient of a
# single training example.

# One-hot encode the inputs: one row per example, one column per symbol.
xenc = F.one_hot(xs, num_classes=len(stoi)).float()

In [None]:
# Weight matrix of the single-layer model, drawn from a standard normal and
# tracked by autograd. Multiplying a one-hot row by W selects one row of W,
# so row i holds the logits for the symbol following symbol i.
W = torch.randn(len(stoi), len(stoi), device=device, requires_grad=True)
W.shape

In [None]:
# Training aka gradient descent: full-batch updates of the single weight
# matrix W, minimising the mean negative log-likelihood of the next symbol.
epochs = 2_000
learning_rate = 10.0
for k in range(epochs):
    # forward pass
    logits = xenc @ W
    # cross_entropy fuses log-softmax and the mean negative log-likelihood.
    # Unlike exponentiating the logits by hand, it cannot overflow to
    # inf/NaN when the logits grow large.
    loss = F.cross_entropy(logits, ys)

    print(f'epoch {k}, loss: {loss.item()}')

    # backward pass
    W.grad = None  # drop the previous gradient instead of accumulating
    loss.backward()
    # For reference, the analytic gradient of this loss is
    #   xenc.T @ (softmax(logits) - one_hot(ys)) / len(ys)
    # i.e. "(s - y) * x" averaged over the batch.

    # update: modify W in place without recording the step in autograd
    with torch.no_grad():
        W -= learning_rate * W.grad

In [None]:
# Sample a new melody from the trained single-layer model: start at the
# start/stop marker, repeatedly draw the next symbol from the predicted
# distribution, and stop as soon as the marker is drawn again.
generated_encoded_song = []
ix = stoi[TERM_SYMBOL]  # same marker the training sequences were padded with
with torch.no_grad():  # inference only, no autograd graph needed
    while True:
        # Same forward pass as in training, for a single symbol. A local
        # name is used on purpose so the training inputs held in `xenc`
        # are not overwritten.
        x_step = F.one_hot(
            torch.tensor([ix], device=device), num_classes=len(stoi)
        ).float()
        probs = F.softmax(x_step @ W, dim=1)
        ix = torch.multinomial(probs, num_samples=1,
                               replacement=True).item()
        char = itos[ix]
        if char == TERM_SYMBOL:
            break
        generated_encoded_song.append(char)
len(generated_encoded_song)


In [None]:
# Turn the sampled symbol sequence back into a song with the encoder and
# render it using the 'midi' format.
# NOTE(review): presumably decode_song returns a music21 stream — confirm
# in preprocess.NoteEncoder.
generated_song = encoder.decode_song(generated_encoded_song)
generated_song.show('midi')


In [None]:
# Weights of the two-layer model: W1 maps the one-hot input down to a hidden
# layer a quarter of the vocabulary size, and W2 maps that back up to one
# logit per symbol. Both are drawn from a standard normal and tracked by
# autograd.
W1 = torch.randn(len(stoi), len(stoi) // 4, device=device, requires_grad=True)
W2 = torch.randn(len(stoi) // 4, len(stoi), device=device, requires_grad=True)

In [None]:
# Training aka gradient descent: full-batch updates of the two weight
# matrices W1 and W2, minimising the mean negative log-likelihood of the
# next symbol.
epochs = 2_000
learning_rate = 10.0

# Rebuild the one-hot training inputs from xs rather than relying on the
# global `xenc`, which other cells may rebind to something that no longer
# has one row per training example.
x_train = F.one_hot(xs, num_classes=len(stoi)).float()

for k in range(epochs):
    # forward pass
    hidden = x_train @ W1
    logits = hidden @ W2
    # cross_entropy fuses log-softmax and the mean negative log-likelihood.
    # Unlike exponentiating the logits by hand, it cannot overflow to
    # inf/NaN when the logits grow large.
    loss = F.cross_entropy(logits, ys)

    print(f'epoch {k}, loss: {loss.item()}')

    # backward pass
    W1.grad = None  # drop the previous gradients instead of accumulating
    W2.grad = None
    loss.backward()

    # update: modify the weights in place without recording the step
    with torch.no_grad():
        W1 -= learning_rate * W1.grad
        W2 -= learning_rate * W2.grad