In [None]:
import glob
import os
import numpy as np
import matplotlib.pyplot as plt
get_ipython().run_line_magic('matplotlib', 'inline')

from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset

# Formula for the output width of a convolution layer (same for height):
# W_out = floor((W − K + 2P) / S) + 1, where W = input width, K = kernel width, P = padding and S = stride

# Formula for the output width of a transposed convolution layer (same for height):
# W_out = (W − 1) * S + K − 2P, where W = input width, K = kernel width, P = padding and S = stride

# GAN for musical accompaniment

# 1- Dataset

In [None]:
# See Code in .py file

In [None]:
# NOTE(review): `CustomDataset` and `npy_path` are not defined in this notebook — presumably they
# come from the accompanying .py file mentioned above; make sure it is imported/run first.
dataset = CustomDataset(folder=npy_path)

# 2- Model

## 1) Modules

In [None]:
class TemporalNetwork(nn.Module):
    """
    Extends a single latent vector along the time (bar) axis.

    Maps noise of shape (batch, 1, z_dim, 1) to a tensor of shape (batch, z_dim, n_bars, 1):
    one z_dim-vector per bar, with the bar index on dimension 2 (the Generator indexes the
    result as ``[:, :, bar, :]``).
    """

    def __init__(self, z_dim, n_bars):
        super(TemporalNetwork, self).__init__()
        # NOTE(review): PyTorch's BatchNorm ``momentum`` is the weight of the *new* batch
        # statistics (the opposite convention from Keras); 0.9 looks ported from a Keras model.
        self.layer_1 = nn.Sequential(
            # (batch, z_dim, 1, 1) -> (batch, 1024, 2, 1)
            nn.ConvTranspose2d(z_dim, 1024, (2, 1), stride=(1, 1)),
            nn.BatchNorm2d(1024, momentum=0.9),
            nn.ReLU()
        )
        self.layer_2 = nn.Sequential(
            # Output height = (2 - 1) * 1 + (n_bars - 1) = n_bars
            nn.ConvTranspose2d(1024, z_dim, (n_bars - 1, 1), stride=(1, 1)),
            nn.BatchNorm2d(z_dim, momentum=0.9),
            nn.ReLU()
        )

        nn.init.kaiming_normal_(self.layer_1[0].weight)
        nn.init.kaiming_normal_(self.layer_2[0].weight)

    def forward(self, x):
        """
        Args:
          x: noise of shape (batch, 1, z_dim, 1).
        Returns:
          Tensor of shape (batch, z_dim, n_bars, 1).
        """
        # (batch, 1, z_dim, 1) -> (batch, z_dim, 1, 1): the latent vector becomes the channel axis
        x = x.permute(0, 2, 1, 3)
        # -> (batch, 1024, 2, 1)
        x = self.layer_1(x)
        # -> (batch, z_dim, n_bars, 1); callers rely on the bar index being dimension 2,
        # so the result is intentionally not permuted back.
        x = self.layer_2(x)

        return x


In [None]:
class BarEncoder(nn.Module):
    """
    Compresses one complete bar (all five tracks) into a flat embedding.

    The embedding is fed back to the Bar Generators as a short-term memory of the previous bar.
    Input:  (batch, 5 tracks, 1 bar, 96 steps, 128 pitches)
    Output: (batch, output_dim)
    """

    def __init__(self, output_dim):
        super(BarEncoder, self).__init__()

        # (batch, 5, 96, 128) -> conv (batch, 16, 32, 30) -> pool (batch, 16, 16, 15)
        self.layer_1 = nn.Sequential(
            nn.Conv2d(5, 16, (3, 12), stride=(3, 4), padding=(0, 0)),
            nn.LeakyReLU(),
            nn.MaxPool2d((2, 2)),
        )
        # (batch, 16, 16, 15) -> conv (batch, 16, 8, 7) -> pool (batch, 16, 4, 7)
        self.layer_2 = nn.Sequential(
            nn.Conv2d(16, 16, (2, 3), stride=(2, 2), padding=(0, 0)),
            nn.LeakyReLU(),
            nn.MaxPool2d((2, 1)),
        )

        for conv_block in (self.layer_1, self.layer_2):
            nn.init.kaiming_normal_(conv_block[0].weight)

        # 16 channels * 4 * 7 = 448 flattened features
        self.linear = nn.Linear(in_features=448, out_features=output_dim)

    def forward(self, bar):
        # Drop the singleton bar axis: (batch, 5, 1, 96, 128) -> (batch, 5, 96, 128)
        features = bar.squeeze(2)
        for conv_block in (self.layer_1, self.layer_2):
            features = conv_block(features)
        # (batch, 16, 4, 7) -> (batch, 448) -> (batch, output_dim)
        return self.linear(features.flatten(1))

