<a href="https://colab.research.google.com/github/DRIP-AI-RESEARCH-JUNIOR/MUSIC_GENEARATION/blob/master/MusicGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class Generator(nn.Module):
    """Autoregressive LSTMCell encoder/decoder generator (first prototype).

    At every time step the input frame is concatenated with the previous
    output frame (U(0, 1) noise at the first step). The result is projected
    back to ``num_features`` by a ReLU linear layer, then run through an
    encoder LSTMCell whose hidden state is dropped out, a decoder LSTMCell
    and a sigmoid linear layer. The sigmoid output is both the step's
    prediction and the next step's feedback frame.

    Parameters
    ----------
    num_features : int
        Features per time step; also the hidden size of both LSTM cells.
    p : float
        Dropout probability applied to the encoder hidden state.
    """

    def __init__(self, num_features, p=0.1):
        super(Generator, self).__init__()
        self.fc_encoder = nn.Linear(2 * num_features, num_features)
        self.lstmCell_encoder = nn.LSTMCell(input_size=num_features, hidden_size=num_features)
        self.dropout = nn.Dropout(p=p)
        self.fc_decoder = nn.Linear(num_features, num_features)
        self.lstmCell_decoder = nn.LSTMCell(input_size=num_features, hidden_size=num_features)

    def weight_init(self):
        """Zero-initialise every bias tensor of the model.

        The previous body called ``torch.nn.init.zeros_(tensor)`` on an
        undefined ``tensor`` and therefore always raised NameError. Weights
        keep PyTorch's default initialisation.
        """
        for name, param in self.named_parameters():
            # Linear layers use ``bias``; LSTMCell uses ``bias_ih``/``bias_hh``.
            if name.rsplit('.', 1)[-1].startswith('bias'):
                nn.init.zeros_(param)

    def forward(self, x, he, ce, hd, cd):
        """Produce one output frame per input frame.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``(batch, seq_len, num_features)``.
        he, ce : torch.Tensor
            Initial encoder LSTMCell hidden/cell state, ``(batch, num_features)``.
        hd, cd : torch.Tensor
            Initial decoder LSTMCell hidden/cell state, ``(batch, num_features)``.

        Returns
        -------
        torch.Tensor
            ``(batch, seq_len, num_features)`` with values in (0, 1).
        """
        batch, _, num_features = x.shape
        # Feedback frame for step 0. Create it on x's device/dtype so the model
        # also works on GPU (the old torch.empty() always lived on the CPU).
        start = x.new_empty(batch, num_features).uniform_()
        outputs = []
        for x_step in x.unbind(dim=1):
            encoded = torch.relu(self.fc_encoder(torch.cat((x_step, start), dim=-1)))
            he, ce = self.lstmCell_encoder(encoded, (he, ce))
            # The dropped-out hidden state is also what the next step's encoder sees.
            he = self.dropout(he)
            hd, cd = self.lstmCell_decoder(he, (hd, cd))
            start = torch.sigmoid(self.fc_decoder(hd))
            outputs.append(start)
        # seq_len tensors of (batch, features) -> (batch, seq_len, features)
        return torch.stack(outputs, dim=1)


In [None]:
# Dummy batch for a smoke test: (batch=2, seq_len=3, features=88), plus the
# four (batch, features) LSTMCell states the first Generator's forward takes.
x = torch.randn(2, 3, 88)
he, ce, hd, cd = (torch.randn(2, 88) for _ in range(4))

In [None]:
# Run the first Generator once on the dummy batch.
model = Generator(num_features=88)
out = model(x, he=he, ce=ce, hd=hd, cd=cd)



In [None]:
# Shape of the generated batch: (batch, seq_len, features).
print(out.size())

torch.Size([2, 3, 88])


