In [None]:
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
import torchvision.transforms as transforms
from torch import optim
import torch
from matplotlib import pyplot as plt

In [None]:
from sklearn.preprocessing import OneHotEncoder
import numpy as np

In [None]:
# Run on the first CUDA device when one is visible; otherwise fall back to the CPU.
device = torch.device(
    type='cuda' if torch.cuda.is_available() else 'cpu',
    index=0,
)

In [None]:
# Location of the raw ABC-notation corpus (Kaggle dataset mount).
file_path = '/kaggle/input/piano-musics-abc-notation/piano-musics-abc-notation.txt'

# Decode explicitly as UTF-8: without an `encoding` argument, open() falls back
# to the platform's locale encoding, which would make the character vocabulary
# built from this text differ from machine to machine.
with open(file_path, 'r', encoding='utf-8') as file:
    abc_data = file.read()

In [None]:
# Character-level vocabulary: every distinct character of the corpus, sorted so
# that the id assignment is deterministic.
tokens = sorted(set(abc_data))

# The model vocabulary is one larger than the number of real characters, which
# leaves the highest id free (presumably for the special boundary token used
# when sequence pairs are built - confirm against the pair-building cells).
num_tokens = len(tokens) + 1

# Forward and reverse lookup tables between characters and their integer ids.
token_to_index = dict(zip(tokens, range(len(tokens))))
index_to_token = dict(enumerate(tokens))

In [None]:
# abc_data = abc_data[0:10000]

In [None]:
SEQ_LENGTH = 100    # number of characters in each input window and each target window
STEP = SEQ_LENGTH   # stride between consecutive window starts (inputs do not overlap)

# Id of the special boundary token (appended to inputs as EOS, prepended to
# targets as SOS). It is derived from the vocabulary size instead of being
# hard-coded: `num_tokens` reserves exactly one id past the real characters, so
# this value can neither collide with a real character nor fall outside the
# embedding table when the corpus (and therefore the vocabulary) changes.
SPECIAL_TOKEN_INDEX = num_tokens - 1

# The whole corpus as a flat list of character ids.
indices_sequence = [token_to_index[token] for token in abc_data]

# A window start `i` yields a full pair as long as i + 2*SEQ_LENGTH does not run
# past the end of the sequence. range() excludes its stop value, hence the +1 so
# that the last full pair is not dropped.
last_valid_start = len(indices_sequence) - 2 * SEQ_LENGTH

# Each pair is (input window + [EOS], [SOS] + the window that follows it).
input_target_pairs = [
    (
        indices_sequence[i:i + SEQ_LENGTH] + [SPECIAL_TOKEN_INDEX],
        [SPECIAL_TOKEN_INDEX] + indices_sequence[i + SEQ_LENGTH:i + 2 * SEQ_LENGTH],
    )
    for i in range(0, last_valid_start + 1, STEP)
]

In [None]:
class MusicDataset(Dataset):
    """Dataset over pre-built (input ids, target ids) pairs.

    Args:
        pairs: indexable collection of 2-item entries, each holding the input
            id list and the target id list of one example.
    """

    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        """Return the number of stored pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the pair at `idx` as two tensors: (input, target)."""
        source_ids, target_ids = self.pairs[idx]
        source_tensor = torch.tensor(source_ids)
        target_tensor = torch.tensor(target_ids)
        return source_tensor, target_tensor


In [None]:
# Serve the training pairs in shuffled mini-batches.
batch_size = 4
music_dataset = MusicDataset(input_target_pairs)
data_loader = DataLoader(
    dataset=music_dataset,
    shuffle=True,
    batch_size=batch_size,
)

In [None]:
# for batch_idx, (input_batch, target_batch) in enumerate(data_loader):
#     print(input_batch.shape, target_batch.shape)

