# Libraries

In [1]:
# %%capture
# !pip install nnAudio
# !pip install openunmix
# !pip install torchlibrosa

In [1]:
1+1

2

In [146]:
import torch
from torch import nn
import copy
import torch.nn.init as init
from torchaudio.transforms import MelScale, AmplitudeToDB
from nnAudio import Spectrogram
from openunmix.transforms import make_filterbanks
import torchlibrosa as tl
from functools import reduce
import math
import torch.nn.functional as F
from torch.autograd import Function
import os
import torch.optim as optim
import re
import json
import matplotlib.pyplot as plt
import torchaudio
import numpy as np
import datetime
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import librosa
import time
import pyloudnorm as pyln
import pandas as pd
import tqdm
import random

# Model Tools

## Config

In [62]:
class settings:
    PAIR_DIR = './DJtransGAN-dg-pipeline/data/track/pair'
    MIX_DIR = './DJtransGAN-dg-pipeline/data/mix'
    STORE_DIR = './results'


    # STFT Parameters
    WINDOW = 'hann'
    N_FFT = 2048  # 256 lose time resolution
    HOP_LENGTH = 512   # 128
    N_MELS = 128

    # Band Parameters
    BAND_FREQS = [20, 300, 5000, 20000]


    # Others
    SR = 44100 # sampling rate
    EPSILON = 1e-12 # avoid divide zeros
    RANDOM_SEED = 0     # random seed
    N_TIME = 60
    CUE_BAR = 8

## STFT

In [63]:
class NNaudioSTFT(nn.Module):
    def __init__(self,
                 n_fft=settings.N_FFT,
                 hop_length=settings.HOP_LENGTH,
                 sr=settings.SR,
                 power=1,
                 center=True,
                 length=None):
        super(NNaudioSTFT, self).__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.power = power
        self.center = center
        self.length = length
        self.stft = Spectrogram.STFT(n_fft=n_fft, hop_length=512,
                                     center=center, sr=sr, iSTFT=True)


    def inverse(self, mags, phases, length=None):

        if self.power == 2:
            mags = torch.sqrt(mags)

        matrixs = torch.stack([mags*torch.cos(phases),
                               mags*torch.sin(phases)], dim=-1)
        stft = self.stft.to(mags.device)
        length = length if length is not None else self.length
        if length is not None:
            waves = stft.inverse(matrixs, onesided=True, length=length)
        else:
            waves = stft.inverse(matrixs, onesided=True)
        return waves.unsqueeze(1)

    def forward(self, waves, phase=True):
        stft = self.stft.to(waves.device)
        specs = stft(waves)
        reals = specs[..., 0]
        imags = specs[..., 1]
        mags = torch.sqrt(reals**2 + imags**2)
        if phase:
            phases = torch.atan2(imags, reals)
            return mags, phases
        return mags,

In [64]:
class AsteroidSTFT(nn.Module):
    def __init__(self,
                 n_fft=settings.N_FFT,
                 hop_length=settings.HOP_LENGTH,
                 sr=settings.SR,
                 center=True,
                 power=1,
                 length=None):

        super(AsteroidSTFT, self).__init__()
        self.power = power
        fbs = make_filterbanks(n_fft=n_fft,
                               n_hop=hop_length,
                               center=center,
                               sample_rate=sr,
                               method='asteroid')
        self.length = length
        self.encoder = fbs[0]
        self.decoder = fbs[1]

    def inverse(self, mags, phases, length=None):
        if self.power == 2:
            mags = torch.sqrt(mags)
        reals = mags * torch.cos(phases)
        imags = mags * torch.sin(phases)
        matrixs = torch.stack((reals, imags), axis=-1)
        length = length if length is not None else self.length
        return self.decoder(matrixs, length=length)


    def forward(self, waves, phase=True):
        matrixs = self.encoder(waves)
        reals = matrixs[..., 0]
        imags = matrixs[..., 1]
        mags = reals**2 + imags**2
        if self.power == 1:
            mags = torch.sqrt(mags)

        if phase:
            phases = torch.atan2(imags, reals)
            return mags, phases
        return mags,

In [65]:
class TorchlibrosaSTFT(nn.Module):
    def __init__(self,
                 n_fft=settings.N_FFT,
                 hop_length=settings.HOP_LENGTH,
                 sr=settings.SR,
                 power=1,
                 length=None):

        super(TorchlibrosaSTFT, self).__init__()
        self.power = power
        self.length = length
        self.encoder = tl.STFT(n_fft=n_fft, hop_length=hop_length)
        self.decoder = tl.ISTFT(n_fft=n_fft, hop_length=hop_length)

    def swapaxes(self, data):
        return torch.swapaxes(data, 2, 3)

    def inverse(self, mags, phases, length=None):
        if self.power == 2:
            mags = torch.sqrt(mags)
        reals = self.swapaxes(mags * torch.cos(phases))
        imags = self.swapaxes(mags * torch.sin(phases))
        length = length if length is not None else self.length
        return self.decoder(reals, imags, length).unsqueeze(1)

    def forward(self, waves, phase=True):
        reals, imags  = self.encoder(waves.squeeze(1))
        reals = self.swapaxes(reals)
        imags = self.swapaxes(imags)
        mags = reals**2 + imags**2
        if self.power == 1:
            mags = torch.sqrt(mags)

        if phase:
            phases = torch.atan2(imags, reals)
            return mags, phases
        return mags,

## Help funcs

In [66]:
def init_weights(m):
    '''
    Usage:
        model = Model()
        model.apply(weight_init)
    '''
    if isinstance(m, nn.Conv1d):
        init.normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.Conv2d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.Conv3d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose1d):
        init.normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose2d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose3d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.BatchNorm1d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm3d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        init.normal_(m.bias.data)
    elif isinstance(m, nn.LSTM):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
    elif isinstance(m, nn.LSTMCell):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
    elif isinstance(m, nn.GRU):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
    elif isinstance(m, nn.GRUCell):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)

In [67]:
def get_drop_func(name, prob):
    return {
        'drop1d': nn.Dropout(p=prob),
        'drop2d': nn.Dropout2d(p=prob),
        }.get(name)

In [68]:
def get_activate_func(name):
    return {
        'tanh': nn.Tanh() ,
        'relu': nn.ReLU() ,
        'prelu': nn.PReLU(),
        'lrelu': nn.LeakyReLU(),
        'sigmoid': nn.Sigmoid(),
        }.get(name)

In [69]:
def get_norm_func(name, out_dim, **kargs):
    n_group = kargs.get('group')
    return {
        'batch1d': nn.BatchNorm1d(out_dim),
        'batch2d': nn.BatchNorm2d(out_dim),
        'layer': nn.LayerNorm(out_dim),
        'group': nn.GroupNorm(n_group if n_group else 1, out_dim)
    }.get(name)

In [70]:
def get_class_name(class_):
    return class_.__class__.__name__

