In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchaudio
import matplotlib.pyplot as plt
import numpy as np
import os
import json

In [None]:
#Change this to the location of the NSynth dataset on your machine
NSYNTH_AUDIO_DIR = '/media/data/nsynth-train/audio'
# Listing of every audio file in the dataset directory.
# NOTE(review): `files` appears unused by the cells below — confirm before removing.
files = os.listdir(NSYNTH_AUDIO_DIR)

In [None]:
LEN = 256 #Number of samples in each training clip
FILES = 6 #Number of files to use in a batch
SEGMN = 1 #Number of segments to take per file
BATCH = FILES * SEGMN #Batch size
START = 3000 #Sample offset in each file where the first clip starts
NF = 0.05 #'Noise floor' i.e. amount of noise to add for regularization
BITR = 64 #Number of quantization levels: output classes for the crossentropy loss, and mu-law uses mu = BITR - 1
SEED = 0 #Random seed so weight initialization and training noise are reproducible

# Seed torch's global RNG before any model is built or noise is drawn
torch.manual_seed(SEED)

In [None]:
X = torch.zeros(BATCH, LEN)
Param = torch.zeros(BATCH, 1)

#Change this to the location of the NSynth dataset on your machine
audio_dir = '/media/data/nsynth-train/audio'

for f in range(FILES):
    # File f is pitch 60 + 2*f at velocity 100. Zero-pad the pitch to three
    # digits so pitches >= 100 still yield a valid NSynth-style file name.
    fname = 'brass_acoustic_037-{:03d}-100.wav'.format(f*2 + 60)
    sound, sr = torchaudio.load(os.path.join(audio_dir, fname))
    for b in range(SEGMN):
        # Consecutive segments start LEN samples apart. Without the b*LEN
        # offset, SEGMN > 1 would store SEGMN copies of the same clip.
        start = START + b*LEN
        clip = sound[0, start:start+LEN]
        if clip.numel() != LEN:
            raise ValueError('{}: expected {} samples at offset {}, got {}'.format(
                fname, LEN, start, clip.numel()))

        # Shift so the minimum is 2*NF, then divide by (max + 2*NF). Values land
        # strictly inside (0, 1), leaving a margin for the NF-scaled noise
        # added during training.
        clip = clip - clip.min() + NF*2
        clip = clip / (clip.max() + NF*2)

        X[f*SEGMN + b] = clip
        # Conditioning value: file index scaled to [0, (FILES-1)/FILES]
        Param[f*SEGMN + b] = f / FILES


In [None]:
class Network(nn.Module):
    """A single GRU cell followed by a linear read-out, conditioned on extra values.

    At each time step the input sample is concatenated with a conditioning
    vector and fed to a GRUCell. The resulting hidden state is projected to
    `outsize` output units.

    Args:
        statesize: hidden-state width of the GRU cell.
        outsize: number of output units produced per step.
        paramsize: width of the conditioning vector passed to forward().
    """

    def __init__(self, statesize, outsize, paramsize):
        super().__init__()
        self.statesize = statesize

        # Per-step input width = 1 sample column + `paramsize` conditioning columns
        self.gateLayer1 = nn.GRUCell(paramsize + 1, statesize)
        self.linOut = nn.Linear(statesize, outsize)

    def forward(self, inpt, state, param):
        """Advance the recurrence by one time step.

        Args:
            inpt: (batch, 1) input sample for this step.
            state: (batch, k, statesize) previous state; only slot 0 along dim 1 is read.
            param: (batch, paramsize) conditioning values.

        Returns:
            A tuple (output of shape (batch, outsize), new state of shape (batch, 1, statesize)).
        """
        step_input = torch.cat([inpt, param], dim=1)
        hidden = self.gateLayer1(step_input, state[:, 0])
        logits = self.linOut(hidden)
        return logits, hidden.view(-1, 1, self.statesize)

In [None]:
#Mu-Law encoding is used for compressing the audio
def muLaw(tensor, bitr=None):
    """Mu-law compand `tensor` with mu = bitr - 1.

    Computes sign(x) * ln(1 + mu*|x|) / ln(1 + mu). Inputs in [-1, 1] map to
    [-1, 1]; small magnitudes are expanded and large ones compressed. The sign
    is kept, so negative samples no longer fold onto positive ones. For
    non-negative input the result is unchanged from the sign-free version.

    Args:
        tensor: torch tensor of samples, expected in [-1, 1].
        bitr: number of quantization levels; defaults to the global BITR.

    Returns:
        Tensor with the same shape as `tensor`.
    """
    if bitr is None:
        bitr = BITR
    mu = bitr - 1
    # ln(1 + mu) == ln(bitr), so dividing normalizes |x| == 1 to 1
    return torch.sign(tensor) * torch.log(1 + mu*torch.abs(tensor)) / np.log(bitr)