In [None]:
class Discriminator(nn.Module):
    """Bidirectional 2-layer LSTM that scores every time step (first prototype).

    Parameters
    ----------
    num_feature : int
        Features per time step. Each LSTM direction has ``num_feature // 2``
        hidden units.
    p : float
        Dropout probability applied to the input sequence.
    """

    def __init__(self, num_feature, p=0.1):
        super(Discriminator, self).__init__()
        self.dropout = nn.Dropout(p=p)
        hidden = num_feature // 2
        self.lstm = nn.LSTM(num_feature, hidden, num_layers=2, batch_first=True, bidirectional=True)
        # The bidirectional output has 2 * hidden features; for odd num_feature
        # this differs from num_feature, which previously crashed the layer.
        self.fc = nn.Linear(2 * hidden, 1)

    def forward(self, x, h, c):
        """Return a per-step probability of shape ``(batch, seq_len, 1)``.

        ``x`` is ``(batch, seq_len, num_feature)``. ``h`` and ``c`` are
        ``(4, batch, num_feature // 2)`` (2 layers x 2 directions).
        """
        # Previously the dropped-out tensor was computed but never used.
        x = self.dropout(x)
        out, _ = self.lstm(x, (h, c))
        return torch.sigmoid(self.fc(out))

In [None]:
# Smoke test for the first Discriminator: the states are (2 layers * 2 directions,
# batch, 88 // 2).
x = torch.randn(2, 3, 88)
h, c = torch.randn(4, 2, 44), torch.randn(4, 2, 44)
model = Discriminator(num_feature=88)
out = model(x, h=h, c=c)
print(out.shape)

torch.Size([2, 3, 1])




In [None]:
# Mount Google Drive (Colab only) so the data folders can be copied from it below.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Copy the Nottingham dataset and the `midi` helper directory from the mounted
# Drive into /content (presumably the working directory the relative paths
# 'Nottingham/...' and 'midi' used later resolve against).
!cp -r /content/drive/My\ Drive/Nottingham /content
!cp -r /content/drive/My\ Drive/midi /content

In [1]:
%matplotlib inline
import os
import sys
import random
import math
# Make the copied `midi` directory importable (midi_utils is imported from it).
# Relative path: assumes the current working directory contains `midi`.
sys.path.append('midi')
 
# NOTE(review): unlike the first cell, this does not import torch.nn.functional as F.
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as data

In [None]:
from midi_utils import midiread, midiwrite
from matplotlib import pyplot as plt
import skimage.io as io
from IPython.display import FileLink

In [None]:
import numpy as np
import torch
import torch.utils.data as data
 
 
def midi_filename_to_piano_roll(midi_filename, dt=0.3):
    """Load a MIDI file and return it as a binary piano roll.

    Parameters
    ----------
    midi_filename : str
        Path of the MIDI file, read with ``midiread``.
    dt : float, optional
        Time-step size forwarded to ``midiread``. The default 0.3 is the value
        that used to be hard-coded here.

    Returns
    -------
    numpy.ndarray
        The transpose of ``midiread(...).piano_roll`` with every positive entry
        set to 1. It is presumably laid out as (keys, time steps), given how
        ``pad_piano_roll`` and the dataset consume it.
    """
    midi_data = midiread(midi_filename, dt=dt)
    piano_roll = midi_data.piano_roll.transpose()
    # Binarise: any pressed note (positive entry) becomes 1.
    piano_roll[piano_roll > 0] = 1
    return piano_roll
 
 
def pad_piano_roll(piano_roll, max_length=132333, pad_value=0):
    """Left-pad a 2-D piano roll along its time axis to ``max_length`` steps.

    The original frames are placed at the *end* of the result, and the
    leading columns are filled with ``pad_value``.

    Parameters
    ----------
    piano_roll : numpy.ndarray
        2-D array whose second axis is time. The row count is preserved; the
        previous version hard-coded 88 rows.
    max_length : int
        Number of time steps in the result.
    pad_value : number
        Fill value for the padding columns.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(piano_roll.shape[0], max_length)``.

    Raises
    ------
    ValueError
        If ``piano_roll`` has more than ``max_length`` time steps.
    """
    n_rows, original_length = piano_roll.shape
    if original_length > max_length:
        raise ValueError(
            "piano roll has {} time steps, more than max_length={}".format(
                original_length, max_length))

    # Explicit float64 matches the old np.zeros(...) buffer even for int pad values.
    padded_piano_roll = np.full((n_rows, max_length), pad_value, dtype=np.float64)
    # Guard the empty case: ``[:, -0:]`` selects *every* column, not none.
    if original_length:
        padded_piano_roll[:, -original_length:] = piano_roll
    return padded_piano_roll
 
 
