In [2]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pickle
import pretty_midi
import librosa
import librosa.display
import gc
from sklearn.preprocessing import StandardScaler
import warnings
from collections import Counter
from torch.utils.data import Dataset
import torch



import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


from sklearn.preprocessing import StandardScaler

from Preprocessing import *
#from ExtractGenre import *
from CNN_ExtractGenre import *
import Util as Util

import DatasetLoader as DL



In [5]:
InputPath = os.path.realpath('YAMF/genres_original')

GenreMapping = {'metal': 0, 'disco': 1, 'classical': 2, 'hiphop': 3, 'jazz': 4,
          'country': 5, 'pop': 6, 'blues': 7, 'reggae': 8, 'rock': 9}

In [135]:
PolData = DL.PolyphonicDataset(Genre = 'jazz')
PolTrainData = DataLoader(PolData, batch_size=30, shuffle=True, num_workers=0)

In [22]:
class Generator(nn.Module):

    def __init__(self, NoiseSize, HowManyInstrument = 4):

        super().__init__()


        self.FFNN = nn.Sequential(
            nn.Linear(NoiseSize, 256*32*4, bias=False),
            nn.BatchNorm1d(256*32*4),
            nn.LeakyReLU()
        )

        self.UpscalingConv = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=5, bias=False, padding=2),
            nn.BatchNorm2d((128)),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, bias=False, padding=2, output_padding=1),
            nn.BatchNorm2d((64)),
            nn.LeakyReLU(),        
            nn.ConvTranspose2d(64, HowManyInstrument, kernel_size=5, stride=2, bias=False, padding=2, output_padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.FFNN(x)
        y = y.reshape((-1, 256, 32, 4))
        y = self.UpscalingConv(y)
        return y

In [23]:
class Discriminator(nn.Module):
    def __init__(self, num_instruments=4):
        super().__init__()
        
        self.model = nn.Sequential(
            # Input: [batch, num_instruments, 128, 16]
            nn.Conv2d(num_instruments, 64, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Flatten(),
            nn.Linear(256 * 16 * 2, 1), 
            nn.Sigmoid()  
        )

    def forward(self, x):
        # x shape: [batch, num_instruments, 128, 16]
        return self.model(x)

In [24]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # nn.init.normal_(m.weight.data, 0.0, 0.02)
        nn.init.xavier_uniform_(m.weight.data)
    if classname.find('Linear') != -1:
        #nn.init.normal_(m.weight.data, 0.0, 0.02)
        nn.init.xavier_uniform_(m.weight.data)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.2)
        nn.init.constant_(m.bias.data, 0)

In [27]:
G = Generator(256)
G.apply(weights_init)

noise = torch.normal(0, 1, [10, 256])
print(noise.shape)
generated_image = G(noise).detach()
print(generated_image.shape)

discriminator = Discriminator()
discriminator.apply(weights_init)


decision = discriminator(generated_image)
print(decision)


torch.Size([10, 256])
torch.Size([10, 4, 128, 16])
tensor([[0.5305],
        [0.5107],
        [0.5192],
        [0.5313],
        [0.5185],
        [0.5240],
        [0.5039],
        [0.5257],
        [0.5016],
        [0.5226]], grad_fn=<SigmoidBackward0>)


In [28]:
generated_image

tensor([[[[0.5254, 0.4803, 0.4056,  ..., 0.4645, 0.5434, 0.4590],
          [0.4707, 0.5925, 0.3780,  ..., 0.5558, 0.4530, 0.4564],
          [0.4722, 0.4119, 0.2904,  ..., 0.5106, 0.5997, 0.4150],
          ...,
          [0.4753, 0.5831, 0.4585,  ..., 0.5232, 0.5599, 0.4855],
          [0.4824, 0.4774, 0.4131,  ..., 0.4993, 0.3893, 0.4459],
          [0.4697, 0.5798, 0.5100,  ..., 0.5365, 0.5050, 0.5098]],

         [[0.4954, 0.5083, 0.4873,  ..., 0.5898, 0.4235, 0.5305],
          [0.6029, 0.5140, 0.7182,  ..., 0.5204, 0.5628, 0.5112],
          [0.4565, 0.3938, 0.2940,  ..., 0.4683, 0.5304, 0.5056],
          ...,
          [0.5499, 0.4467, 0.5039,  ..., 0.4643, 0.5803, 0.4787],
          [0.5545, 0.5831, 0.4355,  ..., 0.5374, 0.3663, 0.5574],
          [0.5077, 0.4849, 0.5587,  ..., 0.4697, 0.5449, 0.5343]],

         [[0.6329, 0.4644, 0.6164,  ..., 0.4884, 0.5899, 0.4294],
          [0.4618, 0.5876, 0.5451,  ..., 0.5423, 0.4373, 0.5484],
          [0.5194, 0.2954, 0.4422,  ..., 0