In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchaudio
import matplotlib.pyplot as plt
import numpy as np
import os
import json

In [None]:
#Change this to the location of the NSynth dataset on your machine.
#os.listdir returns entries in arbitrary (filesystem-dependent) order, so the
#names are sorted to make "the first N files" the same on every run and machine.
#Entries that are not .wav files are dropped so a stray file in the directory
#is never handed to the audio loader.
files = sorted(
    name for name in os.listdir('/media/data/nsynth-train/audio')
    if name.lower().endswith('.wav')
)

In [None]:
#Clip geometry and batch size
LEN = 2500   #number of samples kept from each audio file
BATCH = 50   #number of clips in the training batch
START = 100  #sample offset at which each clip begins

#Training data: one row per clip
X = torch.zeros(BATCH, LEN)

for i in range(BATCH):
    #Change this to the location of the NSynth dataset on your machine
    path = '/media/data/nsynth-train/audio/' + files[i]
    sound, sr = torchaudio.load(path)

    #Keep only the window [START, START+LEN) along the last dimension
    clip = sound[:, START:START + LEN]

    #Min-max normalize the window: shift so its minimum is 0, then divide by
    #the shifted maximum (plus a small constant so the divisor is never 0)
    shifted = clip - clip.min()
    X[i] = shifted / (shifted.max() + 0.001)

In [None]:
#Encoder: a stack of three GRU cells followed by a linear projection
class Encoder(nn.Module):
    """Three stacked GRU cells that consume one input segment per call.

    Args:
        insize: feature size of the input fed to the first GRU cell.
        statesize: hidden size shared by all three GRU cells.
        embedsize: output size of the final linear layer.
    """

    def __init__(self, insize, statesize, embedsize):
        super().__init__()
        self.statesize = statesize

        #Layer 1 reads the input; layers 2 and 3 read the layer below them
        self.gateLayer1 = nn.GRUCell(insize, statesize)
        self.gateLayer2 = nn.GRUCell(statesize, statesize)
        self.gateLayer3 = nn.GRUCell(statesize, statesize)
        #Projects the top layer's hidden state to the embedding
        self.out = nn.Linear(statesize, embedsize)

    def forward(self, inpt, state):
        """Advance all three layers by one step.

        Args:
            inpt: input for the first GRU cell (assumed (batch, insize)).
            state: previous hidden states; state[:, k] is used as the hidden
                state of layer k+1.

        Returns:
            (embed, newstate): the linear projection of the third layer's new
            hidden state, and the three new hidden states concatenated along
            dimension 1 in the same layout as `state`.
        """
        cells = (self.gateLayer1, self.gateLayer2, self.gateLayer3)
        hidden = []
        signal = inpt
        for layer, cell in enumerate(cells):
            #Each cell's new hidden state is the input of the next cell
            signal = cell(signal, state[:, layer])
            hidden.append(signal.view(-1, 1, self.statesize))

        embed = self.out(signal)
        newstate = torch.cat(hidden, 1)
        return embed, newstate

In [None]:
class Decoder(nn.Module):
    """Three stacked GRU cells that emit one output segment per call.

    Args:
        segments: size of both the input segment and the output segment.
        statesize: hidden size shared by all three GRU cells.
        embedsize: size of the embedding concatenated onto every input.
    """

    def __init__(self, segments, statesize, embedsize):
        super().__init__()
        self.statesize = statesize
        self.segments = segments

        #The first layer sees the input segment with the embedding appended
        self.gateLayer1 = nn.GRUCell(segments + embedsize, statesize)
        self.gateLayer2 = nn.GRUCell(statesize, statesize)
        self.gateLayer3 = nn.GRUCell(statesize, statesize)
        #Maps the top layer's hidden state back to one segment
        self.linOut = nn.Linear(statesize, segments)

    def forward(self, inpt, state, embed):
        """Advance all three layers by one step.

        Args:
            inpt: input segment (assumed (batch, segments)).
            state: previous hidden states; state[:, k] is used as the hidden
                state of layer k+1.
            embed: embedding concatenated to `inpt` along dimension 1.

        Returns:
            (output, newstate): the linear projection of the third layer's new
            hidden state, and the three new hidden states concatenated along
            dimension 1 in the same layout as `state`.
        """
        cells = (self.gateLayer1, self.gateLayer2, self.gateLayer3)
        hidden = []
        signal = torch.cat((inpt, embed), 1)
        for layer, cell in enumerate(cells):
            #Each cell's new hidden state is the input of the next cell
            signal = cell(signal, state[:, layer])
            hidden.append(signal.view(-1, 1, self.statesize))

        output = self.linOut(signal)
        newstate = torch.cat(hidden, 1)
        return output, newstate

In [None]:
#Model hyperparameters
STATE = 512  #width of each GRU hidden state
EMBED = 25   #width of the encoding / embedding produced by the encoder
SEG = 25     #segment size: number of audio samples fed in per step

