<a href="https://colab.research.google.com/github/asigalov61/SuperPiano/blob/master/Super_Piano_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Super Piano 3: Google Music Transformer
## Generating Music with Long-Term structure
### Based on 2019 ICLR paper by Cheng-Zhi Anna Huang, Google Brain

Huge thanks go out to the following people who contributed the code/repos used in this colab. Additional contributors are listed in the code as well.

1) Naikaru https://github.com/Naikaru/AI_MusicTransformer

2) Kevin-Yang https://github.com/jason9693/midi-neural-processor

3) jinyi12, Zac Koh, Akamight, Zhang https://github.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248

Thank you so much for your hard work and for sharing it with the world :)


###Setup Environment and Dependencies. Check GPU.

In [None]:
#@title Check if GPU (driver) is avaiiable (you do not want to run this on CPU, trust me)
!nvcc --version

In [None]:
#@title Clone/Install all dependencies
!git clone https://github.com/asigalov61/SuperPiano
!git clone https://github.com/asigalov61/midi-neural-processor
!pip install pretty-midi
!pip install tqdm

In [None]:
#@title Import all needed modules. Check GPU again (if there is no GPU, CPU will be enabled but it will be very slow)
import pickle
import os
import sys
import math
import random
import pickle
import numpy as np
import tqdm
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

## Setup the MIDI Pre-Processor/Encoder and Directories/Paths

In [None]:
#@title Model Main Hyperparameters
d_model = 512
nhead = 8
dim_feedforward = 1024
dropout = 0.2
num_layer = 6
batch_size = 8
sequence_length = 1024
warmup_steps = 4000
pad_token = 1   
# vocabulary_size = 388 # depends on the dataset

In [None]:
#@title Import Special MIDI Encoder
%cd /content/midi-neural-processor/

from processor import encode_midi, decode_midi

%cd /content

def encode_midi_files(midi_dir_path, save_dir_path, extension):
  #create directory for saving files
  os.makedirs(save_dir_path, exist_ok=True)
  #get all midi files from midi_directory
  for file in os.listdir(midi_dir_path):
    if file.endswith(tuple(extension)):
        print(os.path.join(midi_dir_path, file))
        print(file + ' is being processed', flush=True)
        try:
          encoded_file = encode_midi(midi_dir_path+file)
        except KeyboardInterrupt:
            print(' Stopped by keyboard')
            return
        except EOFError:
            print('EOF Error')
            return
        with open(save_dir_path+file+'.encoded', 'wb') as f:
            pickle.dump(encoded_file, f)


In [None]:
#@title Create all IO Directories and Paths
!mkdir /content/midi_files
!mkdir /content/encoded_midi
!mkdir /content/training_set
!mkdir /content/test_set
%cd /content/


midi_dir_path = '/content/midi_files/'
save_dir_path = '/content/encoded_midi/'
train_dir_path = '/content/training_set/'
test_dir_path = '/test_set/'
extension = ['.mid', '.midi']

#Download and Unzip MIDI DataSet

In [None]:
#@title (The Best Choice/Works best stand-alone) Super Piano 2 Original 2500 MIDIs of Piano Music
%cd /content/midi_files/
!wget 'https://github.com/asigalov61/SuperPiano/raw/master/Super_Piano_2_MIDI_DataSet_CC_BY_NC_SA.zip'
!unzip -j 'Super_Piano_2_MIDI_DataSet_CC_BY_NC_SA.zip'
%cd /content/

#Encode all MIDI files from the DataSet

In [None]:
#@title RUN ONCE or after you added/removed MIDIs
encode_midi_files(midi_dir_path, save_dir_path, extension) # only use once

In [None]:
#@title Split training data and test data in two different folders
# for more convenience we split training data and test data in two different folders

import shutil

def create_dataset(save_dir_path, split_ratio=0.9):
    dataset = [file for file in os.listdir(save_dir_path)]
    np.random.shuffle(dataset)

    train_set = dataset[:int(len(dataset) * split_ratio)]    
    test_set = dataset[int(len(dataset) * split_ratio):]
    
    shutil.rmtree(train_dir_path)
    shutil.rmtree(test_dir_path)

    os.makedirs(train_dir_path, exist_ok=True)
    os.makedirs(test_dir_path, exist_ok=True)

    for file in os.listdir(save_dir_path):
        if os.stat(save_dir_path+file).st_size != 0:
            if file in test_set:
                shutil.copyfile(save_dir_path+file, test_dir_path+file)
            else:
                shutil.copyfile(save_dir_path+file, train_dir_path+file)


In [None]:
create_dataset(save_dir_path, split_ratio=0.9) # only use once

In [None]:
def load_dataset(train_dir_path, test_dir_path):
    #load all encoded file
    train_set = [file for file in os.listdir(train_dir_path)]
    test_set = [file for file in os.listdir(test_dir_path)]

    return train_set, test_set

In [None]:
def generate_batch(dataset, dir_path, sequence_length=1024, batch_size=8):
    sequence_length += 1
    batch_midi = []
    while len(batch_midi) < batch_size:
        file = random.choice(dataset)
        #if the midi contains more sequence that the sequence length
        try:
            with open(dir_path+file, 'rb') as f:
                data = pickle.load(f)
        except:
            print(dir_path+file + " file not found.")
        if sequence_length <= len(data):
            begin_index = random.randrange(0, len(data) - sequence_length)
            data = data[begin_index:begin_index + sequence_length]
            batch_midi.append(data)
    batch_midi = torch.Tensor(batch_midi)
    inputs = batch_midi[:, :-1]
    labels = batch_midi[:, 1:]
    return inputs, labels

## 2 - Attention

