<a href="https://colab.research.google.com/github/ashwinvaswani/Generative-Modelling-of-Images-from-Speech/blob/master/src/pytorch/pytorch_encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Print NVIDIA GPU status for this runtime; later cells move the model to CUDA.
!nvidia-smi

In [0]:
# Mount Google Drive at /content/drive so the './drive/My Drive/TIP/' paths used
# below resolve (assumes the working directory is /content -- Colab's default).
from google.colab import drive
drive.mount('/content/drive')

In [0]:
import os
import pickle
import timeit

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader,sampler,Dataset
import torchvision.datasets as dset
import torchvision.transforms as T
from PIL import Image
import numpy as np
import scipy.io
import pandas as pd
import torchvision.models.inception as inception
import cv2
from torchsummary import summary

In [0]:
# Project folder on the mounted Google Drive; everything else hangs off it.
PATH_TO_MAIN = './drive/My Drive/TIP/'
# Dataset directory inside the project folder.
PATH = PATH_TO_MAIN + 'Dataset/'
# YouTube watch-URL prefix (presumably a video id gets appended).
YT_LINK = 'www.youtube.com/watch?v='

In [0]:
# Encoder inputs. NOTE(review): pickle.load can execute arbitrary code --
# only load pickles produced by this project.
enc_x_file = PATH_TO_MAIN + 'Pickles/encoder_trainX.pkl'
with open(enc_x_file, 'rb') as fh:
    x_enc_train = pickle.load(fh)

In [0]:
# Encoder targets. NOTE(review): pickle.load can execute arbitrary code --
# only load pickles produced by this project.
enc_y_file = PATH_TO_MAIN + 'Pickles/encoder_trainY.pkl'
with open(enc_y_file, 'rb') as fh:
    y_train_encoder = pickle.load(fh)

In [0]:
np.shape(x_enc_train)  # shape of the full encoder input array

In [0]:
from sklearn.model_selection import train_test_split

# Hold out 5% of the samples for validation. A fixed random_state makes the split
# reproducible across kernel restarts -- the torch seed set before training does
# not influence sklearn's shuffling.
X_train, X_val, y_train, y_val = train_test_split(
    x_enc_train, y_train_encoder, test_size=0.05, random_state=54321)

In [0]:
np.shape(X_train)  # training split, before batching

In [0]:
# Tensor type every input/target is converted to: float32 on the GPU
# (torch.cuda.FloatTensor) -- so a CUDA device is required.
dtype = torch.cuda.FloatTensor
# How often (in training batches) the training loss is printed.
print_every = 10


def reset(m):
    """Re-initialise module `m` in place if it defines `reset_parameters`.

    Intended for ``model.apply(reset)``: layers such as Conv2d, Linear and
    BatchNorm2d are re-initialised, while modules without a
    ``reset_parameters`` method are left untouched.
    """
    if hasattr(m, 'reset_parameters'):
        m.reset_parameters()

In [0]:
class Flatten(nn.Module):
    """Flatten every dimension except the leading batch dimension.

    (N, C, H, W) -> (N, C*H*W); more generally (N, ...) -> (N, prod(...)).
    """

    def forward(self, x):
        # reshape (rather than view) also handles non-contiguous inputs such as
        # permuted tensors, and no fixed rank is assumed.
        return x.reshape(x.size(0), -1)

In [0]:
def _conv_bn_block(in_channels, out_channels, stride=1):
    """Return [Conv2d(kernel 4) -> LeakyReLU -> BatchNorm2d] as a flat list of layers."""
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride),
        nn.LeakyReLU(inplace=True),
        nn.BatchNorm2d(out_channels),
    ]


def _halve_height():
    """Max-pool that halves dim 2 and leaves dim 3 unchanged."""
    return nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))


# All layers go into ONE flat nn.Sequential (no nesting), so module indices --
# and therefore state_dict keys like '0.weight' -- stay stable for checkpoints.
# Operands of `+` are evaluated left to right, so layers are constructed in order.
encoder_layers = (
    _conv_bn_block(2, 64)
    + _conv_bn_block(64, 64)
    + _conv_bn_block(64, 128) + [_halve_height()]
    + _conv_bn_block(128, 128) + [_halve_height()]
    + _conv_bn_block(128, 256) + [_halve_height()]
    + _conv_bn_block(256, 512)
    + _conv_bn_block(512, 512, stride=2)
    + [
        nn.Conv2d(512, 512, kernel_size=4, stride=2),
        # For a (2, 598, 257) input dim 2 is 15 here; this collapses it to 1.
        nn.AvgPool2d(kernel_size=(15, 1), stride=(1, 1)),
        nn.LeakyReLU(inplace=True),
        nn.BatchNorm2d(512),
        Flatten(),
        nn.Linear(29696, 4096),  # 29696 = 512 channels * 1 * 58 for a (2, 598, 257) input
        nn.Dropout(0.3),
        nn.Linear(4096, 2048),
    ]
)
fixed_model_base = nn.Sequential(*encoder_layers)


