In [None]:
import sys
sys.path.insert(0, '../')

import torch

from data.audioLoader import AudioLoader
from data.trainDataset import TrainDataset
from ganSystem import GANSystem
import logging

# logging.getLogger().setLevel(logging.DEBUG)  # set root logger to debug

# Console logging setup.
# The handler is looked up by name before it is created, so re-running this cell
# reuses the existing handler instead of stacking another one on the root logger
# (each extra handler would print every log record one more time).
# The handler itself accepts DEBUG and above; the root logger's own level is left
# untouched unless the line above is uncommented.
formatter = logging.Formatter('%(name)s:%(levelname)s:%(message)s')
_root_logger = logging.getLogger()
console_handler = next(
    (handler for handler in _root_logger.handlers if handler.get_name() == 'notebook_console'),
    None,
)
if console_handler is None:
    console_handler = logging.StreamHandler()
    console_handler.set_name('notebook_console')
    _root_logger.addHandler(console_handler)
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(formatter)

__author__ = 'Andres'

# Widths, along the last spectrogram axis, of [left border, gap, right border].
signal_split = [480, 64, 480]
# Base filter count; every 'nfilter' list below is expressed as a multiple of it.
md = 32

# Discriminator working on STFT spectrograms: five 5x5 layers with stride 2.
params_stft_discriminator = {
    'stride': [2, 2, 2, 2, 2],
    'nfilter': [md, 2 * md, 4 * md, 8 * md, 16 * md],
    'shape': [[5, 5], [5, 5], [5, 5], [5, 5], [5, 5]],
    'data_size': 2,
}