class NotesGenerationDataset(data.Dataset):
    """Next-step prediction dataset over a folder of MIDI files.

    Each item is built from one file's piano roll. The input is every frame
    except the last, and the target is every frame except the first. Both are
    left-padded to ``longest_sequence_length`` steps.

    Parameters
    ----------
    midi_folder_path : str
        Folder whose regular files are all treated as MIDI files.
    longest_sequence_length : int or None
        Padded length. If ``None``, it is computed by decoding every file once
        (see ``update_the_max_length``).
    """

    def __init__(self, midi_folder_path, longest_sequence_length=1491):
        self.midi_folder_path = midi_folder_path
        self.longest_sequence_length = longest_sequence_length

        # Sorted for a reproducible index -> file mapping (os.listdir order is
        # arbitrary). Directories such as .ipynb_checkpoints are skipped, since
        # midiread cannot open them.
        full_paths = (os.path.join(midi_folder_path, name)
                      for name in sorted(os.listdir(midi_folder_path)))
        self.midi_full_filenames = [path for path in full_paths if os.path.isfile(path)]

        if longest_sequence_length is None:
            self.update_the_max_length()

    def update_the_max_length(self):
        """Set ``longest_sequence_length`` to the longest piano roll in the folder.

        An empty folder yields 0 instead of raising from ``max()``.
        """
        sequences_lengths = (midi_filename_to_piano_roll(filename).shape[1]
                             for filename in self.midi_full_filenames)
        self.longest_sequence_length = max(sequences_lengths, default=0)

    def __len__(self):
        return len(self.midi_full_filenames)

    def __getitem__(self, index):
        """Return ``(input, target, length)`` for the file at ``index``.

        ``input`` is a float32 tensor and ``target`` an int64 tensor, both of
        shape ``(longest_sequence_length, rows)``. ``length`` is a 1-element
        int64 tensor holding the unpadded number of steps.
        """
        piano_roll = midi_filename_to_piano_roll(self.midi_full_filenames[index])

        # Shift by one time step: the target at t is the input frame at t + 1.
        sequence_length = piano_roll.shape[1] - 1
        input_sequence = piano_roll[:, :-1]
        ground_truth_sequence = piano_roll[:, 1:]

        # Left-pad to a common length. Padded target steps are marked with -100
        # (presumably to match a loss's ignore_index; verify against the loss used).
        input_sequence_padded = pad_piano_roll(
            input_sequence, max_length=self.longest_sequence_length)
        ground_truth_sequence_padded = pad_piano_roll(
            ground_truth_sequence, max_length=self.longest_sequence_length, pad_value=-100)

        # Time-major rows: (rows, time) -> (time, rows).
        return (torch.as_tensor(input_sequence_padded.transpose(), dtype=torch.float32),
                torch.as_tensor(ground_truth_sequence_padded.transpose(), dtype=torch.long),
                torch.as_tensor([sequence_length], dtype=torch.long))
 
    
def post_process_sequence_batch(batch_tuple):
    """Sort a collated batch by length (longest first) and trim shared padding.

    Parameters
    ----------
    batch_tuple : tuple of torch.Tensor
        ``(input_sequences, output_sequences, lengths)``. The sequences are
        ``(batch, time, features)`` and left-padded; ``lengths`` holds one
        length per sample (e.g. shape ``(batch, 1)`` as collated from
        NotesGenerationDataset).

    Returns
    -------
    tuple
        ``(inputs, outputs, lengths)``, all ordered by decreasing length, with
        ties kept in their original order:

        - ``inputs`` is time-major, ``(time, batch, features)``;
        - ``outputs`` stays ``(batch, time, features)``;
        - ``lengths`` is a Python list of ints.

        ``time`` is cut to the longest length in the batch.
    """
    input_sequences, output_sequences, lengths = batch_tuple

    per_sample = [int(n) for n in lengths.reshape(-1)]
    # Stable sort of sample indices, longest first.
    order = sorted(range(len(per_sample)), key=per_sample.__getitem__, reverse=True)
    index = torch.as_tensor(order, dtype=torch.long)
    longest = per_sample[order[0]] if order else 0

    def _keep_last_steps(sequences):
        # Sequences are left-padded, so the real data is the last `longest`
        # steps. An explicit start index avoids `[-0:]`, which keeps everything.
        start = max(sequences.shape[1] - longest, 0)
        return sequences[index, start:, :]

    input_sequence_batch_sorted = _keep_last_steps(input_sequences)
    output_sequence_batch_sorted = _keep_last_steps(output_sequences)

    return (input_sequence_batch_sorted.transpose(0, 1),
            output_sequence_batch_sorted,
            [per_sample[i] for i in order])