In [None]:
#@title Attention Code Implementation
class MusicMultiheadAttention(torch.nn.MultiheadAttention):
    def __init__(self, embed_dim, nhead, dropout=0.1, bias=True, add_bias_kv=False, 
                 add_zero_attn=False, kdim=None, vdim=None):
        
        torch.nn.MultiheadAttention.__init__(self, embed_dim, nhead, dropout=0.1, 
                                             bias=True, add_bias_kv=False, 
                                             add_zero_attn=False, kdim=None, vdim=None)
        
        self.embed_dim = embed_dim
        self.weights_q = torch.nn.Linear(self.embed_dim, self.embed_dim)
        self.weights_k = torch.nn.Linear(self.embed_dim, self.embed_dim)
        self.weights_v = torch.nn.Linear(self.embed_dim, self.embed_dim)
        self.weights_o = torch.nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        Q, K, V = self.transform_input(query, key, value)
        # Reshaping the matrices 
        # Each L × D query, key, and value matrix is then split into H L × D 
        # h_D parts or attention heads, indexed by h, and with dimension D_h = D/H
        Q = self.matrix_to_heads(Q)
        K = self.matrix_to_heads(K)
        V = self.matrix_to_heads(V)
        
        # learning a separate relative position embedding Er of shape (H, L, Dh)
        Er = torch.randn([self.num_heads, query.size(1), self.head_dim], requires_grad=False).to(device)

        # we transpose the two last dimensions of Er to realize Q*Er^T 
        QEr = torch.matmul(Q, torch.transpose(Er,1,2))
        # QEr of shape (B, H, L, L)     
        # QEr = torch.einsum('bhld,ld->bhll', [Q, Er])

        # 1. Pad a dummy column vector of length L before the leftmost column.
        QEr = torch.nn.functional.pad(QEr, (1,0), mode="constant", value=0)

        # 2. Reshape the matrix to have shape (L+1, L). 
        QEr = torch.reshape(QEr, [QEr.size(0), QEr.size(1), QEr.size(3), QEr.size(2)])
        
        # 3. Slice that matrix to retain only the last l rows and all the columns, 
        # resulting in a (L, L) matrix again, but now absolute-by-absolute indexed, 
        # which is the S rel that we need.
        S_rel = QEr[:,:,1:,:]

        z_attention = self.attention(Q, K, V, S_rel, attn_mask)
        z_attention = self.weights_o(z_attention)
        # Masking can be added and Dropout ?

        return z_attention

    def attention(self, Q, K, V, S, mask):
        # Dh = self.head_dim // self.num_heads
        logits = torch.add(torch.matmul(Q, torch.transpose(K, 2, 3)), S) / math.sqrt((self.head_dim // self.num_heads))
        # print("logits : ", logits.size())
        # print("mask : ", mask.size())
        if mask is not None:
        #    mask = mask.unsqueeze(1) #shape of mask must be broadcastable with shape of underlying tensor
            logits = logits.masked_fill(mask == 0, -1e9) #masked_fill fills elements of scores with -1e9 where mask == 0
        #if mask is not None:
        #    logits += (mask.to(torch.int64) * -1e9).to(logits.dtype)        
            
        activation = F.softmax(logits, -1)
        attention = torch.matmul(activation, V)
        attention = torch.reshape(attention, (attention.size(0), -1, self.embed_dim))
        return attention
    
    def matrix_to_heads(self, qkv):
        '''
            Takes a query/key/value (qkv) matrix and reshapes it to  B * H * L * D_h heads 
            with dimension D_h = D/H
        '''
        batch_size_q = qkv.size(0)
        #qkv = torch.reshape(qkv, (batch_size_q, qkv.size(0), self.num_heads, self.head_dim))
        qkv = torch.reshape(qkv, (batch_size_q, self.num_heads, qkv.size(1), self.head_dim))
        return qkv

    def transform_input(self, query, key, value):
        '''
            Transforming the input vector, X, of LxD dimension 
            into 
                queries: Q = XW^Q 
                keys:    K = XW^K
            and values:  V = XW^V
            which are all DxD square matrices.
        '''
        return self.weights_q(query), self.weights_k(key), self.weights_v(value)
    

class MusicTransformerEncoderLayer(torch.nn.TransformerEncoderLayer):
    def __init__(self, d_model, nhead, dim_feedforward=2048,
                 dropout=0.1, activation="relu"):
        torch.nn.TransformerEncoderLayer.__init__(self, d_model, nhead)
        self.d_model = d_model
        # OverRide
        self.self_attn = MusicMultiheadAttention(d_model, nhead)

class MusicTransformerDecoderLayer(torch.nn.TransformerDecoderLayer):
    def __init__(self, d_model, nhead, dim_feedforward=2048,
                 dropout=0.1, activation="relu"):
        torch.nn.TransformerDecoderLayer.__init__(self, d_model,nhead)
        self.d_model = d_model
        # OverRide
        self.self_attn = MusicMultiheadAttention(d_model, nhead)

class MusicTransformerEncoder(torch.nn.TransformerEncoder):
    def __init__(self, encoder_layer, vocabulary_size=390, num_encoder_layers=6, normalization=None):
        super().__init__(encoder_layer, num_encoder_layers, normalization)
        self.d_model = encoder_layer.d_model
        self.vocabulary_size = vocabulary_size
        self.dropout = torch.nn.Dropout(encoder_layer.dropout.p)
        self.embedding = torch.nn.Embedding(num_embeddings=self.vocabulary_size, embedding_dim=self.d_model)
    
    def forward(self, src, mask=None, src_key_padding_mask=None):
        
        pos_encoding = DynamicPositionEmbedding(self.d_model, src.size(1))
        
        src = math.sqrt(self.d_model) * self.embedding(src.to(torch.long).to(device))
        src = pos_encoding(src)
        src = self.dropout(src)

        return super().forward(src, mask=mask, src_key_padding_mask=src_key_padding_mask)


class MusicTransformerDecoder(torch.nn.TransformerDecoder):
    def __init__(self, decoder_layer, vocabulary_size=390, num_decoder_layers=6, normalization=None):
        super().__init__(decoder_layer, num_decoder_layers, normalization)
        self.d_model = decoder_layer.d_model
        self.vocabulary_size = vocabulary_size
        self.dropout = torch.nn.Dropout(decoder_layer.dropout.p)
        self.embedding = torch.nn.Embedding(num_embeddings=self.vocabulary_size, embedding_dim=self.d_model)

    def forward(self, tgt, memory, tgt_mask=None, 
                memory_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):        
                
        pos_encoding = DynamicPositionEmbedding(self.d_model, tgt.size(1))

        tgt = pos_encoding(math.sqrt(self.d_model) * self.embedding(tgt.to(torch.long).to(device)))
        tgt = self.dropout(tgt)

        return super().forward(tgt, memory, tgt_mask=tgt_mask, 
                memory_mask=memory_mask, tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask)


class MusicTransformer(torch.nn.modules.Transformer):
    def __init__(self, d_model=512, nhead=8, vocabulary_size=388,
                 num_encoder_layers=6, num_decoder_layers=6, 
                 dim_feedforward=2048, dropout=0.1, activation="relu", 
                 custom_encoder=None, custom_decoder=None):
        
        super().__init__(d_model=d_model, nhead=nhead, 
                         num_encoder_layers=num_encoder_layers,
                         num_decoder_layers=num_decoder_layers, 
                         dim_feedforward=dim_feedforward, dropout=dropout, activation=activation, 
                         custom_encoder=custom_encoder, custom_decoder=custom_decoder)
        
        self.vocabulary_size = vocabulary_size
        ###        
        self.fc = torch.nn.Linear(self.d_model, self.vocabulary_size)
        self._reset_parameters()

    def forward(self, src, tgt, src_mask=None, tgt_mask=None,
                memory_mask=None, src_key_padding_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):

        #if src.size(1) != tgt.size(1):
        #    raise RuntimeError("the batch number of src and tgt must be equal")

        #if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
        #    raise RuntimeError("the feature number of src and tgt must be equal to d_model")
            
        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)

        output = self.fc(output)        
        
        return output




class DynamicPositionEmbedding(torch.nn.Module):
    def __init__(self, embedding_dim, max_seq=1024):
        super().__init__()
        embed_sinusoid_list = np.array([[
            [
                math.sin(
                    pos * math.exp(-math.log(10000) * i/embedding_dim) *
                    math.exp(math.log(10000)/embedding_dim * (i % 2)) + 0.5 * math.pi * (i % 2)
                )
                for i in range(embedding_dim)
            ]
            for pos in range(max_seq)
        ]])
        self.positional_embedding = embed_sinusoid_list

    def forward(self, x):
        x = x + torch.from_numpy(self.positional_embedding[:, :x.size(1), :]).to(x.device, dtype=x.dtype)
        return x


class PositionalEncoder(torch.nn.Module):
    def __init__(self, d_model, max_seq_len=1024):
        super().__init__()
        self.d_model = d_model
        pe = torch.zeros(max_seq_len, d_model)
        for pos in range(max_seq_len):
            for i in range(0, d_model, 2):
                pe[pos, i] = \
                    math.sin(pos / (10000 ** ((2 * i) / d_model)))
                pe[pos, i + 1] = \
                    math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        with torch.no_grad():
            x = x * math.sqrt(self.d_model)
            seq_len = x.size(1)
            self.pe.to(device)
            print("self.pe : ", self.pe.device.type)
            pe = self.pe[:, :seq_len]
            print("pe : ", pe.device.type)
            x = x + pe
            return x   

In [None]:
#@title Dataset Prep
from torch.utils.data.dataset import Dataset

def tensorFromSequence(sequence):
    """
    Generate tensors from the sequence in numpy.
    """
    output = torch.tensor(sequence).long()

    return output


def PrepareData(npz_file, split='train', L=1024):
    """
    Function to prepare the data into pairs (input, target).
    Adds [PAD], [SOS] and [EOS] tokens into the data,
    where [PAD]=1, [SOS]=2, [EOS]=3.
    Limits the sequence to length of L.
    """
    print("Preparing data for",split,"split...")
    # Load in the data
    full_data = np.load(npz_file, fix_imports=True, encoding="latin1", allow_pickle=True)
    data = full_data[split]

    # Extract the vocab from file
    vocab = GenerateVocab(npz_file)
    # Generate new vocab to map to later
    new_vocab = np.arange(len(vocab))

    # Initialize the tokens
    pad_token = np.array([[1]])

    # Repeat for all samples in data
    pairs = []
    for samples in data:
        # Serialise the dataset so that the resulting sequence is
        # S_1 A_1 T_1 B_1, S_2 A_2 T_2 B_2, ...

        # Generate input
        input_seq = samples.flatten()

        # Cut off the samples so that it has length of 1024
        if(len(input_seq) >= L):
            # input_seq = input_seq[:L-1]
            input_seq = input_seq[:L]

        # Set the NaN values to 0 and reshape accordingly
        input_seq = np.nan_to_num(input_seq.reshape(1,input_seq.size))

        # Generate target
        output_seq = input_seq[:,1:]

        # For both sequences, pad to sequence length L
        pad_array = pad_token * np.ones((1,L-input_seq.shape[1]))
        input_seq = np.append(input_seq, pad_array,axis=1)
        pad_array = pad_token * np.ones((1,L-output_seq.shape[1]))
        output_seq = np.append(output_seq, pad_array,axis=1)

        # Map the pitch value to int values below vocab size
        for i, val in enumerate(vocab):
            input_seq[input_seq==val] = new_vocab[i]
            output_seq[output_seq==val] = new_vocab[i]

        # Make it into a pair
        pair = [input_seq, output_seq]

        # Combine all pairs into one big list of pairs
        pairs.append(pair)

    print("Generated data pairs.")
    return np.array(pairs)

def GenerateVocab(npz_file):
    """
    Generate vocabulary for the dataset including the custom tokens.
    """
    full_data = np.load(npz_file, fix_imports=True, encoding="latin1", allow_pickle=True)
    train_data = full_data['train']
    validation_data = full_data['valid']
    test_data = full_data['test']

    combined_data = np.concatenate((train_data, validation_data, test_data))

    vocab = np.nan
    for sequences in combined_data:
        vocab = np.append(vocab,np.unique(sequences))

    vocab = np.unique(vocab)
    vocab = vocab[~np.isnan(vocab)]
    vocab = np.append([0,1],vocab)
    return vocab 
 

def batched_learning(train,batch_size):
    for i in range(0, len(train), batch_size):
        train1 = train[i:i + batch_size]
        yield train1[:,0],train1[:,1]


##3 - Training

In [None]:
#@title Create and Init the Model
 
# vocabulary_size depends on the midi encoding
# ~> 388(+2) for encoded_midi / epiano compt 
# ~> 46(+2) for encoded_midi / epiano compt                 

# DEFINING THE MODEL

vocabulary_size = 390

normalization = torch.nn.LayerNorm(d_model)

custom_encoder_layer = MusicTransformerEncoderLayer(d_model=d_model, nhead=nhead, 
                                               dim_feedforward=dim_feedforward, 
                                               dropout=dropout, activation="relu")

custom_decoder_layer = MusicTransformerDecoderLayer(d_model=d_model, nhead=nhead, 
                                               dim_feedforward=dim_feedforward, 
                                               dropout=dropout, activation="relu")

custom_encoder = MusicTransformerEncoder(custom_encoder_layer, vocabulary_size, num_layer, normalization)
custom_decoder = MusicTransformerDecoder(custom_decoder_layer, vocabulary_size, num_layer, normalization)

model = MusicTransformer(d_model=d_model, nhead=nhead, 
                         vocabulary_size=vocabulary_size, 
                         num_encoder_layers=num_layer, 
                         num_decoder_layers=num_layer, 
                         dim_feedforward=dim_feedforward, 
                         dropout=dropout, activation="relu", 
                         custom_encoder=custom_encoder, 
                         custom_decoder=custom_decoder)
# Give model to the current device (hopefully cuda)
model.to(device)

# Optimizer
# Adam optimizer [20] with β 1 = 0.9, β 2 = 0.98 and  = 10 −9
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-09, weight_decay=1e-4)