In [71]:
def estimate_channel(n_down):
    step  = 2**(n_down)
    return int(settings.N_MELS // step)

In [72]:
def estimate_cnn_out_dim(n_down, last_dim):
    return int(last_dim/estimate_channel(n_down))

In [73]:
def get_mel_func():
    return MelScale(n_mels=settings.N_MELS,
                    n_stft=settings.N_FFT // 2 + 1,
                    sample_rate=settings.SR)

In [74]:
def get_stft_func(**kargs): # stype: asteroid, nnaudio, torchlibrosa

    stft_type = kargs.get('stft_type')
    if stft_type:
        kargs.pop('stft_type')
    else:
        stft_type = 'torchlibrosa'

    return {
        'nnaudio' : NNaudioSTFT(**kargs),
        # 'asteroid': AsteroidSTFT(**kargs),
        'torchlibrosa': TorchlibrosaSTFT(**kargs),
    }[stft_type]

In [75]:
def get_amp2db_func(power=1):
    return AmplitudeToDB(stype=['amplitude', 'power'][int(power)-1])

In [76]:
def estimate_frame(in_dim, n_down, in_type):
    step = 2**n_down
    frame_dict = {'time': round(in_dim*settings.SR/settings.HOP_LENGTH),
                  'sample': round(in_dim/settings.HOP_LENGTH),
                  'frame': in_dim
                 }
    frame = frame_dict.get(in_type)

    if frame is None:
        print(f'do not found {in_type} input type')

    return int(math.ceil(frame/step))

In [77]:
def estimate_mlps_in_dim(in_dim, n_down, in_type='time'):
    out_dims = (estimate_channel(n_down),
                estimate_frame(in_dim, n_down, in_type))
    out_dim  = reduce(lambda x, y: x*y, out_dims)
    return (out_dim, out_dims)

In [78]:
def estimate_postprocessor_in_dim(encoder_args):
    in_dim = encoder_args.get('last_dim')
    if in_dim is None:
        in_dim = estimate_mlps_in_dim(encoder_args.get('n_time', settings.N_TIME),
                                     len(encoder_args.get('out_dims')))[0]
    return in_dim

In [79]:
def estimate_postprocessor_out_dims(processor_args, unzipper_args):
    n_band   = unzipper_args['n_band']
    n_fader  = unzipper_args['n_fader']
    n_param  = unzipper_args['n_param']
    last_dim = (n_band-1) * 2 + n_band * n_fader * n_param
    out_dims = [last_dim]
    if processor_args.get('out_dims'):
        out_dims = processor_args.get('out_dims') + out_dims
    return out_dims

In [80]:
def unsqueeze_dim(data, num):
    for i in range(num):
        data = data.unsqueeze(-1)
    return data

In [81]:
def linear_transform(v, minv, maxv):
    return (maxv - minv) * v + minv

In [82]:
def linear_transform_inv(v, minv, maxv, maxy=6):
    return maxy / (maxv - minv) / v

In [83]:
class custom_relu6(Function):
    @staticmethod
    def forward(self, inp):
        return F.relu6(inp)

    @staticmethod
    def backward(self, grad_out):
        return grad_out

In [84]:
def prelu6(x, start, sacle, mode='custom'):
    if mode == 'custom':
        return custom_relu6.apply(custom_relu6.apply(x - start) * sacle)
    return F.relu6(F.relu6(x - start) * sacle)

In [85]:
def choose_fade_shape(fade, fade_shape='linear'):
    transforms = {
        'linear': fade,
        'exponential': torch.pow(2, (fade - 1)) * fade,
        'logarithmic': (torch.log10(.1 + fade) + 1) / 1.0414,
        'quarter_sine': torch.sin(fade * math.pi / 2),
        'half_sine': torch.sin(fade * math.pi - math.pi / 2) / 2 + 0.5
    }
    if fade_shape not in transforms.keys():
        return
    return transforms[fade_shape]

In [86]:
def fade_transform(curve, fade_type, fade_shape):
    return {
        'fi': choose_fade_shape(curve, fade_shape),
        'fo': choose_fade_shape(1. - curve, fade_shape)
    }[fade_type]

In [87]:
def purify_device(data, idxs=None):
    if idxs is None:
        idxs = range(len(data))
    if isinstance(data, list):
        return [d.detach().cpu().clone() if i in idxs else d for i, d in enumerate(data)]
    if isinstance(data, tuple):
        return tuple([d.detach().cpu().clone() if i in idxs else d for i, d in enumerate(data)])
    if isinstance(data, torch.Tensor):
        return data.detach().cpu().clone()

## Conv2d

In [88]:
class Conv2d(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 k_size = 3, # kernel_size
                 stride = 1,
                 drop_prob = None,
                 activate = None,
                 norm_type = None):

        super(Conv2d, self).__init__()
        self.conv = nn.Conv2d(in_dim, out_dim, k_size, stride=stride, padding=k_size//2)
        self.dropout = get_drop_func('drop2d', drop_prob) if drop_prob else None
        self.activate = get_activate_func(activate) if activate  else None
        self.norm = get_norm_func(norm_type, out_dim) if norm_type else None
        self.apply(init_weights)

    def forward(self, x):
        x = self.conv(x)

        if self.dropout:
            x = self.dropout(x)
        if self.norm:
            if get_class_name(self.norm) == 'LayerNorm':
                norm = nn.LayerNorm(x.size()[1:]).to(x.device)
                x = norm(x)
            else:
                x = self.norm(x)

        if self.activate:
            x = self.activate(x)

        return x

## Blocks

In [89]:
class Conv2dBlock(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 k_size = 3,
                 stride = 2,
                 activate = 'lrelu',
                 norm_type = 'batch2d'):

        super(Conv2dBlock, self).__init__()

        strides = [stride, 1]
        in_dims = [in_dim, out_dim]
        self.convs = nn.ModuleList([Conv2d(in_dim, out_dim, k_size=k_size, stride=stride,
                                           activate=activate, norm_type=norm_type) for (in_dim, stride) in zip(in_dims, strides)])

    def forward(self, x):
        for conv in self.convs:
            x = conv(x)
        return x

In [90]:
class PoolingBlock(nn.Module):
    def __init__(self,
                 out_dim = 1024, # tuple -> 2D, int -> 1D
                 activate = None,
                 pool_type = 'avg'):
        super(PoolingBlock, self).__init__()

        if isinstance(out_dim, tuple):
            self.data_type = '2d'
            self.pool = {'avg': nn.AdaptiveAvgPool2d, 'max': nn.AdaptiveMaxPool2d}[pool_type](out_dim)
        else:
            self.data_type = '1d'
            self.pool = {'avg': nn.AdaptiveAvgPool1d, 'max': nn.AdaptiveMaxPool1d}[pool_type](out_dim)

        self.activate = get_activate_func(activate)

    def foward(self, x):
        if self.data_type == '1d':
            y = torch.flatten(x, start_dim=1).unsqueeze(1)
            y = self.pool(y).squeeze(1)
        else:
            y = self.pool(x)

        if self.activate:
            y = self.activate(y)
        return y

## CNN

In [91]:
class CNN(nn.Module):
    def __init__(self,
                 in_dim = 1,
                 out_dims = [4, 8, 16],
                 kernel_size = 3,
                 last_dim = None, # last dim is provied for poolblock to compress it to specific size
                 activate = 'lrelu',
                 norm_type = 'batch2d'):
        super(CNN, self).__init__()

        self.n_down = len(out_dims)
        self.n_layer = len(out_dims) + 1
        in_dims = [in_dim] + out_dims[:-1]
        conv_blocks = [Conv2dBlock(in_dim, out_dim, k_size=kernel_size, activate=activate,
                                   norm_type = norm_type) for (in_dim, out_dim) in zip(in_dims, out_dims)]

        conv_blocks += [Conv2d(out_dims[-1], 1, k_size=1, activate=activate, norm_type=norm_type)]
        self.conv_blocks = nn.ModuleList(conv_blocks)

        if last_dim:
            out_dim = estimate_cnn_out_dim(self.ndown, last_dim)
            self.pool = PoolingBlock((None, out_dim), pool_type='avg')
        else:
            self.pool = None

        self.apply(init_weights)


    def forward(self, x):
        for conv_block in self.conv_blocks:
            x = conv_block(x)

        if self.pool:
            x = self.pool(x)
        return x

## ResConv2dBlock

In [92]:
class ResConv2dBlock(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 k_size=3,
                 stride=2,
                 activate='relu',
                 norm_type='batch2d'):

        super(ResConv2dBlock, self).__init__()
        strides=[stride, 1, stride]
        in_dims=[in_dim, out_dim, in_dim]
        self.convs=nn.ModuleList([Conv2d(in_dim,
                                         out_dim,
                                         k_size=k_size,
                                         stride=stride,
                                         activate=activate,
                                         norm_type=norm_type) for (in_dim, stride) in zip(in_dims, strides)])
        self.diff = (strides[-1] != 1) or (in_dim != out_dim)
        self.activate = get_activate_func(activate)

    def forward(self, x):
        y = x.clone()
        for conv in self.convs[:2]:
            y = conv(y)

        if self.diff:
            x = self.convs[-1](x)

        y = x + y
        y = self.activate(y)
        return y

## ResCNN

In [93]:
class ResCNN(nn.Module):
    def __init__(self,
                 in_dim=1,
                 out_dims=[4, 8, 16],
                 last_dim=None, # last dim is provied for poolblock to compress it to specific size
                 k_size=3,
                 activate='lrelu',
                 norm_type='batch2d'):

        super(ResCNN, self).__init__()

        self.n_down = len(out_dims)
        self.n_layer = len(out_dims) + 1
        in_dims = [in_dim] + out_dims[:-1]
        conv_blocks = [ResConv2dBlock(in_dim, out_dim, k_size=k_size,
                                      activate=activate, norm_type=norm_type) for (in_dim, out_dim) in zip(in_dims, out_dims)]
        conv_blocks += [Conv2d(out_dims[-1], 1, k_size=1, activate=activate, norm_type=norm_type)]
        self.conv_blocks = nn.ModuleList(conv_blocks)

        if last_dim:
            out_dim = estimate_cnn_out_dim(last_dim=last_dim, n_down=self.ndown)
            self.pool = PoolingBlock((None, out_dim), pool_type='avg')
        else:
            self.pool = None

        self.apply(init_weights)

    def forward(self, x):
        for conv_block in self.conv_blocks:
            x = conv_block(x)

        if self.pool:
            x = self.pool(x)
        return x

## Encoder

In [94]:
class Encoder(nn.Module):
    def __init__(self, encoder_args):
        super(Encoder, self).__init__()

        encoder_args = copy.deepcopy(encoder_args)

        if 'times' in encoder_args.keys():
            encoder_args.pop('times')

        cnn_type = encoder_args.pop('cnn_type')
        net = {'cnn': CNN, 'res_cnn': ResCNN}[cnn_type]

        self.net = net(**encoder_args)
        self.apply(init_weights)

        self.n_layer = self.net.n_layer
        self.n_down = self.net.n_down


    def forward(self, x):
        return self.net(x)

## Frontend

In [95]:
class Frontend(nn.Module):
    def __init__(self,
                 stft_type = 'torchlibrosa',
                 power=1,
                 length=None):
        super(Frontend, self).__init__()

        self.power = power
        self.stft_type = stft_type
        self.mel_scale = get_mel_func()
        self.stft = get_stft_func(stft_type=stft_type,
                                  power=power,
                                  length=length)
        self.amp2db = get_amp2db_func(power=power)

    def mel_sacle(self, mag):
        mel_mag = self.mel_scale.to(mag.device)(mag)
        mel_mag = self.amp2db.to(mel_mag.device)(mel_mag)
        return mel_mag

    def inverse(self, mags, phases, length=None):
        return self.stft.to(mags.device).inverse(mags, phases, length)

    def forward(self, x, phase=False):
        out_tuple = ()
        if phase:
            mag, phase = self.stft.to(x.device)(x, phase=phase)
            out_tuple += (mag, phase)
        else:
            if len(x.size()) < 4:
                mag,  = self.stft.to(x.device)(x, phase=phase)
            else:
                mag   = x
            out_tuple = (mag,)
        mel_mag   = self.mel_sacle(mag)
        out_tuple = (mel_mag,) + out_tuple
        return out_tuple

## Unzipper

In [96]:
class Unzipper(nn.Module):
    def __init__(self, n_band=4, n_fader=1, n_param=2):

        super(Unzipper, self).__init__()
        self.shape = {'band':  ((n_band - 1), 2),
                      'fader': (n_band, n_fader, n_param)}
        self.n_params = {k: reduce(lambda x, y: x*y, v)  for (k, v) in self.shape.items()}
        self.s_param = reduce(lambda x, y: x+y, self.n_params.values()) # total num of parameters

    def forward(self, params):
        assert params.size(-1) == self.s_param
        unzip_params = {}

        if self.n_params['band'] == 0:
            unzip_params['fader'] = params
        elif self.n_params['fader'] == 0:
            unzip_params['band']  = params
        else:
            (band_params, fader_params) = torch.split(params, tuple(self.n_params.values()), dim=-1)
            unzip_params['band']  = band_params
            unzip_params['fader'] = fader_params

        return {k: unzip_params[k].view((-1,) + self.shape[k]) for k in unzip_params}

## Context

In [97]:
class Context(nn.Module):
    def __init__(self,
                 activate  = 'lrelu',
                 norm_type = 'batch2d',
                ):
        super(Context, self).__init__()
        self.conv = Conv2d(2,
                           1,
                           k_size=1,
                           activate=activate,
                           norm_type=norm_type)
        self.apply(init_weights)

    def forward(self, x1, x2):
        if len(x1.size()) == 3:
            y = torch.stack((x1, x2), axis=1)
        else:
            y = torch.cat((x1, x2), axis=1)
        y = self.conv(y)
        return y

## MLPs

In [98]:
class MLP(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 drop_prob = None,
                 activate = None,
                 norm_type = None):
        super(MLP, self).__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.dropout = get_drop_func('drop1d', drop_prob) if drop_prob else None
        self.norm = get_norm_func(norm_type, out_dim) if norm_type else None
        self.activate = get_activate_func(activate) if activate else None
        self.apply(init_weights)

    def forward(self, x):
        x = self.linear(x)

        if self.dropout:
            x = self.dropout(x)
        if self.norm:
            if get_class_name(self.norm) == 'LayerNorm':
                norm = nn.LayerNorm(x.size()[1:]).to(x.device)
                x = norm(x)
            else:
                x = self.norm(x)

        if self.activate:
            x = self.activate(x)
        return x

In [99]:
class MLPs(nn.Module):
    def __init__(self,
                 in_dim = 1,
                 out_dims = [512, 256],
                 activate = 'lrelu',
                 norm_type = 'batch1d'):

        super(MLPs, self).__init__()

        self.n_layer = len(out_dims)+2
        in_dims = [in_dim] + out_dims[:-1]
        mlps = [MLP(in_dim, out_dim,
                    activate=activate,
                    norm_type = norm_type) for (in_dim, out_dim) in zip(in_dims[:-1],
                                                                        out_dims[:-1])]
        mlps += [MLP(in_dims[-1], out_dims[-1])]
        self.mlps = nn.ModuleList(mlps)
        self.apply(init_weights)

    def forward(self, x):
        if len(x.size()) > 2:
            x = torch.flatten(x, start_dim=1)

        for i, mlp in enumerate(self.mlps):
            x = mlp(x)
        return x

## PostProcessor

In [100]:
class PostProcessor(nn.Module):
    def __init__(self,
                 processor_args,
                 context_args=None):

        super(PostProcessor, self).__init__()

        processor_args = copy.deepcopy(processor_args)
        self.pool_type = processor_args.pop('pool_type')
        self.context = Context(**context_args) if context_args else None

        if 'last_activate' in processor_args.keys():
            last_activate = processor_args.pop('last_activate')
        else:
            last_activate = None

        if 'loss_type' in processor_args.keys():
            processor_args.pop('loss_type')

        if self.pool_type:
            self.predictor = MLPs(**processor_args)
        else:
            if 'norm' in processor_args.keys():
                processor_args.pop('norm')
            self.predictor = PoolingBlock(**processor_args)

        self.last_activate = get_activate_func(last_activate)

    def forward(self, x, z=None):

        if self.context and z is not None:
            x = self.context(x, z)

        y = self.predictor(x)
        if self.last_activate:
            y = self.last_activate(y)

        return y

## Faders

In [101]:
class Relu6Faders(nn.Module):
    def __init__(self, sum_type, fade_type, fade_shape):
        super(Relu6Faders, self).__init__()
        self.sum_type = sum_type
        self.fade_type = fade_type
        self.fade_shape = fade_shape

    def calibrate(self, curves):
        return curves - curves[..., :1]

    def normalize(self, curves):
        return curves / curves[..., -1:]

    def expand_to(self, source, target):
        return source.unsqueeze(-1).expand_as(target)

    def unzip(self, params, minv, maxv):
        if self.sum_type == 'mean':
            if isinstance(minv, torch.Tensor):
                minv = unsqueeze_dim(minv, len(params[..., 0].size()) - len(minv.size()))
            if isinstance(maxv, torch.Tensor):
                maxv = unsqueeze_dim(maxv, len(params[..., 0].size()) - len(maxv.size()))

            start = linear_transform(params[..., 0], minv, maxv)
            slope = linear_transform_inv(params[..., 1], start, maxv)
            if params.size(-1) > 2:
                weight = params[..., 2]
            else:
                return [start, slope]
        elif self.sum_type == 'sum':
            new_params = [[], []]
            if isinstance(minv, torch.Tensor):
                minv = unsqueeze_dim(minv, len(params[..., 0, 0].size()) - len(minv.size()))
            if isinstance(maxv, torch.Tensor):
                maxv = unsqueeze_dim(maxv, len(params[..., 0, 0].size()) - len(maxv.size()))

            for index in range(params.size(-2)):
                start = linear_transform(params[..., index, 0], minv , maxv)
                slope = linear_transform_inv(params[..., index, 1], start, maxv)
                new_params[0].append(start)
                new_params[1].append(slope)
                if params.size(-1) > 2:
                    weight = params[..., 2]
                    new_params[2].append(weight)
                minv = start + (maxv / slope)
            return [torch.stack(new_param, axis=-1) for new_param in new_params]

    def get_raw_curve(self, waves, params):
        x = torch.linspace(0, 6, waves.size(-1), device=waves.device)
        if len(params.size()) > 3:
            n_batch, n_wave, n_fader, n_param = params.size()
            return x.expand(n_batch, n_wave, n_fader, -1)
        else:
            n_batch, n_fader, n_param = params.size()
            return x.expand(n_batch, n_fader, -1)

    def render_curve(self, x, start, slope):
        curves = prelu6(x, start, slope) / 6
        curves = fade_transform(curves, self.fade_type, self.fade_shape)
        return curves

    def sum_curves(self, curves, waves):
        if self.sum_type == 'mean':
            sum_curve = torch.mean(curves, dim=-2)
            sum_curve = sum_curve.unsqueeze(-2).expand_as(waves)
            return sum_curve
        elif self.sum_type == 'sum':
            sum_curve = torch.sum(curves, axis=-2)
            sum_curve = self.normalize(self.calibrate(sum_curve))
            sum_curve = sum_curve.unsqueeze(-2).expand_as(waves)
            return sum_curve

    def forward(self, waves, params, minv=0, maxv=1):
        '''
        Args:
            wave (Tensor)  : (n_batch, n_channel, n_samples)
            params (Tensor): (n_batch, n_fader  , n_param)

            or

            wave (Tensor)  : (n_batch, n_wave, n_fader, n_samples)
            params (Tensor): (n_batch, n_wave, n_fader, n_samples)
        '''
        params = params[:waves.size(0), ...]
        raw_curves = self.get_raw_curve(waves, params)
        unzip_params = self.unzip(params, minv * 6, maxv * 6)
        params = [self.expand_to(param, raw_curves) for param in unzip_params]
        processed_curves = self.render_curve(raw_curves, params[0], params[1])
        processed_curves = processed_curves * params[2] if len(params) > 2 else processed_curves
        processed_curves = self.sum_curves(processed_curves, waves)
        processed_waves  = processed_curves * waves
        return processed_waves, purify_device(processed_curves)

## Masks

In [102]:
BAND_TO_FADE = {
    'low' : 'fo',
    'high': 'fi'
}

In [103]:
class VMask(nn.Module):
    def __init__(self, sum_type, band_type, fade_shape, split_type):

        '''
        Args:
            sum_type  :  (mean, sum)
            band_type :  (low, high)
            fade_shape:  (linear, exponential, logarithmic, quarter_sine, half_sine)
            split_type:  (bias, equal)

        '''

        super(VMask, self).__init__()
        self.nyqf = settings.SR / 2
        self.band_freqs = settings.BAND_FREQS
        self.fader = Relu6Faders(sum_type, BAND_TO_FADE[band_type], fade_shape)
        self.band_type = band_type
        self.split_type = split_type

    def freq2ratio(self, freq):
        return freq / self.nyqf

    def get_bounded_ratio(self, index, n_band):
        if self.split_type == 'bias':
            min_freq = self.band_freqs[index]
            max_freq = self.band_freqs[index + 1]
            return self.freq2ratio(min_freq), self.freq2ratio(max_freq)
        elif self.split_type == 'equal':
            unit = self.nyqf / n_band
            min_freq = unit * index
            max_freq = unit * (index + 1)
            return self.freq2ratio(min_freq), self.freq2ratio(max_freq)

    def get_bounded_ratios(self, params, total_band, start_index):

        if total_band is None:
            n_band  = params.size(1) + 1
        else:
            n_band  = total_band
        bounded_ratios = [self.get_bounded_ratio(index + start_index, n_band) for index in range(params.size(1))]
        return bounded_ratios

    def curve_fillout(self, curves):
        n_batch, _, n_channel, n_bins = curves.size()
        zeros = torch.zeros(n_batch, 1, n_channel, n_bins, device=curves.device)
        ones  = torch.ones(n_batch , 1, n_channel, n_bins, device=curves.device)

        return {
            'low' : torch.cat((zeros, curves, ones), axis=1),
            'high': torch.cat((zeros, torch.flip(curves, dims=[1]), ones), axis=1)
        }[self.band_type]

    def band_exciter(self, curves):
        curves_fill = self.curve_fillout(curves)
        curves_diff = torch.diff(curves_fill, axis=1)
        return curves_diff

    def create_fake_waves(self, mags):
        return torch.ones_like(mags[..., 0], device=mags.device)

    def create_mask(self, curves, n_frames):
        expand_dims = list(curves.size()) + [n_frames]
        return curves.unsqueeze(-1).expand(expand_dims)

    def masking(self, mags, curves):
        masks = self.create_mask(curves, mags.size(-1))
        mags  = mags.unsqueeze(1).expand_as(masks)
        return mags * masks, masks



    def forward(self, mags, params, total_band=None, start_index=0):
        '''
        Args:
            mags (Tensor)  : (n_batch, n_channel , n_bins, n_frames)
            params (Tensor): (n_batch, n_band - 1, 2)

        Return:
        '''
        n_batch, n_channel , n_bins, n_frames = mags.size()
        fake_waves = self.create_fake_waves(mags)
        bounded_ratios = self.get_bounded_ratios(params, total_band, start_index)

        curves = torch.stack([self.fader(fake_waves,
                                         params[:, index:index+1, :],
                                         bounded_ratios[index][0],
                                         bounded_ratios[index][1]) [0] for index in range(params.size(1))], axis=1)
        processed_curves = self.band_exciter(curves)
        processed_mags, processed_masks = self.masking(mags, processed_curves)
        return processed_mags, purify_device(processed_masks), purify_device(processed_curves), purify_device(curves)


In [104]:
class HMask(nn.Module):
    def __init__(self, sum_type, fade_type, fade_shape):

        '''
        Args:
            sum_type  :  (mean, sum)
            fade_type :  (fi, fo)
            fade_shape:  (linear, exponential, logarithmic, quarter_sine, half_sine)
        '''

        super(HMask, self).__init__()
        self.nyqf = settings.SR / 2
        self.fader = Relu6Faders(sum_type, fade_type, fade_shape)
        self.fade_type = fade_type

    def create_fake_waves(self, mags):
        return torch.ones_like(mags[..., 0, :], device=mags.device)

    def create_mask(self, curves, n_bins):
        expand_dims = list(curves.size())
        expand_dims.insert(-1, n_bins)
        return curves.unsqueeze(-2).expand(expand_dims)

    def masking(self, mags, curves):
        masks = self.create_mask(curves, mags.size(-2))
        return mags * masks, masks

    def forward(self, mags, params, cue_region=None):
        '''
        Args:
            mags (Tensor)  : (n_batch, n_channel, n_bins, n_frames)
            params (Tensor): (n_batch, n_fader, n_params)

        Return:
        '''
        n_batch = params.size(0)
        n_band = params.size(1)
        cue_region = torch.tensor([0, 1]).expand(n_batch, -1) if cue_region is None else cue_region
        fake_waves = self.create_fake_waves(mags)
        processed_curves = self.fader(fake_waves, params, cue_region[..., 0], cue_region[..., 1])[0]
        processed_mags, processed_masks = self.masking(mags, processed_curves)
        return processed_mags, purify_device(processed_masks), purify_device(processed_curves)

## Mixer

In [105]:
class Mixer(nn.Module):
    def __init__(self, band_args, fader_args):
        super(Mixer, self).__init__()
        self.vmask = self.build_vmask(band_args)
        self.hmask = self.build_hmask(fader_args)
        self.out_type = band_args['out_type']

    def build_vmask(self, band_args):
        return VMask(band_args['sum_type'],
                     band_args['band_type'],
                     band_args['fade_shape'],
                     band_args['split_type'])

    def build_hmask(self, fader_args):
        return HMask(fader_args['sum_type'],
                     fader_args['fade_type'],
                     fader_args['fade_shape'])

    def _get_cue_region(self, cue_region):
        return torch.tensor([0, 1]) if cue_region is None else cue_region

    def forward(self,
                mags,
                band_params=None,
                fader_params=None,
                cue_region=None):
        out_dict = {}
        if band_params is not None:
            processed_mags, band_masks, _, band_curves = self.vmask(mags, band_params)
            out_dict['band'] = band_curves
            processed_masks  = band_masks
        else:
            processed_mags = mags.unsqueeze(1)
            processed_masks = torch.ones_like(processed_mags).cpu()

        if fader_params is not None:
            cue_region  = self._get_cue_region(cue_region).to(mags.device)
            processed_mags, fader_masks, fader_curves = self.hmask(processed_mags,
                                                                   fader_params,
                                                                   cue_region=cue_region)
            out_dict['fader'] = fader_curves
            processed_masks *= fader_masks

        out_dict['mask'] = processed_masks

        if processed_mags.size(1) > 1:
            if self.out_type == 'mix':
                processed_mags = torch.sum(processed_mags, axis=1)
        else:
            processed_mags = processed_mags.squeeze(1)
        return processed_mags, out_dict

# Model

## Generator

In [106]:
class Generator(nn.Module):
    def __init__(self,
                 unzipper_args,
                 encoder_args,
                 context_args,
                 processor_args,
                 mixer_args):

        super(Generator, self).__init__()

        self.data_types = ['prev', 'next']
        self.hyper = self.get_hyperparameter(unzipper_args, encoder_args,
                                             context_args, processor_args,
                                             mixer_args)
        self.frontend = Frontend()
        self.unzipper = Unzipper(**unzipper_args)

        processor_args['in_dim'] = estimate_postprocessor_in_dim(encoder_args)
        processor_args['out_dims'] = estimate_postprocessor_out_dims(processor_args, unzipper_args)
        processor_args['last_activate'] = 'sigmoid'


        self.encoder = Encoder(encoder_args)
        self.post_processors = nn.ModuleList([PostProcessor(processor_args, context_args) for i in range(2)])
        self.mixers = [Mixer(mixer_args.get('band_args'), {**mixer_args.get('fader_args'), 'fade_type': fade_type}) for fade_type in ['fo', 'fi']]


    def get_hyperparameter(self,
                           unzipper_args,
                           encoder_args,
                           context_args,
                           processor_args,
                           mixer_args):
        return {
            'unzipper_args': copy.deepcopy(unzipper_args),
            'encoder_args': copy.deepcopy(encoder_args),
            'context_args': copy.deepcopy(context_args),
            'processor_args': copy.deepcopy(processor_args),
            'mixer_args': copy.deepcopy(mixer_args)
        }


    def encode(self, in_waves, phase=False):
        out_tuple = ()
        encodeds = [self.frontend(in_wave, phase=phase) for in_wave in in_waves]
        mel_mags = [encoded[0] for encoded in encodeds]
        mags = [encoded[1] for encoded in encodeds]
        vecs = [self.encoder(mel_mag) for mel_mag in mel_mags]
        out_tuple = (vecs, mags)

        if phase:
            phases = [encoded[2] for encoded in encodeds]
            out_tuple += (phases,)
        return out_tuple

    def render(self, in_mags, in_vecs, cue_region=None):
        render_params = [self.unzipper(processor(*in_vecs)) for processor in self.post_processors]
        render_outs = [mixer(in_mag,
                             band_params = render_param.get('band'),
                             fader_params = render_param.get('fader'),
                             cue_region = cue_region) for (in_mag, render_param, mixer) in zip(in_mags, render_params, self.mixers)]
        return render_outs

    def mix(self, in_datas):
        return torch.sum(torch.stack(in_datas, axis=1), axis=1)

    # Apply Inverse STFT to get audio back
    def infer(self,
              *in_waves,
              cue_region=None,
              mix=True):

        with torch.no_grad():

            length = in_waves[0].size(-1)
            in_vecs, in_mags, in_phases = self.encode(in_waves, phase=True)

            render_outs =  self.render(in_mags, in_vecs, cue_region=cue_region)
            out_results = {k: render_out[1] for (k, render_out) in zip(self.data_types, render_outs)}

            out_mags = [render_out[0] for render_out in render_outs]
            out_mags = [torch.sum(out_mag, axis=1) if len(out_mag.size()) == 5 else out_mag for out_mag in out_mags]
            out_waves = [self.frontend.inverse(mag, phase, length) for (mag, phase) in zip(out_mags, in_phases)]

            if mix:
                out_datas = self.mix(out_waves)
            else:
                out_datas = {key: wave for (key, wave) in zip(self.data_types, out_waves)}

            return out_datas, out_results

    def forward(self,
                *in_waves,
                cue_region=None,
                mix=True):

        length = in_waves[0].size(-1)
        encoded = self.encode(in_waves)
        in_vecs, in_mags = encoded

        render_outs =  self.render(in_mags, in_vecs, cue_region=cue_region)
        out_results = {key: render_out[1] for (key, render_out) in zip(self.data_types, render_outs)}
        out_datas = [render_out[0] for render_out in render_outs]

        if mix:
            out_datas = self.mix(out_datas)

        return out_datas, out_results

## Discriminator

In [107]:
class Discriminator(nn.Module):
    def __init__(self,
                 encoder_args,
                 processor_args):

        super(Discriminator, self).__init__()

        self.hyper = self.get_hyperparameter(encoder_args, processor_args)

        self.frontend = Frontend()
        encoder_args = copy.deepcopy(encoder_args)

        processor_args['in_dim'] = estimate_postprocessor_in_dim(encoder_args)
        processor_args['out_dims'] = processor_args.get('out_dims') + [2 if processor_args.get('loss_type') == 'minmax' else 1]
        processor_args['last_activate'] = 'sigmoid' if processor_args.get('loss_type') == 'minmax' else None
        self.post_processor = PostProcessor(processor_args)
        self.encoder = Encoder(encoder_args)

    def get_hyperparameter(self, encoder_args, processor_args):
        return {
                'encoder_args'  : copy.deepcopy(encoder_args),
                'processor_args': copy.deepcopy(processor_args)
               }

    def encode(self, in_data):
        return self.encoder(self.frontend(in_data)[0])

    def forward(self, in_data):
        return self.post_processor(self.encode(in_data))

# Init Model

In [108]:
def get_generator(cnn_type='res_cnn'): # loss_type

    encoder_args   = {
        'in_dim': 1,
        'out_dims': [4, 8, 16],
        'last_dim': None,
        'cnn_type': cnn_type,
        'activate': 'relu',
        'norm_type': 'batch2d'
    }

    processor_args = {
        'out_dims': [1024, 512],
        'activate': 'lrelu',
        'norm_type': 'batch1d',
        'pool_type': 'mlps'
    }

    context_args = {
        'activate': 'lrelu',
        'norm_type': 'batch2d'
    }

    unzipper_args  = {
        'n_band': 4,
        'n_fader': 1,
        'n_param': 2
    }

    mixer_args= {
        'band_args'  : {
            'split_type': 'bias',
            'out_type': 'track',
            'sum_type': 'mean',
            'band_type': 'low',
            'fade_shape': 'linear'
        },
        'fader_args' : {
            'sum_type': 'mean',
            'fade_shape': 'linear'
        }
    }

    return Generator(unzipper_args,
                     encoder_args,
                     context_args,
                     processor_args,
                     mixer_args)

In [109]:
def get_discriminator(cnn_type, loss_type):

    encoder_args = {
        'in_dim': 1,
        'out_dims': [4, 8, 16],
        'last_dim': None,
        'cnn_type': cnn_type,
        'activate': 'lrelu',
        'norm_type': 'batch2d'
     }

    processor_args = {
        'out_dims': [1024, 512],
        'activate': 'lrelu',
        'norm_type': 'batch1d',
        'pool_type': 'mlps',
        'loss_type': loss_type
    }

    return Discriminator(encoder_args,
                         processor_args)

In [110]:
def get_net(**kwargs):
    cnn_type = kwargs.get('cnn_type', 'res_cnn')
    loss_type = kwargs.get('loss_type', 'minmax')

    G = get_generator(cnn_type)
    D = get_discriminator(cnn_type, loss_type)
    return (G, D)

# Trainer

### Losses

In [111]:
def get_labels(data, true=False):
    return torch.ones_like(data) if true else torch.zeros_like(data)

In [112]:
class MinMaxLoss(nn.Module):
    def __init__(self, reduction='mean'):
        super(MinMaxLoss, self).__init__()
        self.reduction = reduction
        self.criterion = nn.BCELoss(reduction=reduction)

    def generator_loss(self, dgz):
        reals = get_labels(dgz, true=True)
        return self.criterion(dgz, reals)

    def discriminator_loss(self, dgz, dx):
        fakes = get_labels(dgz, true=False)
        reals = get_labels(dx, true=True)
        fake_loss = self.criterion(dgz, fakes)
        real_loss = self.criterion(dx, reals)
        return fake_loss + real_loss

    def forward(self, dgz, dx=None):
        if dx is None:
            return self.generator_loss(dgz)
        return self.discriminator_loss(dgz, dx)


class LeastSquaresLoss(nn.Module):
    def __init__(self, a=0.0, b=1.0, c=1.0, reduction='mean'):
        super(LeastSquaresLoss, self).__init__()
        self.reduction = reduction
        self.a = a
        self.b = b
        self.c = c

    def generator_loss(self, dgz):
        return 0.5 * reduce((dgz - self.c) ** 2, self.reduction)

    def discriminator_loss(self, dgz, dx):
        fake_loss = 0.5 * (reduce((dgz - self.a) ** 2, self.reduction))
        real_loss = 0.5 * (reduce((dx - self.b) ** 2, self.reduction))
        return fake_loss + real_loss

    def forward(self, dgz, dx=None):
        if dx is None:
            return self.generator_loss(dgz)
        return self.discriminator_loss(dgz, dx)

## Help funcs

In [113]:
def get_device():
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [114]:
def get_criterion(loss_name):
    criterion_dict = {
        'minmax': MinMaxLoss(),
        'least_square': LeastSquaresLoss()
    }

    criterion = criterion_dict.get(loss_name)

    if criterion is None:
        print(f'Cannot find {loss_name} criterion')
        return

    return criterion

## DataLoader

In [115]:
def batch_collect(batch):
    item_nums = len(batch[0])
    batch_out = ()
    for i in range(item_nums):
        if isinstance(batch[0][i], torch.Tensor):
            batch_item = torch.stack(pad_sequence([item[i].permute(1, 0) for item in batch],
                                                   batch_first=True).permute(0, 2, 1))
        else:
            batch_item = torch.stack([item[i] for item in batch])
        batch_out  += (batch_item, )

    return batch_out

In [116]:
def batchlize(dataset, batch_size, shuffle=False, custom=False):
    if custom:
        return DataLoader(dataset,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          drop_last=True,
                          collate_fn=batch_collect)
    return DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=shuffle,
                        drop_last=True)

In [117]:
class DataLoaderSampler():
    def __init__(self, dataset, batch_size, drop_last=True, shuffle=True):
        self.count = 0
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle

        self.current = self.get_new_dataloader()
        self.length = len(self.current)

    def get_new_dataloader(self):
        return iter(batchlize(self.dataset, self.batch_size, shuffle=self.shuffle))

    def __call__(self):
        self.count += 1
        if self.count > self.length:
            self.current = self.get_new_dataloader()
            self.length = len(self.current)
            self.count += 1

        return next(self.current)

## Optim

In [118]:
def get_optimizer(net, optim_name, lr, **kwargs):
    optimizer_dict = {
        'SGD': optim.SGD(net.parameters(), lr=lr, momentum=kwargs.get('momentum', 0.9)),
        'Adam': optim.Adam(net.parameters(), lr=lr, betas=kwargs.get('betas', (0., 0.999)), eps=kwargs.get('eps', 1e-08)),
        'RMSProp': optim.RMSprop(net.parameters(), lr=lr, alpha=kwargs.get('alpha', 0.99), eps=kwargs.get('eps', 1e-08))
    }

    optimizer = optimizer_dict.get(optim_name)

    if optimizer is None:
        print(f'Cannot find {optim_name} optimizer')
        return

    return optimizer

In [119]:
def check_nan(data):
    return (torch.sum(torch.isnan(torch.tensor(data))) >= 1).item()

## Storer

In [120]:
def check_exist(out_path):
    if re.compile(r'^.*\.[^\\]+$').search(out_path):
        out_path = os.path.split(out_path)[0]
    existed = os.path.exists(out_path)
    if not existed:
        os.makedirs(out_path, exist_ok=True)
    return existed

In [121]:
def out_json(data, out_path, cover=True):
    if check_exist(out_path) and cover == False: return
    with open(out_path, 'w') as outfile: json.dump(data, outfile)

In [122]:
def out_pt(data, out_path):
    check_exist(out_path)
    torch.save(data, out_path)

In [123]:
def loss_visualize(loss, plt, title=None):
    if isinstance(loss, list):
        loss = torch.tensor(loss)

    plt.plot(loss['G'], label='G Loss')
    plt.plot(loss['D'], label='D Loss')

    if title:
        plt.suptitle(title)

    plt.legend(loc='best')

    return plt

In [124]:
def save_figure(plt, out_path):
    plt.savefig(out_path, transparent=False, facecolor='w', edgecolor='w', orientation='portrait')

In [125]:
def out_audio(data, out_path, sr=settings.SR):
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data).float()

    if len(data.size()) == 1:
      data = data.unsqueeze(0)
    torchaudio.save(out_path, data, sr)