In [None]:
class MelodyEncoder(nn.Module):
    """
    Embeds the conditional melody bar by bar for the Bar Generators.

    Input:  (batch, 1 track, n_bars, 96 steps, 128 pitches)
    Output: (batch, n_bars, output_dim), one embedding per bar.
    All kernels and pools have depth 1, so every bar is encoded independently.
    """

    def __init__(self, output_dim):
        super(MelodyEncoder, self).__init__()
        # (batch, 1, n_bars, 96, 128) -> conv (batch, 32, n_bars, 24, 30) -> pool (.., 12, 15)
        first_conv = nn.Conv3d(1, 32, (1, 4, 12), stride=(1, 4, 4))
        self.layer_1 = nn.Sequential(first_conv, nn.ReLU(), nn.MaxPool3d((1, 2, 2)))
        # -> conv (batch, 64, n_bars, 4, 8) -> pool (batch, 64, n_bars, 2, 4)
        second_conv = nn.Conv3d(32, 64, (1, 3, 3), stride=(1, 3, 2), padding=(0, 0, 1))
        self.layer_2 = nn.Sequential(second_conv, nn.ReLU(), nn.MaxPool3d((1, 2, 2)))
        # 64 * 2 * 4 = 512 features per bar
        self.layer_3 = nn.Linear(512, output_dim)

        for conv in (first_conv, second_conv):
            nn.init.kaiming_normal_(conv.weight)

    def forward(self, melody):
        n_songs, n_bars = melody.shape[0], melody.shape[2]
        feature_maps = self.layer_2(self.layer_1(melody))
        # Move bars next to the batch axis and flatten each bar's features:
        # (batch, 64, n_bars, 2, 4) -> (batch, n_bars, 64, 2, 4) -> (batch, n_bars, 512)
        per_bar = feature_maps.transpose(1, 2).reshape(n_songs, n_bars, -1)
        return self.layer_3(per_bar)

In [None]:
class BarGenerator(nn.Module):
    """
    Turns a flat conditioning vector into one bar of one track.

    The vector is projected to 1024 features. A stack of transposed convolutions then grows the
    time axis (1 -> 96 steps) and the pitch axis (1 -> 128 pitches). A final tanh bounds the
    output to [-1, 1].
    Input:  (batch, 1, input_dim, 1)
    Output: (batch, 1, 96, 128)
    """

    def __init__(self, input_dim):
        super(BarGenerator, self).__init__()

        def deconv(in_channels, out_channels, kernel, stride=None, padding=(0, 0), activation=None):
            # Transposed conv (stride defaults to the kernel size) + activation (ReLU by default)
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, kernel,
                                   stride=kernel if stride is None else stride,
                                   padding=padding),
                nn.ReLU() if activation is None else activation,
            )

        # Dense projection of the conditioning vector (BatchNorm layers are deliberately omitted)
        self.layer_1 = nn.Sequential(nn.Linear(input_dim, 1024), nn.ReLU())
        # Time axis: 1 -> 2 -> 4 -> 8 -> 16 -> 32 -> 96 steps
        self.layer_2 = deconv(1024, 512, (2, 1))
        self.layer_3 = deconv(512, 256, (2, 1))
        self.layer_4 = deconv(256, 256, (2, 1))
        self.layer_5 = deconv(256, 256, (2, 1))
        self.layer_6 = deconv(256, 256, (2, 1))
        self.layer_7 = deconv(256, 256, (3, 1))
        # Pitch axis: 1 -> 4 -> 16 -> (16 - 1) * 8 + 12 - 2 * 2 = 128 pitches
        self.layer_8 = deconv(256, 256, (1, 4))
        self.layer_9 = deconv(256, 128, (1, 4))
        self.layer_10 = deconv(128, 1, (1, 12), stride=(1, 8), padding=(0, 2), activation=nn.Tanh())

        for block in (self.layer_2, self.layer_3, self.layer_4, self.layer_5, self.layer_6,
                      self.layer_7, self.layer_8, self.layer_9, self.layer_10):
            nn.init.kaiming_normal_(block[0].weight)

    def forward(self, x):
        # (batch, 1, input_dim, 1) -> (batch, input_dim) -> (batch, 1024)
        hidden = self.layer_1(x.squeeze(3).squeeze(1))
        # (batch, 1024) -> (batch, 1024, 1, 1) so the transposed convolutions can grow it
        hidden = hidden[:, :, None, None]
        for block in (self.layer_2, self.layer_3, self.layer_4, self.layer_5, self.layer_6,
                      self.layer_7, self.layer_8, self.layer_9, self.layer_10):
            hidden = block(hidden)
        # (batch, 1, 96, 128)
        return hidden

## 2) Main parts