# Define a scheduler to vary the learning rate

class Scheduler:
    def __init__(self, optimizer, d_model=d_model, warmup_steps=4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.step_num = 1
        self.l_rate = 0
        self.warmup_steps = warmup_steps

    def step(self):        
        # increment step
        self.step_num += 1

        # compute new learning rate        
        self.l_rate = self.d_model**(-.5) * min(self.step_num**(-.5), self.step_num * self.warmup_steps**(-1.5))

        # update optimizer learning rate
        for p in optimizer.param_groups:
            p['lr'] = self.l_rate

        # update the weights in the network
        self.optimizer.step()


# See if it is possible to do it using lr_scheduler.LambdaLR lr_scheduler.StepLR
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)
scheduler = Scheduler(optimizer, d_model, warmup_steps)

In [None]:
#@title Mask Generation
### Mask Generation from https://github.com/COMP6248-Reproducability-Challenge/music-transformer-comp6248/blob/master/MaskGen.py

# Filename: MaskGen.py
# Date Created: 15-Mar-2019 2:42:12 pm
# Description: Functions used to generate masks w.r.t. given inputs.
from torch.autograd import Variable

def nopeak_mask(size):
    np_mask = np.triu(np.ones((1, size, size)), k=1).astype('uint8')
    np_mask =  Variable(torch.from_numpy(np_mask) == 0).to(device)
    return np_mask