# Discriminator working on mel spectrograms: same layout with a quarter of the filters.
params_mel_discriminator = {
    'stride': [2, 2, 2, 2, 2],
    'nfilter': [md // 4, 2 * md // 4, 4 * md // 4, 8 * md // 4, 16 * md // 4],
    'shape': [[5, 5], [5, 5], [5, 5], [5, 5], [5, 5]],
    'data_size': 2,
}

# Generator: five layers whose filter counts shrink from 8 * md down to 1.
params_generator = {
    'stride': [2, 2, 2, 2, 2],
    'nfilter': [8 * md, 4 * md, 2 * md, md, 1],
    'shape': [[4, 4], [4, 4], [8, 8], [8, 8], [8, 8]],
    'padding': [[1, 1], [1, 1], [3, 3], [3, 3], [3, 3]],
    'residual_blocks': 2,
    'full': 256 * md,
    'summary': True,
    'data_size': 2,
    'in_conv_shape': [16, 2],
    # Settings for the encoders that process the left/right borders.
    'borders': {
        # Filter counts must be integers: use floor division (md // 2 == 16) rather
        # than true division, which would put the float 16.0 in the list.
        'nfilter': [md, 2 * md, md, md // 2],
        'shape': [[5, 5], [5, 5], [5, 5], [5, 5]],
        'stride': [2, 2, 2, 2],
        'data_size': 2,
        'border_scale': 1,
        # This does not work because of flipping, border 2 need to be flipped tf.reverse(l, axis=[1]), ask Nathanael
        'width_full': None,
    },
}

# Optimisation settings, inspired by 'Self-Attention Generative Adversarial Networks':
# - Spectral normalization GEN DISC
# - Batch norm GEN
# - TTUR ('GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium')
# - ADAM  beta1=0 beta2=0.9, disc lr 0.0004, gen lr 0.0001
# - Hinge loss
# Similar parameters appear in:
# - 'PROGRESSIVE GROWING OF GANS FOR IMPROVED QUALITY, STABILITY, AND VARIATION'
# - 'LARGE SCALE GAN TRAINING FOR HIGH FIDELITY NATURAL IMAGE SYNTHESIS'
# - 'CGANS WITH PROJECTION DISCRIMINATOR'
#
# NOTE(review): the values actually set below differ from the ADAM line above:
# both networks use 'kwargs' [0.5, 0.9] (presumably Adam's beta1/beta2 - confirm
# in GANSystem) and the same learning rate of 1e-4.
params_optimization = {
    'batch_size': 64,
    'n_critic': 1,
    'generator': {
        'optimizer': 'adam',
        'kwargs': [0.5, 0.9],
        'learning_rate': 1e-4,
    },
    'discriminator': {
        'optimizer': 'adam',
        'kwargs': [0.5, 0.9],
        'learning_rate': 1e-4,
    },
}

# Both discriminator configurations carry the same batch size as the optimiser.
params_stft_discriminator['batch_size'] = 64
params_mel_discriminator['batch_size'] = 64

# Aggregate of every parameter group defined above.
params = {
    # All the parameters for the model
    'net': {
        'generator': params_generator,
        'stft_discriminator': params_stft_discriminator,
        'mel_discriminator': params_mel_discriminator,
        'prior_distribution': 'gaussian',
        'shape': [1, 512, 1024],  # Shape of the image
        'inpainting': {'split': signal_split},
        'gamma_gp': 10,  # Gradient penalty
        'loss_type': 'wasserstein',
    },
    'optimization': params_optimization,
    'summary_every': 250,  # Tensorboard summaries every ** iterations
    'print_every': 50,  # Console summaries every ** iterations
    'save_every': 1000,  # Save the model every ** iterations
}
# Entries that are currently disabled:
# params['net']['fs'] = 16000//downscale
# params['summary_dir'] = os.path.join(global_path, name +'_summary/')
# params['save_dir'] = os.path.join(global_path, name + '_checkpoints/')

# Flat configuration dictionary handed to GANSystem.
args = {
    # Network definitions (the same dict objects as defined above, not copies).
    'generator': params_generator,
    'stft_discriminator_count': 2,
    'mel_discriminator_count': 3,
    'stft_discriminator': params_stft_discriminator,
    'mel_discriminator': params_mel_discriminator,
    'borderEncoder': params_generator['borders'],
    'stft_discriminator_in_shape': [1, 512, 64],
    'mel_discriminator_in_shape': [1, 80, 64],
    'mel_discriminator_start_powscale': 2,
    'generator_input': 1440,
    # Optimisation and data layout.
    'optimizer': params_optimization,
    'split': signal_split,
    'log_interval': 100,
    'spectrogram_shape': params['net']['shape'],
    'gamma_gp': params['net']['gamma_gp'],
    # Logging and checkpointing.
    'tensorboard_interval': 500,
    'save_path': '../saved_results/',
    'experiment_name': 'real_data',
    'save_interval': 10000,
    # Audio analysis settings.
    'fft_length': 1024,
    'fft_hop_size': 256,
    'sampling_rate': 22050,
}

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

examples_per_file = 32
dataFolder = "../../Datasets/maestro-v2.0.0/"

# Audio loader configured from the sampling rate, FFT length and hop size in args.
# NOTE(review): the meaning of the trailing positional argument (50) is not
# visible from this notebook - confirm against AudioLoader's signature.
audioLoader = AudioLoader(
    args['sampling_rate'],
    args['fft_length'],
    args['fft_hop_size'],
    50,
)

# The GAN system is built from the full configuration dictionary.
ganSystem = GANSystem(args)


In [None]:
# Training step and epoch that identify the checkpoint to restore.
start_at_step = 400000
start_at_epoch = 1

# Restore the GAN system from that checkpoint.
# NOTE(review): presumably read from args['save_path'] / args['experiment_name'] -
# confirm in GANSystem.loadModel.
ganSystem.loadModel(start_at_step, start_at_epoch)

In [None]:
import numpy as np
from data.baseDataset import BaseDataset

__author__ = 'Andres'


class ValidDataset(BaseDataset):
    """Dataset serving random fixed-width excerpts taken from whole audio files.

    Every item is a batch of ``examples_per_file`` spectrogram windows, each
    ``window_size`` frames wide and drawn at random positions from one loaded
    file, together with the audio segments that start at the matching
    hop-aligned sample offsets.
    """

    def _sliceAudio(self, audio):
        """Return ``audio`` unchanged; this dataset does not slice the file."""
        return audio

    def _saveNewFile(self, name, audio, spectrogram):
        """Cache a loaded file under ``name`` and advance the file index.

        The cached entry is ``[0, spectrogram, audio]``.
        NOTE(review): the leading 0 is presumably a usage counter maintained by
        BaseDataset - confirm there.
        """
        self._loaded_files[name] = [0, spectrogram, audio]
        self._index += 1

    def __getitem__(self, unused_index):
        """Draw ``examples_per_file`` random windows from one cached file.

        The index argument is ignored; the file is chosen by ``_selectFile``.

        Returns:
            A tuple ``(spectrograms, audios)`` of numpy arrays:
            ``spectrograms`` has shape ``[examples_per_file,
            windowLength() // 2, window_size]`` (the last row along axis 1 is
            dropped) and ``audios`` has shape ``[examples_per_file,
            window_size * hopSize()]``.

        Raises:
            ValueError: if the selected spectrogram has fewer than
                ``window_size`` frames.
        """
        filename = self._selectFile()
        spectrogram, audio = self._loaded_files[filename][1], self._loaded_files[filename][2]

        window = self._window_size
        last_start = spectrogram.shape[1] - window
        if last_start < 0:
            raise ValueError(
                "spectrogram of '{}' has {} frames, fewer than the window size {}".format(
                    filename, spectrogram.shape[1], window))
        # randint's upper bound is exclusive: the +1 makes the last valid window
        # reachable and lets a file of exactly `window` frames be used.
        starts = np.random.randint(0, last_start + 1, self._examples_per_file)

        hop = self._audio_loader.hopSize()
        spectrograms = np.zeros([self._examples_per_file, self._audio_loader.windowLength()//2+1, window], dtype=np.float64)
        audio_length = window * hop
        audios = np.zeros([self._examples_per_file, audio_length])
        # Latest offset that still leaves a full segment; clamped at 0 so audio
        # shorter than one segment never produces a negative slice start.
        last_audio_start = max(audio.shape[0] - audio_length, 0)

        for index, start in enumerate(starts):
            spectrograms[index] = spectrogram[:, start:start + window]
            audio_start = min(start * hop, last_audio_start)
            segment = audio[audio_start:audio_start + audio_length]
            # A segment shorter than audio_length keeps the zero padding at its tail.
            audios[index, :segment.shape[0]] = segment
        self._usedFilename(filename)

        return spectrograms[:, :-1], audios

In [None]:
from data.audioLoader import AudioLoader
# from data.validDataset import ValidDataset

examples_per_file = 1
loader_batch_size = 64

dataFolder = "../../../../Datasets/maestro-v2.0.0/"

validDataset = ValidDataset(dataFolder, window_size=1024, audio_loader=audioLoader,
                            examples_per_file=examples_per_file, loaded_files_buffer=20, file_usages=1)

valid_loader = torch.utils.data.DataLoader(validDataset,
                                           batch_size=loader_batch_size // examples_per_file, shuffle=False,
                                           num_workers=0, drop_last=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Only the first batch is needed. Fail here, rather than with a NameError in a
# later cell, when the loader yields nothing (drop_last=True discards a partial batch).
data = next(iter(valid_loader), None)
if data is None:
    raise RuntimeError("valid_loader produced no batch; check dataFolder and loader_batch_size")
print('loaded')
spectrograms = data[0]
audios = data[1]

# Merge the loader's batch axis with the per-file example axis.
audios = audios.view(loader_batch_size, -1)

spectrograms = spectrograms.to(device).float()
spectrograms = spectrograms.view(loader_batch_size, *args['spectrogram_shape'])
# Repeat each example so the batch matches the optimiser batch size
# (integer division: the repeat count must be an int).
spectrograms = torch.repeat_interleave(spectrograms, args['optimizer']['batch_size'] // loader_batch_size, 0)

# Split along the last axis into the context on either side of the gap.
left_borders = spectrograms[:, :, :, :args['split'][0]]
right_borders = spectrograms[:, :, :, args['split'][0] + args['split'][1]:]
print('generate')
# Inference only: skip building the autograd graph for the generator pass.
with torch.no_grad():
    generated_spectrograms = ganSystem.generateGap([left_borders, right_borders])

# Reassemble left border + generated gap + right border along the last axis.
fake_spectrograms = torch.cat((left_borders, generated_spectrograms, right_borders), 3)

In [None]:
import matplotlib.pyplot as plt

# Run the left borders through the mel spectrogram, time_average (argument 4)
# and the first border encoder, then inspect the distribution of the result.
encoded = ganSystem.border_encoders[0](
    ganSystem.time_average(
        ganSystem.mel_spectrogram(left_borders),
        4,
    )
)
print(encoded.shape)
# 30-bin histogram over every encoded value.
plt.hist(encoded.detach().cpu().numpy().ravel(), 30)

In [None]:
# Disabled alternative: histogram of the same time-averaged mel values.
# plt.hist(ganSystem.time_average(ganSystem.mel_spectrogram(left_borders), 4).detach().cpu().numpy().flatten(), 30)

# Show the time-averaged mel spectrogram at index [0, 0]
# (presumably first example, first channel - confirm the axis layout).
plt.imshow(
    ganSystem.time_average(ganSystem.mel_spectrogram(left_borders), 4)
    .detach()
    .cpu()
    .numpy()[0, 0]
)

In [None]:
import numpy as np

# 2 x 4 grid: the top row shows the encoder input, the bottom row the encoder
# output averaged over its first axis, for batch entries 4 to 7.
fig, a = plt.subplots(2, 4, figsize=(6, 2), dpi=300, sharey=False, sharex=False)
input_spectrogram = (
    ganSystem.time_average(ganSystem.mel_spectrogram(left_borders), 4).detach().cpu().numpy()
)

for column in range(4):
    example = column + 4
    upper_axis, lower_axis = a[0, column], a[1, column]

    # Both images are flipped along axis 0 before being drawn.
    im_bottom = upper_axis.imshow(np.flip(input_spectrogram[example, 0], 0), cmap='inferno', vmax=0.41)
    channel_mean = np.mean(encoded[example].detach().cpu().numpy(), axis=0)
    im_top = lower_axis.imshow(np.flip(channel_mean, 0), cmap='inferno')

    # Hide every tick mark and tick label (major and minor, on both axes).
    for panel in (upper_axis, lower_axis):
        panel.tick_params(
            axis='both',
            which='both',
            bottom=False,
            left=False,
            top=False,
            labelleft=False,
            labelbottom=False,
        )

fig.tight_layout(pad=0.0, w_pad=1.0, h_pad=1.0)
plt.savefig('left_border_mean_encoded_2.png')

In [None]:
import matplotlib.gridspec as gridspec
import numpy as np

input_spectrogram = (
    ganSystem.time_average(ganSystem.mel_spectrogram(left_borders), 4).detach().cpu().numpy()
)

fig = plt.figure(figsize=(12, 8), dpi=300, constrained_layout=True)
gs = gridspec.GridSpec(4, 4, figure=fig, wspace=0.05, hspace=0.05)

# Settings that hide every tick mark and tick label (major and minor, both axes).
ticks_hidden = dict(
    axis='both',
    which='both',
    bottom=False,
    left=False,
    top=False,
    labelleft=False,
    labelbottom=False,
)

# The top-left 2 x 2 area shows the encoder input for batch entry 2; only its
# first 78 rows are kept to try to match the ratio of the encoded representation.
ax00 = fig.add_subplot(gs[:2, :2])
ax00.imshow(np.flip(input_spectrogram[2, 0, :78], 0), cmap='inferno')
ax00.tick_params(**ticks_hidden)

# Each remaining grid cell shows one slice encoded[2, k], where k is the cell's
# row-major position in the 4 x 4 grid.
for cell in range(16):
    i, j = divmod(cell, 4)
    if i < 2 and j < 2:
        # This cell is covered by the large input panel.
        continue

    ax01 = fig.add_subplot(gs[i, j])
    ax01.imshow(np.flip(encoded[2, cell].detach().cpu().numpy(), axis=0), cmap='inferno')
    ax01.tick_params(**ticks_hidden)

plt.savefig('left_border_encoded_v5.png')