In [None]:
class Generator(nn.Module):
    """
    The Generator of the model.

    For every bar and every accompaniment track, it concatenates:
    - an embedding of the conditional melody,
    - four latent vectors (style, groove, chords, track),
    - an embedding of the previously generated bar,
    and feeds the result to that track's BarGenerator.
    """

    def __init__(self, batch_size, output_shape, z_dim, melody_output_dim, bar_encoder_output_dim):
        """
        Args:
          batch_size: stored for reference; forward() reads the batch size from its inputs.
          output_shape: (n_tracks, n_bars, ...); track 0 is the conditional melody.
          z_dim: size of each latent vector.
          melody_output_dim: size of the per-bar melody embedding.
          bar_encoder_output_dim: size of the previous-bar embedding.
        """
        super(Generator, self).__init__()

        self.batch_size = batch_size
        self.n_tracks = output_shape[0]
        self.n_bars = output_shape[1]
        self.bar_encoder_output_dim = bar_encoder_output_dim

        self.melody_encoder = MelodyEncoder(melody_output_dim)
        self.chord_temporal_network = TemporalNetwork(z_dim, self.n_bars)
        # nn.ModuleList (instead of a plain list + manual register_parameter) registers the
        # sub-networks' parameters AND buffers (BatchNorm running statistics), so that
        # .to()/.float()/.train()/.eval()/state_dict() reach them.
        # NOTE: state_dict key names differ from the old "tempnet*/bargen*" scheme.
        self.track_temporal_networks = nn.ModuleList(
            [TemporalNetwork(z_dim, self.n_bars) for _ in range(self.n_tracks - 1)])
        self.bar_encoder = BarEncoder(bar_encoder_output_dim)
        self.track_bar_generators = nn.ModuleList(
            [BarGenerator(melody_output_dim + z_dim * 4 + bar_encoder_output_dim)
             for _ in range(self.n_tracks - 1)])

    def forward(self, chord_noise, style_noise, tracks_noise, groove_noise, conditionnal_melody):
        """
        Generate a multi-track song conditioned on a melody.

        Args:
          chord_noise: (batch, 1 bar, z_dim, 1)
          style_noise: (batch, 1 bar, z_dim, 1)
          tracks_noise: (batch, N-1 tracks, z_dim, 1)
          groove_noise: (batch, N-1 tracks, 1 bar, z_dim, 1)
          conditionnal_melody: (batch, 1 track, N bars, 96, 128)

        Returns:
          Tensor of shape (batch, N tracks, N bars, 96, 128). Track 0 is the conditional melody
          itself; tracks 1..N-1 are generated.
        """
        batch_size = chord_noise.shape[0]

        ####################### Build inputs #######################
        # Encoded melody: (batch, N bars, melody_output_dim, 1)
        encoded_melody = self.melody_encoder(conditionnal_melody).unsqueeze(-1)

        # TemporalNetwork returns (batch, z_dim, N bars, 1): the bar index is dimension 2.
        chord_extended = self.chord_temporal_network(chord_noise)
        # One (batch, z_dim, N bars, 1) tensor per accompaniment track
        tracks_noise_extended = [
            temporal_network(tracks_noise[:, track].unsqueeze(1))
            for track, temporal_network in enumerate(self.track_temporal_networks)
        ]

        ####################### Concatenate latents #######################
        # track_bars[bar][track]: (batch, 1, z_dim * 4, 1)
        track_bars = []
        for bar in range(self.n_bars):
            tracks = []
            for track in range(self.n_tracks - 1):
                latents = torch.cat((
                    style_noise.squeeze(1),                      # (batch, z_dim, 1): shared by all bars/tracks
                    groove_noise[:, track, :, :, :].squeeze(1),  # (batch, z_dim, 1): per track
                    chord_extended[:, :, bar, :],                # (batch, z_dim, 1): per bar
                    tracks_noise_extended[track][:, :, bar, :],  # (batch, z_dim, 1): per bar and track
                ), dim=1)
                tracks.append(latents.unsqueeze(1))
            # Exactly one entry per bar, so track_bars[bar] holds the latents of bar `bar`.
            track_bars.append(tracks)

        ####################### Generate bars #######################
        # Embedding of the previous bar (zeros before the first bar). It is created from
        # chord_noise so that it matches the inputs' device and dtype.
        previous_bar = chord_noise.new_zeros(batch_size, 1, self.bar_encoder_output_dim, 1)
        song = []
        for bar in range(self.n_bars):
            generated_tracks_bar = []
            for track, bar_generator in enumerate(self.track_bar_generators):
                # (batch, 1 bar, melody_output_dim + z_dim * 4 + bar_encoder_output_dim, 1)
                track_bar = torch.cat(
                    (encoded_melody[:, bar].unsqueeze(1), track_bars[bar][track], previous_bar), dim=2)
                # (batch, 1, 96, 128) -> (batch, 1 track, 1 bar, 96, 128)
                generated_tracks_bar.append(bar_generator(track_bar).unsqueeze(1))
            # (batch, N-1 tracks, 1 bar, 96, 128)
            whole_bar = torch.cat(generated_tracks_bar, dim=1)
            # Prepend the conditional melody: (batch, N tracks, 1 bar, 96, 128)
            whole_song_bar = torch.cat((conditionnal_melody[:, :, bar].unsqueeze(2), whole_bar), dim=1)
            # (batch, bar_encoder_output_dim) -> (batch, 1, bar_encoder_output_dim, 1)
            previous_bar = self.bar_encoder(whole_song_bar).unsqueeze(1).unsqueeze(-1)
            song.append(whole_song_bar)

        # (batch, N tracks, N bars, 96, 128)
        return torch.cat(song, dim=2)