def create_masks(src, trg, pad_token):
    src_mask = (src != pad_token).unsqueeze(-2).to(device)

    if trg is not None:
        trg_mask = (trg != pad_token).unsqueeze(-2).to(device)
        size = trg.size(1) # get seq_len for matrix
        np_mask = nopeak_mask(size)
        trg_mask = trg_mask & np_mask
    else:
        trg_mask = None
    return src_mask, trg_mask


def count_nonpad_tokens(target):
    nonpads = (target != 1).squeeze()
    ntokens = torch.sum(nonpads)
    return ntokens        

### Training with data from Midi_Encoded 

In [None]:
train_data, test_data = load_dataset(train_dir_path, test_dir_path)

ckpt_dir = '/content/checkpoints/'
if not os.path.exists(ckpt_dir):
  os.makedirs(ckpt_dir)
  
best_loss = 10.
model_name = 'midi_encoded_6-1'
ckpt_path = '/content/checkpoints/train-nlayer_'+model_name+'.pt'
if os.path.exists(ckpt_path):
    ckpt = torch.load(ckpt_path)
    try:
      model.load_state_dict(ckpt['my_model'])
      optimizer.load_state_dict(ckpt['optimizer'])
      best_acc = ckpt['best_loss']
    except RuntimeError as e:
        print('wrong checkpoint')
    else:    
      print('checkpoint is loaded !')
      print('current best loss : %.2f' % best_loss)

train_writer = SummaryWriter()
test_writer = SummaryWriter()

In [None]:
#@title Tensorboard Graphs and Stats
# Load the TensorBoard notebook extension
%reload_ext tensorboard
import tensorflow as tf
import datetime, os
%tensorboard --logdir /content/runs

In [18]:
#@title Start Training
# Training
max_epochs = 100
n_iter = 0

# each 50 iterations we are going to compare the losses
total_train_loss = []
total_valid_loss = []

