In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from TIMIT_dataloader_DAE import prepareTIMIT_train
from tqdm import tqdm, tqdm_notebook

import matplotlib.pyplot as plt

In [2]:
num_frames = 11
batch_size = 256         # Number of samples in each minibatch
use_cuda = torch.cuda.is_available()

# Setup GPU optimization if CUDA is supported
if use_cuda:
    computing_device = torch.device("cuda")
    extras = {"num_workers": 8, "pin_memory": True}
    print("CUDA is supported")
else: # Otherwise, train on the CPU
    computing_device = torch.device("cpu")
    extras = False
    print("CUDA NOT supported")

train_loader, val_loader = prepareTIMIT_train(batch_size = batch_size, 
                                              num_frames = num_frames, 
                                              shuffle = True,
                                              seed = 1,
                                              extras=extras)

CUDA is supported


In [3]:
class AE(nn.Module):
    
    def __init__(self, 
                 input_size, 
                 hidden_size, 
                 num_layers = 1, 
                 tied = True,
                 batch_normalization = False):
        
        super(AE, self).__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.tied = tied
        self.batch_normalization = batch_normalization
        
        self.W1 = nn.Parameter(torch.randn(hidden_size, input_size) * 0.1)
        self.W2 = nn.Parameter(torch.randn(input_size, hidden_size) * 0.1)
        self.b = nn.Parameter(torch.randn(hidden_size) * 0.01)
        self.c = nn.Parameter(torch.randn(input_size) * 0.01)
        
        self.normed = nn.BatchNorm1d(hidden_size)
    
    def forward(self, batch):
        
        for i in range(self.num_layers):
            batch = F.linear(batch, self.W1, bias = self.b)
#             if self.batch_normalization:
            batch = self.normed(batch)
            batch = F.sigmoid(batch)
            # if self.tied:
            batch = F.linear(batch, weight = self.W1.t(), bias = self.c)
            # else:
                # batch = F.linear(batch, weight = self.W2, bias = self.c)
        
        return batch

In [6]:
# input size as dimension of features, hidden size is # hidden unit, and number of classes is the dimension of dictionary
input_size = num_frames * 65
hidden_size = 500
num_layers = 3

model = AE(input_size = input_size, 
           hidden_size = hidden_size,
           num_layers = num_layers,
           tied = True,
           batch_normalization = False)
model = model.to(computing_device)

criterion = nn.MSELoss()
criterion = criterion.to(computing_device)

criterion_val = nn.MSELoss(reduction='sum')
criterion_val = criterion_val.to(computing_device)

learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay = 0.0002)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, patience=3)

In [7]:
def validate(model, criterion, val_loader):

    full_val_loss = 0.0
    count = 0.0

    with torch.no_grad():
        model.eval()
        for minibatch_count, (input, target) in enumerate(val_loader, 0):

            batch_size, _  = input.shape
            
            input = input.float()
            target = target.float()
            
            input = input.to(computing_device)
            target = target.to(computing_device)

            loss = criterion(output, target)

            full_val_loss += loss
            count += batch_size
    
    avg_loss = full_val_loss / count
    scheduler.step(avg_loss)
    return avg_loss

In [None]:
# clip = 1.0
epochs_number = 100
sample_history = []
best_val_loss = float("inf")
loss_list = []
train_loss_list = []
val_loss_list = []

for epoch_number in range(epochs_number):   

    N = 100
    
    for minibatch_count, (input, target) in enumerate(tqdm_notebook(train_loader), 0):
        model.train()
        
        input = input.float()
        target = target.float()
        input = input.to(computing_device)
        target = target.to(computing_device)
        
        optimizer.zero_grad()

        output = model(input)
        
        loss = criterion(output, target)
        loss_list.append(loss)
        loss.backward()
    
        if minibatch_count % N == 0:
            print('Minibatch: {}. Training loss: {}'.format(minibatch_count, loss))
    
        optimizer.step()
    
    if epoch_number % N == 0: 
        current_val_loss = validate(model, criterion_val, val_loader)
        print('epoch: {}. Training loss: {}, Val loss: {}'.format(epoch_number, 
                                                                  loss,
                                                                  current_val_loss))
        val_loss_list.append(current_val_loss)
    
    else:
        print('epoch: {}. Training loss: {}'.format(epoch_number, loss))
    
    
    if current_val_loss < best_val_loss :        
        torch.save(model.state_dict(), 'DAE.pt')
        best_val_loss = current_val_loss
        
torch.save((loss_list, val_loss_list), './losses.pt')
plt.figure()
plt.plot(range(len(loss_list)), loss_list, label='Training Loss')
plt.plot(range(0, len(val_loss_list) * N * minibatch_count, N * minibatch_count) , val_loss_list, label='Vadiation Loss')
plt.xlabel('minibatches')
plt.ylabel('loss')
plt.legend()

HBox(children=(IntProgress(value=0, max=64532), HTML(value='')))



Minibatch: 0. Training loss: 4.315248489379883
Minibatch: 100. Training loss: 1.5449827909469604
Minibatch: 200. Training loss: 1.3817142248153687
Minibatch: 300. Training loss: 1.3263052701950073
Minibatch: 400. Training loss: 1.4581021070480347
Minibatch: 500. Training loss: 1.2081118822097778
Minibatch: 600. Training loss: 1.2433403730392456
Minibatch: 700. Training loss: 1.1042286157608032
Minibatch: 800. Training loss: 1.107130765914917
Minibatch: 900. Training loss: 1.0067908763885498
Minibatch: 1000. Training loss: 1.1362861394882202
Minibatch: 1100. Training loss: 1.0064464807510376
Minibatch: 1200. Training loss: 0.8355127573013306
Minibatch: 1300. Training loss: 1.069661259651184
Minibatch: 1400. Training loss: 0.9805367588996887
Minibatch: 1500. Training loss: 0.9639266729354858
Minibatch: 1600. Training loss: 0.9421364665031433
Minibatch: 1700. Training loss: 0.8836820721626282
Minibatch: 1800. Training loss: 0.9240055084228516
Minibatch: 1900. Training loss: 0.923991262912