In [None]:
class Discriminator(nn.Module):
    """
    The Discriminator (WGAN critic) of the model.

    Scores whole songs of shape (batch, n_tracks, n_bars, n_steps_per_bar, n_pitches) with one
    unbounded real number per song.
    """

    def __init__(self, input_dim):
        """
        Args:
          input_dim: (n_tracks, n_bars, n_steps_per_bar, n_pitches).
            - n_tracks sets the input channels.
            - n_bars sets the temporal kernel of conv2.
            - n_steps_per_bar and n_pitches set the size of the first dense layer
              (2560 for 96 steps x 128 pitches).
        """
        super(Discriminator, self).__init__()

        n_tracks, n_bars, n_steps, n_pitches = input_dim[0], input_dim[1], input_dim[2], input_dim[3]

        # Bar axis: n_bars -> n_bars - 1 -> 1
        self.conv1 = self._conv_block(n_tracks, 128, (2, 1, 1), (1, 1, 1), 0)
        self.conv2 = self._conv_block(128, 128, (n_bars - 1, 1, 1), (1, 1, 1), 0)
        # Pitch axis: 128 -> 11 -> 2
        self.conv3 = self._conv_block(128, 128, (1, 1, 12), (1, 1, 12), (0, 0, 2))
        self.conv4 = self._conv_block(128, 128, (1, 1, 7), (1, 1, 4), 0)
        # Time axis: 96 -> 48 -> 24 -> 11 -> 5
        self.conv5 = self._conv_block(128, 128, (1, 2, 1), (1, 2, 1), (0, 0, 0))
        self.conv6 = self._conv_block(128, 128, (1, 2, 1), (1, 2, 1), (0, 0, 0))
        self.conv7 = self._conv_block(128, 256, (1, 4, 1), (1, 2, 1), (0, 0, 0))
        self.conv8 = self._conv_block(256, 256, (1, 3, 1), (1, 2, 1), (0, 0, 0))

        # Flattened size after conv8, derived with the same output-size formula PyTorch uses
        # (2560 = 256 * 5 * 2 for 96 steps x 128 pitches).
        n_time = n_steps
        for kernel, stride in ((2, 2), (2, 2), (4, 2), (3, 2)):  # conv5..conv8
            n_time = self._conv_out(n_time, kernel, stride)
        n_pitch = self._conv_out(self._conv_out(n_pitches, 12, 12, 2), 7, 4)  # conv3, conv4
        flat_features = 256 * n_time * n_pitch

        self.linear1 = nn.Sequential(
            nn.Linear(in_features=flat_features, out_features=1024),
            nn.LeakyReLU(),
        )
        nn.init.kaiming_normal_(self.linear1[0].weight)

        self.linear2 = nn.Sequential(
            nn.Linear(in_features=1024, out_features=1),
        )
        nn.init.kaiming_normal_(self.linear2[0].weight)

    @staticmethod
    def _conv_block(in_channels, out_channels, kernel_size, stride, padding):
        """Conv3d followed by LeakyReLU, with Kaiming-normal convolution weights."""
        block = nn.Sequential(
            nn.Conv3d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=kernel_size,
                      stride=stride,
                      padding=padding,
                      ),
            nn.LeakyReLU(),
        )
        nn.init.kaiming_normal_(block[0].weight)
        return block

    @staticmethod
    def _conv_out(size, kernel, stride, padding=0):
        """Output length of a convolution along one axis: floor((W - K + 2P) / S) + 1."""
        return (size + 2 * padding - kernel) // stride + 1

    def forward(self, song):
        """
        Args:
          song: (batch, n_tracks, n_bars, n_steps_per_bar, n_pitches)
        Returns:
          (batch, 1) critic scores.
        """
        # Shapes for the default (5, N_bars, 96, 128) input:
        # conv1 (b,128,N-1,96,128) -> conv2 (b,128,1,96,128) -> conv3 (b,128,1,96,11)
        # -> conv4 (b,128,1,96,2) -> conv5 (b,128,1,48,2) -> conv6 (b,128,1,24,2)
        # -> conv7 (b,256,1,11,2) -> conv8 (b,256,1,5,2)
        out = song
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4,
                     self.conv5, self.conv6, self.conv7, self.conv8):
            out = conv(out)
        # (batch, 2560) -> (batch, 1024) -> (batch, 1)
        out = self.linear1(out.flatten(1))
        return self.linear2(out)