In [None]:
# Training split. longest_sequence_length=None makes the dataset scan every
# file once to find the padded length.
trainset = NotesGenerationDataset(midi_folder_path='Nottingham/train/',
                                  longest_sequence_length=None)

trainset_loader = data.DataLoader(
    trainset,
    batch_size=8,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Validation split, padded to its own longest sequence.
valset = NotesGenerationDataset(midi_folder_path='Nottingham/valid/',
                                longest_sequence_length=None)

valset_loader = data.DataLoader(
    valset,
    batch_size=8,
    shuffle=True,
    drop_last=True,
)

In [None]:
class Generator(nn.Module):
    """Encoder/decoder LSTMCell generator with autoregressive feedback.

    At each time step the input frame is concatenated with the previous
    output frame (a U(0, 1) random frame at step 0). The result goes through
    a ReLU linear layer and the encoder LSTMCell. The encoder hidden state is
    dropped out, then decoded by a second LSTMCell and a linear layer back to
    ``num_fea`` features.

    Parameters
    ----------
    num_fea : int
        Features per time step (last dim of the input).
    hidden_dim : int
        Hidden size of both LSTM cells.
    drop : float
        Dropout probability on the encoder hidden state.
    device : str or torch.device
        Device on which ``init_hidden`` places the initial states.
    """

    def __init__(self, num_fea, hidden_dim=256, drop=0.6, device='cuda'):
        super(Generator, self).__init__()

        # Was ``hiddeen_dim``, which raised NameError on construction.
        self.hidden_dim = hidden_dim
        self.device = device
        self.num_fea = num_fea

        self.en_fc = nn.Linear(2 * num_fea, hidden_dim)
        self.en_lstm = nn.LSTMCell(input_size=hidden_dim, hidden_size=hidden_dim)

        self.bottleneck_drop = nn.Dropout(p=drop)

        self.de_lstm = nn.LSTMCell(input_size=hidden_dim, hidden_size=hidden_dim)
        self.de_fc = nn.Linear(hidden_dim, num_fea)

    def forward(self, x, states):
        """Generate one frame per input step.

        Parameters
        ----------
        x : torch.Tensor
            ``(batch, seq_len, num_fea)``.
        states : tuple
            ``((h_enc, c_enc), (h_dec, c_dec))``, each ``(batch, hidden_dim)``,
            e.g. from ``init_hidden``.

        Returns
        -------
        tuple
            ``(out, states)``. ``out`` is ``(batch, seq_len, num_fea)`` (raw
            linear outputs, no activation) and ``states`` has the same
            structure as the input states.
        """
        # Was ``= x``, which tries to unpack the tensor itself.
        batch_size, _, num_fea = x.shape

        # Random feedback frame for step 0, created on x's device/dtype so the
        # concatenation below never mixes devices.
        sos = x.new_empty(batch_size, num_fea).uniform_()

        en_state, de_state = states
        out_fea = []

        for x_ in x.unbind(dim=1):
            # Was ``z_`` (undefined). torch.relu is used instead of F.relu because
            # F is not imported in this part of the notebook.
            en_out = torch.relu(self.en_fc(torch.cat((x_, sos), dim=-1)))
            hE, cE = self.en_lstm(en_out, en_state)

            # The dropped-out hidden state is also carried to the next step.
            hE = self.bottleneck_drop(hE)

            hD, cD = self.de_lstm(hE, de_state)
            sos = self.de_fc(hD)

            out_fea.append(sos)

            en_state = (hE, cE)
            de_state = (hD, cD)

        # seq_len tensors of (b, n) -> (b, seq_len, n)
        out_fea = torch.stack(out_fea, dim=1)
        return out_fea, (en_state, de_state)

    def init_hidden(self, batch_size):
        """Return zero ``((h_enc, c_enc), (h_dec, c_dec))`` states on ``self.device``."""
        w = next(self.parameters())

        def zeros():
            return w.new_zeros(batch_size, self.hidden_dim).to(self.device)

        return ((zeros(), zeros()), (zeros(), zeros()))

In [None]:
class Discriminator(nn.Module):
    """Bidirectional two-layer LSTM critic that scores whole sequences.

    Every time step of the LSTM output is mapped to a sigmoid probability.
    The per-step probabilities are averaged into one score per sample.

    Parameters
    ----------
    num_fea : int
        Features per time step of the input.
    hidden_dim : int
        Hidden units per LSTM direction.
    drop : float
        Dropout on the input and between the LSTM layers.
    device : str or torch.device
        Device on which ``init_hidden`` places the initial states.
    """

    def __init__(self, num_fea, hidden_dim=256, drop=0.6, device='cuda'):
        super(Discriminator, self).__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = 2
        self.device = device

        self.drop = nn.Dropout(p=drop)
        self.lstm = nn.LSTM(
            input_size=num_fea,
            hidden_size=hidden_dim,
            num_layers=self.num_layers,
            batch_first=True,
            dropout=drop,
            bidirectional=True,
        )
        # Forward and backward directions are concatenated -> 2 * hidden_dim.
        self.fc = nn.Linear(2 * hidden_dim, 1)

    def forward(self, x, state):
        """Return ``(score, lstm_features, state)``.

        ``score`` has shape ``(batch,)``. ``lstm_features`` is the
        ``(batch, seq, 2 * hidden_dim)`` LSTM output and ``state`` the final
        ``(h, c)``.
        """
        seq_features, state = self.lstm(self.drop(x), state)
        step_scores = self.fc(seq_features).sigmoid()  # (batch, seq, 1)

        # Average over every non-batch dimension.
        reduce_dims = tuple(range(1, step_scores.dim()))
        score = step_scores.mean(dim=reduce_dims)

        return score, seq_features, state

    def init_hidden(self, batch_size):
        """Return zero ``(h, c)`` of shape ``(2 * num_layers, batch, hidden_dim)``."""
        ref = next(self.parameters())
        shape = (self.num_layers * 2, batch_size, self.hidden_dim)
        return tuple(ref.new_zeros(shape).to(self.device) for _ in range(2))

In [None]:
# Lower clamp bound applied to probabilities before taking the log, so log(0)
# never yields -inf. NOTE: 1e-40 is below float32's smallest normal number
# (~1.18e-38), i.e. a subnormal value, and it underflows to 0 in float16.
EPS = 1e-40


class GenLoss(nn.Module):
    """Non-saturating generator loss: mean of ``-log D(G(z))``.

    ``fake_logits`` are the discriminator's outputs on generated samples.
    Despite the name they must be probabilities; they are clamped to
    ``[EPS, 1]`` before the log.
    """

    def __init__(self):
        super(GenLoss, self).__init__()

    def forward(self, fake_logits):  # was misspelled ``seld``
        return torch.mean(-torch.log(torch.clamp(fake_logits, EPS, 1.0)))

In [None]:
class DisLoss(nn.Module):
    """Binary cross-entropy loss for the discriminator.

    ``loss = mean(-log D(real) - log(1 - D(fake)))``. With ``smooth=True`` the
    target for real samples is 0.9 instead of 1 (one-sided label smoothing).
    Inputs must be probabilities; they are clamped to ``[EPS, 1]`` before the
    log.
    """

    def __init__(self, smooth=False):
        super(DisLoss, self).__init__()

        self.smooth = smooth

    def forward(self, real_logits, fake_logits):
        # BCE: -(y*log(p) + (1-y)*log(1-p)); y=1 for real samples, y=0 for fakes.
        d_loss_real = -torch.log(torch.clamp(real_logits, EPS, 1.0))

        if self.smooth:
            # Real target y=0.9: 0.9*(-log p) + 0.1*(-log(1-p)).
            d_loss_real_as_fake = -torch.log(torch.clamp(1 - real_logits, EPS, 1.0))
            d_loss_real = 0.9 * d_loss_real + 0.1 * d_loss_real_as_fake

        # The minus sign was missing here (and in the smoothing term). That
        # rewarded D for scoring fakes as real.
        d_loss_fake = -torch.log(torch.clamp(1 - fake_logits, EPS, 1.0))

        return torch.mean(d_loss_real + d_loss_fake)

In [None]:
def train(train_loader, net_g, net_d, optimizer_g, optimizer_d, *,
          max_seq=None, label_smoothing=False, device=None):
    """Run one epoch of alternating generator / discriminator updates.

    For every batch, G is updated once with ``GenLoss`` on D's scores of fresh
    fakes. D is then updated once with ``DisLoss`` on the real batch and the
    detached fakes.

    Parameters
    ----------
    train_loader : iterable
        Yields either a tensor ``(batch, seq, num_fea)`` of real data, or a
        list/tuple whose first element is that tensor (as NotesGenerationDataset
        + DataLoader produce).
    net_g, net_d : nn.Module
        Generator and discriminator exposing ``init_hidden(batch_size)``;
        ``net_g`` must have a ``num_fea`` attribute.
    optimizer_g, optimizer_d : torch.optim.Optimizer
        Optimizers for ``net_g`` and ``net_d``.
    max_seq : int, optional
        Length of the generated sequences. Defaults to the real batch's length.
    label_smoothing : bool
        Forwarded to ``DisLoss(smooth=...)``.
    device : torch.device, optional
        Where data and noise are placed. Defaults to ``net_g``'s parameter device.

    These keyword arguments replace the undefined globals ``MAX_SEQ``, ``l_s``,
    ``num_fea`` and ``device`` that the previous version relied on.

    Returns
    -------
    tuple
        ``(net_g, net_d, mean_g_loss, mean_d_loss, d_accuracy)``, where
        ``d_accuracy`` is the fraction of real scores > 0.5 plus fake scores
        < 0.5 in the D step.
    """
    net_g.train()
    net_d.train()

    if device is None:
        device = next(net_g.parameters()).device
    gen_criterion = GenLoss()
    dis_criterion = DisLoss(label_smoothing)

    d_total_loss = 0.0
    g_total_loss = 0.0
    n_correct = 0
    n_scored = 0
    n_batches = 0

    for batch in train_loader:
        # A DataLoader over NotesGenerationDataset yields [input, target, length];
        # only the input piano roll is used as real data.
        real = batch[0] if isinstance(batch, (list, tuple)) else batch
        real = real.to(device)
        batch_size = real.shape[0]
        seq_len = real.shape[1] if max_seq is None else max_seq

        state_g = net_g.init_hidden(batch_size)
        state_d = net_d.init_hidden(batch_size)

        # Net-G: push D to score the fakes as real.
        optimizer_g.zero_grad()
        noise = torch.empty(batch_size, seq_len, net_g.num_fea, device=device).uniform_()
        g_fea, _ = net_g(noise, state_g)
        d_logits_fake, _, _ = net_d(g_fea, state_d)  # D returns (score, features, state)
        loss_g = gen_criterion(d_logits_fake)
        loss_g.backward()
        optimizer_g.step()

        # Net-D: zero_grad also discards the gradients loss_g left on D's parameters.
        optimizer_d.zero_grad()
        d_logits_real, _, _ = net_d(real, state_d)
        d_logits_fake, _, _ = net_d(g_fea.detach(), state_d)
        loss_d = dis_criterion(d_logits_real, d_logits_fake)
        loss_d.backward()
        optimizer_d.step()

        g_total_loss += loss_g.item()
        d_total_loss += loss_d.item()
        n_correct += (d_logits_real > 0.5).sum().item() + (d_logits_fake < 0.5).sum().item()
        n_scored += d_logits_real.numel() + d_logits_fake.numel()
        n_batches += 1

    # Avoid division by zero on an empty loader.
    mean_g_loss = g_total_loss / max(n_batches, 1)
    mean_d_loss = d_total_loss / max(n_batches, 1)
    accuracy = n_correct / n_scored if n_scored else 0.0
    return net_g, net_d, mean_g_loss, mean_d_loss, accuracy