In [126]:
class Storer:
    def __init__(self, dataset, n_sample, out_dir, hyper):
        sampler = DataLoaderSampler(dataset, batch_size=n_sample, shuffle=False)
        stored_out = sampler()

        self.out_dir = out_dir
        self.pair_audio = stored_out[0]
        self.cue = stored_out[1]

        self.loss = {'D': [], 'G': []}
        self.loss_type = hyper['train']['loss_type']
        self.store_hyperparameter(hyper)

    # Store data
    def store_hyperparameter(self, hyper):
        out_path = os.path.join(self.out_dir, 'hyperparameter.json')
        check_exist(out_path)
        out_json(hyper, out_path)

    def store_net(self, net, epoch):
        out_dir = os.path.join(self.out_dir, 'net', f'epoch_{epoch+1}')
        out_path = os.path.join(out_dir, 'generator.pth')
        check_exist(out_path)
        out_pt(purify_device(net[0].state_dict()), out_path)

        out_path = os.path.join(out_dir, 'discriminator.pth')
        check_exist(out_path)
        out_pt(purify_device(net[1].state_dict()), out_path)

    def store_loss(self, loss, epoch):
        out_dir = os.path.join(self.out_dir, 'loss', f'epoch_{epoch+1}')

        fig = loss_visualize(loss, plt, title=self.loss_type)

        out_path = os.path.join(out_dir, f'loss_data.pt')
        check_exist(out_path)
        out_pt(loss, out_path)

        out_path = os.path.join(out_dir, 'loss_figure.png')
        check_exist(out_path)
        save_figure(fig, out_path)

    def store_mix(self, mix_audio, mix_out, epoch):
        out_dir = os.path.join(self.out_dir, 'mix', f'epoch_{epoch+1}')

        # save mix audio
        for i, audio in enumerate(mix_audio):
            out_path = os.path.join(out_dir, 'audio', f'{i}.wav')
            check_exist(out_path)
            out_audio(purify_device(audio), out_path)

        # save mix res
        out_path = os.path.join(out_dir, 'mix_out.pt')
        check_exist(out_path)
        out_pt(purify_device(mix_out), out_path)

    # Log
    def log_loss(self, epoch):
        print(f"epoch_{epoch+1} G={self.loss['G'][-1]:.4f},"
              f"epoch_{epoch+1} D={self.loss['D'][-1]:.4f}")

    def generate_mix(self, generator):
        generator.eval()
        device = next(generator.parameters()).device
        pair_audio = [audio.to(device) for audio in self.pair_audio]
        cue = self.cue.to(device)
        mix_audio, mix_out = generator.infer(*pair_audio, cue_region=cue)
        return (mix_audio, mix_out)

    def log(self, loss, epoch):
        self.loss['G'].append(loss[0])
        self.loss['D'].append(loss[1])
        self.log_loss(epoch)

    def __call__(self, net, epoch):
        if net is not None:

            # generate
            mix_audio, mix_out = self.generate_mix(net[0])
            mix_out['cue'] = self.cue * settings.N_TIME

            # store
            self.store_mix(mix_audio, mix_out, epoch)
            self.store_net(net, epoch)
            if epoch >= 0:
                self.store_loss(self.loss, epoch)