## 3) Putting it all together

In [None]:
class HarmonyGAN:
    """
    A GAN model for musical accompaniment, based on MuseGAN.

    It wraps the Generator and Discriminator classes defined above and trains them with the
    WGAN algorithm (RMSprop + critic weight clipping).
    """

    def __init__(self,
                 grad_weight = 10,
                 z_dim = 32,
                 batch_size = 16,
                 n_bars = 3,
                 n_steps_per_bar = 96,
                 melody_embed_dim = 16,
                 bar_embed_dim = 16,
                ):
        """
        Args:
          grad_weight: stored as ``self.grad_weight``; train() uses weight clipping, not a
            gradient penalty, so it is currently unused.
          z_dim: size of each latent vector fed to the Generator.
          batch_size: maximum number of songs per training step.
          n_bars: number of bars per song.
          n_steps_per_bar: time steps per bar. The sub-networks hard-code 96.
          melody_embed_dim: size of the per-bar melody embedding.
          bar_embed_dim: size of the previous-bar embedding.
        """
        self.name = 'HarmonyGAN'

        self.z_dim = z_dim

        self.n_tracks = 5
        self.n_bars = n_bars
        self.n_steps_per_bar = n_steps_per_bar
        self.n_pitches = 128

        self.input_dim = (self.n_tracks, n_bars, n_steps_per_bar, self.n_pitches)

        self.grad_weight = grad_weight
        self.batch_size = batch_size

        # Per-epoch average losses, kept for plotting:
        self.d_losses = []
        self.g_losses = []

        # Number of epochs trained so far (training can be resumed):
        self.epoch = 0

        # Build model:
        self.D = Discriminator(self.input_dim).float()
        self.G = Generator(self.batch_size,
                           self.input_dim,
                           self.z_dim,
                           melody_embed_dim,
                           bar_embed_dim).float()

    def _reset_gradient(self):
        """Zero the gradients of both G and D."""
        self.G_optimizer.zero_grad()
        self.D_optimizer.zero_grad()

    def _device(self):
        """Device holding the generator's parameters; noise and inputs are placed there."""
        return next(self.G.parameters()).device

    def _sample_noise(self, batch_size):
        """
        Draw the four latent inputs of the Generator for ``batch_size`` songs.

        Returns:
          (chord, style, tracks, groove) noise with shapes (batch, 1, z_dim, 1),
          (batch, 1, z_dim, 1), (batch, n_tracks-1, z_dim, 1) and (batch, n_tracks-1, 1, z_dim, 1).
        """
        device = self._device()
        chord_noise = torch.randn(batch_size, 1, self.z_dim, 1, device=device)
        style_noise = torch.randn(batch_size, 1, self.z_dim, 1, device=device)
        tracks_noise = torch.randn(batch_size, self.n_tracks - 1, self.z_dim, 1, device=device)
        groove_noise = torch.randn(batch_size, self.n_tracks - 1, 1, self.z_dim, 1, device=device)
        return chord_noise, style_noise, tracks_noise, groove_noise

    def _prepare_batch(self, batch, device):
        """
        Shuffle and truncate one dataset batch, then convert it to the model's layout.

        Input (batch, n_bars, n_steps_per_bar, n_pitches, n_tracks) with values in [0, 1]
        becomes a float tensor (<= self.batch_size, n_tracks, n_bars, n_steps_per_bar, n_pitches)
        with values in [-1, 1].
        """
        # Random permutation within the batch, then keep at most self.batch_size songs:
        idx = torch.randperm(batch.shape[0])
        songs = batch[idx][:self.batch_size]
        songs = songs.permute(0, 4, 1, 2, 3).float().to(device)
        # Match the [-1, 1] range of the generator's tanh output:
        return 2 * (songs - 0.5)

    def _critic_step(self, true_songs, true_melodies, batch_size, clamp_weights):
        """One WGAN critic update. Returns the critic loss as a float."""
        # detach(): the critic loss must not backpropagate into G. Otherwise G's gradients
        # would be computed only to be zeroed before the generator step.
        fake_songs = self.G(*self._sample_noise(batch_size), true_melodies).detach()

        true_scores = self.D(true_songs)
        fake_scores = self.D(fake_songs)
        # Wasserstein critic loss: push real scores up and fake scores down.
        d_loss = -torch.mean(true_scores) + torch.mean(fake_scores)

        self._reset_gradient()
        d_loss.backward()
        self.D_optimizer.step()

        # Weight clipping keeps the critic (approximately) Lipschitz, as WGAN requires.
        with torch.no_grad():
            for param in self.D.parameters():
                param.clamp_(-clamp_weights, clamp_weights)

        return d_loss.item()

    def _generator_step(self, true_melodies, batch_size):
        """One generator update against the current critic. Returns the generator loss as a float."""
        # Real melodies condition the generated accompaniment.
        fake_songs = self.G(*self._sample_noise(batch_size), true_melodies)
        # The generator wants the critic to score its songs as real (high).
        g_loss = -torch.mean(self.D(fake_songs))

        self._reset_gradient()
        g_loss.backward()
        # Only G is stepped; the gradients left on D are zeroed by the next _reset_gradient().
        self.G_optimizer.step()

        return g_loss.item()

    def train(self, dataset, epochs=500, save_every_n_epochs=50, d_loops=1, clamp_weights=0.01, lr_G = 0.00005, lr_D = 0.00005):
        """
        Train the model with the WGAN algorithm (RMSprop optimizers + critic weight clipping).

        Args:
          dataset: object supporting ``len(dataset)`` and ``dataset[i]``, plus an ``items``
            sequence that is shuffled in place every epoch.
            - ``dataset[i]`` returns one batch tensor of shape
              (batch, n_bars, n_steps_per_bar, n_pitches, n_tracks) with values in [0, 1].
            - Track 0 is used as the conditional melody.
          epochs: epochs to run, on top of those already trained (``self.epoch``).
          save_every_n_epochs: save an intermediary checkpoint whenever ``self.epoch`` is a
            multiple of this value.
          d_loops: critic updates per generator update. G has many more parameters than D, so
            extra critic steps keep the critic from falling behind.
          clamp_weights: critic weights are clipped to [-clamp_weights, clamp_weights].
          lr_G, lr_D: RMSprop learning rates of G and D.
        """
        self.G_optimizer = torch.optim.RMSprop(self.G.parameters(), lr=lr_G)
        self.D_optimizer = torch.optim.RMSprop(self.D.parameters(), lr=lr_D)
        device = self._device()

        def average(values):
            return float(np.mean(values)) if values else float('nan')

        # The epoch counter continues from the number of already-trained epochs.
        tqdm_epochs = tqdm(range(self.epoch, self.epoch + epochs), desc='Training ', unit='epoch',
                           initial=self.epoch, total=self.epoch + epochs)

        for _epoch in tqdm_epochs:
            self.epoch += 1

            # Randomly shuffle batches:
            np.random.shuffle(dataset.items)

            tqdm_dataloader = tqdm(range(len(dataset)), desc='D_loss  ... - G_loss  ...',
                                   unit='batches', leave=False)
            losses_d_internal = []
            losses_g_internal = []

            for i in tqdm_dataloader:
                # (batch, n_tracks, n_bars, n_steps_per_bar, n_pitches), values in [-1, 1]
                true_songs = self._prepare_batch(dataset[i], device)
                batch_size = true_songs.shape[0]
                # Conditional melody (track 0): (batch, 1, n_bars, n_steps_per_bar, n_pitches)
                true_melodies = true_songs[:, 0, :, :, :].unsqueeze(1)

                for _loop in range(d_loops):
                    losses_d_internal.append(
                        self._critic_step(true_songs, true_melodies, batch_size, clamp_weights))

                losses_g_internal.append(self._generator_step(true_melodies, batch_size))

                tqdm_dataloader.set_description(
                    desc='D_loss  ' + str(round(average(losses_d_internal), 3))
                         + ' - G_loss  ' + str(round(average(losses_g_internal), 3)),
                    refresh=True)

            # Average losses across the epoch:
            self.d_losses.append(average(losses_d_internal))
            self.g_losses.append(average(losses_g_internal))

            if self.epoch % save_every_n_epochs == 0:
                print('Model and losses saved.')
                self.save_model(version='intermediary')

    def _binarize(self, generated, thresh):
        """Return a float array with 1.0 where ``generated > thresh`` and 0.0 elsewhere."""
        return np.where(generated > thresh, 1.0, 0.0)

    def accompaniement(self, conditionnal_track, thresh=0.2):
        """
        Create an accompaniment for ``conditionnal_track`` using fresh random noise.

        Args:
          conditionnal_track: melody tensor of shape
            (batch, 1, n_bars, n_steps_per_bar, n_pitches). train() conditions the generator on
            melodies rescaled to [-1, 1], so apply the same scaling here.
          thresh: generator outputs strictly above this value become 1.0; the rest become 0.0.

        Returns:
          numpy array of shape (batch, n_tracks, n_bars, n_steps_per_bar, n_pitches) with values
          in {0.0, 1.0}. Track 0 is the (binarized) conditional melody.
        """
        conditionnal_track = conditionnal_track.to(self._device())
        # No autograd graph is needed at inference time. BatchNorm layers stay in whatever
        # mode the generator is currently in.
        with torch.no_grad():
            generated = self.G(*self._sample_noise(conditionnal_track.shape[0]), conditionnal_track)
        return self._binarize(generated.cpu().numpy(), thresh)

    def show_losses(self, directory='', show=True):
        """
        Plot the per-epoch critic and generator losses and save the plot as ``directory + 'losses.png'``.

        Args:
          directory: prefix (ending with a separator) for the saved image.
          show: display the figure; otherwise it is closed after saving.
        """
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
        ax.plot(self.d_losses, label='Critic loss', alpha=1)
        ax.plot(self.g_losses, label='Generator loss', alpha=1)
        ax.legend(loc='upper right')
        ax.set_xlabel('Epochs', fontsize=18)
        ax.set_ylabel('loss', fontsize=16)

        fig.savefig(directory + 'losses.png')

        if show:
            plt.show()
        else:
            # save_model() calls this repeatedly during training; close to avoid leaking figures.
            plt.close(fig)

    @staticmethod
    def _next_run_number(run_dirs):
        """Return 1 + the largest N among paths whose last component is ``Model_N`` (0 if none)."""
        numbers = []
        for path in run_dirs:
            name = os.path.basename(os.path.normpath(path))
            prefix, _, suffix = name.rpartition('_')
            if prefix == 'Model' and suffix.isdigit():
                numbers.append(int(suffix))
        return max(numbers, default=-1) + 1

    def save_model(self, version='final'):
        """
        Save the whole G, D and their sub-networks (as pickled modules) plus the loss plot.

        Args:
          version: 'final' saves to ``Models/final/Model_<next run number>``.
            'intermediary' saves to ``Models/intermediary/Epoch_<epoch>``, overwriting an
            existing checkpoint for the same epoch.

        Raises:
          ValueError: if ``version`` is neither 'final' nor 'intermediary'.
        """
        for subdir in ('Models/final/', 'Models/intermediary/'):
            os.makedirs(subdir, exist_ok=True)

        if version == 'final':
            base = 'Models/final/'
            directory = base + 'Model_' + str(self._next_run_number(glob.glob(base + '*/')))
            os.makedirs(directory)
        elif version == 'intermediary':
            directory = 'Models/intermediary/' + 'Epoch_' + str(self.epoch)
            # exist_ok: restarting training in the same folder must not crash at the first save.
            os.makedirs(directory, exist_ok=True)
        else:
            raise ValueError("version must be 'final' or 'intermediary', got " + repr(version))

        # Save models (whole-module pickles: loading them requires these class definitions):
        torch.save(self.G, directory + '/Generator.ckpt')
        torch.save(self.D, directory + '/Discriminator.ckpt')
        torch.save(self.G.bar_encoder, directory + '/BarEncoder.ckpt')
        torch.save(self.G.melody_encoder, directory + '/MelodyEncoder.ckpt')
        torch.save(self.G.chord_temporal_network, directory + '/ChordsTemporalNetwork.ckpt')
        for i in range(self.n_tracks - 1):
            torch.save(self.G.track_temporal_networks[i], directory + '/TracksTemporalNetwork_' + str(i + 1) + '.ckpt')
            torch.save(self.G.track_bar_generators[i], directory + '/BarGenerator_' + str(i + 1) + '.ckpt')

        # Save losses:
        self.show_losses(directory=directory + '/', show=False)
    