for e in range(max_epochs):
    model.train()
    train_loss = []
    for b_train in tqdm.tqdm(range(len(train_data) // batch_size)):
        
        n_iter += 1
        # train phase
        # feed data into the network and get outputs.
        # feed data into the network and get outputs.
        inputs, target = generate_batch(train_data, train_dir_path, sequence_length=sequence_length, batch_size=batch_size)
              
        # Train on GPU
        inputs.to(device)
        target.to(device)
        ys = target.contiguous().view(-1).to(torch.long).to(device)
        
        # Create mask for both input and target sequences
        input_mask, target_mask = create_masks(torch.reshape(inputs, (batch_size, 1, -1)), torch.reshape(target, (batch_size, 1, -1)), pad_token)        
        
        # feed data into the network and get outputs.
        preds_idx = model(inputs, target, input_mask, target_mask)
        
        # Flush out gradients computed at the previous step before computing gradients at the current step. 
        #       Otherwise, gradients would accumulate.
        optimizer.zero_grad()

        # calculate loss
        loss = F.cross_entropy(preds_idx.contiguous().view(preds_idx.size(-1), -1).transpose(0,1), ys, ignore_index = pad_token, size_average = False) / (count_nonpad_tokens(ys))

        # accumulates the gradient and backprogate loss.
        loss.backward()

        # performs a parameter update based on the current gradient
        scheduler.step()    
        
        print('\n====================================================')
        print('Epoch/Batch: {}/{}'.format(e, b_train))
        print('Train >>>> Loss: {:6.6}'.format(loss))

        train_loss.append(loss.item())

    #print('\n**************************************************')
    #print("\n*** Test *** ")
    
    # Validation step
    model.eval()
    valid_loss = []
    with torch.no_grad():
        for b_test in range(len(test_data) // batch_size):        
            inputs, target = generate_batch(test_data, test_dir_path, sequence_length=sequence_length, batch_size=batch_size)
                  
            # Train on GPU
            inputs.to(device)
            target.to(device)
            ys = target.contiguous().view(-1).to(torch.long).to(device)
            
            # Create mask for both input and target sequences
            input_mask, target_mask = create_masks(torch.reshape(inputs, (batch_size, 1,-1)), torch.reshape(target, (batch_size, 1, -1)), pad_token)        

            # Feed Forward
            preds_validate = model(inputs, target, input_mask, target_mask)
            loss = F.cross_entropy(preds_validate.contiguous().view(preds_validate.size(-1), -1).transpose(0,1), ys, \
                                    ignore_index = pad_token, size_average = False) / (count_nonpad_tokens(ys))
            valid_loss.append(loss.item())

    avg_train_loss = np.mean(train_loss)
    avg_valid_loss = np.mean(valid_loss)

    total_train_loss.append(avg_train_loss)
    total_valid_loss.append(avg_valid_loss)

    print("[Average Train Loss]: {:6.6}".format(avg_train_loss))
    print("[Average Testing Loss]: {:6.6}".format(avg_valid_loss))

    # save checkpoint whenever there is improvement in performance
    if avg_valid_loss < best_loss:
        best_loss = avg_valid_loss
        # Note: optimizer also has states ! don't forget to save them as well.
        ckpt = {'my_model':model.state_dict(),
                'optimizer':optimizer.state_dict(),
                'best_loss':best_loss}
        torch.save(ckpt, ckpt_path)
        print('checkpoint is saved !')

    train_writer.add_scalar('loss/train', avg_train_loss, global_step=n_iter)
    test_writer.add_scalar('loss/valid', avg_valid_loss, global_step=n_iter)
    #print('\n**************************************************')

train_writer = SummaryWriter()
test_writer = SummaryWriter()

 59%|█████▊    | 170/290 [05:11<03:39,  1.83s/it]


Epoch/Batch: 7/169
Train >>>> Loss: 5.87859


 59%|█████▉    | 171/290 [05:13<03:37,  1.82s/it]


Epoch/Batch: 7/170
Train >>>> Loss: 5.85414


 59%|█████▉    | 172/290 [05:15<03:36,  1.83s/it]


Epoch/Batch: 7/171
Train >>>> Loss: 5.91029


 60%|█████▉    | 173/290 [05:17<03:33,  1.83s/it]


Epoch/Batch: 7/172
Train >>>> Loss: 5.86714


 60%|██████    | 174/290 [05:19<03:32,  1.83s/it]


Epoch/Batch: 7/173
Train >>>> Loss: 5.89046


 60%|██████    | 175/290 [05:21<03:30,  1.83s/it]


Epoch/Batch: 7/174
Train >>>> Loss: 5.84257


 61%|██████    | 176/290 [05:22<03:28,  1.83s/it]


Epoch/Batch: 7/175
Train >>>> Loss: 5.87305


 61%|██████    | 177/290 [05:24<03:26,  1.83s/it]


Epoch/Batch: 7/176
Train >>>> Loss: 5.89277


 61%|██████▏   | 178/290 [05:26<03:24,  1.83s/it]


Epoch/Batch: 7/177
Train >>>> Loss: 5.87278


 62%|██████▏   | 179/290 [05:28<03:22,  1.82s/it]


Epoch/Batch: 7/178
Train >>>> Loss: 5.8581


 62%|██████▏   | 180/290 [05:30<03:20,  1.83s/it]


Epoch/Batch: 7/179
Train >>>> Loss: 5.88501


 62%|██████▏   | 181/290 [05:32<03:18,  1.82s/it]


Epoch/Batch: 7/180
Train >>>> Loss: 5.86344


 63%|██████▎   | 182/290 [05:33<03:16,  1.82s/it]


Epoch/Batch: 7/181
Train >>>> Loss: 5.85615


 63%|██████▎   | 183/290 [05:35<03:15,  1.83s/it]


Epoch/Batch: 7/182
Train >>>> Loss: 5.85881


 63%|██████▎   | 184/290 [05:37<03:13,  1.82s/it]


Epoch/Batch: 7/183
Train >>>> Loss: 5.86897


 64%|██████▍   | 185/290 [05:39<03:11,  1.82s/it]


Epoch/Batch: 7/184
Train >>>> Loss: 5.87244


 64%|██████▍   | 186/290 [05:41<03:09,  1.82s/it]


Epoch/Batch: 7/185
Train >>>> Loss: 5.85547


 64%|██████▍   | 187/290 [05:42<03:07,  1.82s/it]


Epoch/Batch: 7/186
Train >>>> Loss: 5.88189


 65%|██████▍   | 188/290 [05:44<03:05,  1.82s/it]


Epoch/Batch: 7/187
Train >>>> Loss: 5.85982


 65%|██████▌   | 189/290 [05:46<03:04,  1.82s/it]


Epoch/Batch: 7/188
Train >>>> Loss: 5.86864


 66%|██████▌   | 190/290 [05:48<03:02,  1.82s/it]


Epoch/Batch: 7/189
Train >>>> Loss: 5.90305


 66%|██████▌   | 191/290 [05:50<03:00,  1.83s/it]


Epoch/Batch: 7/190
Train >>>> Loss: 5.88543


 66%|██████▌   | 192/290 [05:52<02:58,  1.83s/it]


Epoch/Batch: 7/191
Train >>>> Loss: 5.85767


 67%|██████▋   | 193/290 [05:53<02:57,  1.82s/it]


Epoch/Batch: 7/192
Train >>>> Loss: 5.90446


 67%|██████▋   | 194/290 [05:55<02:54,  1.82s/it]


Epoch/Batch: 7/193
Train >>>> Loss:  5.881


 67%|██████▋   | 195/290 [05:57<02:52,  1.82s/it]


Epoch/Batch: 7/194
Train >>>> Loss: 5.88915


 68%|██████▊   | 196/290 [05:59<02:51,  1.82s/it]


Epoch/Batch: 7/195
Train >>>> Loss: 5.90721


 68%|██████▊   | 197/290 [06:01<02:49,  1.82s/it]


Epoch/Batch: 7/196
Train >>>> Loss: 5.86305


 68%|██████▊   | 198/290 [06:02<02:47,  1.82s/it]


Epoch/Batch: 7/197
Train >>>> Loss: 5.89463


 69%|██████▊   | 199/290 [06:04<02:45,  1.82s/it]


Epoch/Batch: 7/198
Train >>>> Loss: 5.88459


 69%|██████▉   | 200/290 [06:06<02:44,  1.83s/it]


Epoch/Batch: 7/199
Train >>>> Loss: 5.89994


 69%|██████▉   | 201/290 [06:08<02:43,  1.83s/it]


Epoch/Batch: 7/200
Train >>>> Loss: 5.86162


 70%|██████▉   | 202/290 [06:10<02:41,  1.83s/it]


Epoch/Batch: 7/201
Train >>>> Loss: 5.87455


 70%|███████   | 203/290 [06:12<02:39,  1.83s/it]


Epoch/Batch: 7/202
Train >>>> Loss: 5.85298


 70%|███████   | 204/290 [06:13<02:37,  1.83s/it]


Epoch/Batch: 7/203
Train >>>> Loss: 5.90132


 71%|███████   | 205/290 [06:15<02:35,  1.83s/it]


Epoch/Batch: 7/204
Train >>>> Loss: 5.84119


 71%|███████   | 206/290 [06:17<02:33,  1.83s/it]


Epoch/Batch: 7/205
Train >>>> Loss: 5.86502


 71%|███████▏  | 207/290 [06:19<02:31,  1.83s/it]


Epoch/Batch: 7/206
Train >>>> Loss: 5.87978


 72%|███████▏  | 208/290 [06:21<02:29,  1.82s/it]


Epoch/Batch: 7/207
Train >>>> Loss: 5.8859


 72%|███████▏  | 209/290 [06:23<02:27,  1.82s/it]


Epoch/Batch: 7/208
Train >>>> Loss: 5.84231


 72%|███████▏  | 210/290 [06:24<02:26,  1.83s/it]


Epoch/Batch: 7/209
Train >>>> Loss: 5.83269


 73%|███████▎  | 211/290 [06:26<02:24,  1.83s/it]


Epoch/Batch: 7/210
Train >>>> Loss: 5.88068


 73%|███████▎  | 212/290 [06:28<02:22,  1.83s/it]


Epoch/Batch: 7/211
Train >>>> Loss: 5.84315


 73%|███████▎  | 213/290 [06:30<02:20,  1.83s/it]


Epoch/Batch: 7/212
Train >>>> Loss: 5.87322


 74%|███████▍  | 214/290 [06:32<02:19,  1.83s/it]


Epoch/Batch: 7/213
Train >>>> Loss: 5.89127


 74%|███████▍  | 215/290 [06:34<02:16,  1.83s/it]


Epoch/Batch: 7/214
Train >>>> Loss: 5.89715


 74%|███████▍  | 216/290 [06:35<02:15,  1.83s/it]


Epoch/Batch: 7/215
Train >>>> Loss: 5.87049


 75%|███████▍  | 217/290 [06:37<02:13,  1.83s/it]


Epoch/Batch: 7/216
Train >>>> Loss: 5.91095


 75%|███████▌  | 218/290 [06:39<02:11,  1.83s/it]


Epoch/Batch: 7/217
Train >>>> Loss: 5.88499


 76%|███████▌  | 219/290 [06:41<02:09,  1.83s/it]


Epoch/Batch: 7/218
Train >>>> Loss: 5.86966


 76%|███████▌  | 220/290 [06:43<02:07,  1.82s/it]


Epoch/Batch: 7/219
Train >>>> Loss: 5.88307


 76%|███████▌  | 221/290 [06:45<02:05,  1.83s/it]


Epoch/Batch: 7/220
Train >>>> Loss: 5.87204


 77%|███████▋  | 222/290 [06:46<02:04,  1.83s/it]


Epoch/Batch: 7/221
Train >>>> Loss: 5.8742


 77%|███████▋  | 223/290 [06:48<02:02,  1.83s/it]


Epoch/Batch: 7/222
Train >>>> Loss:  5.883


 77%|███████▋  | 224/290 [06:50<02:00,  1.83s/it]


Epoch/Batch: 7/223
Train >>>> Loss: 5.82975


 78%|███████▊  | 225/290 [06:52<01:58,  1.83s/it]


Epoch/Batch: 7/224
Train >>>> Loss: 5.83616


 78%|███████▊  | 226/290 [06:54<01:57,  1.83s/it]


Epoch/Batch: 7/225
Train >>>> Loss: 5.84642


 78%|███████▊  | 227/290 [06:56<01:55,  1.83s/it]


Epoch/Batch: 7/226
Train >>>> Loss: 5.91602


 79%|███████▊  | 228/290 [06:57<01:53,  1.83s/it]


Epoch/Batch: 7/227
Train >>>> Loss: 5.89037


 79%|███████▉  | 229/290 [06:59<01:51,  1.83s/it]


Epoch/Batch: 7/228
Train >>>> Loss: 5.90602


 79%|███████▉  | 230/290 [07:01<01:49,  1.83s/it]


Epoch/Batch: 7/229
Train >>>> Loss: 5.92208


 80%|███████▉  | 231/290 [07:03<01:47,  1.82s/it]


Epoch/Batch: 7/230
Train >>>> Loss: 5.86361


 80%|████████  | 232/290 [07:05<01:45,  1.82s/it]


Epoch/Batch: 7/231
Train >>>> Loss: 5.88511


 80%|████████  | 233/290 [07:06<01:44,  1.83s/it]


Epoch/Batch: 7/232
Train >>>> Loss: 5.87148


 81%|████████  | 234/290 [07:08<01:42,  1.83s/it]


Epoch/Batch: 7/233
Train >>>> Loss: 5.88079


 81%|████████  | 235/290 [07:10<01:40,  1.83s/it]


Epoch/Batch: 7/234
Train >>>> Loss: 5.90962


 81%|████████▏ | 236/290 [07:12<01:38,  1.83s/it]


Epoch/Batch: 7/235
Train >>>> Loss: 5.87023


 82%|████████▏ | 237/290 [07:14<01:36,  1.82s/it]


Epoch/Batch: 7/236
Train >>>> Loss: 5.89173


 82%|████████▏ | 238/290 [07:16<01:35,  1.83s/it]


Epoch/Batch: 7/237
Train >>>> Loss:  5.903


 82%|████████▏ | 239/290 [07:17<01:33,  1.83s/it]


Epoch/Batch: 7/238
Train >>>> Loss: 5.86212


 83%|████████▎ | 240/290 [07:19<01:31,  1.83s/it]


Epoch/Batch: 7/239
Train >>>> Loss: 5.8647


 83%|████████▎ | 241/290 [07:21<01:29,  1.83s/it]


Epoch/Batch: 7/240
Train >>>> Loss: 5.85361


 83%|████████▎ | 242/290 [07:23<01:27,  1.83s/it]


Epoch/Batch: 7/241
Train >>>> Loss: 5.86521


 84%|████████▍ | 243/290 [07:25<01:25,  1.83s/it]


Epoch/Batch: 7/242
Train >>>> Loss: 5.84689


 84%|████████▍ | 244/290 [07:27<01:23,  1.83s/it]


Epoch/Batch: 7/243
Train >>>> Loss: 5.90783


 84%|████████▍ | 245/290 [07:28<01:22,  1.82s/it]


Epoch/Batch: 7/244
Train >>>> Loss: 5.83394


 85%|████████▍ | 246/290 [07:30<01:20,  1.82s/it]


Epoch/Batch: 7/245
Train >>>> Loss: 5.84492


 85%|████████▌ | 247/290 [07:32<01:18,  1.82s/it]


Epoch/Batch: 7/246
Train >>>> Loss: 5.88521


 86%|████████▌ | 248/290 [07:34<01:16,  1.83s/it]


Epoch/Batch: 7/247
Train >>>> Loss: 5.86782


 86%|████████▌ | 249/290 [07:36<01:14,  1.82s/it]


Epoch/Batch: 7/248
Train >>>> Loss: 5.86979


 86%|████████▌ | 250/290 [07:38<01:13,  1.83s/it]


Epoch/Batch: 7/249
Train >>>> Loss: 5.8721


 87%|████████▋ | 251/290 [07:39<01:11,  1.82s/it]


Epoch/Batch: 7/250
Train >>>> Loss: 5.89378


 87%|████████▋ | 252/290 [07:41<01:09,  1.83s/it]


Epoch/Batch: 7/251
Train >>>> Loss: 5.88254


 87%|████████▋ | 253/290 [07:43<01:07,  1.83s/it]


Epoch/Batch: 7/252
Train >>>> Loss: 5.87999


 88%|████████▊ | 254/290 [07:45<01:05,  1.82s/it]


Epoch/Batch: 7/253
Train >>>> Loss: 5.85106


 88%|████████▊ | 255/290 [07:47<01:03,  1.82s/it]


Epoch/Batch: 7/254
Train >>>> Loss: 5.8827


 88%|████████▊ | 256/290 [07:49<01:02,  1.83s/it]


Epoch/Batch: 7/255
Train >>>> Loss: 5.87418


 89%|████████▊ | 257/290 [07:50<01:00,  1.83s/it]


Epoch/Batch: 7/256
Train >>>> Loss: 5.84674


 89%|████████▉ | 258/290 [07:52<00:58,  1.83s/it]


Epoch/Batch: 7/257
Train >>>> Loss: 5.86594


 89%|████████▉ | 259/290 [07:54<00:56,  1.83s/it]


Epoch/Batch: 7/258
Train >>>> Loss: 5.9179


 90%|████████▉ | 260/290 [07:56<00:54,  1.83s/it]


Epoch/Batch: 7/259
Train >>>> Loss: 5.87041


 90%|█████████ | 261/290 [07:58<00:52,  1.83s/it]


Epoch/Batch: 7/260
Train >>>> Loss: 5.87535


 90%|█████████ | 262/290 [07:59<00:51,  1.82s/it]


Epoch/Batch: 7/261
Train >>>> Loss: 5.87399


 91%|█████████ | 263/290 [08:01<00:49,  1.83s/it]


Epoch/Batch: 7/262
Train >>>> Loss: 5.85861


 91%|█████████ | 264/290 [08:03<00:47,  1.83s/it]


Epoch/Batch: 7/263
Train >>>> Loss: 5.8535


 91%|█████████▏| 265/290 [08:05<00:45,  1.83s/it]


Epoch/Batch: 7/264
Train >>>> Loss: 5.84807


 92%|█████████▏| 266/290 [08:07<00:43,  1.83s/it]


Epoch/Batch: 7/265
Train >>>> Loss: 5.91478


 92%|█████████▏| 267/290 [08:09<00:42,  1.83s/it]


Epoch/Batch: 7/266
Train >>>> Loss: 5.89147


 92%|█████████▏| 268/290 [08:10<00:40,  1.83s/it]


Epoch/Batch: 7/267
Train >>>> Loss: 5.86361


 93%|█████████▎| 269/290 [08:12<00:38,  1.83s/it]


Epoch/Batch: 7/268
Train >>>> Loss: 5.87768


 93%|█████████▎| 270/290 [08:14<00:36,  1.84s/it]


Epoch/Batch: 7/269
Train >>>> Loss: 5.86707


 93%|█████████▎| 271/290 [08:16<00:35,  1.84s/it]


Epoch/Batch: 7/270
Train >>>> Loss: 5.8573


 94%|█████████▍| 272/290 [08:18<00:33,  1.84s/it]


Epoch/Batch: 7/271
Train >>>> Loss: 5.90703


 94%|█████████▍| 273/290 [08:20<00:31,  1.84s/it]


Epoch/Batch: 7/272
Train >>>> Loss: 5.86831


 94%|█████████▍| 274/290 [08:21<00:29,  1.84s/it]


Epoch/Batch: 7/273
Train >>>> Loss: 5.86054


 95%|█████████▍| 275/290 [08:23<00:27,  1.83s/it]


Epoch/Batch: 7/274
Train >>>> Loss: 5.87766


 95%|█████████▌| 276/290 [08:25<00:25,  1.83s/it]


Epoch/Batch: 7/275
Train >>>> Loss: 5.87726


 96%|█████████▌| 277/290 [08:27<00:23,  1.84s/it]


Epoch/Batch: 7/276
Train >>>> Loss: 5.85965


 96%|█████████▌| 278/290 [08:29<00:22,  1.85s/it]


Epoch/Batch: 7/277
Train >>>> Loss: 5.83895


 96%|█████████▌| 279/290 [08:31<00:20,  1.84s/it]


Epoch/Batch: 7/278
Train >>>> Loss: 5.88173


 97%|█████████▋| 280/290 [08:33<00:18,  1.84s/it]


Epoch/Batch: 7/279
Train >>>> Loss: 5.89214


 97%|█████████▋| 281/290 [08:34<00:16,  1.84s/it]


Epoch/Batch: 7/280
Train >>>> Loss: 5.88817


 97%|█████████▋| 282/290 [08:36<00:14,  1.83s/it]


Epoch/Batch: 7/281
Train >>>> Loss: 5.87259


 98%|█████████▊| 283/290 [08:38<00:12,  1.83s/it]


Epoch/Batch: 7/282
Train >>>> Loss: 5.89176


 98%|█████████▊| 284/290 [08:40<00:10,  1.83s/it]


Epoch/Batch: 7/283
Train >>>> Loss: 5.87077


 98%|█████████▊| 285/290 [08:42<00:09,  1.83s/it]


Epoch/Batch: 7/284
Train >>>> Loss: 5.89316


 99%|█████████▊| 286/290 [08:43<00:07,  1.83s/it]


Epoch/Batch: 7/285
Train >>>> Loss: 5.90742


 99%|█████████▉| 287/290 [08:45<00:05,  1.82s/it]


Epoch/Batch: 7/286
Train >>>> Loss: 5.8981


 99%|█████████▉| 288/290 [08:47<00:03,  1.83s/it]


Epoch/Batch: 7/287
Train >>>> Loss: 5.92574


100%|█████████▉| 289/290 [08:49<00:01,  1.83s/it]


Epoch/Batch: 7/288
Train >>>> Loss: 5.85131


100%|██████████| 290/290 [08:51<00:00,  1.83s/it]


Epoch/Batch: 7/289
Train >>>> Loss: 5.91338



  0%|          | 0/290 [00:00<?, ?it/s]

[Average Train Loss]: 5.87723
[Average Testing Loss]: 5.88378


  0%|          | 1/290 [00:01<08:59,  1.87s/it]


Epoch/Batch: 8/0
Train >>>> Loss: 5.90702


  1%|          | 2/290 [00:03<08:55,  1.86s/it]


Epoch/Batch: 8/1
Train >>>> Loss: 5.86553


  1%|          | 3/290 [00:05<08:51,  1.85s/it]


Epoch/Batch: 8/2
Train >>>> Loss: 5.90569


  1%|▏         | 4/290 [00:07<08:47,  1.84s/it]


Epoch/Batch: 8/3
Train >>>> Loss: 5.91822


  2%|▏         | 5/290 [00:09<08:43,  1.84s/it]


Epoch/Batch: 8/4
Train >>>> Loss: 5.90035


  2%|▏         | 6/290 [00:11<08:40,  1.83s/it]


Epoch/Batch: 8/5
Train >>>> Loss: 5.87782


  2%|▏         | 7/290 [00:12<08:39,  1.84s/it]


Epoch/Batch: 8/6
Train >>>> Loss: 5.85824


  3%|▎         | 8/290 [00:14<08:36,  1.83s/it]


Epoch/Batch: 8/7
Train >>>> Loss: 5.9074


  3%|▎         | 9/290 [00:16<08:34,  1.83s/it]


Epoch/Batch: 8/8
Train >>>> Loss: 5.89129


  3%|▎         | 10/290 [00:18<08:31,  1.83s/it]


Epoch/Batch: 8/9
Train >>>> Loss: 5.85539


  4%|▍         | 11/290 [00:20<08:31,  1.83s/it]


Epoch/Batch: 8/10
Train >>>> Loss: 5.87485


  4%|▍         | 12/290 [00:21<08:28,  1.83s/it]


Epoch/Batch: 8/11
Train >>>> Loss: 5.85324


  4%|▍         | 13/290 [00:23<08:25,  1.83s/it]


Epoch/Batch: 8/12
Train >>>> Loss: 5.84745


  5%|▍         | 14/290 [00:25<08:23,  1.82s/it]


Epoch/Batch: 8/13
Train >>>> Loss: 5.86305


  5%|▌         | 15/290 [00:27<08:20,  1.82s/it]


Epoch/Batch: 8/14
Train >>>> Loss: 5.8566


  6%|▌         | 16/290 [00:29<08:18,  1.82s/it]


Epoch/Batch: 8/15
Train >>>> Loss: 5.91288


  6%|▌         | 17/290 [00:31<08:17,  1.82s/it]


Epoch/Batch: 8/16
Train >>>> Loss: 5.87495


  6%|▌         | 18/290 [00:32<08:17,  1.83s/it]


Epoch/Batch: 8/17
Train >>>> Loss: 5.90463


  7%|▋         | 19/290 [00:34<08:16,  1.83s/it]


Epoch/Batch: 8/18
Train >>>> Loss: 5.87565


  7%|▋         | 20/290 [00:36<08:14,  1.83s/it]


Epoch/Batch: 8/19
Train >>>> Loss: 5.90156


  7%|▋         | 21/290 [00:38<08:11,  1.83s/it]


Epoch/Batch: 8/20
Train >>>> Loss: 5.87591


  8%|▊         | 22/290 [00:40<08:10,  1.83s/it]


Epoch/Batch: 8/21
Train >>>> Loss: 5.89949


  8%|▊         | 23/290 [00:42<08:07,  1.83s/it]


Epoch/Batch: 8/22
Train >>>> Loss: 5.89283


  8%|▊         | 24/290 [00:43<08:05,  1.82s/it]


Epoch/Batch: 8/23
Train >>>> Loss: 5.84822


  9%|▊         | 25/290 [00:45<08:04,  1.83s/it]


Epoch/Batch: 8/24
Train >>>> Loss: 5.88076


  9%|▉         | 26/290 [00:47<08:02,  1.83s/it]


Epoch/Batch: 8/25
Train >>>> Loss: 5.82566


  9%|▉         | 27/290 [00:49<08:01,  1.83s/it]


Epoch/Batch: 8/26
Train >>>> Loss: 5.87919


 10%|▉         | 28/290 [00:51<08:00,  1.83s/it]


Epoch/Batch: 8/27
Train >>>> Loss: 5.90984


KeyboardInterrupt: ignored

In [None]:
import matplotlib.pyplot as plt

plt.plot(total_train_loss, label='train loss')
plt.plot(total_valid_loss, label='validation loss')
plt.legend()
plt.xlabel("epoch")
plt.ylabel("loss")
img_path = os.path.join('/gdrive/My Drive/my_data/library/checkpoints/', "encoded_midi_6-1.png")
plt.savefig(img_path)

### Training with data from JS Bach Chorales 

In [None]:
src_data = '/gdrive/My Drive/my_data/library/Jsb16thSeparated.npz'

# Generate the vocabulary from the data
vocab = GenerateVocab(src_data)
vocabulary_size = len(vocab)
pad_token = 1

# since we change de vocab len the model should be create after this
# we just show this as an exemple of training of JS Bach Chorales dataset 
# thanks to propressing functions from another repository.

# Setup the dataset for training split and validation split
train_data = PrepareData(src_data ,'train', int(sequence_length))
valid_data = PrepareData(src_data ,'valid', int(sequence_length))

ckpt_dir = os.path.join(gdrive_root, '/my_data/library/checkpoints/')
if not os.path.exists(ckpt_dir):
  os.makedirs(ckpt_dir)
  
best_loss = 10.
model_name = 'midi_encoded_6-1'
ckpt_path = '/gdrive/My Drive/my_data/library/checkpoints/train-nlayer_'+model_name+'.pt'
if os.path.exists(ckpt_path):
    ckpt = torch.load(ckpt_path)
    try:
      model.load_state_dict(ckpt['my_model'])
      optimizer.load_state_dict(ckpt['optimizer'])
      best_acc = ckpt['best_loss']
    except RuntimeError as e:
        print('wrong checkpoint')
    else:    
      print('checkpoint is loaded !')
      print('current best loss : %.2f' % best_loss)

In [None]:
# Training
max_epochs = 100
n_iter = 0

# each 50 iterations we are going to compare the losses
total_train_loss = []
total_valid_loss = []

for e in range(max_epochs):
    model.train()
    random.shuffle(train_data)
    train_loss = []
    for b_train, batch in enumerate(batched_learning(train_data, batch_size=batch_size)):        
        
        n_iter += 1
        # train phase
        # feed data into the network and get outputs.
        inputs, target = batch
        
        #print(inputs.shape, target.shape)

        # Train on GPU
        inputs = (tensorFromSequence(inputs)).to(device)
        target = (tensorFromSequence(target)).to(device)

        # Create mask for both input and target sequences
        input_mask, target_mask = create_masks(inputs, target, pad_token)

        inputs = inputs[:,0,:]
        target = target[:,0,:] 

        ys = target.contiguous().view(-1)              

        # ys = labels.contiguous().view(-1).to(torch.long).to(device)
        
        # feed data into the network and get outputs.
        preds_idx = model(inputs, target, input_mask, target_mask)
        
        # Flush out gradients computed at the previous step before computing gradients at the current step. 
        #       Otherwise, gradients would accumulate.
        optimizer.zero_grad()

        # calculate loss
        loss = F.cross_entropy(preds_idx.contiguous().view(preds_idx.size(-1), -1).transpose(0,1), ys, ignore_index = pad_token, size_average = False) / (count_nonpad_tokens(ys))

        # accumulates the gradient and backprogate loss.
        loss.backward()

        # performs a parameter update based on the current gradient
        scheduler.step()    
        
        print('\n====================================================')
        print('Epoch/Batch: {}/{}'.format(e, b_train))
        print('Train >>>> Loss: {:6.6}'.format(loss))

        train_loss.append(loss.item())

    print('\n**************************************************')
    print("\n*** Test *** ")
    
    # Validation step
    model.eval()
    valid_loss = []
    with torch.no_grad():
        pair = valid_data
        inputs = tensorFromSequence(pair[0]).to(device)
        target = tensorFromSequence(pair[1]).to(device)
        
        # Create mask for both input and target sequences
        input_mask, target_mask = create_masks(inputs, target, pad_token)

        inputs = inputs[:,0,:]
        target = target[:,0,:] 
        ys = target.contiguous().view(-1)

        preds_validate = model(inputs, target, input_mask, target_mask)
        loss = F.cross_entropy(preds_validate.contiguous().view(preds_validate.size(-1), -1).transpose(0,1), ys, \
                                ignore_index = pad_token, size_average = False) / (count_nonpad_tokens(ys))
        valid_loss.append(loss.item())

    avg_train_loss = np.mean(train_loss)
    avg_valid_loss = np.mean(valid_loss)

    total_train_loss.append(avg_train_loss)
    total_valid_loss.append(avg_valid_loss)

    print("[Average Train Loss]: {:6.6}".format(avg_train_loss))
    print("[Average Testing Loss]: {:6.6}".format(avg_valid_loss))

    # save checkpoint whenever there is improvement in performance
    if avg_valid_loss < best_loss:
        best_loss = avg_valid_loss
        # Note: optimizer also has states ! don't forget to save them as well.
        ckpt = {'my_model':model.state_dict(),
                'optimizer':optimizer.state_dict(),
                'best_loss':best_loss}
        torch.save(ckpt, ckpt_path)
        print('checkpoint is saved !')

    train_writer.add_scalar('loss/train', avg_train_loss, global_step=n_iter)
    test_writer.add_scalar('loss/valid', avg_valid_loss, global_step=n_iter)
    print('\n**************************************************')
    # torch.cuda.empty_cache()
    # torch.save(model.state_dict(), '/gdrive/My Drive/my_data/library/checkpoints/train-{}.pt'.format(e))
    ckpt = {'my_model':model.state_dict(),
            'optimizer':optimizer.state_dict(),
            'best_loss':best_loss}
    torch.save(model.state_dict(), ckpt_backup_path)
    torch.save(ckpt, ckpt_fullbackup_path)

train_writer = SummaryWriter()
test_writer = SummaryWriter()

# Prediction

In [21]:
inputs = torch.randint(0, 1000, (1,10)).to(device)
gen_length = 30
generated_midi = torch.Tensor()
for seq in range(gen_length):
    if seq % 20 == 0: 
        print(seq)
    #print(inputs)
    logits = F.softmax(model(inputs[:, :-1], inputs[:, 1:]), -1).to(device)

    logits = logits[0, :, :]
    
    one_hot = torch.distributions.OneHotCategorical(probs=logits[:,-1])
    res = one_hot.sample().argmax(-1).unsqueeze(-1).to(device)

    #res = torch.transpose(res, 0, 1)

    inputs = torch.cat((inputs[0], res), dim=-1)
    inputs = torch.reshape(inputs, (1, -1))

0


RuntimeError: ignored

In [22]:
from matplotlib import pyplot as pt

x = range(0, len(inputs[0]))
y = [ints.item() for ints in inputs[0]]

p = pt.scatter(x, y, label='generated pitches')

pt.xlabel('index')
pt.ylabel('midi value (pitch)')
pt.show()

RuntimeError: ignored

In [None]:
midi_boy = decode_midi(inputs[0])
torch.save(midi_boy, gdrive_root + '/my_data/library/test.midi')