## Trainer

In [127]:
class Trainer():
    def __init__(self,
                 model_args,
                 train_args,
                 store_args):

        self.device = get_device()
        self.epochs = train_args['epoch']
        self.loss_type = train_args['loss_type']
        self.n_critic = train_args['n_critic']


        self.criterion = get_criterion(train_args['loss_type'])
        self.generator = model_args['G'].to(self.device)
        self.discriminator = model_args['D'].to(self.device)


        dataset = train_args['dataset']
        batch_size = train_args['batch_size']

        self.datasampler = DataLoaderSampler(dataset[0], batch_size=batch_size, shuffle=True) if len(dataset) > 1 else None
        self.dataloader = batchlize(dataset[1], batch_size=batch_size, shuffle=True)

        self.otpim_g = get_optimizer(self.generator,
                                             train_args['optim'][0],
                                             train_args['lr'][0])
        self.otpim_d = get_optimizer(self.discriminator,
                                             train_args['optim'][1],
                                             train_args['lr'][1])

        self.log_interval = store_args['log_interval']
        self.out_dir = os.path.join(settings.STORE_DIR, store_args['out_dir'])
        self.begin_date = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        self.hyper = self.get_hyperparameter(model_args, train_args, store_args)
        self.storer = Storer(dataset[0],
                             store_args['n_sample'],
                             self.out_dir,
                             self.hyper)

    def get_hyperparameter(self,
                           model_args,
                           train_args,
                           store_args):

        train_args.pop('dataset')

        return {
            'model': {key: model.hyper for (key, model) in model_args.items()},
            'train': train_args,
            'store': store_args,
            'date': self.begin_date
        }

    def average_loss(self, loss):
        return torch.mean(torch.tensor(loss)).item()

    def generate_mix(self):
        pair_audios, cue = self.datasampler()
        cue = cue.to(self.device)
        pair_audios = [audio.to(self.device) for audio in pair_audios]
        mix_mag, mix_out = self.generator(*pair_audios, cue_region=cue)
        mix_mag = torch.sum(mix_mag, axis=1) if len(mix_mag.size()) == 5 else mix_mag
        return (mix_mag, mix_out)

    def train_step_G(self):
        self.generator.zero_grad()
        fake_mix, fake_mix_out = self.generate_mix()
        dgz = self.discriminator(fake_mix)
        loss = self.criterion(dgz)
        loss.backward()
        self.otpim_g.step()
        return loss.item()

    def train_step_D(self, real_mix, real_cue):
        self.discriminator.zero_grad()
        with torch.no_grad():
            fake_mix, fake_mix_out = self.generate_mix()

        dgz = self.discriminator(fake_mix)
        dx = self.discriminator(real_mix)
        loss = self.criterion(dgz, dx)
        loss.backward()
        self.otpim_d.step()
        return loss.item()

    def train(self):

        begin_time = time.time()
        D_Losses = []

        self.storer([self.generator, self.discriminator], -1)
        for epoch in range(self.epochs):
            for batch_idx, (real_mix, real_cue) in enumerate(tqdm(self.dataloader)):

                real_mix = real_mix.to(self.device)
                real_cue = real_cue.to(self.device)

                D_Losses.append(self.train_step_D(real_mix, real_cue)) # train D for one step
                if check_nan(D_Losses[-1]):
                    print(f'Batch: {batch_idx+1}, Epoch: {epoch+1}, D Loss is nan .....')

                if (batch_idx+1) % self.n_critic == 0:
                    G_Loss = self.train_step_G()
                    D_Loss = self.average_loss(D_Losses)
                    D_Losses.clear()

                    if check_nan(G_Loss):
                        print(f'Batch: {batch_idx+1}, Epoch: {epoch+1}, G Loss is nan .....')
                        break

                    if (batch_idx+1) % self.log_interval == 0:
                        self.storer.log([G_Loss, D_Loss], epoch)

            self.storer([self.generator, self.discriminator], epoch)
            print(f'[{epoch + 1}/{self.epochs}] one epoch completed ...')

        # print('Train finished. Elapsed: %s' % datetime.timedelta(seconds=time.time() - begin_time))