# 3 - Use

In [None]:
# Maximum number of songs per training step (HarmonyGAN.train truncates each dataset batch to it).
batch_size = 32

In [None]:
# 8-bar model; songs fed to it (for training or accompaniment) must also have 8 bars.
harmony = HarmonyGAN(n_bars=8,batch_size=batch_size)

In [None]:
# NOTE(review): no cell in this notebook calls `harmony.train(dataset, ...)`, so the loss history
# plotted here is empty unless training was run elsewhere.
harmony.show_losses()

# 4 - Examples

In [None]:
# Load a small example set of pianorolls. NOTE(review): hard-coded Google Drive (Colab) path.
# Presumably shaped (songs, n_bars, 96, 128, n_tracks), tracks last like the training batches.
final_np=np.load('/content/gdrive/My Drive/Perso/H2/small_dataset.npy').astype(np.float32)
# Track order in the file: Drums, Piano, Guitar, Bass, Strings
final_np = final_np[:,:,:,:,[1,2,4,3,0]]
# Track order after reindexing the last axis: Piano, Guitar, Strings, Bass, Drums
ds = TensorDataset(torch.from_numpy(final_np))

In [None]:
[reference_song] = ds[15]
# (n_bars, 96, 128, n_tracks) -> (1, n_tracks, n_bars, 96, 128); track 0 (Piano) is the melody.
reference_song = reference_song.unsqueeze(0).permute(0,4,1,2,3)
# HarmonyGAN.train rescales pianorolls from [0, 1] to [-1, 1] before conditioning the generator,
# so the conditional melody gets the same scaling here. This assumes the stored rolls are 0/1 like
# the training data. reference_song itself stays in [0, 1] for plotting/export.
melody = 2 * (reference_song[:,0,:,:,:].unsqueeze(1) - 0.5)