In [None]:
class Encoder(nn.Module):
    """Token embedding + dropout followed by a 2-layer, batch-first GRU.

    Args:
        input_size: number of distinct token ids (embedding table rows).
        embed_size: embedding dimension.
        hidden_size: GRU hidden dimension.
        dropout_p: dropout probability applied to the embeddings.
    """

    def __init__(self, input_size, embed_size, hidden_size, dropout_p=0.1):
        super().__init__()
        self.e = nn.Embedding(num_embeddings=input_size, embedding_dim=embed_size)
        self.dropout = nn.Dropout(p=dropout_p)
        self.gru = nn.GRU(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True,
        )

    def forward(self, x):
        """Encode a batch of token ids.

        Returns:
            (outputs, hidden): the GRU's per-step outputs and its final hidden
            state, exactly as returned by `nn.GRU`.
        """
        embedded = self.dropout(self.e(x))
        outputs, hidden = self.gru(embedded)
        return outputs, hidden

In [None]:
class Decoder(nn.Module):
    """Embedding -> ReLU -> 2-layer batch-first GRU -> linear -> log-softmax.

    Args:
        output_size: number of distinct token ids (embedding rows and width of
            the output distribution).
        embed_size: embedding dimension.
        hidden_size: GRU hidden dimension.
    """

    def __init__(self, output_size, embed_size, hidden_size):
        super().__init__()
        self.e = nn.Embedding(num_embeddings=output_size, embedding_dim=embed_size)
        self.relu = nn.ReLU()
        self.gru = nn.GRU(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True,
        )
        self.lin = nn.Linear(in_features=hidden_size, out_features=output_size)
        self.lsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, x, prev_hidden):
        """Run the decoder over `x`, starting from `prev_hidden`.

        Returns:
            (y, hidden): log-probabilities over the vocabulary for every input
            step (log-softmax over the last dimension) and the GRU's final
            hidden state.
        """
        activated = self.relu(self.e(x))
        gru_out, hidden = self.gru(activated, prev_hidden)
        log_probs = self.lsoftmax(self.lin(gru_out))
        return log_probs, hidden

In [None]:
def train_one_epoch():
    """Run one teacher-forced pass over `data_loader` and update both models.

    Uses the notebook-level `encoder`, `decoder`, `loss_fn`, `opte`, `optd`,
    `data_loader` and `device`.

    Returns:
        The mean of the per-batch losses over the epoch.
    """
    encoder.train()
    decoder.train()
    running_loss = 0

    for source, target in data_loader:
        source = source.to(device)
        target = target.to(device)

        # The encoder's final hidden state seeds the decoder.
        _, context = encoder(source)

        # Teacher forcing: feed every target token except the last and train
        # the decoder to predict the token that follows each of them.
        log_probs, _ = decoder(target[:, 0:-1], context)
        expected = target[:, 1:].reshape(-1)

        # Flatten (batch, step, vocab) to (batch*step, vocab) for the loss.
        flat_log_probs = log_probs.view(-1, log_probs.shape[-1])
        loss = loss_fn(flat_log_probs, expected)
        running_loss += loss.item()

        opte.zero_grad()
        optd.zero_grad()
        loss.backward()
        opte.step()
        optd.step()

    return running_loss / len(data_loader)

In [None]:
# Evaluation excerpt: the first 1000 characters of the corpus.
test_abc = abc_data[:1000]

In [None]:
# Display the number of characters in the evaluation excerpt.
len(test_abc)

In [None]:
SEQ_LENGTH = 100    # number of characters in each input window and each target window
STEP = SEQ_LENGTH   # stride between consecutive window starts (inputs do not overlap)

# Id of the special boundary token (appended to inputs as EOS, prepended to
# targets as SOS). Derived from the vocabulary size instead of being
# hard-coded: `num_tokens` reserves exactly one id past the real characters, so
# this value can neither collide with a real character nor fall outside the
# embedding table.
SPECIAL_TOKEN_INDEX = num_tokens - 1

# NOTE: this rebinds `indices_sequence` to the ids of the evaluation excerpt.
indices_sequence = [token_to_index[token] for token in test_abc]

# Only keep window starts whose target window is complete
# (i + 2*SEQ_LENGTH <= len). range() excludes its stop value, hence the +1.
last_valid_start = len(indices_sequence) - 2 * SEQ_LENGTH

