In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.autograd import Variable


class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder,self).__init__()
        
        
        self.encoder=nn.ModuleList([
            nn.Conv1d(1,8,8,4,0,dilation=2),nn.BatchNorm1d(8), nn.ReLU(),
            nn.Conv1d(8,64,8,4,0,dilation=2),nn.BatchNorm1d(64), nn.ReLU(),
            nn.Conv1d(64,128,8,4,0,dilation=2),nn.BatchNorm1d(128), nn.ReLU(),
            nn.Conv1d(128,256,4,2,0,dilation=1),nn.BatchNorm1d(256), nn.ReLU(),
            nn.Conv1d(256,256,4,2,0,dilation=1), nn.BatchNorm1d(256),nn.ReLU(),  
        ])
        self.meanL=nn.Sequential(
            nn.Linear(256,256),nn.BatchNorm1d(256),nn.ReLU(),
            nn.Linear(256,256),nn.BatchNorm1d(256),nn.ReLU(),
            nn.Linear(256,128),nn.BatchNorm1d(128),nn.ReLU(),
            nn.Linear(128,128),nn.BatchNorm1d(128),nn.ReLU(),
            nn.Linear(128,128),nn.BatchNorm1d(128),nn.ReLU(),
            nn.Linear(128,127),nn.BatchNorm1d(127),nn.ReLU()
        )
        self.sigmaL=nn.Sequential(
            nn.Linear(256,256),nn.BatchNorm1d(256),nn.ReLU(),
            nn.Linear(256,256),nn.BatchNorm1d(256),nn.ReLU(),
            nn.Linear(256,128),nn.BatchNorm1d(128),nn.ReLU(),
            nn.Linear(128,128),nn.BatchNorm1d(128),nn.ReLU(),
            nn.Linear(128,128),nn.BatchNorm1d(128),nn.ReLU(),
            nn.Linear(128,127),nn.BatchNorm1d(127),nn.ReLU()
        )
        
        self.LinDecoder=nn.Sequential(
            nn.Linear(128,128),nn.BatchNorm1d(128),nn.ReLU(),
            nn.Linear(128,128),nn.BatchNorm1d(128),nn.ReLU(),
            nn.Linear(128,256),nn.BatchNorm1d(256),nn.ReLU(),
            nn.Linear(256,256),nn.BatchNorm1d(256),nn.ReLU(),
            nn.Linear(256,256),nn.BatchNorm1d(256),nn.ReLU(),
            nn.Linear(256,512),nn.BatchNorm1d(512),nn.ReLU(),
            nn.Linear(512,512),nn.BatchNorm1d(512),nn.ReLU(),
            nn.Linear(512,512),nn.Sigmoid()
        )
        self.up=nn.Upsample(scale_factor=2)
        self.UpDec=nn.Sequential(
            nn.Linear(1024,1024),nn.ReLU(),
            nn.Linear(1024,1024),nn.ReLU()
        )
        

    def sample_latent(self,x,cl):
        mean=self.meanL(x)
        sigma=self.sigmaL(x)
        sigma=torch.sqrt(torch.exp(sigma))
        self.mean=mean
        self.sigma=sigma
        eps = torch.distributions.normal.Normal(0, 1).sample(sample_shape=sigma.size())
        z=mean+sigma*Variable(eps,requires_grad=False).cuda()
        z=torch.cat((z,cl),dim=1)
        return z
    
    def forward(self,x):
        cl=x[0:x.shape[0],0,-1].view(x.shape[0],1)
        x=x[0:x.shape[0],0,0:1024].view(x.shape[0],1,1024)

        for conv in self.encoder:
            x=conv(x)

        x=x.view(x.shape[0],256)
        z=self.sample_latent(x,cl)

        
        
        x=self.LinDecoder(z)
        x=x.view(x.shape[0],1,512)
        x=self.up(x)
        x=self.UpDec(x)
        return x

In [2]:
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn as nn
import torchaudio
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader
from utils import load_dataset,myspec
from os import listdir
import math
dataloader=load_dataset(samples=2)

trainset=DataLoader(dataloader, batch_size=62, shuffle=False,drop_last=True, num_workers=4)
vae=AutoEncoder().cuda()
optimizer = optim.Adam(vae.parameters(),lr=1e-3)

crit=nn.MSELoss()

Collecting Samples In flute
Collecting Samples In bass
Collecting Samples In vocal
Collecting Samples In mallet
Collecting Samples In keyboard
Collecting Samples In string
Collecting Samples In brass
Collecting Samples In organ
Collecting Samples In reed
Collecting Samples In guitar
Current Samples:  1240


In [3]:
num_epochs=5
for epoch in range(num_epochs):
    for i in trainset:
        x=Variable(i).cuda()
        y=vae(x)
        loss=crit(y,x[:,:,0:1024])
        loss.backward()
        optimizer.step()
    print(loss.item())



0.1119002252817154
0.08301513642072678
0.13044005632400513
0.21896444261074066
0.205958753824234


In [5]:
newstate={}
for i in vae.state_dict():
    try:
        newstate[i]=checkpoint[i]
    except:
        newstate[i]=vae.state_dict()[i]

In [7]:
torch.save(newstate,'../wavenet/Upsave.tar')

In [4]:
checkpoint = torch.load('../wavenet/batchMSE.tar')

In [6]:
for i in newstate:
    print(i)

encoder.0.weight
encoder.0.bias
encoder.1.weight
encoder.1.bias
encoder.1.running_mean
encoder.1.running_var
encoder.1.num_batches_tracked
encoder.3.weight
encoder.3.bias
encoder.4.weight
encoder.4.bias
encoder.4.running_mean
encoder.4.running_var
encoder.4.num_batches_tracked
encoder.6.weight
encoder.6.bias
encoder.7.weight
encoder.7.bias
encoder.7.running_mean
encoder.7.running_var
encoder.7.num_batches_tracked
encoder.9.weight
encoder.9.bias
encoder.10.weight
encoder.10.bias
encoder.10.running_mean
encoder.10.running_var
encoder.10.num_batches_tracked
encoder.12.weight
encoder.12.bias
encoder.13.weight
encoder.13.bias
encoder.13.running_mean
encoder.13.running_var
encoder.13.num_batches_tracked
meanL.0.weight
meanL.0.bias
meanL.1.weight
meanL.1.bias
meanL.1.running_mean
meanL.1.running_var
meanL.1.num_batches_tracked
meanL.3.weight
meanL.3.bias
meanL.4.weight
meanL.4.bias
meanL.4.running_mean
meanL.4.running_var
meanL.4.num_batches_tracked
meanL.6.weight
meanL.6.bias
meanL.7.weight
