In [97]:
import sys
sys.path.append("../")

In [98]:
import matplotlib.pyplot as plt
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import IPython.display as ipd

from mel2wav.dataset import AudioDataset
from mel2wav.modules import Generator, Discriminator, Audio2Mel

In [99]:
ngf = 32
n_residual_layers = 3
num_D = 3
ndf = 16 
n_layers_D = 4
downsamp_factor = 4
n_mel_channels = 80

netG = Generator(n_mel_channels, ngf, n_residual_layers).cuda()
netD = Discriminator(num_D, ndf, n_layers_D, downsamp_factor).cuda()
fft = Audio2Mel(n_mel_channels=n_mel_channels).cuda()

optG = torch.optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
optD = torch.optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))

In [100]:
seq_len = 8192 * 28
data_path = '..\data\jazz_classical'

train_set = AudioDataset(
        Path(data_path) / "train_files.txt", seq_len, sampling_rate=22050)
style_set = AudioDataset(
        Path(data_path) / "style_files.txt", seq_len, sampling_rate=22050)

In [101]:
load_root = Path('../data/audio_decoder')

if load_root and load_root.exists():
    netG.load_state_dict(torch.load(load_root / "netG.pt"))
    optG.load_state_dict(torch.load(load_root / "optG.pt"))
    netD.load_state_dict(torch.load(load_root / "netD.pt"))
    optD.load_state_dict(torch.load(load_root / "optD.pt"))
    print('weights successfully loaded ...')

weights successfully loaded ...


In [102]:
class ContentLoss(nn.Module):

    def __init__(self, target):
        super(ContentLoss, self).__init__()
        # we 'detach' the target content from the tree used
        # to dynamically compute the gradient: this is a stated value,
        # not a variable. Otherwise the forward method of the criterion
        # will throw an error.
        self.target = target

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input

In [307]:
content_layers_default = list(range(6)) 
style_layers_default = [6] # list(range(nb_layers))

def get_style_model_and_losses(net_ensamble, style_song, content_song,
                               content_layers=content_layers_default,
                               style_layers=style_layers_default,
                               pref_disc='disc_0'):
    
    # just in order to have an iterable access to or list of content/syle
    # losses
    content_losses = []
    style_losses = []
    
    model = nn.Sequential()
    source_net = netD.model[pref_disc]
    
    for idx, [name, layer] in enumerate(source_net.model.items()):
        if isinstance(layer, nn.Sequential):
            for i, x in enumerate(layer):
                model.add_module(name + f'_{i}', x)
                    
        if idx in content_layers:
            target = model(content_song)
            content_loss = ContentLoss(target)
            model.add_module("content_loss_{}".format(idx), content_loss) 
            content_losses.append(content_loss)
            
        if idx in style_layers:
            target = model(style_song)
            style_loss = ContentLoss(target)
            model.add_module("style_loss_{}".format(idx), style_loss)
            style_losses.append(style_loss)
            
    return model, style_losses, content_losses

In [308]:
def get_input_optimizer(input_img):
    # this line to show that input is a parameter that requires a gradient
    optimizer = optim.LBFGS([input_img.requires_grad_()], max_iter=10)
    return optimizer

In [309]:
def run_style_transfer(netD, style_song, content_song, input_song, num_steps=10, pref_disc='disc_0',
                       style_weight=10, content_weight=1):
    """Run the style transfer."""
    print('Building the style transfer model..')
    model, style_losses, content_losses = get_style_model_and_losses(netD, style_song, content_song, pref_disc=pref_disc)
    optimizer = get_input_optimizer(input_song)
    
    print('Optimizing..')
    run = [0]
    while run[0] <= num_steps:

        def closure():
            # correct the values of updated input image
            optimizer.zero_grad()
            model(input_song)
            style_score = 0
            content_score = 0

            for i, sl in enumerate(style_losses):
                style_score += sl.loss * style_weight / (i + 1)
                
            for cl in content_losses:
                content_score += cl.loss

            content_score *= content_weight

            loss = style_score + content_score
            loss.backward(retain_graph=True)

            run[0] += 1
            if run[0] % 10 == 0:
                print("run {}:".format(run))
                print('Style Loss : {:4f} Content Loss: {:4f}'.format(
                    style_score.item(), content_score.item()))
                print()

            return style_score + content_score

        optimizer.step(closure)

    return input_song

In [310]:
cont_idx = np.random.choice(len(train_set)) 
styl_idx = np.random.choice(len(style_set)) 

content_song = train_set[cont_idx].unsqueeze(dim=0).cuda()
style_song = style_set[styl_idx].unsqueeze(dim=0).cuda()

output = run_style_transfer(netD, style_song, content_song, content_song, num_steps=100)

Building the style transfer model..
Optimizing..
run [10]:
Style Loss : 145.779129 Content Loss: 46.327599

run [20]:
Style Loss : 61.447201 Content Loss: 55.653515

run [30]:
Style Loss : 39.649448 Content Loss: 60.021965

run [40]:
Style Loss : 31.786802 Content Loss: 61.513420

run [50]:
Style Loss : 27.351236 Content Loss: 62.560951

run [60]:
Style Loss : 25.217480 Content Loss: 62.864075

run [70]:
Style Loss : 23.935440 Content Loss: 63.026802

run [80]:
Style Loss : 23.059238 Content Loss: 63.083595

run [90]:
Style Loss : 22.420280 Content Loss: 63.136528

run [100]:
Style Loss : 21.991289 Content Loss: 63.103439

run [110]:
Style Loss : 21.602676 Content Loss: 63.101456



In [311]:
src  = train_set[cont_idx].numpy()[0]
dst  = style_set[styl_idx].numpy()[0]
ceva = output.cpu().detach().numpy()[0][0]

In [312]:
ipd.Audio(src, rate=22050)

In [313]:
ipd.Audio(dst, rate=22050)

In [314]:
ipd.Audio(ceva, rate=22050)