In [1]:
import sys
# Add the parent directory to the module search path.
# NOTE(review): presumably this is what makes the `mel2wav` package imported below resolvable — confirm.
sys.path.append("../")

In [2]:
import matplotlib.pyplot as plt
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import IPython.display as ipd

from mel2wav.dataset import AudioDataset
from mel2wav.modules import Generator, Discriminator, Audio2Mel

In [3]:
# Hyper-parameters handed to the Generator constructor below.
n_mel_channels = 80
ngf = 32
n_residual_layers = 3

# Hyper-parameters handed to the Discriminator constructor below.
num_D = 3
ndf = 16
n_layers_D = 4
downsamp_factor = 4

# Build the networks and the audio-to-mel front end, all placed on the GPU.
netG = Generator(n_mel_channels, ngf, n_residual_layers).cuda()
netD = Discriminator(num_D, ndf, n_layers_D, downsamp_factor).cuda()
fft = Audio2Mel(n_mel_channels=n_mel_channels).cuda()

# Both optimizers share the same Adam settings.
adam_settings = dict(lr=1e-4, betas=(0.5, 0.9))
optG = optim.Adam(netG.parameters(), **adam_settings)
optD = optim.Adam(netD.parameters(), **adam_settings)

In [4]:
# Number of audio samples per excerpt handed to AudioDataset.
seq_len = 8192 * 12
# Forward slashes keep the path portable: pathlib accepts them on every
# platform, whereas '..\data\...' relied on the invalid escape sequences
# '\d' and '\j' and only resolved to a directory on Windows.
data_path = '../data/jazz_classical'

# Two datasets built from file lists located in the same data directory.
train_set = AudioDataset(
        Path(data_path) / "train_files.txt", seq_len, sampling_rate=22050)
style_set = AudioDataset(
        Path(data_path) / "style_files.txt", seq_len, sampling_rate=22050)

In [5]:
load_root = Path('../data/audio_decoder')

# Restore network weights and optimizer state when the checkpoint
# directory is present; otherwise keep the freshly initialised objects.
if load_root and load_root.exists():
    checkpoint_files = (
        (netG, "netG.pt"),
        (optG, "optG.pt"),
        (netD, "netD.pt"),
        (optD, "optD.pt"),
    )
    for stateful_obj, checkpoint_name in checkpoint_files:
        stateful_obj.load_state_dict(torch.load(load_root / checkpoint_name))
    print('weights successfully loaded ...')

weights successfully loaded ...


In [6]:
class StructureLoss(nn.Module):
    """Pass-through layer that records the MSE between its input and a fixed target.

    The module returns its input unchanged, so it can be inserted anywhere in
    an ``nn.Sequential``; the most recent loss value is stored in ``self.loss``
    after every forward pass.

    Args:
        target: reference activation the input is compared against.
    """

    def __init__(self, target):
        super(StructureLoss, self).__init__()
        # Detach the target from the graph that produced it: it is a fixed
        # reference value, not something to differentiate through. Keeping it
        # attached would make every backward pass also traverse the graph of
        # the network that computed the target.
        self.target = target.detach()

    def forward(self, input):
        """Store the MSE against the target in ``self.loss`` and return ``input`` as is."""
        self.loss = F.mse_loss(input, self.target)
        return input

In [7]:
content_layers_default = [2, 4]
style_layers_default = [1, 3, 5] # list(range(nb_layers))

def get_style_model_and_losses(net_ensamble, style_song, content_song,
                               content_layers=content_layers_default,
                               style_layers=style_layers_default,
                               pref_disc='disc_0'):
    """Build a sequential feature extractor with loss probes inserted.

    The sub-modules of ``net_ensamble.model[pref_disc].model`` are copied (by
    reference) into a new ``nn.Sequential``. After the entry at each index
    listed in ``content_layers`` / ``style_layers`` a ``StructureLoss`` is
    appended whose target is the activation of the model built so far for
    ``content_song`` / ``style_song`` respectively.

    Args:
        net_ensamble: object exposing ``.model[pref_disc].model`` as a mapping
            of name -> module (iterated with ``.items()``).
        style_song: tensor fed through the partial model to get style targets.
        content_song: tensor fed through the partial model to get content targets.
        content_layers: indices (position in the mapping) that get a content loss.
        style_layers: indices (position in the mapping) that get a style loss.
        pref_disc: key selecting which sub-network of ``net_ensamble`` to use.

    Returns:
        ``(model, style_losses, content_losses)`` where the two lists hold the
        inserted ``StructureLoss`` modules in insertion order.
    """
    # Kept as lists so the caller can iterate over the inserted loss modules.
    content_losses = []
    style_losses = []

    model = nn.Sequential()
    # Use the network that was passed in; the previous code ignored the
    # argument and silently read the notebook-global `netD` instead.
    source_net = net_ensamble.model[pref_disc]

    for idx, (name, layer) in enumerate(source_net.model.items()):
        # Flatten nested Sequential blocks so every sub-layer becomes a direct
        # child of `model`. Entries that are not nn.Sequential are not copied.
        # NOTE(review): skipping non-Sequential entries is kept from the
        # original — confirm that this is intended.
        if isinstance(layer, nn.Sequential):
            for i, x in enumerate(layer):
                model.add_module(name + f'_{i}', x)

        if idx in content_layers:
            # Targets are constants: detach them from the graph that produced them.
            target = model(content_song).detach()
            content_loss = StructureLoss(target)
            model.add_module("content_loss_{}".format(idx), content_loss)
            content_losses.append(content_loss)

        if idx in style_layers:
            target = model(style_song).detach()
            style_loss = StructureLoss(target)
            model.add_module("style_loss_{}".format(idx), style_loss)
            style_losses.append(style_loss)

    return model, style_losses, content_losses