In [None]:
# Binarized numpy array of shape (1, n_tracks, n_bars, 96, 128); track 0 is the conditioning melody.
accompaniement = harmony.accompaniement(melody,thresh=0.3)

In [None]:
accompaniement.shape

In [None]:
def tensor_song_to_array(t_song):
    """
    Flatten a single song into one pianoroll per track.

    Args:
      t_song: torch.Tensor or numpy array of shape (1, n_tracks, n_bars, steps_per_bar, n_pitches).

    Returns:
      numpy array of shape (n_tracks, n_bars * steps_per_bar, n_pitches), with the bars
      concatenated along time.

    Raises:
      ValueError: if the input is not 5-D or contains more than one song.
    """
    if isinstance(t_song, torch.Tensor):
        # detach()/cpu() so tensors that require grad or live on a GPU convert as well
        t_song = t_song.detach().cpu().numpy()
    if t_song.ndim != 5 or t_song.shape[0] != 1:
        raise ValueError('expected a single song of shape (1, n_tracks, n_bars, steps_per_bar, n_pitches), got '
                         + str(t_song.shape))
    _, nb_tracks, nb_bars, steps_per_bar, pitches = t_song.shape
    return t_song.reshape((nb_tracks, nb_bars * steps_per_bar, pitches))

In [None]:
# Concatenate bars along time: (1, n_tracks, n_bars, 96, 128) -> (n_tracks, n_bars * 96, 128)
accompaniement = tensor_song_to_array(accompaniement)
reference_song = tensor_song_to_array(reference_song)