# Each pair is (input window + [EOS], [SOS] + the window that follows it).
test_input_target_pairs = [
    (
        indices_sequence[i:i + SEQ_LENGTH] + [SPECIAL_TOKEN_INDEX],
        [SPECIAL_TOKEN_INDEX] + indices_sequence[i + SEQ_LENGTH:i + 2 * SEQ_LENGTH],
    )
    for i in range(0, last_valid_start + 1, STEP)
]

In [None]:
# Serve the evaluation pairs one at a time, in shuffled order.
test_batch_size = 1
test_music_dataset = MusicDataset(test_input_target_pairs)
test_data_loader = DataLoader(
    dataset=test_music_dataset,
    shuffle=True,
    batch_size=test_batch_size,
)

In [None]:
# for batch_idx, (input_batch, target_batch) in enumerate(test_data_loader):
#     print(input_batch.shape, target_batch.shape)

In [None]:
#eval loop (written assuming batch_size=1)
def eval_one_epoch(e, n_epochs, max_new_tokens=100):
    """Greedily generate a continuation for every test sequence and print it.

    Uses the notebook-level `encoder`, `decoder`, `test_data_loader`, `device`
    and `index_to_token`. Both models are switched to eval mode on every call.
    Generation and printing happen only on the final epoch, because that is
    the only epoch whose output is shown.

    Args:
        e: zero-based index of the epoch that just finished.
        n_epochs: total number of epochs; output is produced when
            ``e + 1 == n_epochs``.
        max_new_tokens: number of tokens to generate per test sequence.

    Returns:
        None. Results are printed.
    """
    encoder.eval()
    decoder.eval()

    # Nothing is reported before the last epoch, so skip the decoding work.
    if e + 1 != n_epochs:
        return

    with torch.no_grad():
        for i, (seq, next_seq) in enumerate(test_data_loader):
            seq = seq.to(device)
            next_seq = next_seq.to(device)

            # The encoder's final hidden state seeds the decoder.
            _, decoder_hidden = encoder(seq)

            # The first target token is the start-of-sequence marker.
            input_ids = next_seq[:, 0]

            print("\n\n\n#Seq: ", i)

            generated = []
            for _ in range(max_new_tokens):
                log_probs, decoder_hidden = decoder(input_ids.unsqueeze(1), decoder_hidden)
                # Greedy decoding: (batch, 1, vocab) -> (batch, 1) -> (batch,).
                input_ids = log_probs.argmax(dim=-1).squeeze(1)
                # The reserved special id has no entry in index_to_token;
                # render it as an empty string instead of raising KeyError.
                generated.append(index_to_token.get(input_ids.item(), ""))
            pred_music = "".join(generated)

            # Drop the trailing EOS id before mapping the prompt back to text.
            prompt_ids = seq[0].tolist()[:-1]
            input_tokes = "".join(index_to_token[idx] for idx in prompt_ids)

            print("\n\nInput Seq: ", input_tokes)
            print("\n\nPredicted Music:", pred_music)
            print("\n\nMusic: ", input_tokes + pred_music)

In [None]:
# Model dimensions.
embed_size=300
hidden_size=512

# Both networks share the full vocabulary size (real characters + the reserved id).
encoder=Encoder(num_tokens,embed_size,hidden_size).to(device)
decoder=Decoder(num_tokens,embed_size,hidden_size).to(device)

# The decoder emits log-probabilities, hence NLLLoss. No class is ignored:
# this pipeline has no padding id, and id 0 is a real character (the first
# entry of the sorted vocabulary), so `ignore_index=0` would silently exclude
# that character from training.
loss_fn=nn.NLLLoss().to(device)
lr=0.01
opte=optim.Adam(params=encoder.parameters(), lr=lr, weight_decay=0.001)
optd=optim.Adam(params=decoder.parameters(), lr=lr, weight_decay=0.001)

n_epochs=20

# Train for n_epochs; the evaluation routine is invoked after every epoch.
for e in range(n_epochs):
    print("Epoch=",e+1, end=": ")
    print("Train Loss=", round(train_one_epoch(),4))
    eval_one_epoch(e,n_epochs)