# Module.type converts in place and returns the same module object.
fixed_model = fixed_model_base.type(dtype)

fixed_model.cuda()

In [0]:
# Per-layer output shapes and parameter counts for one (2, 598, 257) input.
summary(fixed_model, input_size=(2, 598, 257))

In [0]:
## Smoke test: feed a random batch through the model and check the output size.
x = torch.randn(16, 2, 598, 257).type(dtype)
was_training = fixed_model.training
fixed_model.eval()  # keep random noise out of the BatchNorm running statistics
with torch.no_grad():  # shape check only -- no autograd graph needed
    ans = fixed_model(x)
fixed_model.train(was_training)  # restore the caller's mode

print(np.array(ans.size()))
assert tuple(ans.size()) == (16, 2048), 'unexpected encoder output shape'
np.array_equal(np.array(ans.size()), np.array([16, 2048]))

In [0]:
def train(model, loss_fn, optimizer, x_train, y_train, x_val, y_val, num_epochs=1,
          scheduler=None, n_epochs_stop=15):
    """Train `model` for up to `num_epochs` epochs with checkpointing and early stopping.

    x_train/y_train and x_val/y_val are indexable collections of numpy arrays,
    one model-ready batch per entry. Batches are moved to the device and float
    dtype of the model's first parameter.

    After each epoch the model is evaluated with `check_accuracy`. The mean
    validation loss (summed loss / number of validation batches) drives the LR
    scheduler, checkpointing and early stopping:
      * 'tmp_new_best.pt'  -- {'state_dict': ...}, written when val accuracy improves
      * 'tmp_new.pt'       -- {'state_dict': ...}, written when val loss improves
      * 'tmp_new_model.pt' -- raw state_dict, same trigger as 'tmp_new.pt'
    If the val loss does not improve for `n_epochs_stop` consecutive epochs, the
    best weights saved during this call are reloaded into `model` and training stops.

    Args:
        model: network to train (updated in place).
        loss_fn: callable ``loss_fn(scores, targets)`` returning a scalar tensor.
        optimizer: optimizer over `model`'s parameters.
        num_epochs: maximum number of epochs.
        scheduler: LR scheduler stepped with the mean val loss each epoch
            (e.g. ReduceLROnPlateau). When None, falls back to a module-level
            `scheduler` if one exists, matching the previous behaviour.
        n_epochs_stop: early-stopping patience in epochs.

    Uses the module-level `print_every` to throttle training-loss prints.
    """
    if scheduler is None:
        # Backward compatibility: older callers relied on the notebook-global `scheduler`.
        scheduler = globals().get('scheduler')

    ref_param = next(model.parameters())
    device, float_dtype = ref_param.device, ref_param.dtype

    min_val_loss = np.inf  # np.Inf alias was removed in NumPy 2.0
    max_acc = 0
    epochs_no_improve = 0
    best_saved = False  # guards against reloading a stale checkpoint from an earlier run
    for epoch in range(num_epochs):
        print('Starting epoch %d / %d' % (epoch + 1, num_epochs))

        model.train()
        for t in range(len(x_train)):
            x_batch = torch.from_numpy(x_train[t]).to(device=device, dtype=float_dtype)
            y_batch = torch.from_numpy(y_train[t]).to(device=device, dtype=float_dtype)

            scores = model(x_batch)
            loss = loss_fn(scores, y_batch)
            if (t + 1) % print_every == 0:
                print("training loss : " + str(loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validate AFTER this epoch's updates so the checkpoint written below holds
        # exactly the weights that produced `val_loss` (previously validation ran
        # before training, so saved weights never matched the evaluated ones).
        val_loss, val_acc = check_accuracy(model, x_val, y_val, loss_fn, 0)
        val_loss = val_loss / max(len(x_val), 1)
        print("valid loss : " + str(val_loss))
        if scheduler is not None:
            scheduler.step(val_loss)

        if val_acc > max_acc:
            torch.save({'state_dict': model.state_dict()}, 'tmp_new_best.pt')
            print("Best Model Saved")
            max_acc = val_acc

        if val_loss < min_val_loss:
            torch.save({'state_dict': model.state_dict()}, 'tmp_new.pt')
            torch.save(model.state_dict(), 'tmp_new_model.pt')
            print("Model saved")
            best_saved = True
            epochs_no_improve = 0
            min_val_loss = val_loss
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= n_epochs_stop:
                print('Early stopping!')
                # Restore the best weights into the model we were given, then stop.
                if best_saved:
                    model.load_state_dict(torch.load('tmp_new_model.pt', map_location=device))
                break


In [0]:
def check_accuracy(model, x_val, y_val, loss_fn, val_loss):
    """Evaluate `model` on pre-batched data and return (summed loss, accuracy).

    Args:
        model: network to evaluate. Inputs and targets are moved to the device
            and float dtype of its first parameter (CPU float32 if it has none).
        x_val: indexable of numpy arrays, one model-ready batch per entry.
        y_val: indexable of numpy arrays, the matching target batches.
        loss_fn: callable ``loss_fn(scores, targets)`` returning a scalar tensor.
        val_loss: running total to add to (callers pass 0). Per-batch losses are
            summed onto it, not averaged.

    Returns:
        (val_loss, acc). `acc` is the fraction of samples whose predicted argmax
        index matches the target index. When targets have the same shape as the
        scores (e.g. regression vectors), the target index is the target's argmax;
        otherwise targets are treated as class indices. `acc` is 0.0 for empty input.

    The model's train/eval mode is restored before returning.
    """
    ref_param = next(model.parameters(), None)
    device = ref_param.device if ref_param is not None else torch.device('cpu')
    float_dtype = ref_param.dtype if ref_param is not None else torch.float32

    num_correct = 0
    num_samples = 0
    was_training = model.training
    model.eval()  # dropout off, BatchNorm uses running statistics
    try:
        with torch.no_grad():  # evaluation only: don't build an autograd graph
            for t in range(len(x_val)):
                x_batch = torch.from_numpy(x_val[t]).to(device=device, dtype=float_dtype)
                y_batch = torch.from_numpy(y_val[t]).to(device=device, dtype=float_dtype)
                scores = model(x_batch)
                val_loss += loss_fn(scores, y_batch).item()

                preds = scores.argmax(dim=1)
                if y_batch.shape == scores.shape:
                    targets = y_batch.argmax(dim=1)
                else:
                    targets = y_batch.reshape(-1).long()
                num_correct += (preds == targets).sum().item()
                num_samples += preds.size(0)
    finally:
        model.train(was_training)

    acc = float(num_correct) / num_samples if num_samples else 0.0
    print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
    return val_loss, acc

In [0]:

# Adam over every encoder parameter; the LR is cut when validation loss plateaus.
optimizer = torch.optim.Adam(fixed_model_base.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', patience=15, verbose=True)
# Mean-squared error against the 2048-d target vectors.
loss_fn = nn.MSELoss()
np.shape(X_val)

In [0]:
# One sample per batch: (n_samples, 1, 2, 598, 257). The -1 lets the sample count
# follow the train/val split instead of hard-coding 39 / 3.
# Assumes every sample holds 2 * 598 * 257 values.
X_train = np.reshape(X_train, (-1, 1, 2, 598, 257))
X_val = np.reshape(X_val, (-1, 1, 2, 598, 257))

In [0]:
np.shape(X_train)  # (n_batches, 1, 2, 598, 257)

In [0]:
# Targets batched like the inputs: (n_samples, 1, 2048). -1 avoids hard-coding
# the split sizes; assumes each target has 2048 values.
y_train = np.reshape(y_train, (-1, 1, 2048))
y_val = np.reshape(y_val, (-1, 1, 2048))

In [0]:
np.shape(y_train)  # (n_batches, 1, 2048)

In [0]:
# Re-initialise every layer from a fixed torch seed, then train.
torch.manual_seed(54321)
fixed_model.apply(reset)
fixed_model.train()
train(fixed_model, loss_fn, optimizer, X_train, y_train, X_val, y_val,
      num_epochs=10)
# To also report accuracy on the training set:
# check_accuracy(fixed_model, X_train, y_train, loss_fn, 0)

In [0]:
# Persist the encoder weights to Drive, wrapped in a dict under 'state_dict'.
encoder_ckpt_path = PATH_TO_MAIN + 'Models/pytorch_encoders.pt'
torch.save({'state_dict': fixed_model.state_dict()}, encoder_ckpt_path)