In [None]:
import subprocess
import sys
def install(package):
    """Install ``package`` (a pip requirement specifier) into the running kernel's Python environment."""
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])
# array_to_pypianoroll below uses the pre-1.0 pypianoroll API (Multitrack(beat_resolution=...),
# plot_multitrack), which changed in pypianoroll 1.0. Pin below 1.0 so a fresh environment works.
install('pypianoroll<1.0')
import pypianoroll

In [None]:
def array_to_pypianoroll(array,tempo=60):
    """
    Wrap a per-track pianoroll array in a ``pypianoroll.Multitrack``.

    Args:
      array: array of shape (n_tracks, time_steps, n_pitches), at most 5 tracks, ordered
        Piano, Guitar, Strings, Bass, Drums.
      tempo: tempo passed to the Multitrack.

    Returns:
      pypianoroll.Multitrack with one Track per input track and ``beat_resolution=24``
      (96 steps per bar / 4 beats).
    """
    # pypianoroll uses 0-indexed General MIDI program numbers (0 = Acoustic Grand Piano),
    # so GM instrument #N is program N - 1.
    programs = [0,   # Acoustic Grand Piano (GM #1)
                28,  # Electric Guitar (muted) (GM #29)
                48,  # String Ensemble 1 (GM #49)
                33,  # Electric Bass (finger) (GM #34)
                0,   # Drum track: played on the percussion channel; program 0 = standard kit
               ]
    is_drum = [False,False,False,False,True]
    # NOTE(review): pypianoroll may use non-boolean pianoroll values as MIDI velocities when
    # writing; a 0/1 float roll could then be written with velocity 1 — confirm and consider
    # passing `array[track] > 0`.
    tracks = []
    for track in range(array.shape[0]):
        tracks.append(pypianoroll.Track(pianoroll=array[track,:,:],
                                        program=programs[track],
                                        is_drum=is_drum[track]))
    return pypianoroll.Multitrack(tracks=tracks,tempo=tempo,beat_resolution=96//4)


In [None]:
# Wrap both (n_tracks, time, 128) arrays as pypianoroll Multitracks for plotting and MIDI export.
accompaniement = array_to_pypianoroll(accompaniement)
reference_song = array_to_pypianoroll(reference_song)

In [None]:
# Pianoroll plot of the original song, labelled by MIDI program.
fig,ax=pypianoroll.plot_multitrack(reference_song,track_label='program')
fig.set_size_inches(10,10)
plt.savefig('pianoroll_reference.png')

In [None]:
# Pianoroll plot of the generated accompaniment (track 0 is the conditioning melody).
fig,ax=pypianoroll.plot_multitrack(accompaniement,track_label='program')
fig.set_size_inches(10,10)
plt.savefig('pianoroll_generated.png')

In [None]:
# Export both songs as MIDI files in the working directory.
reference_song.write('reference.mid')
accompaniement.write('generated.mid')