---   
# HW3 - Transfer learning

#### Due October 30, 2019

In this assignment you will learn about transfer learning. It is one of the most widely used techniques in industry. When the problem you want to solve does not have enough data, we use a different (larger) dataset to learn representations that help us solve our task on the smaller dataset.

The general steps to transfer learning are as follows:

1. Find a huge dataset with similar characteristics to the problem you are interested in.
2. Choose a model powerful enough to extract meaningful representations from the huge dataset.
3. Train this model on the huge dataset.
4. Fine-tune this model on the smaller dataset.


### This homework has the following sections:
1. Question 1: MNIST fine-tuning (Parts A, B, C, D)
2. Question 2: Pretrain on Wikitext-2 (Parts A, B, C, D)
3. Question 3: Fine-tune on MNLI (Parts A, B, C, D)
4. Question 4: Fine-tune using pretrained BERT (Parts A, B, C)

---   
## Question 1 (MNIST transfer learning)
To grasp the high-level approach to transfer learning, let's first do a simple example using computer vision. 

The torchvision library has models (ResNets, VGG nets, etc.) pretrained on the ImageNet dataset. The ImageNet training set used for these models
has about 1.3 million images covering 1000 object classes. When you load one of these models with `pretrained=True`, its weights are initialized
with the weights saved from training on ImageNet.

In this task we will:
1. Choose a pretrained model.
2. Freeze the model so that the weights don't change.
3. Fine-tune on a few labels of MNIST.   

#### Choose a model
Here we pick one of the pretrained models from torchvision (ResNet-18).

In [22]:
import torch
import torchvision.models as models

class Identity(torch.nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        
    def forward(self, x):
        return x

In [23]:


# init the pretrained feature extractor
pretrained_resnet18 = models.resnet18(pretrained=True)
num_ftrs = pretrained_resnet18.fc.in_features
pretrained_resnet18.fc = Identity()

In [24]:
num_ftrs

512

#### Freeze the model
Here we freeze the weights of the model. Freezing means no gradients are computed for these weights,
so the optimizer never updates them.

By doing this you can think of the model as a feature extractor. This feature extractor outputs
a **representation** of an input. For ResNet-18 with its final `fc` layer replaced by `Identity`, this representation is a 512-dimensional vector per image (a `[batch_size, 512]` matrix per batch) that encodes information about the input.

In [25]:
def freeze_model(model):  # feature extraction: weights stay fixed
    for param in model.parameters():
        param.requires_grad = False

def unfreeze_model(model):  # fine-tuning: all weights become trainable again
    for param in model.parameters():
        param.requires_grad = True

freeze_model(pretrained_resnet18)
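A quick way to confirm that freezing behaves as described is to check gradients and BatchNorm buffers on a tiny stand-in model. The sketch below uses a hypothetical small `backbone` (not ResNet-18) so it runs in seconds. It also shows a caveat worth knowing: `requires_grad = False` stops weight updates, but a BatchNorm layer in `train()` mode still updates its running mean and variance on every forward pass.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a pretrained backbone (not ResNet-18): conv + batch norm
backbone = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3),
    nn.BatchNorm2d(4),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
head = nn.Linear(4, 10)  # new trainable head

# Freeze the backbone the same way freeze_model() does
for p in backbone.parameters():
    p.requires_grad = False

x = torch.randn(8, 3, 16, 16)
y = torch.randint(0, 10, (8,))

running_mean_before = backbone[1].running_mean.clone()
backbone.train()  # train mode: BatchNorm updates its running statistics
loss = nn.functional.cross_entropy(head(backbone(x)), y)
loss.backward()

print(backbone[0].weight.grad)  # None: frozen weights get no gradient
print(head.weight.grad is not None)  # True: the head still trains
print(torch.equal(running_mean_before, backbone[1].running_mean))  # False: BN buffers changed
```

If the backbone should stay completely fixed, one option is to keep its BatchNorm layers in `eval()` mode while the head trains, so they use their stored ImageNet statistics.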

#### Init target dataset
Here we define the dataset we are actually interested in.

In [26]:
import os
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

# train/val split
mnist_dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(mnist_dataset, [55000, 5000])

mnist_train = DataLoader(mnist_train, batch_size=32)
mnist_val = DataLoader(mnist_val, batch_size=32)

# test split
mnist_test = DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)

In [27]:
import numpy as np

In [28]:
for step, batch in enumerate(mnist_test):
    print(len(batch))
    print(batch[0].shape)

    a = np.repeat(batch[0], 3, axis=1)  # replicate the single channel 3 times
    print(a.shape)
    break
# each mnist_test batch is a list of length 2:
# batch[0] is test_X: [32, 1, 28, 28] => 32 images per batch, each 1 channel of 28 x 28 pixels
# batch[1] is test_Y: [32] => one label per image

2
torch.Size([32, 1, 28, 28])
torch.Size([32, 3, 28, 28])
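`np.repeat` works on the tensor here, but it goes through NumPy. A torch-native way to turn the 1-channel MNIST batch into the 3 channels ResNet expects is `Tensor.repeat` (copies the data) or `Tensor.expand` (a view, no copy). A small sketch using a random batch shaped like the one above:

```python
import torch

batch = torch.rand(32, 1, 28, 28)  # same shape as an MNIST batch above

copied = batch.repeat(1, 3, 1, 1)     # tile the channel dim 3 times (new memory)
viewed = batch.expand(-1, 3, -1, -1)  # broadcast view of the single channel (no copy)

print(copied.shape)  # torch.Size([32, 3, 28, 28])
print(viewed.shape)  # torch.Size([32, 3, 28, 28])
print(torch.equal(copied, viewed))  # True: same values, different memory layout
```

`expand` is enough for a read-only forward pass; `repeat` is the safer choice if the tensor is later modified in place.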


In [29]:
len(mnist_train)

1719
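The 1719 above counts batches, not images: `DataLoader` keeps the last partial batch by default (`drop_last=False`), so the number of batches is the ceiling of 55,000 / 32. Worked out:

```python
import math

num_train_images = 55000  # size of the training split from random_split above
batch_size = 32

num_batches = math.ceil(num_train_images / batch_size)
last_batch_size = num_train_images - (num_batches - 1) * batch_size

print(num_batches)      # 1719
print(last_batch_size)  # 24
```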

### Part A (init fine-tune model)
Decide which model to put on top of the feature extractor for fine-tuning.

In [30]:
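import torch.nn as nn


def init_fine_tune_model(num_ftrs, hidden_dim=100):
    """
    MLP head mapping the num_ftrs-dim ResNet features (512 for ResNet-18)
    to the 10 MNIST classes; hidden_dim=100 is an arbitrary choice.
    """
    num_classes = 10
    
    fc = nn.Sequential(
            nn.Linear(num_ftrs, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes))
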
    return fc 
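For reference, this head is tiny compared to ResNet-18. Its trainable parameters (weights plus biases of the two `nn.Linear` layers) can be counted by hand:

```python
num_ftrs, hidden, num_classes = 512, 100, 10  # sizes used in init_fine_tune_model

# nn.Linear(in, out) has an (out x in) weight matrix and an out-sized bias
layer1 = num_ftrs * hidden + hidden          # 51,300
layer2 = hidden * num_classes + num_classes  # 1,010
total = layer1 + layer2

print(total)  # 52310
```

These are the four parameter tensors (two weights, two biases) reported as `num param req grad 4` in the unfrozen run's log, since `freeze_model` leaves only the new head trainable.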

### Part B (Fine-tune (Frozen))

The actual problem we care about solving likely has a different number of classes or is a different task altogether. Fine-tuning is the process of using the extracted representations (features) to solve this downstream task (the task you're interested in).

To illustrate this, we'll use our model pretrained on ImageNet to solve the MNIST classification task.

There are two types of fine-tuning.

#### 1. Frozen feature_extractor
In the first type we fine-tune with the FROZEN feature_extractor and NEVER unfreeze it.


#### 2. Unfrozen feature_extractor
In the second, we fine-tune with a FROZEN feature_extractor for a few epochs, then unfreeze the feature extractor and finish training.


In this part we will use the first version.

In [31]:
def test_model(loader, model):
    """
    Helper function that returns the model's accuracy (in percent) on a dataset
    @param: loader - data loader for the dataset to test against
    @param: model - model to evaluate; inputs are sent to the device its weights are on
    """
    device = next(model.parameters()).device
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in loader:
            images = np.repeat(images, 3, axis=1)  # convert to 3 channels
            inputs, labels = images.to(device), labels.to(device)
            outputs = F.softmax(model(inputs), dim=1)
            predicted = outputs.max(1, keepdim=True)[1]
            total += labels.size(0)
            correct += predicted.eq(labels.view_as(predicted)).sum().item()

    return (100 * correct / total)
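The `F.softmax` call in `test_model` does not change any prediction: softmax is strictly increasing in each logit, so the argmax of the probabilities equals the argmax of the raw logits. A dependency-free check on a made-up logit vector:

```python
import math

logits = [1.5, -0.3, 4.2, 0.0]  # hypothetical outputs for one image, 4 classes

def softmax(xs):
    # subtract the max for numerical stability; does not change the result
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(xs):
    return max(range(len(xs)), key=lambda i: xs[i])

probs = softmax(logits)
print(argmax(logits), argmax(probs))  # 2 2
```

Dropping the softmax would therefore give the same accuracy with slightly less computation.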


In [32]:
import torch.optim as optim

from tqdm import trange

def FROZEN_fine_tune_mnist(feature_extractor, fine_tune_model, num_epochs, mnist_train, mnist_val):
    """
    feature_extractor is a pretrained ResNet whose fc layer will be replaced.
    Attach fine_tune_model as the new head and train only the head on MNIST.
    
    return (model, train_losses, val_accs)
    """     
    device = torch.device("cuda:0")
    # INSERT YOUR CODE: (train the fine_tune model using features extracted by feature_extractor)
    # first freeze the pretrained layers
    freeze_model(feature_extractor)
    
    # attach the new MLP head; its parameters are freshly created, so they stay trainable
    feature_extractor.fc = fine_tune_model
    model = feature_extractor
    
    # the optimizer only sees the trainable (head) parameters
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=2e-05, eps=1e-08)
    criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)
    model.to(device)
    
    train_losses = []
    val_accs = []
    model.train()
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(mnist_train):
            images = np.repeat(images, 3, axis=1)  # convert to 3 channels
            inputs, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        
            train_losses.append(loss.item())
        
            # validate every 500 iterations
            if i > 0 and i % 500 == 0:
                val_acc = test_model(mnist_val, model)
                val_accs.append(val_acc)
                print('Epoch: [{}/{}], Step: [{}/{}], Train Loss {}, Validation Acc: {}'.format(
                      epoch+1, num_epochs, i+1, len(mnist_train), loss.item(), val_acc))
                model.train()  # test_model switched to eval mode; go back to training

    return model, train_losses, val_accs
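With 1,719 batches per epoch, the condition `i > 0 and i % 500 == 0` fires three times per epoch, which is why the training log shows steps 501, 1001 and 1501 (the printout uses `i + 1`). Listing the triggering steps directly:

```python
num_batches = 1719    # len(mnist_train) from above
validate_every = 500  # interval used in the training loop

val_steps = [i for i in range(num_batches) if i > 0 and i % validate_every == 0]
printed_steps = [i + 1 for i in val_steps]  # the log prints 1-based steps

print(val_steps)      # [500, 1000, 1500]
print(printed_steps)  # [501, 1001, 1501]
```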


### Part C (compute test accuracy)
Compute the test accuracy of fine-tuned model on MNIST

In [33]:
#def calculate_mnist_test_accuracy(feature_extractor, fine_tune_model, mnist_test):
def calculate_mnist_test_accuracy(model, mnist_test):   
    test_accuracy = test_model(mnist_test, model)
    
    return test_accuracy

### Grade!
Let's see how you did

In [34]:
def grade_mnist_frozen():
    
    # init a ft model
    fine_tune_model = init_fine_tune_model(num_ftrs)
    
    # run the transfer learning routine
    num_epochs = 10
    model, train_losses, val_accs=FROZEN_fine_tune_mnist(pretrained_resnet18, fine_tune_model, num_epochs, mnist_train, mnist_val)
    
    # calculate test accuracy
    test_accuracy = calculate_mnist_test_accuracy(model, mnist_test)
    
    # the real threshold will be released by Oct 11 
    assert test_accuracy > 0.0, 'your accuracy is too low...'
    
    #save model

    PATH_TO_FOLDER=  '/scratch/cp2530/myjupyter/'
    torch.save({
            'epoch': num_epochs,
            'model_state_dict': model.state_dict(),
            'loss': train_losses[-1],
            'frozen_test_accuracy': test_accuracy 
           
            }, PATH_TO_FOLDER + "models/ResNet18Freeze_CP")

    
    
    return test_accuracy
    
frozen_test_accuracy = grade_mnist_frozen()

Epoch: [1/10], Step: [501/1719], Train Loss 2.0012693405151367, Validation Acc: 39.52
Epoch: [1/10], Step: [1001/1719], Train Loss 1.5530532598495483, Validation Acc: 55.0
Epoch: [1/10], Step: [1501/1719], Train Loss 1.3267741203308105, Validation Acc: 59.6
Epoch: [2/10], Step: [501/1719], Train Loss 1.3506823778152466, Validation Acc: 64.14
Epoch: [2/10], Step: [1001/1719], Train Loss 1.043002724647522, Validation Acc: 66.28
Epoch: [2/10], Step: [1501/1719], Train Loss 1.0089775323867798, Validation Acc: 66.72
Epoch: [3/10], Step: [501/1719], Train Loss 1.1376147270202637, Validation Acc: 68.9
Epoch: [3/10], Step: [1001/1719], Train Loss 0.873719334602356, Validation Acc: 70.0
Epoch: [3/10], Step: [1501/1719], Train Loss 0.8653533458709717, Validation Acc: 69.9
Epoch: [4/10], Step: [501/1719], Train Loss 1.0356626510620117, Validation Acc: 71.1
Epoch: [4/10], Step: [1001/1719], Train Loss 0.7839066386222839, Validation Acc: 71.98
Epoch: [4/10], Step: [1501/1719], Train Loss 0.78432679

In [35]:
frozen_test_accuracy 


75.9


### Part D (Fine-tune Unfrozen)
Now we'll learn how to train using the "unfrozen" approach.

In this approach we'll:
1. Keep the feature_extractor frozen for a few epochs (this solution uses 5 of the 10 epochs).
2. Unfreeze it.
3. Finish training
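Recreating the optimizer after unfreezing (as the code below does) resets Adam's moment estimates for the head. An alternative is to keep the existing optimizer and register the backbone with `Optimizer.add_param_group`, optionally with a smaller learning rate so the pretrained weights move slowly. This is a sketch with a hypothetical toy backbone and head rather than ResNet-18, and the learning rates are illustrative, not tuned:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-ins for the pretrained feature extractor and the new head
backbone = nn.Linear(16, 8)
head = nn.Linear(8, 10)

# Phase 1: backbone frozen, the optimizer only sees the head
for p in backbone.parameters():
    p.requires_grad = False
optimizer = optim.Adam(head.parameters(), lr=1e-3)

# ... train the head for a few epochs ...

# Phase 2: unfreeze and add the backbone as a second parameter group
for p in backbone.parameters():
    p.requires_grad = True
optimizer.add_param_group({"params": backbone.parameters(), "lr": 1e-4})

x, y = torch.randn(4, 16), torch.randint(0, 10, (4,))
loss = nn.functional.cross_entropy(head(backbone(x)), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()  # updates both groups, each with its own learning rate

print([g["lr"] for g in optimizer.param_groups])  # [0.001, 0.0001]
```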

In [37]:
def UNFROZEN_fine_tune_mnist(feature_extractor, fine_tune_model, num_epochs, mnist_train, mnist_val):
    """
    feature_extractor is a pretrained ResNet whose fc layer will be replaced.
    Attach fine_tune_model as the new head, train it with the backbone frozen,
    then unfreeze everything and train for the remaining epochs.
    
    return (model, train_losses, val_accs)
    """     
    # INSERT YOUR CODE:
    # of the num_epochs (10), keep the backbone frozen for 5 and unfrozen for 5
    device = torch.device("cuda:0")
    num_freeze_epochs = 5

    # first freeze the pretrained layers
    freeze_model(feature_extractor)

    # attach the new MLP head (trainable)
    feature_extractor.fc = fine_tune_model
    model = feature_extractor
    criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)
    model.to(device)

    train_losses = []
    val_accs = []

    def run_epochs(optimizer, n_epochs):
        # train for n_epochs, validating every 500 iterations
        model.train()
        for epoch in range(n_epochs):
            for i, (images, labels) in enumerate(mnist_train):
                images = np.repeat(images, 3, axis=1)  # convert to 3 channels
                inputs, labels = images.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                train_losses.append(loss.item())

                if i > 0 and i % 500 == 0:
                    val_acc = test_model(mnist_val, model)  # calls model.eval()
                    val_accs.append(val_acc)
                    print('Epoch: [{}/{}], Step: [{}/{}], Train Loss {}, Validation Acc: {}'.format(
                          epoch+1, n_epochs, i+1, len(mnist_train), loss.item(), val_acc))
                    model.train()  # go back to training

    # phase 1: only the head is trainable
    param_list = [p for p in model.parameters() if p.requires_grad]
    print("num param req grad {}".format(len(param_list)))
    optimizer = optim.Adam(param_list, lr=2e-05, eps=1e-08)
    run_epochs(optimizer, num_freeze_epochs)

    # phase 2: unfreeze the whole network (backbone + head) and finish training
    num_left = max(num_epochs - num_freeze_epochs, 0)
    unfreeze_model(model)
    param_list = [p for p in model.parameters() if p.requires_grad]
    print('unfreeze')
    print("num param req grad {}".format(len(param_list)))
    optimizer = optim.Adam(param_list, lr=2e-05, eps=1e-08)
    run_epochs(optimizer, num_left)

    return model, train_losses, val_accs

### Grade UNFROZEN
Let's see if there's a difference in accuracy!

In [38]:
def grade_mnist_unfrozen():
    
    # init a ft model
    fine_tune_model = init_fine_tune_model(num_ftrs)
    
    # run the transfer learning routine
    num_epochs = 10
    model, train_losses, val_accs = UNFROZEN_fine_tune_mnist(pretrained_resnet18, fine_tune_model, num_epochs, mnist_train, mnist_val)
    
    # calculate test accuracy
    test_accuracy = calculate_mnist_test_accuracy(model, mnist_test)
    print(test_accuracy)
    # the real threshold will be released by Oct 11 
    assert test_accuracy > 0.0, 'your accuracy is too low...'
    
    #save model

    PATH_TO_FOLDER =  '/scratch/cp2530/myjupyter/'
    torch.save({
            'epoch': num_epochs,
            'model_state_dict': model.state_dict(),
            'loss': train_losses[-1],
            'unfrozen_test_accuracy': test_accuracy 
           
            }, PATH_TO_FOLDER + "models/ResNet18UnFreeze_CP")

    
    return test_accuracy
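Both grading cells pass the same `pretrained_resnet18` object, and each fine-tuning function attaches a new head to it and trains it in place. The unfrozen run therefore starts from a backbone the frozen run has already touched: its weights are unchanged, but its BatchNorm running statistics were updated while the frozen model trained. A minimal sketch of one way to give each experiment an independent starting point, using `copy.deepcopy` on a hypothetical small model standing in for ResNet-18:

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small "pretrained" model standing in for pretrained_resnet18
pretrained = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

run_a = copy.deepcopy(pretrained)  # independent copy for the frozen experiment
run_b = copy.deepcopy(pretrained)  # independent copy for the unfrozen experiment

# Simulate training run_a: a forward pass in train mode updates its BN buffers
run_a.train()
run_a(torch.randn(8, 4))

# run_b and the original are untouched by what happened to run_a
print(torch.equal(run_b[1].running_mean, pretrained[1].running_mean))  # True
print(torch.equal(run_a[1].running_mean, pretrained[1].running_mean))  # False
```

In the notebook this would amount to passing `copy.deepcopy(pretrained_resnet18)` to each fine-tuning function instead of the shared object; treat it as a suggestion, not part of the required assignment code.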
    


In [39]:
# load the trained unfrozen checkpoint (did not work)
# feature_extractor = pretrained_resnet18
# feature_extractor.fc = init_fine_tune_model(num_ftrs= 512)
# model_unfrozen = feature_extractor

# checkpoint = torch.load('/scratch/cp2530/myjupyter/models/ResNet18UnFreeze_CP')
# model_unfrozen.load_state_dict(checkpoint['model_state_dict'])
# epoch = checkpoint['epoch']
# loss = checkpoint['loss']

# unfrozen_test_accuracy = calculate_mnist_test_accuracy(model_unfrozen, mnist_test)

In [40]:
unfrozen_test_accuracy = grade_mnist_unfrozen()

num param req grad 4
Epoch: [1/5], Step: [501/1719], Train Loss 2.0520548820495605, Validation Acc: 41.48
Epoch: [1/5], Step: [1001/1719], Train Loss 1.665267825126648, Validation Acc: 54.38
Epoch: [1/5], Step: [1501/1719], Train Loss 1.4178346395492554, Validation Acc: 59.5
Epoch: [2/5], Step: [501/1719], Train Loss 1.3698444366455078, Validation Acc: 63.4
Epoch: [2/5], Step: [1001/1719], Train Loss 1.1348119974136353, Validation Acc: 65.46
Epoch: [2/5], Step: [1501/1719], Train Loss 1.0798262357711792, Validation Acc: 65.9
Epoch: [3/5], Step: [501/1719], Train Loss 1.1365126371383667, Validation Acc: 68.48
Epoch: [3/5], Step: [1001/1719], Train Loss 0.9355572462081909, Validation Acc: 69.38
Epoch: [3/5], Step: [1501/1719], Train Loss 0.9325653910636902, Validation Acc: 69.18
Epoch: [4/5], Step: [501/1719], Train Loss 1.0224225521087646, Validation Acc: 71.0
Epoch: [4/5], Step: [1001/1719], Train Loss 0.8333955407142639, Validation Acc: 70.96
Epoch: [4/5], Step: [1501/1719], Train Los

In [41]:
assert unfrozen_test_accuracy > frozen_test_accuracy, 'the unfrozen model should be better'

--- 
# Question 2 (train a model on Wikitext-2)

Here we'll apply what we just learned to NLP. In this section we'll make our own feature extractor and pretrain it on Wikitext-2.

The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia; WikiText-2, used here, is its small version with roughly 2 million training tokens. The dataset is available under the Creative Commons Attribution-ShareAlike License.

#### Part A
In this section you need to generate the training, validation and test split. Feel free to use code from your previous lectures.

In [10]:
#!pip install jsonlines



In [15]:
import os
import json
import jsonlines
from collections import defaultdict

In [16]:
#!pip install tqdm


In [17]:
import importlib.util
PATH_TO_FOLDER=  '/scratch/cp2530/myjupyter/'

def module_from_file(module_name, file_path):
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

model_CP = module_from_file("model_CP", PATH_TO_FOLDER+"hw2/model_CP.py")

In [16]:
#!pip install torchtext

In [18]:

# support code from lab
import io
import json
import os
import pickle
from collections import defaultdict

import jsonlines
import numpy
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torch import nn
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm import tqdm

# def load_vectors(fname):
#     fin = io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore')
#     n, d = map(int, fin.readline().split())
#     embedding_size = 300
#     max_vocab_size = 35000
#     embedding_dict = np.random.randn(max_vocab_size+2, embedding_size)
#     all_train_tokens = []
#     i = 0
    
#     for line in fin:
#         tokens = line.rstrip().split(' ')
#         all_train_tokens.append(tokens[0])
#         embedding_dict[i+2] = list(map(float, tokens[1:]))
#         i += 1
#         if i == max_vocab_size:
#             break
            
#     return embedding_dict, all_train_tokens
  
# # download the vectors yourself
# fasttext_embedding_dict, all_fasttext_tokens = load_vectors('wiki-news-300d-1M.vec')
  
class LMDataset(Dataset):
    def __init__(self, list_of_token_lists):
        self.input_tensors = []
        self.target_tensors = []

        for sample in list_of_token_lists:
            self.input_tensors.append(torch.tensor([sample[:-1]], dtype=torch.long))
            self.target_tensors.append(torch.tensor([sample[1:]], dtype=torch.long))

    def __len__(self):
        return len(self.input_tensors)

    def __getitem__(self, idx):
        return (self.input_tensors[idx], self.target_tensors[idx])


def tokenize_dataset(datasets, dictionary):
    tokenized_datasets = {}
    for split, dataset in datasets.items():
        _current_dictified = []
        for l in tqdm(dataset):
            l = ['<bos>'] + l + ['<eos>']
            encoded_l = dictionary.encode_token_seq(l)
            _current_dictified.append(encoded_l)
        tokenized_datasets[split] = _current_dictified
    return tokenized_datasets

def tokenize_mnli_dataset(datasets, dictionary):
    tokenized_datasets = {}
    for split, dataset in datasets.items():
        _current_dictified = []
        for s1, s2 in tqdm(dataset):
            s1 = ['<bos>'] + s1 + ['<eos>']
            s2 = ['<bos>'] + s2 + ['<eos>']
            encoded_s1 = dictionary.encode_token_seq(s1)            
            encoded_s2 = dictionary.encode_token_seq(s2)
            _current_dictified.append([encoded_s1, encoded_s2])
        tokenized_datasets[split] = _current_dictified
    return tokenized_datasets

def pad_list_of_tensors(list_of_tensors, pad_token):
    max_length = max([t.size(-1) for t in list_of_tensors])
    padded_list = []
    for t in list_of_tensors:
        padded_tensor = torch.cat(
            [t, torch.tensor([[pad_token] * (max_length - t.size(-1))], dtype=torch.long)], dim=-1)
        padded_list.append(padded_tensor)

    padded_tensor = torch.cat(padded_list, dim=0)
    return padded_tensor


def pad_collate_fn(pad_idx, batch):
    input_list = [s[0] for s in batch]
    target_list = [s[1] for s in batch]
    input_tensor = pad_list_of_tensors(input_list, pad_idx)
    target_tensor = pad_list_of_tensors(target_list, pad_idx)
    return input_tensor, target_tensor


def load_wikitext(data_dir):
    import subprocess
    filename = os.path.join(data_dir, 'wikitext2-sentencized.json')
    if not os.path.exists(filename):
        os.makedirs(data_dir, exist_ok=True)
        url = "https://nyu.box.com/shared/static/9kb7l7ci30hb6uahhbssjlq0kctr5ii4.json"
        args = ['wget', '-O', filename, url]
        subprocess.call(args)
    raw_datasets = json.load(open(filename, 'r'))
    for name in raw_datasets:
        raw_datasets[name] = [x.split() for x in raw_datasets[name]]

    if os.path.exists(os.path.join(data_dir, 'vocab.pkl')):
        vocab = pickle.load(open(os.path.join(data_dir, 'vocab.pkl'), 'rb'))
    else:
        vocab = Dictionary(raw_datasets, include_valid=False)
        pickle.dump(vocab, open(os.path.join(data_dir, 'vocab.pkl'), 'wb'))

    tokenized_datasets = tokenize_dataset(raw_datasets, vocab)
    datasets = {name: LMDataset(ds) for name, ds in tokenized_datasets.items()}
    print("Vocab size: %d" % (len(vocab)))
    print(" padding index {}".format(vocab.get_id('<pad>')))
    return raw_datasets, datasets, vocab


class Dictionary(object):
    def __init__(self, datasets, include_valid=False):
        self.tokens = []
        self.ids = {}
        self.counts = {}
        
        # add special tokens
        self.add_token('<bos>')
        self.add_token('<eos>')
        self.add_token('<pad>')
        self.add_token('<unk>')
        
        for line in tqdm(datasets['train']):
            for w in line:
                self.add_token(w)
                    
        if include_valid is True:
            for line in tqdm(datasets['valid']):
                for w in line:
                    self.add_token(w)
                            
    def add_token(self, w):
        if w not in self.ids:  # dict lookup is O(1); checking the token list was O(n)
            self.tokens.append(w)
            _w_id = len(self.tokens) - 1
            self.ids[w] = _w_id
            self.counts[w] = 1
        else:
            self.counts[w] += 1

    def get_id(self, w):
        return self.ids[w]
    
    def get_token(self, idx):
        return self.tokens[idx]
    
    def decode_idx_seq(self, l):
        return [self.tokens[i] for i in l]
    
    def encode_token_seq(self, l):
        return [self.ids[i] if i in self.ids else self.ids['<unk>'] for i in l]
    
    def __len__(self):
        return len(self.tokens)
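To make the data pipeline concrete: `tokenize_dataset` wraps each sentence in `<bos>`/`<eos>`, `LMDataset` shifts it by one token to build (input, target) pairs, and `pad_collate_fn` right-pads every sequence in a batch to the longest one with the `<pad>` id. A plain-Python sketch of the same steps, using made-up word ids and the special-token ids that `Dictionary` assigns (0 = `<bos>`, 1 = `<eos>`, 2 = `<pad>`, in the order it adds them):

```python
BOS, EOS, PAD = 0, 1, 2  # ids of the first special tokens added by Dictionary

# Two already-encoded sentences with hypothetical word ids
sentences = [[10, 11, 12], [20]]

def make_pair(word_ids):
    # same idea as tokenize_dataset + LMDataset: add markers, then shift by one
    seq = [BOS] + word_ids + [EOS]
    return seq[:-1], seq[1:]

def pad_batch(seqs, pad_id):
    # same idea as pad_list_of_tensors: right-pad to the longest sequence
    max_len = max(len(s) for s in seqs)
    return [s + [pad_id] * (max_len - len(s)) for s in seqs]

pairs = [make_pair(s) for s in sentences]
inputs = pad_batch([p[0] for p in pairs], PAD)
targets = pad_batch([p[1] for p in pairs], PAD)

print(inputs)   # [[0, 10, 11, 12], [0, 20, 2, 2]]
print(targets)  # [[10, 11, 12, 1], [20, 1, 2, 2]]
```

The padded target positions are exactly the ones that `ignore_index=2` in the language-model loss skips.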





In [19]:
# def perplexity(model, sequences):
#     n_total = 0
#     logp_total = 0
#     for sequence in sequences:
#         logp_total += model.sequence_logp(sequence)
#         n_total += len(sequence) + 1  
#     ppl = 2 ** (- (1.0 / n_total) * logp_total)  
#     return ppl



def init_wikitext_dataset(): #same as grade
    """
    Load WikiText-2 and return padded train/val/test DataLoaders (batch size 32).
    """
    
    raw_datasets, datasets, vocab = load_wikitext(os.getcwd())

    data_loaders = {name: DataLoader(datasets[name], batch_size=32, shuffle=True,
                                     collate_fn=lambda x: pad_collate_fn(vocab.get_id('<pad>'), x))
                    for name in datasets}
    
    wikitext_val = data_loaders['valid'] 
    wikitext_train = data_loaders['train'] 
    wikitext_test = data_loaders['test'] 
    
    #wiki_dict = model_CP.Dictionary(datasets, include_valid=True)
    
    return wikitext_train, wikitext_val, wikitext_test #, wiki_dict ##

In [20]:
init_wikitext_dataset()

100%|██████████| 78274/78274 [00:00<00:00, 117288.21it/s]
100%|██████████| 8464/8464 [00:00<00:00, 86327.61it/s]
100%|██████████| 9708/9708 [00:00<00:00, 48274.98it/s]


Vocab size: 33178
 padding index 2


(<torch.utils.data.dataloader.DataLoader at 0x2b24d7be6d10>,
 <torch.utils.data.dataloader.DataLoader at 0x2b24d7be6810>,
 <torch.utils.data.dataloader.DataLoader at 0x2b24d7be6e90>)

#### Part B   
Here we design our own feature extractor. For MNIST that was a ResNet because we were dealing with images. Now we need a model that handles sequences, so we design an RNN-based model here.

In [26]:
class LSTM_CP(nn.Module):
    def __init__(self, options):
        super().__init__()
        
        # create each LM part here 
        self.lookup = nn.Embedding(num_embeddings=options['num_embeddings'], embedding_dim=options['embedding_dim'], padding_idx=options['padding_idx'])
        
        self.lstm = nn.LSTM(options['input_size'], options['hidden_size'], options['num_layers'], dropout=options['rnn_dropout'], batch_first=True)
        self.projection = nn.Linear(options['hidden_size'], options['num_embeddings'])
        
    def forward(self, encoded_input_sequence):
        """
        Forward method process the input from token ids to logits
        """
        embeddings = self.lookup(encoded_input_sequence)
        lstm_outputs = self.lstm(embeddings)
        logits = self.projection(lstm_outputs[0])
        
        return logits

class RNNLanguageModel(nn.Module):
    """
    This model combines embedding, rnn and projection layer into a single model
    """
    def __init__(self, options):
        super().__init__()
        
        # create each LM part here 
        self.lookup = nn.Embedding(num_embeddings=options['num_embeddings'], embedding_dim=options['embedding_dim'], padding_idx=options['padding_idx'])
        self.rnn = nn.RNN(options['input_size'], options['hidden_size'], options['num_layers'], dropout=options['rnn_dropout'], batch_first=True)
        self.projection = nn.Linear(options['hidden_size'], options['num_embeddings'])
        
    def forward(self, encoded_input_sequence):
        """
        Forward method process the input from token ids to logits
        """
        embeddings = self.lookup(encoded_input_sequence)
        rnn_outputs = self.rnn(embeddings)
        logits = self.projection(rnn_outputs[0])
        
        return logits

def init_feature_extractor(): 
    num_embeddings = 33178#len(vocab)
    embedding_size = 128
    hidden_size = 256
    num_layers = 2
    rnn_dropout = 0.1

    options = {
        'num_embeddings': num_embeddings, #len(wiki_dict),
        'embedding_dim': embedding_size,
        'padding_idx': 2,
        'input_size': embedding_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'rnn_dropout': rnn_dropout,
    }
    
    #feature_extractor =  RNNLanguageModel(options)
    feature_extractor =  LSTM_CP(options)
    #feature_extractor.projection = Identity() #we remove the last layer for now
    
    return feature_extractor
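As a sanity check on model size, the parameter count of `LSTM_CP` with these options can be worked out by hand. PyTorch's `nn.LSTM` stores, per layer, an input-to-hidden weight of shape `(4*hidden, input)`, a hidden-to-hidden weight of shape `(4*hidden, hidden)`, and two bias vectors of length `4*hidden`. The sketch below is plain arithmetic; `sum(p.numel() for p in feature_extractor.parameters())` should give the same total.

```python
vocab, emb, hidden, layers = 33178, 128, 256, 2  # values from init_feature_extractor

embedding = vocab * emb  # nn.Embedding weight (the padding row is still a parameter)

lstm = 0
for layer in range(layers):
    layer_input = emb if layer == 0 else hidden  # layer 2 reads layer 1's hidden state
    lstm += 4 * hidden * layer_input  # weight_ih
    lstm += 4 * hidden * hidden       # weight_hh
    lstm += 2 * 4 * hidden            # bias_ih and bias_hh

projection = hidden * vocab + vocab  # nn.Linear(hidden, vocab) weight and bias

total = embedding + lstm + projection
print(embedding, lstm, projection, total)  # 4246784 921600 8526746 13695130
```

Over 60% of the parameters sit in the output projection, because it maps every hidden state onto the full vocabulary.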

#### Part C
Pretrain the feature extractor

In [22]:
def fit_feature_extractor(feature_extractor, wikitext_train, wikitext_val):
    # FILL IN THE DETAILS
    #define current_device
    model = feature_extractor
    
    current_device = torch.device("cuda:0")
    
    #define criterion and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=2, reduction='sum') #2 is <pad>
    #no freezing yet, we fit the feature extractor
    param_list = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(param_list, lr=2e-05, eps=1e-08)
    model.to(current_device)
    plot_cache = []

    for epoch_number in range(5):
        avg_loss=0
        model.train()

        train_loss_cache = 0
        train_non_pad_tokens_cache = 0
        for i, (inp, target) in enumerate(wikitext_train):
            optimizer.zero_grad()
            inp = inp.to(current_device)
            target = target.to(current_device)
            logits = model(inp)

            loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
            train_loss_cache += loss.item()  # still sum here

            # count the non-pad tokens in the target
            non_pad_tokens = target.view(-1).ne(2).sum().item() #2 is index for <pad>

            train_non_pad_tokens_cache += non_pad_tokens

            loss /= non_pad_tokens  # very important to normalize your current loss before you run .backward()

            loss.backward()
            optimizer.step()

            if i % 500 == 0:
                avg_loss = train_loss_cache / train_non_pad_tokens_cache
                avg_ppl = 2**(avg_loss/numpy.log(2))
                print('Epoch {} Step {} avg train loss = {:.{prec}f} perplexity = {:.{prec}f}'.format(epoch_number, i, avg_loss, avg_ppl, prec=4))
                #train_log_cache = []

        # validation pass
        avg_val_loss = 0
        avg_val_ppl = 0
        valid_loss_cache = 0
        valid_non_pad_tokens_cache = 0

        model.eval()
        with torch.no_grad():
            for i, (inp, target) in enumerate(wikitext_val):
                inp = inp.to(current_device)
                target = target.to(current_device)
                logits = model(inp)

                loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
                valid_loss_cache += loss.item()  # still sum here

                # count the non-pad tokens in the target
                non_pad_tokens = target.view(-1).ne(2).sum().item()# 2 is index for <pad>

                valid_non_pad_tokens_cache += non_pad_tokens

            avg_val_loss = valid_loss_cache / valid_non_pad_tokens_cache
            avg_val_ppl = 2**(avg_val_loss/numpy.log(2))

            print('Validation loss after {} epoch = {:.{prec}f} perplexity = {:.{prec}f}'.format(epoch_number, avg_val_loss, avg_val_ppl,prec=4))

        plot_cache.append((avg_loss, avg_val_loss))

    # save the model after the last epoch
    PATH_TO_FOLDER = '/scratch/cp2530/myjupyter/'
    torch.save({
            'epoch': 5,
            'model_state_dict': model.state_dict(),
            'train_loss': avg_loss,
            'train_perplexity': avg_ppl,
            'val_loss':avg_val_loss,
            'val_perplexity':avg_val_ppl, 
            'plot_cache': plot_cache
           
            }, PATH_TO_FOLDER + "models/LSTMfeatextract_CP_2")



#### Part D
Calculate the test perplexity on wikitext2. Feel free to recycle code from previous assignments from this class. 
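The expression `2**(avg_loss/numpy.log(2))` used in this notebook is simply `exp(avg_loss)`: `CrossEntropyLoss` returns a natural-log loss, and perplexity is the exponential of the average per-token negative log-likelihood. A quick check against the first validation line of the training log (loss 6.8274, perplexity 922.7822):

```python
import math

avg_val_loss = 6.8274  # first "Validation loss" value in the training log

ppl_notebook = 2 ** (avg_val_loss / math.log(2))  # form used in the notebook
ppl_exp = math.exp(avg_val_loss)                  # equivalent, more direct form

print(round(ppl_notebook, 1), round(ppl_exp, 1))  # 922.8 922.8
```

The log rounds the loss to four decimals, so recomputing from the printed loss can differ from the printed perplexity in the second decimal place.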

In [23]:
def calculate_wiki2_test_perplexity(feature_extractor, wikitext_test):
    model = feature_extractor
    current_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(current_device)
    # sum (not average) the loss so it can be normalized by the number of real tokens;
    # 2 is the index for <pad>, so padded positions contribute nothing
    criterion = nn.CrossEntropyLoss(ignore_index=2, reduction='sum')

    test_loss_sum = 0.0
    test_non_pad_tokens = 0
    model.eval()
    with torch.no_grad():
        for inp, target in wikitext_test:
            inp = inp.to(current_device)
            target = target.to(current_device)
            logits = model(inp)
            loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
            test_loss_sum += loss.item()

            # number of non-pad tokens in the target
            test_non_pad_tokens += target.view(-1).ne(2).sum().item()

    avg_test_loss = test_loss_sum / test_non_pad_tokens
    # perplexity = exp(average per-token cross-entropy), equal to 2**(avg_test_loss / ln 2)
    test_ppl = numpy.exp(avg_test_loss)

    return test_ppl

#### Let's grade your results!
(don't touch this part)

In [27]:
def grade_wikitext2():
    # load data
    wikitext_train, wikitext_val, wikitext_test = init_wikitext_dataset()

    # load feature extractor
    feature_extractor = init_feature_extractor()

    # pretrain using the feature extractor
    fit_feature_extractor(feature_extractor, wikitext_train, wikitext_val)

    # check test accuracy
    test_ppl = calculate_wiki2_test_perplexity(feature_extractor, wikitext_test)
    print("test_ppl : {}".format( test_ppl)) #should be <200
    # the real threshold will be released by Oct 11 
    assert test_ppl < 10000, 'ummm... your perplexity is too high...'
    
grade_wikitext2()

100%|██████████| 78274/78274 [00:00<00:00, 89388.07it/s]
100%|██████████| 8464/8464 [00:00<00:00, 110954.25it/s]
100%|██████████| 9708/9708 [00:00<00:00, 90987.98it/s]


Vocab size: 33178
 padding index 2
Epoch 0 Step 0 avg train loss = 10.4174 perplexity = 33435.7746
Epoch 0 Step 500 avg train loss = 8.9709 perplexity = 7870.8750
Epoch 0 Step 1000 avg train loss = 8.0873 perplexity = 3252.9791
Epoch 0 Step 1500 avg train loss = 7.7476 perplexity = 2316.0394
Epoch 0 Step 2000 avg train loss = 7.5710 perplexity = 1941.0217
Validation loss after 0 epoch = 6.8274 perplexity = 922.7822
Epoch 1 Step 0 avg train loss = 6.7899 perplexity = 888.8265
Epoch 1 Step 500 avg train loss = 6.9786 perplexity = 1073.3803
Epoch 1 Step 1000 avg train loss = 6.9722 perplexity = 1066.5397
Epoch 1 Step 1500 avg train loss = 6.9646 perplexity = 1058.4797
Epoch 1 Step 2000 avg train loss = 6.9575 perplexity = 1050.9641
Validation loss after 1 epoch = 6.7250 perplexity = 832.9360
Epoch 2 Step 0 avg train loss = 6.9859 perplexity = 1081.3285
Epoch 2 Step 500 avg train loss = 6.8691 perplexity = 962.0473
Epoch 2 Step 1000 avg train loss = 6.8556 perplexity = 949.1375
Epoch 2 Ste

---   
## Question 3 (fine-tune on MNLI) REMOVED
In this question you will use your feature_extractor from question 2
to fine-tune on MNLI.

(From the website):
The Multi-Genre Natural Language Inference (MultiNLI) corpus is a crowd-sourced collection of 433k sentence pairs annotated with textual entailment information. The corpus is modeled on the SNLI corpus, but differs in that it covers a range of genres of spoken and written text, and supports a distinctive cross-genre generalization evaluation. The corpus served as the basis for the shared task of the RepEval 2017 Workshop at EMNLP in Copenhagen.

MNLI has 3 classes (entailment, neutral, contradiction); its sentence pairs are drawn from ten genres of spoken and written English.
The goal of this question is to maximize test accuracy on MNLI.

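The three MNLI labels have to become integer class ids before they reach `nn.CrossEntropyLoss`. A minimal sketch, assuming the same label order as the `label_list` passed to `convert_examples_to_features` in Question 4 (any fixed order works as long as training and evaluation agree); `MNLI_LABELS` and `encode_labels` are hypothetical names:

```python
# hypothetical constants/helpers; the order matches the label_list used in Question 4
MNLI_LABELS = ['contradiction', 'neutral', 'entailment']
LABEL2ID = {label: i for i, label in enumerate(MNLI_LABELS)}


def encode_labels(labels):
    """Map gold-label strings to class ids; an unknown label raises KeyError."""
    return [LABEL2ID[label] for label in labels]


print(encode_labels(['entailment', 'contradiction', 'neutral']))  # [2, 0, 1]
```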
### Part A
In this section you need to generate the training, validation and test splits. Feel free to use code from previous lectures.

In [None]:
from torchtext.datasets import MultiNLI

def init_mnli_dataset():
    """
    Fill in the details
    """
    mnli_val = None # TODO
    mnli_train = None # TODO
    mnli_test = None # TODO
    
    return mnli_train, mnli_val, mnli_test
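MNLI ships a training set plus matched and mismatched dev sets, and its test labels are not public, so one common arrangement is to carve a validation split out of the training data and use a dev set as the test set (Question 4 does this with `random_split`). A framework-free sketch of a reproducible 95/5 split; `split_train_val` is a hypothetical helper:

```python
import random


def split_train_val(examples, train_percent=95, seed=0):
    """Hypothetical helper: shuffle a copy of `examples` and split it into (train, val).

    Integer arithmetic avoids float rounding in the split size, and a fixed seed
    makes the split reproducible across runs.
    """
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    n_train = len(shuffled) * train_percent // 100
    return shuffled[:n_train], shuffled[n_train:]


train, val = split_train_val(range(1000))
print(len(train), len(val))  # 950 50
```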

### Part B
Here we again design a model for fine-tuning. Use the output of your feature extractor as the input to this model. This should be a powerful classifier (the architecture is up to you).

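Because the Question 2 feature extractor encodes one sequence at a time, one common option (popularized by InferSent) for a sentence-pair classifier is to encode premise and hypothesis separately and feed the classifier `[u; v; |u - v|; u * v]`. This is one design choice, not the required one. A minimal list-based sketch of the feature combination, with `pair_features` as a hypothetical name; in PyTorch the equivalent is `torch.cat([u, v, (u - v).abs(), u * v], dim=-1)`:

```python
def pair_features(u, v):
    """Hypothetical helper: combine premise vector u and hypothesis vector v
    as [u; v; |u - v|; u * v]."""
    if len(u) != len(v):
        raise ValueError("u and v must have the same dimension")
    abs_diff = [abs(a - b) for a, b in zip(u, v)]
    product = [a * b for a, b in zip(u, v)]
    return list(u) + list(v) + abs_diff + product


features = pair_features([1.0, 2.0], [3.0, 1.0])
print(features)  # [1.0, 2.0, 3.0, 1.0, 2.0, 1.0, 3.0, 2.0]
```

With this combination, the classifier's first layer takes `4 * hidden_size` inputs.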
In [None]:
def init_finetune_model():
    
    # TODO FILL IN THE DETAILS
    fine_tune_model = ...
    
    return fine_tune_model

### Part C
Use the feature_extractor and your fine_tune_model to fine-tune on MNLI.

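`mnli_val` is what lets you decide when to stop fine-tuning. One simple policy, sketched here as an assumption rather than a requirement, is to evaluate periodically, remember the best validation accuracy, and stop after `patience` evaluations without improvement. `EarlyStopper` is a hypothetical helper; you would typically save `model.state_dict()` whenever `update` returns True:

```python
class EarlyStopper:
    """Hypothetical helper: track the best validation accuracy and signal when to stop."""

    def __init__(self, patience=2):
        self.patience = patience
        self.best = float('-inf')
        self.bad_rounds = 0

    def update(self, val_acc):
        """Record a new validation accuracy; return True if it is a new best."""
        if val_acc > self.best:
            self.best = val_acc
            self.bad_rounds = 0
            return True
        self.bad_rounds += 1
        return False

    @property
    def should_stop(self):
        return self.bad_rounds >= self.patience


stopper = EarlyStopper(patience=2)
history = []
for acc in [0.61, 0.66, 0.65, 0.64, 0.70]:
    history.append(stopper.update(acc))
    if stopper.should_stop:
        break
print(history, stopper.best)  # [True, True, False, False] 0.66
```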
In [None]:
def fine_tune_mnli(feature_extractor, fine_tune_model, mnli_train, mnli_val):
    # YOUR CODE HERE
    pass

### Part D
Evaluate the test accuracy
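Test accuracy is the fraction of examples whose highest-scoring class matches the gold label. A framework-free sketch (with tensors, the same computation is `(logits.argmax(dim=-1) == labels).float().mean()`); `accuracy_from_logits` is a hypothetical helper:

```python
def accuracy_from_logits(logits, labels):
    """Hypothetical helper: fraction of rows whose argmax equals the gold label."""
    if len(logits) != len(labels) or not labels:
        raise ValueError("need the same, non-zero number of predictions and labels")
    correct = 0
    for row, gold in zip(logits, labels):
        predicted = max(range(len(row)), key=lambda k: row[k])
        correct += int(predicted == gold)
    return correct / len(labels)


logits = [[2.0, 0.1, -1.0],   # predicts class 0
          [0.2, 0.3, 1.5],    # predicts class 2
          [0.9, 1.1, 0.0]]    # predicts class 1
print(accuracy_from_logits(logits, [0, 2, 2]))  # 0.6666666666666666
```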

In [None]:
def calculate_mnli_test_accuracy(feature_extractor, fine_tune_model, mnli_test):
    
    test_accuracy = None  # YOUR CODE HERE...
    
    return test_accuracy

### Let's grade your results

In [None]:
def grade_mnli():
    # load data
    mnli_train, mnli_val, mnli_test = init_mnli_dataset()

    # no need to load a feature extractor: reuse the global `feature_extractor`
    # pretrained on WikiText-2 in Question 2

    # init the fine_tune model
    fine_tune_model = init_finetune_model()
    
    # finetune
    fine_tune_mnli(feature_extractor, fine_tune_model, mnli_train, mnli_val)

    # check test accuracy
    test_accuracy = calculate_mnli_test_accuracy(feature_extractor, fine_tune_model, mnli_test)

    # the real threshold will be released by Oct 11 
    assert test_accuracy > 0.00, 'ummm... your accuracy is too low...'
    
grade_mnli()

---  
## Question 4 (BERT)

A major research direction grew out of BERT, a model released by Google in 2018.

In this question you'll use BERT as your feature_extractor instead of the model you
designed yourself.

To get BERT, use the Hugging Face transformers library (https://github.com/huggingface/transformers) and load your BERT model here.

In [34]:
!pip install transformers

Collecting transformers
  Downloading https://files.pythonhosted.org/packages/fd/f9/51824e40f0a23a49eab4fcaa45c1c797cbf9761adedd0b558dab7c958b34/transformers-2.1.1-py3-none-any.whl (311kB)
     |████████████████████████████████| 317kB 17.8MB/s eta 0:00:01
Collecting sacremoses
  Downloading https://files.pythonhosted.org/packages/1f/8e/ed5364a06a9ba720fddd9820155cc57300d28f5f43a6fd7b7e817177e642/sacremoses-0.0.35.tar.gz (859kB)
     |████████████████████████████████| 860kB 25.7MB/s eta 0:00:01
Collecting boto3
  Downloading https://files.pythonhosted.org/packages/0e/41/27fb3969a76240d4c42a8f64b9d5ae78c676bab38e980e03b1bbaef279bd/boto3-1.10.2-py2.py3-none-any.whl (128kB)
     |████████████████████████████████| 133kB 34.9MB/s eta 0:00:01
Collecting regex
  Downloading https://files.pythonhosted.org/packages/6f/a6/99eeb5904ab763db87af4bd71d9b1dfdd9792681240657a4c0a599c10a81/regex-2019.08

In [36]:
!pip install pandas


Collecting pandas
  Downloading https://files.pythonhosted.org/packages/91/9d/217fc3c4fe19123fcb99385a35c3211e65d5eb07fbe8dd0008fae0d1fe74/pandas-0.25.2-cp37-cp37m-manylinux1_x86_64.whl (10.4MB)
     |████████████████████████████████| 10.4MB 17.1MB/s eta 0:00:01
Collecting pytz>=2017.2
  Downloading https://files.pythonhosted.org/packages/e7/f9/f0b53f88060247251bf481fa6ea62cd0d25bf1b11a87888e53ce5b7c8ad2/pytz-2019.3-py2.py3-none-any.whl (509kB)
     |████████████████████████████████| 512kB 61.9MB/s eta 0:00:01
Installing collected packages: pytz, pandas
Successfully installed pandas-0.25.2 pytz-2019.3


In [37]:
#from will
import os
import urllib.request
import zipfile

import pandas as pd
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, RandomSampler, DataLoader, random_split
from transformers import BertTokenizer
from transformers import glue_convert_examples_to_features as convert_examples_to_features
from transformers.data.processors.glue import MnliProcessor

# WordPiece tokenizer for the same checkpoint that init_bert() loads in Part A
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

In [38]:
#from will
TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
TASK2PATH = {
    "CoLA": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4",  # noqa
    "SST": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8",  # noqa
    "MRPC": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc",  # noqa
    "QQP": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP-clean.zip?alt=media&token=11a647cb-ecd3-49c9-9d31-79f8ca8fe277",  # noqa
    "STS": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5",  # noqa
    "MNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce",  # noqa
    "SNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df",  # noqa
    "QNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601",  # noqa
    "RTE": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb",  # noqa
    "WNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf",  # noqa
    "diagnostic": [
        "https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D",  # noqa
        "https://www.dropbox.com/s/ju7d95ifb072q9f/diagnostic-full.tsv?dl=1",
    ],
}

MRPC_TRAIN = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt"
MRPC_TEST = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt"


def download_and_extract(task, data_dir):
    print("Downloading and extracting %s..." % task)
    data_file = "%s.zip" % task
    urllib.request.urlretrieve(TASK2PATH[task], data_file)
    with zipfile.ZipFile(data_file) as zip_ref:
        zip_ref.extractall(data_dir)
    os.remove(data_file)
    print("\tCompleted!")

In [39]:
download_and_extract('MNLI', '.')

Downloading and extracting MNLI...
	Completed!


In [40]:
processor = MnliProcessor()

In [41]:
from torch.utils.data import random_split  # used for the train/val split below


def mnli_examples_to_dataset(examples):
    # tokenize premise/hypothesis pairs into fixed-length (128) BERT inputs
    features = convert_examples_to_features(examples,
                                            tokenizer,
                                            label_list=['contradiction', 'neutral', 'entailment'],
                                            max_length=128,
                                            output_mode='classification',
                                            pad_on_left=False,
                                            pad_token=tokenizer.pad_token_id,
                                            pad_token_segment_id=0)
    return TensorDataset(torch.tensor([f.input_ids for f in features], dtype=torch.long),
                         torch.tensor([f.attention_mask for f in features], dtype=torch.long),
                         torch.tensor([f.token_type_ids for f in features], dtype=torch.long),
                         torch.tensor([f.label for f in features], dtype=torch.long))


def generate_mnli_bert_dataloaders(batch_size=32):
    # ----------------------
    # TRAIN/VAL DATALOADERS
    # ----------------------
    # MNLI test labels are not public, so hold out 5% of the training set for validation
    train_dataset = mnli_examples_to_dataset(processor.get_train_examples('MNLI'))
    nb_train_samples = int(0.95 * len(train_dataset))
    nb_val_samples = len(train_dataset) - nb_train_samples
    bert_mnli_train_dataset, bert_mnli_val_dataset = random_split(train_dataset, [nb_train_samples, nb_val_samples])

    # shuffle for training; evaluation order does not matter, so no sampler is needed
    bert_mnli_train_dataloader = DataLoader(bert_mnli_train_dataset, sampler=RandomSampler(bert_mnli_train_dataset), batch_size=batch_size)
    bert_mnli_val_dataloader = DataLoader(bert_mnli_val_dataset, batch_size=batch_size)

    # ----------------------
    # TEST DATALOADER
    # ----------------------
    # the matched dev set (dev_matched.tsv) serves as the test set
    bert_mnli_test_dataset = mnli_examples_to_dataset(processor.get_dev_examples('MNLI'))
    bert_mnli_test_dataloader = DataLoader(bert_mnli_test_dataset, batch_size=batch_size)

    return bert_mnli_train_dataloader, bert_mnli_val_dataloader, bert_mnli_test_dataloader

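Each feature produced by `convert_examples_to_features` holds three aligned sequences. The framework-free sketch below mimics how a premise/hypothesis pair is packed as `[CLS] premise [SEP] hypothesis [SEP]`, with `token_type_ids` of 0 for the first segment and 1 for the second, an `attention_mask` of 1 on real tokens, and right padding (`pad_on_left=False`, `pad_token_segment_id=0`). It is an illustration, not the library's code: `pack_pair` is hypothetical, the special-token ids 101/102/0 are assumptions (check `tokenizer.cls_token_id`, `tokenizer.sep_token_id`, `tokenizer.pad_token_id`), and over-long pairs are trimmed one token at a time from the longer segment, similar in spirit to a "longest first" strategy.

```python
def pack_pair(ids_a, ids_b, max_length, cls_id=101, sep_id=102, pad_id=0):
    """Hypothetical helper: pack two token-id lists into BERT-style
    (input_ids, attention_mask, token_type_ids). Special-token ids are assumed."""
    if max_length < 3:
        raise ValueError("max_length must leave room for [CLS] and two [SEP] tokens")
    a, b = list(ids_a), list(ids_b)
    # reserve room for [CLS], [SEP], [SEP]; trim the longer segment first
    while len(a) + len(b) > max_length - 3:
        if len(a) >= len(b):
            a.pop()
        else:
            b.pop()
    input_ids = [cls_id] + a + [sep_id] + b + [sep_id]
    token_type_ids = [0] * (len(a) + 2) + [1] * (len(b) + 1)
    attention_mask = [1] * len(input_ids)
    # right-pad all three sequences to max_length
    n_pad = max_length - len(input_ids)
    input_ids += [pad_id] * n_pad
    attention_mask += [0] * n_pad
    token_type_ids += [0] * n_pad
    return input_ids, attention_mask, token_type_ids


ids, mask, types = pack_pair([7, 8, 9], [4, 5], max_length=10)
print(ids)    # [101, 7, 8, 9, 102, 4, 5, 102, 0, 0]
print(mask)   # [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
print(types)  # [0, 0, 0, 0, 0, 1, 1, 1, 0, 0]
```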
### Part A (init BERT)
In this section you need to create an instance of BERT and return it from the function.

In [44]:
from transformers import BertModel

def init_bert():
    # pretrained cased BERT-base; output_attentions=True makes forward() also return attention maps
    BERT = BertModel.from_pretrained('bert-base-cased', output_attentions=True)
    
    return BERT

In [49]:
BERT_feature_extractor = init_bert()

In [51]:
BERT_feature_extractor.config.hidden_size

768

### Part B (fine-tune with BERT)

Use BERT as your feature extractor to fine-tune on MNLI. Use a new fine-tune model (freshly initialized weights).

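`BERT_MNLIClassifier` classifies from the hidden state of the first token (`h[:, 0]`, the `[CLS]` position). A common alternative, shown here only as a swapped-in option, is to average the token vectors while skipping padding via the attention mask. A framework-free sketch of both poolings on one sequence; `cls_pool` and `masked_mean_pool` are hypothetical names:

```python
def cls_pool(hidden_states):
    """Hypothetical helper: return the vector at position 0 ([CLS])."""
    return list(hidden_states[0])


def masked_mean_pool(hidden_states, attention_mask):
    """Hypothetical helper: average the vectors whose attention_mask entry is 1."""
    kept = [h for h, m in zip(hidden_states, attention_mask) if m == 1]
    if not kept:
        raise ValueError("attention_mask has no real tokens")
    dim = len(kept[0])
    return [sum(h[d] for h in kept) / len(kept) for d in range(dim)]


# one sequence: 3 real tokens followed by 1 padding position, hidden size 2
h = [[1.0, 0.0], [3.0, 2.0], [2.0, 4.0], [9.0, 9.0]]
mask = [1, 1, 1, 0]
print(cls_pool(h))                # [1.0, 0.0]
print(masked_mean_pool(h, mask))  # [2.0, 2.0]
```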
In [54]:
class BERT_MNLIClassifier(nn.Module):
    def __init__(self, bert, num_classes):
        super().__init__()
        self.bert = bert
        self.fc = nn.Sequential(
                    nn.Linear(bert.config.hidden_size, 100),
                    nn.ReLU(),
                    nn.Linear(100, num_classes))
        
        self.num_classes = num_classes
        
    def forward(self, input_ids, attention_mask, token_type_ids):
        h, _, attn = self.bert(input_ids=input_ids, 
                               attention_mask=attention_mask, 
                               token_type_ids=token_type_ids)
        h_cls = h[:, 0]
        logits = self.fc(h_cls)
        return logits, attn


def init_finetune_model(bert, num_classes):
    model = BERT_MNLIClassifier(bert, num_classes)
    return model

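The classifier head is tiny next to BERT itself. A quick count for `nn.Linear(768, 100) -> ReLU -> nn.Linear(100, 3)`, where each `nn.Linear(i, o)` has `i * o` weights plus `o` biases; `mlp_param_count` is a hypothetical helper, and the figure of roughly 110M parameters for BERT-base is the one reported in the BERT paper:

```python
def mlp_param_count(layer_sizes):
    """Hypothetical helper: parameters of stacked fully connected layers (weights + biases)."""
    return sum(i * o + o for i, o in zip(layer_sizes, layer_sizes[1:]))


head = mlp_param_count([768, 100, 3])
print(head)  # 77203
```

That is well under 0.1% of BERT-base's roughly 110M parameters, so when BERT is not frozen almost all of the gradient computation and Adam state belongs to BERT.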
In [None]:
def evaluate_mnli_BERT(model, dataloader, device):
    """Classification accuracy of a BERT_MNLIClassifier on a dataloader."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for input_ids, attention_mask, token_type_ids, labels in dataloader:
            logits, _ = model(input_ids.to(device), attention_mask.to(device), token_type_ids.to(device))
            correct += (logits.argmax(dim=-1).cpu() == labels).sum().item()
            total += labels.size(0)
    return correct / total


def fine_tune_mnli_BERT(BERT_feature_extractor, fine_tune_model, mnli_train, mnli_val,
                        num_epochs=3, freeze_bert=False):
    # fine_tune_model (a BERT_MNLIClassifier) already holds BERT_feature_extractor as
    # `fine_tune_model.bert`, so training it updates BERT and the MLP head together
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = fine_tune_model.to(device)

    if freeze_bert:
        # pure feature extraction: keep BERT's pretrained weights fixed, train only the head
        for p in BERT_feature_extractor.parameters():
            p.requires_grad = False

    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=2e-05, eps=1e-08)
    criterion = nn.CrossEntropyLoss().to(device)

    train_losses = []
    val_accs = []
    for epoch in range(num_epochs):
        model.train()
        for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(mnli_train):
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            token_type_ids = token_type_ids.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits, _ = model(input_ids, attention_mask, token_type_ids)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())

            # validate every 500 iterations
            if i > 0 and i % 500 == 0:
                val_acc = evaluate_mnli_BERT(model, mnli_val, device)
                val_accs.append(val_acc)
                print('Epoch: [{}/{}], Step: [{}/{}], Train Loss {:.4f}, Validation Acc: {:.4f}'.format(
                    epoch + 1, num_epochs, i + 1, len(mnli_train), loss.item(), val_acc))
                model.train()  # go back to training mode

    return model, train_losses, val_accs

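BERT fine-tuning is commonly run with a learning rate that warms up linearly and then decays linearly to zero, rather than the constant `lr=2e-05` used in the fine-tuning cell. A framework-free sketch of the multiplier applied to the base learning rate at each optimizer step; `linear_warmup_decay` is a hypothetical name, and the transformers library ships an equivalent scheduler whose exact name depends on the installed version. In PyTorch, a function like this can be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`.

```python
def linear_warmup_decay(step, warmup_steps, total_steps):
    """Hypothetical helper: LR multiplier rising 0 -> 1 over warmup_steps,
    then falling 1 -> 0 by total_steps (and staying at 0 afterwards)."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


base_lr = 2e-05
for step in (0, 50, 100, 550, 1000):
    print(step, linear_warmup_decay(step, warmup_steps=100, total_steps=1000) * base_lr)
```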
### Part C
Evaluate how well we did

In [None]:
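def calculate_mnli_test_accuracy_BERT(feature_extractor, fine_tune_model, mnli_test):
    device = next(fine_tune_model.parameters()).device  # fine_tune_model already wraps BERT
    fine_tune_model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for input_ids, attention_mask, token_type_ids, labels in mnli_test:
            logits, _ = fine_tune_model(input_ids.to(device), attention_mask.to(device), token_type_ids.to(device))
            correct += (logits.argmax(dim=-1).cpu() == labels).sum().item()
            total += labels.size(0)
    return correct / total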

### Let's grade your BERT results!

In [None]:
def grade_mnli_BERT():
    BERT_feature_extractor = init_bert()
    num_classes = 3  # contradiction, neutral, entailment
    
    # load data
    mnli_train, mnli_val, mnli_test = generate_mnli_bert_dataloaders()

    # init the fine_tune model (a fresh classifier head on top of BERT)
    fine_tune_model = init_finetune_model(BERT_feature_extractor, num_classes)
    
    # finetune
    fine_tune_mnli_BERT(BERT_feature_extractor, fine_tune_model, mnli_train, mnli_val)

    # check test accuracy
    test_accuracy = calculate_mnli_test_accuracy_BERT(BERT_feature_extractor, fine_tune_model, mnli_test)
    print("test_accuracy : {}".format(test_accuracy))
    
    # the real threshold will be released by Oct 11 
    assert test_accuracy > 0.0, 'ummm... your accuracy is too low...'
    
grade_mnli_BERT()