In [3]:
# Mount Google Drive at /content/drive. The audio inputs and outputs used later
# in this notebook are addressed as 'drive/MyDrive/...'; those relative paths
# presumably resolve from a working directory of /content -- confirm before
# running from another location.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [7]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import copy
import torch.optim as optim
import scipy
from scipy import signal
from scipy.io.wavfile import write
import matplotlib.pyplot as plt
import os
import PIL
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import soundfile
import librosa

def gram_matrix(map):
    """Return the normalised Gram matrix of a 3-D feature tensor.

    ``map`` is unpacked as (batch_size, feature_maps, c).  It is flattened to a
    (batch_size * feature_maps, c) matrix, multiplied by its own transpose and
    divided by ``feature_maps * c`` (the batch size is not part of the divisor).
    """
    n_batch, n_maps, width = map.size()
    flat = map.view(n_batch * n_maps, width)
    return torch.mm(flat, flat.t()) / (n_maps * width)

class StyleTransfer(nn.Module):
    """Single-layer 1-D convolutional feature extractor.

    The network is one ``Conv1d`` (1025 input channels -> 4096 output channels,
    kernel size 3, padding 1 so the length is preserved) followed by a ReLU.
    It is constructed with PyTorch's default random initialisation; no
    pretrained weights are loaded.

    Parameters
    ----------
    params : dict
        Run configuration; only ``params["device"]`` is read (by ``get_model``).
    """
    def __init__(self,params):
        super(StyleTransfer,self).__init__() #call superclass constructor
        self.params = params
        self.model = nn.Sequential(
            nn.Conv1d(in_channels=1025, out_channels=4096, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

    def get_model(self):
        """Print the layer stack and return it moved to ``params["device"]``."""
        print(self.model)
        return self.model.to(self.params["device"])

    def forward(self,x):
        """Apply the conv + ReLU stack to ``x``.

        The input is passed through unchanged in shape.  Flattening it to 2-D
        first (as an image classifier head would) destroys the channel axis
        that ``Conv1d`` requires and makes the call fail.
        """
        return self.model(x)

class StyleLoss(nn.Module):
    """Pass-through layer that records a weighted Gram-matrix MSE.

    At construction the Gram matrix of ``target`` (detached from the autograd
    graph) is computed and multiplied by ``weight``.  ``forward`` computes the
    Gram matrix of its input, scales it by the same weight, stores the mean
    squared error between the two in ``self.loss`` and returns the input
    untouched so the layer can sit inside an ``nn.Sequential``.
    """
    def __init__(self,target,weight):
        super(StyleLoss,self).__init__()
        self.weight = weight
        # Fixed reference: pre-scaled Gram matrix of the style features.
        self.loss_target = gram_matrix(target.detach()) * weight

    def forward(self,input):
        # Scale the freshly computed Gram matrix in place; it is a new tensor,
        # so the caller's input is not modified.
        scaled_gram = gram_matrix(input.clone()).mul_(self.weight)
        self.loss = F.mse_loss(self.loss_target, scaled_gram)
        return input

class ContentLoss(nn.Module):
    """Pass-through layer that records the MSE against fixed target features.

    ``target`` is detached once at construction.  Every ``forward`` call stores
    ``mse_loss(target, input)`` in ``self.loss`` and hands the input back
    unchanged, so the layer is transparent inside an ``nn.Sequential``.
    """
    def __init__(self,target):
        super().__init__()
        # Detach so the reference features are treated as constants.
        reference = target.detach()
        self.loss_target = reference

    def forward(self,input):
        mse = F.mse_loss(self.loss_target, input)
        self.loss = mse
        return input


class Normalization(nn.Module):
    """Standardise a tensor per channel: ``(img - mean) / std``.

    ``mean`` and ``std`` are converted with ``torch.tensor`` and reshaped to
    (C, 1, 1) so that they broadcast against a [B x C x H x W] input
    (B batch size, C channels, H height, W width).
    """
    def __init__(self, mean, std):
        super().__init__()
        self.mean, self.std = (
            torch.tensor(stat).view(-1, 1, 1) for stat in (mean, std)
        )

    def forward(self, img):
        centred = img - self.mean
        return centred / self.std
def setup_dataset():
    """Build the run configuration and load both clips as log-magnitude spectrograms.

    The style and content audio files named in the configuration are loaded
    with librosa, converted to ``log1p(|STFT|)`` with ``n_fft=2048``, saved as
    ``style.png`` / ``content_start.png`` and returned as torch tensors.

    Returns
    -------
    params : dict
        Hyper-parameters, device and file locations for the run.
    content : torch.Tensor
        Content spectrogram with a leading dimension of size 1 (``unsqueeze(0)``).
    style : torch.Tensor
        Style spectrogram with a leading dimension of size 1.
    style_sr : number
        Sampling rate reported by ``librosa.load`` for the style clip.

    Raises
    ------
    ValueError
        If the two clips were loaded with different sampling rates.
    """
    params = {
        # Use the GPU when one is attached (Colab: Runtime -> Change runtime
        # type); otherwise fall back to the CPU.
        'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        'content_layers': ['conv_1'],
        'style_layers': ['conv_1'],
        's_contribution': [100],
        'c_contribution': [1],
        'lr': 0.01,
        'gamma': 0.01,
        's_weight': 1000000,
        'c_weight': 1,
        'epochs': 2001,
        'full_validate': 1000,
        'train_validate':50,
        'save_loc': 'drive/MyDrive/466 songs/max_save/',
        'model_save_loc': 'drive/MyDrive/466 songs/max_save/models/',
        'content_save_loc': 'drive/MyDrive/466 songs/max_samples/chainsaw.wav',
        'style_save_loc': 'drive/MyDrive/466 songs/max_samples/loopguitar.wav',
        'imsize': 128
    }

    N_FFT=2048

    def read_audio_spectum(filename):
        """Load ``filename`` and return ``(log1p(|STFT|), sampling_rate)``."""
        # At most 58.04 s are read; the original note says this duration was
        # picked to make the resulting sizes convenient.
        x, fs = librosa.load(filename, duration=58.04)
        # n_fft must be passed by keyword: librosa >= 0.10 rejects it positionally.
        S = librosa.stft(x, n_fft=N_FFT)
        return np.log1p(np.abs(S)), fs

    def save_spectrogram(spectrogram, stem):
        """Show ``spectrogram`` as an image and save it to ``<stem>.png``."""
        plt.figure()
        plt.imshow(spectrogram)
        plt.axis('off')
        plt.savefig(stem + ".png",bbox_inches = 0)
        plt.ioff()
        plt.show()

    style_audio, style_sr = read_audio_spectum(params['style_save_loc'])
    content_audio, content_sr = read_audio_spectum(params['content_save_loc'])

    if content_sr == style_sr:
        print('Sampling Rates are same')
    else:
        print('Sampling rates are not same')
        # Raise instead of calling exit(): in a notebook exit() shuts the
        # kernel down and discards all state without a usable traceback.
        raise ValueError(
            'Sampling rates differ: content={} style={}'.format(content_sr, style_sr)
        )

    save_spectrogram(style_audio, "style")
    save_spectrogram(content_audio, "content_start")

    style = torch.from_numpy(style_audio).unsqueeze(0)
    content = torch.from_numpy(content_audio).unsqueeze(0)
    print(style.shape)
    return params,content, style, style_sr

def setup_comp(cnn,params,style,content):
    """Rebuild ``cnn`` layer by layer, inserting loss-recording modules.

    Adapted from https://pytorch.org/tutorials/advanced/neural_style_tutorial.html

    Each child of ``cnn`` is copied into a fresh ``nn.Sequential``.
    Convolutions are named ``conv_<n>``; every other layer is named
    ``<classname>_<n>`` (for example ``relu_1``) where ``n`` is the number of
    convolutions seen so far.  After a layer whose name appears in
    ``params['content_layers']`` / ``params['style_layers']`` a ``ContentLoss``
    / ``StyleLoss`` module is appended, with its target computed by running
    ``content`` / ``style`` through the model built so far.  Copying stops once
    the last-listed content layer and the last-listed style layer have both
    been reached.

    Parameters
    ----------
    cnn : nn.Module
        Network whose direct children are copied in order.
    params : dict
        Reads ``content_layers``, ``style_layers``, ``s_weight`` and ``device``.
    style, content : torch.Tensor
        Inputs used to compute the style and content targets.

    Returns
    -------
    (model, style_vals, content_vals)
        The instrumented model and the lists of inserted style / content loss
        modules, in insertion order.
    """
    style_vals = []
    content_vals = []
    # Both flags start False so the termination test below never reads an
    # unassigned name when style layers complete before content layers.
    s_done = False
    c_done = False
    model = nn.Sequential()

    def unique(candidate):
        # add_module replaces an existing entry that has the same name, which
        # would silently drop a layer; disambiguate with the current length.
        if candidate in dict(model.named_children()):
            return '{}_{}'.format(candidate, len(model))
        return candidate

    print(style.shape)
    i = 0  # number of convolution layers seen so far
    for layer in cnn.children():
        if isinstance(layer, (nn.Conv1d, nn.Conv2d)):
            i += 1
            name = 'conv_{}'.format(i)
        else:
            # Non-conv layers get their own name instead of reusing (and
            # thereby overwriting) the preceding convolution's entry.
            name = unique('{}_{}'.format(type(layer).__name__.lower(), i))
        model.add_module(name,layer)

        if name in params['content_layers']:
            loc = params['content_layers'].index(name)
            if loc == len(params['content_layers']) - 1:
                c_done = True
            out = model(content).detach()
            content_loss = ContentLoss(out).to(params['device'])
            model.add_module(unique("content_loss_{}".format(i)), content_loss)
            content_vals.append(content_loss)

        if name in params['style_layers']:
            loc = params['style_layers'].index(name)
            if loc == len(params['style_layers']) - 1:
                s_done = True
            out = model(style).detach()
            style_loss = StyleLoss(out,params["s_weight"]).to(params['device'])
            model.add_module(unique("style_loss_{}".format(i)), style_loss)
            style_vals.append(style_loss)

        if s_done and c_done:
            break
    return model,style_vals, content_vals

def _checkpoint_epoch(filename):
    """Return the integer after the last '.' in ``filename``, or -1 if it is not a number."""
    suffix = filename.rsplit('.', 1)[-1]
    return int(suffix) if suffix.isdigit() else -1


def _save_snapshot(spectrogram, epoch):
    """Show ``spectrogram`` as an image and save it to ``<epoch>.png``."""
    fig = plt.figure()
    plt.imshow(spectrogram)
    plt.axis('off')
    plt.savefig(str(epoch) + ".png",bbox_inches = 0)
    plt.ioff()
    plt.show()
    # Release the figure; one is created every 'train_validate' epochs.
    plt.close(fig)


def _reconstruct_audio(log_magnitude, n_fft=2048, n_iter=500):
    """Turn a log1p-magnitude spectrogram into a waveform.

    The phase is initialised uniformly at random in [-pi, pi) and refined by
    alternating inverse and forward STFTs ``n_iter`` times (``n_iter`` >= 1)
    while the magnitude is held fixed.
    """
    magnitude = np.expm1(log_magnitude)  # inverse of the log1p applied at load time
    phase = 2 * np.pi * np.random.random_sample(magnitude.shape) - np.pi
    for _ in range(n_iter):
        spectrum = magnitude * np.exp(1j * phase)
        waveform = librosa.istft(spectrum)
        # n_fft must be passed by keyword: librosa >= 0.10 rejects it positionally.
        phase = np.angle(librosa.stft(waveform, n_fft=n_fft))
    return waveform


def apply(model,params,content,style_losses,content_losses, style_sr):
    """Optimise a copy of ``content`` so that it minimises style + content loss.

    A leaf copy of ``content`` is optimised with Adam.  Each epoch the copy is
    run through ``model`` (which makes every loss module refresh its ``.loss``),
    the losses are summed and back-propagated, and the optimiser takes a step.
    Every ``params['train_validate']`` epochs (and at epoch 1) the current
    spectrogram is saved as ``<epoch>.png``; every ``params['full_validate']``
    epochs it is converted back to audio and written under ``params['save_loc']``.

    Parameters
    ----------
    model : nn.Module
        Network containing the loss modules; its parameters are frozen here
        because only the input is optimised.
    params : dict
        Reads ``device``, ``lr``, ``epochs``, ``c_weight``, ``s_contribution``,
        ``c_contribution``, ``train_validate``, ``full_validate``, ``save_loc``
        and ``model_save_loc``.
    content : torch.Tensor
        Starting spectrogram.  It is copied, so the caller's tensor is not
        modified.
    style_losses, content_losses : list
        Modules exposing a ``.loss`` attribute after a forward pass of ``model``.
    style_sr : number
        Sampling rate used when writing reconstructed audio.

    Returns
    -------
    torch.Tensor
        The optimised spectrogram, detached from the autograd graph.
    """
    device = params['device']
    # Optimise an explicit leaf tensor that already lives on the target device.
    # Wrapping ``content.data`` directly shares storage with the caller's
    # tensor, and calling .to() on a Parameter yields a non-leaf tensor when
    # the device differs, which the optimiser rejects.
    content_in = nn.Parameter(content.detach().clone().to(device))
    torch.cuda.empty_cache()

    # Create the output locations up front so a missing directory does not
    # surface only after many epochs of computation.
    os.makedirs(params['model_save_loc'], exist_ok=True)
    os.makedirs(params['save_loc'], exist_ok=True)

    # Order checkpoints by their numeric epoch suffix; plain string sorting
    # would rank 'model.pt.200' after 'model.pt.1000'.
    checkpoints = sorted(
        (item for item in os.listdir(params['model_save_loc']) if ".pt" in item),
        key=lambda item: (_checkpoint_epoch(item), item),
    )
    print(checkpoints)

    start = 1
    optimizer = optim.Adam([content_in],lr=params['lr'])
    if checkpoints:
        newest = checkpoints[-1]
        # NOTE: torch.load unpickles the file -- only load checkpoints you wrote yourself.
        # NOTE(review): the checkpoint layout read here holds model and optimiser
        # state only, so the optimised spectrogram itself restarts from ``content``.
        checkpoint = torch.load(os.path.join(params['model_save_loc'], newest), map_location=device)
        start = checkpoint['Epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['opt'])
        print("Loaded model, Epoch:",start - 1,"S_Loss:",checkpoint['s_loss'])

    model.eval()
    # Only the input is optimised; skip gradient computation for the weights.
    model.requires_grad_(False)

    for epoch in range(start,params['epochs']):
        optimizer.zero_grad()
        model(content_in)  # forward pass refreshes .loss on every loss module

        s_score = 0
        c_score = 0
        # NOTE(review): the per-layer contribution factors are paired with the
        # losses but never applied; confirm whether they were meant to weight
        # each term before changing this, since it would alter the results.
        for loss, _contribution in zip(style_losses,params['s_contribution']):
            s_score += loss.loss
        for loss, _contribution in zip(content_losses,params['c_contribution']):
            c_score += loss.loss
        c_score = params['c_weight'] * c_score
        total_loss = s_score + c_score
        total_loss.backward()

        if epoch % params['train_validate'] == 0 or epoch == 1:
            _save_snapshot(content_in.detach().squeeze().cpu().numpy(), epoch)

        if epoch % params['train_validate'] == 0:
            print('Epoch : {}, Total Loss : {:8f}, Style Loss : {:8f}, Content Loss : {:26f}'.format(epoch,total_loss.item(),s_score.item(),c_score.item()))

        if epoch % params['full_validate'] == 0:
            log_magnitude = content_in.detach().squeeze().cpu().numpy()
            x = _reconstruct_audio(log_magnitude, n_fft=2048, n_iter=500)

            OUTPUT_FILENAME = params["save_loc"] + str(epoch) + "_chainsaw_loopguitar" + '.wav'

            soundfile.write(OUTPUT_FILENAME, x, style_sr,"PCM_24")

        optimizer.step()
    return content_in.detach()
        
def main():
    """Run the whole pipeline: load data, build the instrumented model, optimise."""
    params, content, style, style_sr = setup_dataset()
    device = params['device']

    # get_model() prints the layer stack and returns it on the target device.
    base_cnn = StyleTransfer(params).get_model()
    model, styles, contents = setup_comp(
        base_cnn, params, style.to(device), content.to(device)
    )
    model = model.to(device)

    content = content.to(params["device"])
    # A clone is handed over so the loaded spectrogram stays available here.
    apply(model, params, content.clone(), styles, contents, style_sr)
main()


Output hidden; open in https://colab.research.google.com to view.