## Dataset

In [128]:
def get_list_intersect(lst1, lst2):
    return list(set(lst1) & set(lst2))

In [129]:
def load_json(file_path):
    with open(file_path, 'r') as file:
        data = json.load(file)
    return data

In [130]:
def time_to_samples(time): 
    return librosa.time_to_samples(time, sr=settings.SR)

In [131]:
def get_audio_info(audio_path):
    if 'mp3' in audio_path:
        torchaudio.set_audio_backend('sox_io')
    return torchaudio.info(audio_path)

In [132]:
def to_mono(audio, dim=-2): 
    if len(audio.size()) > 1:
        return torch.mean(audio, dim=dim, keepdim=True)
    else:
        return audio

In [133]:
def load_audio(audio_path, 
               sr=settings.SR, 
               mono=True, 
               start=0, 
               end=None
              ):
    if 'mp3' in audio_path:
        torchaudio.set_audio_backend('sox_io')
    audio, org_sr = torchaudio.load(audio_path, frame_offset=start, num_frames=-1 if end is None else end-start)
    audio = to_mono(audio) if mono else audio
    
    if org_sr != sr:
        audio = torchaudio.transforms.Resample(org_sr, sr)(audio)

    return audio

In [134]:
def pad_tensor(x, num_samples, value=None, last=True):
    if value is None:
        pads = (torch.rand(num_samples) * 2 - 1).unsqueeze(0) * 0.1
    elif value == -1:
        if last:
            pads = torch.linspace(1, settings.EPSILON, num_samples).unsqueeze(0) * 0.1
        else:
            pads = torch.linspace(0, settings.EPSILON, num_samples).unsqueeze(0) * 0.1
        pads *= 0.1
    elif value == 0.:
        pads = torch.full((x.size(-2), num_samples), settings.EPSILON)    
    else:
        pads = torch.full((x.size(-2), num_samples), value)   
    return [torch.cat((pads, x), axis=-1), torch.cat((x, pads), axis=-1)][last]