In [8]:
def get_input_optimizer(input_img):
    """Create the L-BFGS optimizer that updates the input tensor itself.

    ``input_img`` is flagged in place as requiring gradients and becomes the
    only parameter of the returned optimizer, which is configured with
    ``max_iter=1``.
    """
    # The input is the thing being optimised, so it must track gradients.
    trainable_input = input_img.requires_grad_()
    return optim.LBFGS([trainable_input], max_iter=1)

In [28]:
def run_style_transfer(netD, style_song, content_song, input_song, num_steps=25, pref_disc='disc_0',
                       style_weight=0.5, content_weight=0.3):
    """Run the style transfer.

    Builds the loss-instrumented model with ``get_style_model_and_losses`` and
    optimises ``input_song`` in place so that the weighted sum of the recorded
    style and content losses decreases.

    Args:
        netD: network passed through to ``get_style_model_and_losses``.
        style_song: tensor providing the style targets.
        content_song: tensor providing the content targets.
        input_song: tensor that is optimised in place and returned.
        num_steps: number of loss evaluations (closure calls) to perform.
        pref_disc: forwarded to ``get_style_model_and_losses``.
        style_weight: multiplier applied to the summed style losses.
        content_weight: multiplier applied to the summed content losses.

    Returns:
        ``input_song`` after optimisation (the same tensor object).

    Raises:
        ValueError: if the built model contains no style and no content loss.
    """
    print('Building the style transfer model..')
    model, style_losses, content_losses = get_style_model_and_losses(netD, style_song, content_song, pref_disc=pref_disc)
    if not style_losses and not content_losses:
        # Without any loss module there is nothing to back-propagate.
        raise ValueError('no style or content losses were inserted into the model')
    optimizer = get_input_optimizer(input_song)

    print('Optimizing..')
    evaluations = 0

    def closure():
        nonlocal evaluations
        optimizer.zero_grad()
        # The forward pass is run only for its side effect: every loss
        # module stores its current value in `.loss`.
        model(input_song)

        style_score = style_weight * sum(sl.loss for sl in style_losses)
        content_score = content_weight * sum(cl.loss for cl in content_losses)

        loss = style_score + content_score
        # retain_graph is kept because the loss targets may still be attached
        # to a graph shared between evaluations; it is harmless otherwise.
        loss.backward(retain_graph=True)

        evaluations += 1
        if evaluations % 10 == 0:
            print("run {}:".format(evaluations))
            # float() also handles the plain 0 produced by an empty loss list.
            print('Style Loss : {:4f} Content Loss: {:4f}'.format(
                float(style_score), float(content_score)))
            print()

        return loss

    # Strict comparison: `<=` performed one evaluation more than requested.
    while evaluations < num_steps:
        optimizer.step(closure)

    return input_song

In [29]:
# Pick one random content excerpt and one random style excerpt.
cont_idx = np.random.choice(len(train_set))
styl_idx = np.random.choice(len(style_set))

# cont_idx is drawn from train_set, so it must index train_set; indexing
# style_set with it could go out of range when the two sets differ in size.
content_song = train_set[cont_idx].unsqueeze(dim=0).cuda()
style_song = style_set[styl_idx].unsqueeze(dim=0).cuda()
# Start the optimisation from a copy of the style excerpt. Cloning (instead
# of reading the dataset again) guarantees the starting point is exactly the
# tensor used as style reference and leaves `style_song` itself untouched.
# NOTE(review): AudioDataset may return a different crop on each access —
# not visible from this notebook; confirm.
input_song = style_song.clone()

output = run_style_transfer(netD, style_song, content_song, input_song, num_steps=15)

Building the style transfer model..
Optimizing..
run [10]:
Style Loss : 18.791470 Content Loss: 43.738976



In [30]:
# Take the waveforms from the tensors that were actually fed to
# run_style_transfer instead of reading the datasets again: this avoids
# indexing style_set with an index drawn from train_set and guarantees the
# audio played back is the audio that was processed.
# [0][0] drops the leading batch and channel axes added for the network.
style = style_song.detach().cpu().numpy()[0][0]
content = content_song.detach().cpu().numpy()[0][0]
combined = output.cpu().detach().numpy()[0][0]

In [31]:
# from melonet.utils import read_wav_file

In [32]:
# wab_sample_path = '../data/test_transfer/ciorba.mp3'
# cont_idx = np.random.choice(len(train_set)) 

# wav_sample, _ = read_wav_file(wab_sample_path, seq_len=seq_len)
# style_song = torch.from_numpy(wav_sample).float().unsqueeze(dim=0)
# style_song = style_song.unsqueeze(dim=0).cuda()

# content_song = style_set[cont_idx].unsqueeze(dim=0).cuda()
# output = run_style_transfer(netD, style_song, content_song, content_song, num_steps=7)

# style  = wav_sample
# content  = style_set[cont_idx].numpy()[0]
# combined = output.cpu().detach().numpy()[0][0]

In [33]:
# Audio player for the content waveform (22050 Hz).
ipd.Audio(content, rate=22050)

In [34]:
# Audio player for the style waveform (22050 Hz).
ipd.Audio(style, rate=22050)

In [35]:
# Audio player for the optimised (style-transferred) waveform (22050 Hz).
ipd.Audio(combined, rate=22050)