In [None]:
STATE = 40 #Size of the hidden state

# GRU model with BITR output classes and a single conditioning value per sample
rnn = Network(statesize=STATE, outsize=BITR, paramsize=1)
optimizer = optim.Adam(params=rnn.parameters(), lr=0.003)

In [None]:
NEPOCH = 100
TBPTT = 32 #Number of time steps between truncated-backprop updates

for epoch in range(NEPOCH):

    state = torch.zeros(BATCH, 1, STATE)
    loss = 0
    nsteps = 0 #Steps accumulated into `loss` since the last update
    last_step = X.shape[1] - 1

    for step in range(1, X.shape[1]):

        # Teacher forcing: the input is the true previous sample, mu-law encoded
        inpt = muLaw(X[:, step-1]).view(-1, 1)

        #Add random noise for regularization, drawn independently per batch element
        inpt = inpt + torch.randn_like(inpt) * NF

        output, state = rnn(inpt, state, Param)

        # Target class: the next sample scaled by (BITR-1) and truncated to an int.
        # (An MSE loss would need a 1-unit output head; `output` has BITR columns.)
        target = (X[:, step] * (BITR-1)).long()
        loss = loss + F.cross_entropy(output, target)
        nsteps += 1

        #We only perform backpropagation every TBPTT steps
        #This is known as 'Truncated back-propagation through time'
        #Which makes it much easier to train RNNs on long sequences
        #The final (possibly partial) window is flushed too, so no steps are dropped
        if step % TBPTT == TBPTT - 1 or step == last_step:
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Cut the graph so the next window does not backprop into this one
            state = state.detach()
            # Average over the steps actually accumulated (the first window has TBPTT-1)
            print(epoch, "::", loss.item() / nsteps)
            loss = 0
            nsteps = 0

In [None]:
#This cell is used to generate audio from the trained network

# Module.cpu() returns the same module object; nothing in this notebook
# moves the model to a GPU, so this only guarantees CPU placement.
rnn_c = rnn.cpu()

#factor is the number of times to repeat audio generation
factor = 5
n_gen = X.shape[1] * factor

with torch.no_grad():
    # One sequence with a single state slot (forward() only reads state[:, 0])
    state = torch.zeros(1, 1, STATE)
    outputs = torch.zeros(n_gen, 1)

    # Previously generated sample in [0, 1]; generation starts from 0
    sample = torch.zeros(1, 1)

    for step in range(n_gen):
        # Feed the model its own previous prediction, mu-law encoded as in training
        inpt = muLaw(sample)

        #pm is the vector containing pitch information. It sweeps linearly from
        #0 towards 1; training values were f/FILES, i.e. at most (FILES-1)/FILES,
        #so the end of the sweep extrapolates past the training range
        pm = torch.tensor([[step / n_gen]])
        output, state = rnn_c(inpt, state, pm)

        # Invert the training quantization (target = int(x * (BITR-1)))
        sample = torch.argmax(output, dim=1).float().view(-1, 1) / (BITR - 1)
        outputs[step] = sample

fig, ax = plt.subplots()
ax.plot(outputs.numpy())
ax.set(title='Generated waveform', xlabel='sample index', ylabel='amplitude (0 to 1)')
plt.show()

In [None]:
#Output the generated audio
import wave

SAMPLE_RATE = 16000  # NOTE(review): should match `sr` returned by torchaudio.load — confirm
AMPLITUDE = 30000    # peak int16 value, leaving headroom below 32767

oput = outputs.detach().numpy()

# Generated samples lie in [0, 1]. Map them to [-1, 1] so the file has no DC
# offset and uses both halves of the int16 range. Clip as a safety net.
pcm = np.clip((oput * 2 - 1) * AMPLITUDE, -32768, 32767).astype('int16')

# The `with` block guarantees close(), which flushes buffered frames to disk.
# Pass raw bytes, which is the type writeframes expects.
with wave.open('output_1_low.wav', 'wb') as w:
    w.setparams((1, 2, SAMPLE_RATE, pcm.shape[0], 'NONE', 'NONE'))
    w.writeframes(pcm.tobytes())