In [135]:
def samples_to_time(samples):
    return librosa.samples_to_time(samples, sr=settings.SR)

In [136]:
def get_new_cue(cue, cue_mid, step, ratio):
    new_cue = [c-(cue_mid-step) for c in cue]
    if ratio:
        new_cue = [c / (2*step) for c in new_cue]
    else:
        new_cue = [samples_to_time(c) for c in new_cue]
    return torch.tensor(new_cue).float()

In [137]:
def select_audio_region(in_data, cue, n_time, ratio, select_idx):
    data_type = 'tensor' if isinstance(in_data, torch.Tensor) else 'str'
    length = in_data.size(-1) if data_type == 'tensor' else get_audio_info(in_data).num_frames
    cue = [time_to_samples(c) for c in cue] 
    step = int(time_to_samples(n_time/2))
    cue_mid = int(sum(cue) // 2)

    time_dict = [
        [0 if cue_mid-step < 0 else cue_mid-step, cue[1]], 
        [cue[0], length if cue_mid+step > length else cue_mid+step],
        [0 if cue_mid-step < 0 else cue_mid-step, length if cue_mid+step > length else cue_mid+step]
    ]
    timestamps = time_dict[select_idx]
    
    if data_type == 'tensor':
        audio = in_data[:, timestamps[0]:timestamps[1]]
    else:
        audio = load_audio(in_data, 
                           start = timestamps[0], 
                           end   = timestamps[1])
    
    pad = cue_mid + step - length 
    if pad > 0:
        audio  = pad_tensor(audio,
                           pad, 
                           value=0., 
                           last=True)
        
    pad = 2*step - audio.size(-1)
    if pad > 0:
        audio = pad_tensor(audio, 
                           pad, 
                           value=0., 
                           last=(1-select_idx)>0)
    new_cue = get_new_cue(cue, cue_mid, step, ratio)
    return audio, new_cue, (cue, timestamps)

In [139]:
def squeeze_dim(data):
    dims = [i for i in range(len(data.size())) if data.size(i) == 1]
    for dim in dims:
        data = data.squeeze(dim)
    return data

In [140]:
def normalize_loudness(audio, loudness=-12): # unit: db
    meter = pyln.Meter(settings.SR)
    if isinstance(audio, torch.Tensor):
        audio = squeeze_dim(audio).numpy()
        measured = meter.integrated_loudness(audio)
        normalized = pyln.normalize.loudness(audio, measured, loudness)
        normalized = torch.from_numpy(normalized).unsqueeze(0)
    else:
        measured = meter.integrated_loudness(audio)
        normalized = pyln.normalize.loudness(audio, measured, loudness)
    return normalized

In [141]:
def normalize_peak(audio, peak_loudness=-1):
    if isinstance(audio, torch.Tensor):
        audio = squeeze_dim(audio).numpy()
        normalized = pyln.normalize.peak(audio, -1.0)
        normalized = torch.from_numpy(normalized).unsqueeze(0)
    else:
        normalized = pyln.normalize.peak(audio, -1.0)
    return normalized

In [142]:
def normalize(audio, norm_type='loudness'):
    if norm_type is None:
        return audio
    norm_dict = {
        'peak': normalize_peak(audio),
        'loudness': normalize_loudness(audio), 
    }
    normalized = norm_dict.get(norm_type, None)
    return audio if normalized is None else normalized

In [143]:
def load_npy(in_path):
    return np.load(in_path, allow_pickle=True)

In [None]:
    # def __get_pair_ids(self, df, audio_dir):
    #     audio_files = [file[:-4] for file in os.listdir(audio_dir)]
    #     print('Making df...')
    #     meta_info = {} # храним id для вызова: {название трека: , начало первого трека: , начало перехода: , начало второго трека: , начало перехода: }
    #     idx = 0

    #     for i in tqdm.tqdm(range(len(df))):
    #         name = self.df.iloc[i]['mix_id']
    #         row4next_track = df[(df['i_tran']==df.iloc[i]['i_track_next'])&(df['mix_id']==df.iloc[i]['mix_id'])]


    #         if ((name not in audio_files) or 
    #             (1 != len(row4next_track)) or 
    #             pd.isnull(df.iloc[i]['timestamp_prev']) or
    #             pd.isnull(df.iloc[i]['timestamp_next']) or
    #             pd.isnull(row4next_track['mix_cue_out_time_prev']).any()):
    #             continue
            

    #         meta_info[idx] = {
    #             'path': os.path.join(audio_dir, name + '.mp3'), 
    #             'start_time_first_track': df.iloc[i]['timestamp_prev'],
    #             'start_mix': df.iloc[i]['mix_cue_out_time_prev'],
    #             'end_mix': df.iloc[i]['mix_cue_in_time_next'],
    #             'end_time_second_track': row4next_track['mix_cue_out_time_prev'],

    #             '__meta_info': df.iloc[i].to_dict()
    #         }

    #         idx += 1
        
    #     print('Dataframe is ready...')
    #     return meta_info

In [None]:
class PairDJMix(Dataset):
    def __init__(self, 
                 n_time=settings.N_TIME, 
                 norm_type='loudness', 
                 cue_ratio=True,
                 audio_dir:str ='E:\\diplom\\automaticDJtransition\\dataset_seg\\audio',
                 meta_dir: str= 'E:\\diplom\\automaticDJtransition\\dataset_seg\\meta_audio'):
        
        self.cue_ratio = cue_ratio # bool отвечает за какой то шаг в функции get_new_cue 
        self.n_time = n_time # отвечает за то, какая длина у отрывка для обучения
        self.norm_type = norm_type # тип нормализации 

        # df = pd.read_csv(df_path)
        self.df = self.__get_pair_ids(audio_dir, meta_dir)
        self.data_types = ['prev', 'next']
    
    def __get_pair_ids(self, audio_dir, meta_dir):
        audios = [name[:-4] for name in os.listdir(audio_dir)]
        metas = [name[:-4] for name in os.listdir(meta_dir)]
        return_df = {}
        idx = 0
        print('Making pairs...')
        for i in tqdm.tqdm(range(len(audios))):
            if audios[i] in metas:
                with open(os.path.join(meta_dir, audio_dir[i] + '.json'), 'r', encoding='utf-8') as file:
                    data = json.load(file)
                return_df[idx] = {'audio_path': os.path.join(audio_dir, audios[i] + '.mp3'),
                                  'cue': data['cue']}
        return return_df
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        next_track_idx = random.randint(0, len(self.df))
        while next_track_idx == index:
            next_track_idx = random.randint(0, len(self.df))

        pair_audios = []
        cues = []
        for i, pair_obj in enumerate(zip([self.df[index], self.df[next_track_idx]])):
            audio, cue, _ = select_audio_region(pair_obj['audio_path'],
                                                pair_obj['cue'], 
                                                self.n_time, 
                                                self.cue_ratio, i)
            audio = normalize(audio, norm_type=self.norm_type)
            cues.append(cue)
            pair_audios.append(audio)
        return pair_audios, cues[0]

In [73]:
class Pair(Dataset):
    def __init__(self, n_time=settings.N_TIME, norm_type='loudness', cue_ratio=True):
        self.cue_ratio = cue_ratio # bool отвечает 
        self.n_time = n_time 
        self.norm_type = norm_type 
        self.audio_dir = os.path.join(settings.PAIR_DIR, 'audio')
        self.meta_dir = os.path.join(settings.PAIR_DIR, 'meta')
        self.pair_ids = self.__get_pair_ids()
        self.data_types = ['prev', 'next']

    def __get_pair_ids(self):
        audio_ids = list(os.listdir(self.audio_dir))
        meta_ids = [file.split('.json')[0] for file in os.listdir(self.meta_dir) if file.endswith('.json')]
        return get_list_intersect(audio_ids, meta_ids)

    def __get_obj(self, pair_id, data_type):
        audio_path = os.path.join(self.audio_dir, pair_id, f'{data_type}.wav')
        meta = load_json(os.path.join(self.meta_dir, f'{pair_id}.json'))
        return {'cue': meta[data_type]['cue'], 'audio_path': audio_path}

    def __getitem__(self, index):
        pair_id = self.pair_ids[index]
        pair_objs = [self.__get_obj(pair_id, data_type) for data_type in self.data_types]
        pair_audios = []
        cues = []
        for i, pair_obj in enumerate(pair_objs):
            audio, cue, _ = select_audio_region(pair_obj['audio_path'], pair_obj['cue'], self.n_time, self.cue_ratio, i)
            audio = normalize(audio, norm_type=self.norm_type)
            cues.append(cue)
            pair_audios.append(audio)
        return pair_audios, cues[0]

    def __len__(self):
        return len(self.pair_ids)

In [None]:
class Mix(Dataset):
    def __init__(self, n_time=settings.N_TIME, norm_type='loudness', cue_ratio=True):
        self.cue_ratio = cue_ratio
        self.n_time = n_time
        self.norm_type = norm_type
        self.audio_dir = os.path.join(settings.MIX_DIR, 'audio')
        self.obj_dir = os.path.join(settings.MIX_DIR, 'obj')
        self.pair_ids = self.__get_pair_ids()

    def __get_pair_ids(self):
        audio_ids = [file.split('.wav')[0] for file in os.listdir(self.audio_dir) if file.endswith('.wav')]
        obj_ids = [file.split('.npy')[0] for file in os.listdir(self.obj_dir) if file.endswith('.npy')]
        return get_list_intersect(audio_ids, obj_ids)

    def __get_obj(self, pair_id):
        audio_path = os.path.join(self.audio_dir, f'{pair_id}.wav')
        obj = load_npy(os.path.join(self.obj_dir, f'{pair_id}.npy')).item()
        return {'cue': obj['cue'], 'audio_path': audio_path}

    def __getitem__(self, index):
        pair_id = self.pair_ids[index]
        pair_obj = self.__get_obj(pair_id)
        audio, cue, _ = select_audio_region(pair_obj['audio_path'], pair_obj['cue'], self.n_time, self.cue_ratio, -1)
        audio = normalize(audio, norm_type=self.norm_type)
        return audio, cue

    def __len__(self):
        return len(self.pair_ids)

# Main part

In [None]:
def get_trainer(**kwargs):
    loss_type = kwargs.get('loss_type', 'minmax')
    lr = kwargs.get('lr')

    if len(kwargs.dataset) != 2:
        print('please provide proper dataset ...')
        return

    if len(kwargs.net) != 2:
        print('please provide proper G and D ...')
        return

    model_args = {
        'G': kwargs.net[0],
        'D': kwargs.net[1]
    }

    train_args = {
        'lr': [1e-5, 1e-5] if lr is None else lr ,
        'optim': ['Adam', 'Adam'],
        'epoch': kwargs.get('epoch', 5),
        'n_critic': kwargs.get('n_critic', 1),
        'batch_size': kwargs.get('batch_size', 4),
        'loss_type': loss_type,
        'dataset': kwargs.dataset
    }

    store_args = {
        'n_sample': kwargs.get('n_sample', 4),
        'log_interval': kwargs.get('log', 20),
        'out_dir': kwargs.get('out_dir' , f'{loss_type}_gan')
    }
    return Trainer(model_args, train_args, store_args)

In [59]:
net = get_net(cnn_type='res_cnn', loss_type='minmax')

STFT kernels created, time used = 0.0742 seconds
STFT kernels created, time used = 0.0673 seconds


In [151]:
num_params1 = sum(p.numel() for p in net[0].parameters())
num_params1

34833549

In [152]:
num_params2 = sum(p.numel() for p in net[1].parameters())
num_params2

23707241

In [153]:
num_params1 + num_params2

58540790

In [74]:
trainer = get_trainer(net=net,
                      dataset=dataset,
                      log=args.log,
                      lr=args.lr,
                      n_gpu=args.n_gpu,
                      epoch=args.epoch,
                      n_critic=args.n_critic,
                      batch_size=args.batch_size,
                      loss_type=args.loss_type,
                      out_dir=args.out_dir)

NameError: name 'dataset' is not defined