#Build both networks and move them to the GPU
enc = Encoder(insize=SEG, statesize=STATE, embedsize=EMBED).cuda()
dec = Decoder(segments=SEG, statesize=STATE, embedsize=EMBED).cuda()

In [None]:
encopt = optim.Adam(enc.parameters(), lr=0.0003)
decopt = optim.Adam(dec.parameters(), lr=0.0003, weight_decay=0)

#Keep the data on whatever device the models currently live on, so this cell
#also works when it is re-run after the models have been moved (e.g. to the cpu)
device = next(enc.parameters()).device
X = X.to(device)

NEPOCH = 3000
#Number of segments the decoder emits per epoch; used to average the loss
nsteps = len(range(0, X.shape[1], SEG))

#Per-epoch mean loss, kept for the whole run so it can be inspected/plotted later
epochlosses = []

for epoch in range(0, NEPOCH):

    #Both networks start every epoch from zero hidden states, allocated
    #directly on the target device
    encstate = torch.zeros(BATCH, 3, STATE, device=device)
    decstate = torch.zeros(BATCH, 3, STATE, device=device)

    #First, run the encoder over the whole clip one segment at a time;
    #only the embedding produced after the last segment is used below
    for step in range(0, X.shape[1], SEG):
        embed, encstate = enc(X[:, step:step+SEG], encstate)

    loss = 0

    #Then reconstruct the clip segment by segment from that embedding
    for step in range(0, X.shape[1], SEG):
        #On the first step, input zeros
        if step == 0:
            inpt = torch.zeros(BATCH, SEG, device=device)
        #otherwise randomly use teacher forcing (feed the true previous segment);
        #the probability falls linearly from 1.0 towards 0.9 over training ...
        elif np.random.uniform() < 1 - (epoch/NEPOCH)*0.1:
            inpt = X[:, step-SEG:step]
        #...or don't use teacher forcing (feed the decoder's own previous output)
        else:
            inpt = decoding

        #forward pass through the decoder
        decoding, decstate = dec(inpt, decstate, embed)
        loss += F.mse_loss(decoding, X[:, step:step+SEG])

    #loss is a sum of one mean-squared-error per segment, so the per-sample
    #mean is obtained by dividing by the number of segments (not by LEN)
    epochlosses.append(float(loss.detach()) / nsteps)

    #One backward pass updates encoder and decoder together
    loss.backward()
    encopt.step()
    encopt.zero_grad()
    decopt.step()
    decopt.zero_grad()

    print(epoch, epochlosses[-1])
    

In [None]:
torch.set_printoptions(threshold=np.inf)
#move the encoder and decoder from gpu to cpu
#NOTE: nn.Module.cpu() moves the module in place and returns it, so dec_c/enc_c
#are the same objects as dec/enc, which are therefore on the cpu after this cell
dec_c = dec.cpu()
enc_c = enc.cpu()
with torch.no_grad():
    #factor is the number of times to loop the generation
    #for making longer audio samples
    factor = 24
    #samp is the index for the embedding when generating
    samp = 45

    #the clip that is encoded, copied to the cpu once
    clip = X[samp].cpu()

    outputs = torch.zeros(X.shape[1]*factor)
    #encoder state
    encstate = torch.zeros(1, 3, STATE)

    #First, run a pass through the encoder. Every segment is fed in, including
    #the last one, exactly as during training, so the embedding matches what
    #the decoder was trained on
    for step in range(0, X.shape[1], SEG):
        embed, encstate = enc_c(clip[step:step+SEG].view(1, -1), encstate)

    #the decoder starts from a zero segment and zero hidden state
    output = torch.zeros(1, SEG)
    decstate = torch.zeros(1, 3, STATE)

    #From the generated embedding create the output, feeding every generated
    #segment back in as the next input
    for step in range(0, X.shape[1]*factor, SEG):
        inpt = output
        output, decstate = dec_c(inpt, decstate, embed)
        outputs[step:step+SEG] = output

    plt.plot(outputs[0:5000])
    plt.title('Generated audio (first 5000 samples)')
    plt.xlabel('sample')
    plt.show()
    plt.plot(X[samp, 0:].cpu())
    plt.title('Original clip')
    plt.xlabel('sample')
    plt.show()


In [None]:
#Output generated audio
import wave

oput = outputs.detach().numpy()

#Scale to 16-bit PCM. Values are clipped to the int16 range first so that an
#out-of-range sample saturates instead of wrapping around, and the bytes are
#written little-endian, which is what the WAV format stores.
pcm = np.clip(oput * 10000, -32768, 32767).astype('<i2')

#The context manager closes the file, which finalizes the WAV header and
#flushes the data even if writing fails part-way
with wave.open('Long_MSE2.wav', 'wb') as w:
    #mono, 2 bytes per sample, 16 kHz, uncompressed
    w.setparams((1, 2, 16000, oput.shape[0], 'NONE', 'NONE'))
    w.writeframes(pcm.tobytes())