---   
# HW3 - Transfer learning

#### Due October 30, 2019

In this assignment you will learn about transfer learning. This technique is perhaps one of the most important techniques for industry. When a problem you want to solve does not have enough data, we use a different (larger) dataset to learn representations which can help us solve our task using the smaller task.

The general steps to transfer learning are as follows:

1. Find a huge dataset with similar characteristics to the problem you are interested in.
2. Choose a model powerful enough to extract meaningful representations from the huge dataset.
3. Train this model on the huge dataset.
4. Use this model to train on the smaller dataset.


### This homework has the following sections:
1. Question 1: MNIST fine-tuning (Parts A, B, C, D).
2. Question 2: Pretrain on Wikitext2 (Part A, B, C, D)
3. Question 3: Finetune on MNLI (Part A, B, C, D)
4. Question 4: Finetune using pretrained BERT (Part A, B, C)

---   
## Question 1 (MNIST transfer learning)
To grasp the high-level approach to transfer learning, let's first do a simple example using computer vision. 

The torchvision library has pretrained models (resnets, vggnets, etc) on the Imagenet dataset. Imagenet is a dataset
with 1.3 million images covering over 1000 classes of objects. When you use one of these models, the weights of the model initialize
with the weights saved from training on imagenet.

In this task we will:
1. Choose a pretrained model.
2. Freeze the model so that the weights don't change.
3. Fine-tune on a few labels of MNIST.   

#### Choose a model
Here we pick any of the models from torchvision

In [0]:
import torch
import torchvision.models as models

class Identity(torch.nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        
    def forward(self, x):
        return x

# init the pretrained feature extractor
pretrained_resnet18 = models.resnet18(pretrained=True)

# we don't want the built in last layer, we're going to modify it ourselves
pretrained_resnet18.fc = Identity()

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/checkpoints/resnet18-5c106cde.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 74.8MB/s]


#### Freeze the model
Here we freeze the weights of the model. Freezing means the gradients will not backpropagate
into these weights.

By doing this you can think about the model as a feature extractor. This feature extractor outputs
a **representation** of an input. This representation is a matrix that encodes information about the input.

In [0]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False
        
def unfreeze_model(model):
    for param in model.parameters():
        param.requires_grad = True
        
freeze_model(pretrained_resnet18)

#### Init target dataset
Here we define the dataset we are actually interested in.

In [0]:
import os
from torchvision import transforms
from torchvision.datasets import  MNIST
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

transform = transforms.Compose([transforms.Grayscale(3),
transforms.ToTensor()
])                          
#  train/val  split
mnist_dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform)
mnist_train, mnist_val = random_split(mnist_dataset, [55000, 5000])

mnist_train = DataLoader(mnist_train, batch_size=32)
mnist_val = DataLoader(mnist_val,batch_size=32)

# test split
mnist_test = DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transform), batch_size=32)

### Part A (init fine-tune model)
decide what model to use for fine-tuning

In [0]:
import torch.nn as nn

def init_fine_tune_model():
  class fine_tune_model(nn.Module):
    def __init__(self):
        super().__init__()
        self.s= nn.Sequential(nn.Linear(512,216), nn.ReLU(), nn.Dropout(0.4), nn.Linear(216, 10))   
    def forward(self, data):
        logits=self.s(data)
        return logits
  return fine_tune_model()

### Part B (Fine-tune (Frozen))

The actual problem we care about solving likely has a different number of classes or is a different task altogether. Fine-tuning is the process of using the extracted representations (features) to solve this downstream task  (the task you're interested in).

To illustrate this, we'll use our pretrained model (on Imagenet), to solve the MNIST classification task.

There are two types of finetuning. 

#### 1. Frozen feature_extractor
In the first type we pretrain with the FROZEN feature_extractor and NEVER unfreeze it during finetuning.


#### 2. Unfrozen feature_extractor
In the second, we finetune with a FROZEN feature_extractor for a few epochs, then unfreeze the feature extractor and finish training.


In this part we will use the first version

In [0]:
import torch.optim as optim

def FROZEN_fine_tune_mnist(feature_extractor, fine_tune_model, mnist_train, mnist_val):
    """
    model is a feature extractor (resnet).
    Create a new model which uses those features to finetune on MNIST
    
    return the fine_tune model
    """     
        
    device = torch.device("cuda:0")
    criterion = nn.CrossEntropyLoss().to(device)
    fine_tune_model.to(device)
    feature_extractor.to(device)
    feature_extractor.fc = fine_tune_model
    model = feature_extractor
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=2e-05, eps=1e-08)
    # Train!
    for epoch_number in range(15):
        model.train()
        for inp, target in mnist_train:
            optimizer.zero_grad()
            inp = inp.to(device)
            target = target.to(device)   
            logits = model(inp)
            loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
            loss.backward()
            optimizer.step()
        
        model.eval()
        total=0
        correct=0
        with torch.no_grad():
            for inp, target in mnist_val:
                inp = inp.to(device)
                target = target.to(device)
                logits = model(inp)

                outputs = F.softmax(logits,dim=1)
                predicted = outputs.max(1, keepdim=True)[1]
                total += target.size(0)
                correct += predicted.eq(target.view_as(predicted)).sum().item()

            print('Validation acc after '+str(epoch_number+1)+' epochs = {:.{prec}f}'.format((100 * correct / total), prec=4))

### Part C (compute test accuracy)
Compute the test accuracy of fine-tuned model on MNIST

In [0]:
def calculate_mnist_test_accuracy(feature_extractor, fine_tune_model, mnist_test):
    device = torch.device("cuda:0")
    correct = 0
    total = 0
    fine_tune_model.eval()
    for inp, target in mnist_test:
        inp = inp.to(device)
        inp = feature_extractor(inp)
        target = target.to(device)   
        output = F.softmax(inp,dim=1)
        pred = output.max(1,keepdim=True)[1]
        total += target.size(0)
        correct += pred.eq(target.view_as(pred)).sum().item()
    
    return 100*correct/total

### Grade!
Let's see how you did

In [0]:
def grade_mnist_frozen():
    
    # init a ft model
    fine_tune_model = init_fine_tune_model()
    
    # run the transfer learning routine
    FROZEN_fine_tune_mnist(pretrained_resnet18, fine_tune_model, mnist_train, mnist_val)
    
    # calculate test accuracy
    test_accuracy = calculate_mnist_test_accuracy(pretrained_resnet18, fine_tune_model, mnist_test)
    
    # the real threshold will be released by Oct 11 
    assert test_accuracy > 0.0, 'your accuracy is too low...'
    
    return test_accuracy
    
frozen_test_accuracy = grade_mnist_frozen()

Validation acc after 1 epochs = 60.7000
Validation acc after 2 epochs = 67.0600
Validation acc after 3 epochs = 69.7000
Validation acc after 4 epochs = 71.6600
Validation acc after 5 epochs = 73.1200
Validation acc after 6 epochs = 73.9600
Validation acc after 7 epochs = 74.2400
Validation acc after 8 epochs = 74.6000
Validation acc after 9 epochs = 74.7400
Validation acc after 10 epochs = 75.0200
Validation acc after 11 epochs = 75.2800
Validation acc after 12 epochs = 75.5200
Validation acc after 13 epochs = 75.8800
Validation acc after 14 epochs = 76.1200
Validation acc after 15 epochs = 76.3400


In [0]:
frozen_test_accuracy

76.28

### Part D (Fine-tune Unfrozen)
Now we'll learn how to train using the "unfrozen" approach.

In this approach we'll:
1. keep the feature_extract frozen for a few epochs (10)
2. Unfreeze it.
3. Finish training

In [0]:
def UNFROZEN_fine_tune_mnist(feature_extractor, fine_tune_model, mnist_train, mnist_val):

    # INSERT YOUR CODE:
    # keep frozen for 10 epochs
    # ... train
    # unfreeze
    # train for rest of the time
    device = torch.device("cuda:0")
    criterion = nn.CrossEntropyLoss().to(device)
    fine_tune_model.to(device)
    feature_extractor.to(device)
    feature_extractor.fc = fine_tune_model
    model = feature_extractor
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=2e-05, eps=1e-08)
    # Train!
    for epoch_number in range(15):
        if epoch_number == 10:
            unfreeze_model(model)
            optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=2e-05, eps=1e-08)
        model.train()
        for inp, target in mnist_train:
            optimizer.zero_grad()
            inp = inp.to(device)
            target = target.to(device)   
            logits = model(inp)
            loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
            loss.backward()
            optimizer.step()
        
        model.eval()
        total=0
        correct=0
        with torch.no_grad():
            for inp, target in mnist_val:
                inp = inp.to(device)
                target = target.to(device)
                logits = model(inp)

                outputs = F.softmax(logits,dim=1)
                predicted = outputs.max(1, keepdim=True)[1]
                total += target.size(0)
                correct += predicted.eq(target.view_as(predicted)).sum().item()

            print('Validation acc after '+str(epoch_number+1)+' epochs = {:.{prec}f}'.format((100 * correct / total), prec=4))
    pass

### Grade UNFROZEN
Let's see if there's a difference in accuracy!

In [0]:
def grade_mnist_unfrozen():
    
    # init a ft model
    fine_tune_model = init_fine_tune_model()
    
    # run the transfer learning routine
    UNFROZEN_fine_tune_mnist(pretrained_resnet18, fine_tune_model, mnist_train, mnist_val)
    
    # calculate test accuracy
    test_accuracy = calculate_mnist_test_accuracy(pretrained_resnet18, fine_tune_model, mnist_test)
    
    # the real threshold will be released by Oct 11 
    assert test_accuracy > 0.0, 'your accuracy is too low...'
    
    return test_accuracy
    
pretrained_resnet18 = models.resnet18(pretrained=True)
pretrained_resnet18.fc = Identity()
freeze_model(pretrained_resnet18)

unfrozen_test_accuracy = grade_mnist_unfrozen()

Validation acc after 1 epochs = 62.2200
Validation acc after 2 epochs = 68.2400
Validation acc after 3 epochs = 70.6000
Validation acc after 4 epochs = 72.0200
Validation acc after 5 epochs = 72.8200
Validation acc after 6 epochs = 73.7000
Validation acc after 7 epochs = 74.1400
Validation acc after 8 epochs = 74.6800
Validation acc after 9 epochs = 74.9800
Validation acc after 10 epochs = 75.2000
Validation acc after 11 epochs = 97.8000
Validation acc after 12 epochs = 98.3600
Validation acc after 13 epochs = 98.4800
Validation acc after 14 epochs = 98.6000
Validation acc after 15 epochs = 98.8600


In [0]:
unfrozen_test_accuracy

98.9

In [0]:
assert unfrozen_test_accuracy > frozen_test_accuracy, 'the unfrozen model should be better'

--- 
## Question 2 (train a model on Wikitext-2)

Here we'll apply what we just learned to NLP. In this section we'll make our own feature extractor and pretrain it on Wikitext-2.

The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia. The dataset is available under the Creative Commons Attribution-ShareAlike License.

##### Utilities

In [0]:
import os
import json
import numpy as np
from collections import defaultdict
import torch
import torch.nn as nn
from torch.nn import RNNCell
from torch.nn import RNNBase, RNN
from torch.utils.data import Dataset, DataLoader
from torch.nn import Embedding
import torch.optim as optim
from tqdm import tqdm
import sys
import pickle as pkl
import matplotlib.pyplot as plt
import random
import math
import pandas as pd

In [0]:
class Dictionary(object):
    def __init__(self, datasets, include_valid=False):
        self.tokens = []
        self.ids = {}
        self.counts = {}
        self.add_token('<bos>')
        self.add_token('<eos>')
        self.add_token('<pad>')
        self.add_token('<unk>')
        
        for line in tqdm(datasets['train']):
            for w in line:
                self.add_token(w)
                    
        if include_valid is True:
            for line in tqdm(datasets['valid']):
                for w in line:
                    self.add_token(w)
                            
    def add_token(self, w):
        if w not in self.tokens:
            self.tokens.append(w)
            _w_id = len(self.tokens) - 1
            self.ids[w] = _w_id
            self.counts[w] = 1
        else:
            self.counts[w] += 1

    def get_id(self, w):
        return self.ids[w]
    
    def get_token(self, idx):
        return self.tokens[idx]
    
    def decode_idx_seq(self, l):
        return [self.tokens[i] for i in l]
    
    def encode_token_seq(self, l):
        return [self.ids[i] if i in self.ids else self.ids['<unk>'] for i in l]
    
    def __len__(self):
        return len(self.tokens)


def tokenize_dataset(datasets, dictionary, ngram_order=2):
    tokenized_datasets = {}
    for split, dataset in datasets.items():
        _current_dictified = []
        for l in tqdm(dataset):
            l = ['<bos>']*(ngram_order-1) + l + ['<eos>']
            encoded_l = dictionary.encode_token_seq(l)
            _current_dictified.append(encoded_l)
        tokenized_datasets[split] = _current_dictified
        
    return tokenized_datasets
  
class TensoredDataset(Dataset):
    def __init__(self, list_of_lists_of_tokens):
        self.input_tensors = []
        self.target_tensors = []
        
        for sample in list_of_lists_of_tokens:
            self.input_tensors.append(torch.tensor([sample[:-1]], dtype=torch.long))
            self.target_tensors.append(torch.tensor([sample[1:]], dtype=torch.long))
    
    def __len__(self):
        return len(self.input_tensors)
    
    def __getitem__(self, idx):
        # return a (input, target) tuple
        return (self.input_tensors[idx], self.target_tensors[idx])

def pad_list_of_tensors(list_of_tensors, pad_token):
    max_length = max([t.size(-1) for t in list_of_tensors])
    padded_list = []
    
    for t in list_of_tensors:
        padded_tensor = torch.cat([t, torch.tensor([[pad_token]*(max_length - t.size(-1))], dtype=torch.long)], dim = -1)
        padded_list.append(padded_tensor)
        
    padded_tensor = torch.cat(padded_list, dim=0)
    
    return padded_tensor

def pad_collate_fn(batch):
    # batch is a list of sample tuples
    input_list = [s[0] for s in batch]
    target_list = [s[1] for s in batch]
    
    pad_token = wiki_dict.get_id('<pad>')
    
    input_tensor = pad_list_of_tensors(input_list, pad_token)
    target_tensor = pad_list_of_tensors(target_list, pad_token)
    
    return input_tensor, target_tensor      
  
  
def load_wikitext(filename='wikitext2-sentencized.json'):
    if not os.path.exists(filename):
        !wget "https://nyu.box.com/shared/static/9kb7l7ci30hb6uahhbssjlq0kctr5ii4.json" -O $filename
    
    datasets = json.load(open(filename, 'r'))
    for name in datasets:
        datasets[name] = [x.split() for x in datasets[name]]
    vocab = list(set([t for ts in datasets['train'] for t in ts]))      
    print("Vocab size: %d" % (len(vocab)))
    return datasets, vocab

#### Part A
In this section you need to generate the training, validation and test split. Feel free to use code from your previous lectures.

In [0]:
def init_wikitext_dataset():
    """
    Fill in the details
    """
    datasets, vocab=load_wikitext()
    wiki_dict = pkl.load(open("wiki_dict.p", "rb"))

    wiki_tokenized_datasets = tokenize_dataset(datasets, wiki_dict)
    
    wiki_tensor_dataset = {}

    for split, listoflists in wiki_tokenized_datasets.items():
        wiki_tensor_dataset[split] = TensoredDataset(listoflists)

    wiki_loaders = {}
    batch_size = 32
    for split, wiki_dataset in wiki_tensor_dataset.items():
        wiki_loaders[split] = DataLoader(wiki_dataset, batch_size=batch_size, shuffle=True, collate_fn=pad_collate_fn)
    wikitext_train=wiki_loaders['train']
    wikitext_val=wiki_loaders['valid']
    wikitext_test=wiki_loaders['test']
    
    return wikitext_train, wikitext_val, wikitext_test

#### Part B   
Here we design our own feature extractor. In MNIST that was a resnet because we were dealing with images. Now we need to pick a model that can model sequences better. Design an RNN-based model here.

In [0]:
def init_feature_extractor():
    options = {'embedding_dim': 128,
           'hidden_size': 128,
           'input_size': 128,
           'num_embeddings': 33178,
           'num_layers': 2,
           'padding_idx': 2,
           'rnn_dropout': 0.1}
    class feature_extractor(nn.Module):
      def __init__(self, options):
        super().__init__()
        
        self.lookup = nn.Embedding(num_embeddings=options['num_embeddings'], embedding_dim=options['embedding_dim'], padding_idx=options['padding_idx'])
        self.lstm = nn.LSTM(options['input_size'], options['hidden_size'], options['num_layers'], batch_first=True)
        self.projection = nn.Linear(options['hidden_size'], options['num_embeddings'])
        
      def forward(self, encoded_input_sequence):
      
        embeddings = self.lookup(encoded_input_sequence)
        lstm_outputs = self.lstm(embeddings)
        logits = self.projection(lstm_outputs[0])
        
        return logits
    return feature_extractor(options)

#### Part C
Pretrain the feature extractor

In [0]:
def fit_feature_extractor(feature_extractor, wikitext_train, wikitext_val):
    # FILL IN THE DETAILS
  current_device = torch.device('cuda')
  model = feature_extractor.to(current_device)
  criterion = nn.CrossEntropyLoss(ignore_index=wiki_dict.get_id('<pad>'),reduction='sum')
  model_parameters = [p for p in model.parameters() if p.requires_grad]
  optimizer = optim.Adam(model_parameters, lr=0.001)

  for epoch_number in tqdm(range(8)):
        avg_loss=0
        # do train
        model.train()
        train_loss_cache = 0
        train_non_pad_tokens_cache = 0
        for i, (inp, target) in enumerate(wikitext_train):
            optimizer.zero_grad()
            inp = inp.to(current_device)
            target = target.to(current_device)
            logits = model(inp)

            loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
            train_loss_cache += loss.item()
            non_pad_tokens = target.view(-1).ne(wiki_dict.get_id('<pad>')).sum().item()
            train_non_pad_tokens_cache += non_pad_tokens
            loss /= non_pad_tokens 
            loss.backward()
            optimizer.step()
        

        avg_loss = train_loss_cache / train_non_pad_tokens_cache
        ppl = 2**(avg_loss/np.log(2))
        if epoch_number == 7:
            print('\nAvg train perplexity = {:.{prec}f}'.format(ppl, prec=4))
            
            
        valid_loss_cache = 0
        valid_non_pad_tokens_cache = 0
        #do valid
        valid_losses = []
        model.eval()
        with torch.no_grad():
            for i, (inp, target) in enumerate(wikitext_val):
                inp = inp.to(current_device)
                target = target.to(current_device)
                logits = model(inp)

                loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
                valid_loss_cache += loss.item()
                non_pad_tokens = target.view(-1).ne(wiki_dict.get_id('<pad>')).sum().item()
                valid_non_pad_tokens_cache += non_pad_tokens
                
            avg_val_loss = valid_loss_cache / valid_non_pad_tokens_cache
            ppl_val = 2**(avg_val_loss/np.log(2))
            if epoch_number == 7:
                print('\nValidation Perplexity = {:.{prec}f}'.format(ppl_val, prec=4))
                print()

  torch.save(model.state_dict(), "LSTM_feature_extractor.ckpt")

#### Part D
Calculate the test perplexity on wikitext2. Feel free to recycle code from previous assignments from this class. 

In [0]:
def calculate_wiki2_test_perplexity(feature_extractor, wikitext_test):
    
    # FILL IN DETAILS
    current_device = torch.device('cuda')
    model=feature_extractor.to(current_device)
    model.load_state_dict(torch.load('LSTM_feature_extractor.ckpt'))
    criterion = nn.CrossEntropyLoss(ignore_index=wiki_dict.get_id('<pad>'),reduction='sum')
    model.eval()
    test_loss_cache = 0
    test_non_pad_tokens_cache = 0
    
    model.eval()
    for i, (inp, target) in enumerate(wikitext_test):
        inp = inp.to(current_device)
        target = target.to(current_device)
        logits = model(inp)

        loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
        test_loss_cache += loss.item()
        non_pad_tokens = target.view(-1).ne(wiki_dict.get_id('<pad>')).sum().item()
        test_non_pad_tokens_cache += non_pad_tokens

    avg_test_loss = test_loss_cache / test_non_pad_tokens_cache
    test_ppl = 2**(avg_test_loss/np.log(2))
    return test_ppl

#### Let's grade your results!
(don't touch this part)

In [0]:
def grade_wikitext2():
    # load data
    wikitext_train, wikitext_val, wikitext_test = init_wikitext_dataset()

    # load feature extractor
    feature_extractor = init_feature_extractor()

    # pretrain using the feature extractor
    fit_feature_extractor(feature_extractor, wikitext_train, wikitext_val)

    # check test accuracy
    test_ppl = calculate_wiki2_test_perplexity(feature_extractor, wikitext_test)

    # the real threshold will be released by Oct 11 
    assert test_ppl < 200, 'ummm... your perplexity is too high...'
    
grade_wikitext2()

 17%|█▋        | 13051/78274 [00:00<00:00, 130507.01it/s]

Vocab size: 33175


100%|██████████| 78274/78274 [00:00<00:00, 127621.12it/s]
100%|██████████| 8464/8464 [00:00<00:00, 131296.93it/s]
100%|██████████| 9708/9708 [00:00<00:00, 132846.68it/s]
 88%|████████▊ | 7/8 [19:40<02:48, 168.50s/it]


Avg train perplexity = 109.5618


100%|██████████| 8/8 [22:28<00:00, 168.37s/it]


Validation Perplexity = 193.8261






In [82]:
!md5sum LSTM_feature_extractor.ckpt

fa99e1bc2633f1d302aab1c1d9425727  LSTM_feature_extractor.ckpt


---   
## Question 3 (fine-tune on MNLI)
In this question you will use your feature_extractor from question 2
to fine-tune on MNLI.

(From the website):
The Multi-Genre Natural Language Inference (MultiNLI) corpus is a crowd-sourced collection of 433k sentence pairs annotated with textual entailment information. The corpus is modeled on the SNLI corpus, but differs in that covers a range of genres of spoken and written text, and supports a distinctive cross-genre generalization evaluation. The corpus served as the basis for the shared task of the RepEval 2017 Workshop at EMNLP in Copenhagen.

MNLI has 3 genres (3 classes).
The goal of this question is to maximize the test accuracy in MNLI. 

### Part A
In this section you need to generate the training, validation and test split. Feel free to use code from your previous lectures.

In [0]:
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import pickle as pkl
import torch
import torchvision.models as models
import os
from torchvision import transforms
from torchvision.datasets import  MNIST
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F
import json
import numpy as np
from collections import defaultdict
from torch import nn
import pandas as pd 

In [0]:
def init_mnli_dataset():
    y_label_map = {'contradiction':0,'neutral':1,'entailment':2}

    def get_string_tokenized_data(data):

        tokenized_data_x = [];
        y_labels = []
        all_tokens = [];

        for i,x in enumerate(data):
            label = x[2]
            if label == 'nan':
                continue

            label = y_label_map[label]
            y_labels.append(label)

            dp = [x[0].split(), x[1].split()]
            tokenized_data_x.append(dp)
            all_tokens += (dp[0] + dp[1])


        return all_tokens, tokenized_data_x, y_labels

    # LOAD VAL

    val_df = pd.read_csv('mnli_val.tsv', sep="\t")

    val_df  = np.array(val_df)
    val_genre_list = val_df[:, 3]

    _, val_data_x, val_data_y = get_string_tokenized_data(val_df)
    del val_df

    train_df = pd.read_csv('mnli_train.tsv', sep="\t")

    train_df  = np.array(train_df)
    train_genre_list = train_df[:, 3]

    _, train_data_x, train_data_y = get_string_tokenized_data(train_df)
    del train_df


    mnli_raw_datasets = {'train': train_data_x, 'val': val_data_x}
    mnli_tokenized_datasets = tokenize_mnli_dataset(mnli_raw_datasets, wiki_dict)

    train_data_indices = mnli_tokenized_datasets['train']
    val_data_indices = mnli_tokenized_datasets['val']



    del mnli_tokenized_datasets

    unique_genre = list(set(val_genre_list));
    nb_classes = len(y_label_map)

    MAX_SENTENCE_LENGTH = 200

    class MNLIDataset(Dataset):
        def __init__(self, data_x, target_list):
            self.data_x = data_x;
            self.target_list = target_list

            assert(len(data_x) == len(target_list))

        def __len__(self):
            return len(self.target_list)

        def __getitem__(self, key):

            prem_token_idx = self.data_x[key][0][:MAX_SENTENCE_LENGTH]
            hyp_token_idx = self.data_x[key][1][:MAX_SENTENCE_LENGTH]
            label = self.target_list[key]
            return [prem_token_idx, hyp_token_idx, label]


    def encode_collate_func(batch):
        """
        Customized function for DataLoader that dynamically pads the batch so that all
        data have the same length
        """
        prem_data_list = []
        hyp_data_list = []
        label_list = []
        length_list = []
        # print("collate batch: ", batch[0][0])
        # batch[0][0] = batch[0][0][:MAX_SENTENCE_LENGTH]
        for datum in batch:
            label_list.append(datum[2])
        # padding
        for datum in batch:
            prem_padded_vec = np.pad(np.array(datum[0]),
                                     pad_width=((0, MAX_SENTENCE_LENGTH - len(datum[0]))),
                                     mode="constant", constant_values=wiki_dict.get_id('<pad>'))
            hyp_padded_vec = np.pad(np.array(datum[1]),
                                    pad_width=((0, MAX_SENTENCE_LENGTH - len(datum[1]))),
                                    mode="constant", constant_values=wiki_dict.get_id('<pad>'))
            prem_data_list.append(prem_padded_vec)
            hyp_data_list.append(hyp_padded_vec)
        return [torch.from_numpy((np.array(prem_data_list))), torch.from_numpy(np.array(hyp_data_list)),
                torch.LongTensor(label_list)]

    BATCH_SIZE = 32
    nb_train_samples = int(0.95 * len(train_data_indices))
    nb_val_samples = len(train_data_indices) - nb_train_samples

    # train/val split
    train_val_dataset = MNLIDataset(train_data_indices, train_data_y)
    train_dataset, val_dataset = random_split(train_val_dataset, [nb_train_samples, nb_val_samples])

    # train loader
    train_mnli_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=BATCH_SIZE,
                                               collate_fn=encode_collate_func,
                                               shuffle=True)

    # val loader
    val_mnli_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                               batch_size=BATCH_SIZE,
                                               collate_fn=encode_collate_func,
                                               shuffle=True)

    # test loader
    test_dataset = MNLIDataset(val_data_indices, val_data_y)
    test_mnli_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                               batch_size=BATCH_SIZE,
                                               collate_fn=encode_collate_func,
                                               shuffle=True)

    return train_mnli_loader, val_mnli_loader, test_mnli_loader

### Part B
Here we again design a model for finetuning. Use the output of your feature-extractor as the input to this model. This should be a powerful classifier (up to you).

In [0]:
def init_feature_extractor():
    options = {'embedding_dim': 128,
           'hidden_size': 128,
           'input_size': 128,
           'num_embeddings': 33178,
           'num_layers': 2,
           'padding_idx': 2,
           'rnn_dropout': 0.1}
    class feature_extractor(nn.Module):
      def __init__(self, options):
        super().__init__()
        
        self.lookup = nn.Embedding(num_embeddings=options['num_embeddings'], embedding_dim=options['embedding_dim'], padding_idx=options['padding_idx'])
        self.lstm = nn.LSTM(options['input_size'], options['hidden_size'], options['num_layers'], batch_first=True)
        self.projection = nn.Linear(options['hidden_size'], options['num_embeddings'])
        
      def forward(self, encoded_input_sequence):
      
        embeddings = self.lookup(encoded_input_sequence)
        lstm_outputs = self.lstm(embeddings)
        logits = self.projection(lstm_outputs[0])
        
        return logits
    return feature_extractor(options)
device = torch.device('cuda')
feature_extractor_model = init_feature_extractor().to(device)
feature_extractor_model.load_state_dict(torch.load('LSTM_feature_extractor.ckpt'))
feature_extractor_model.projection = Identity()
freeze_model(feature_extractor_model)

In [0]:
def init_fine_tune_model():
  class fine_tune_model(nn.Module):
    def __init__(self):
        super().__init__()
        self.s= nn.Sequential(nn.Linear(2*128,256), nn.ReLU(), nn.Linear(256,512), nn.ReLU(), nn.Dropout(0.2), nn.Linear(512, 3))   
    def forward(self, data):
        logits=self.s(data)
        return logits
  return fine_tune_model()

### Part C
Use the feature_extractor and your fine_tune_model to fine_tune MNLI

In [0]:
import torch.optim as optim
def fine_tune_mnli(feature_extractor, fine_tune_model, mnli_train, mnli_val):
    # YOUR CODE HERE
    current_device = torch.device("cuda:0")
    freeze_model(feature_extractor)
    model = nn.Sequential(feature_extractor.lookup, fine_tune_model)
    model = model.to(current_device)
  
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],lr=2e-5)

    
    for epoch in tqdm(range(10)):
      if epoch>=5:
        unfreeze_model(model)
        optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],lr=2e-5)
      model.train()
      
      for i, (inp1,inp2, target) in enumerate(mnli_train):
          optimizer.zero_grad()
          inp1 = inp1.to(current_device)
          inp2 = inp2.to(current_device)
          target = target.to(current_device)
          feature1 = model[0](inp1)
          feature2 = model[0](inp2)
          feature1 = torch.sum(feature1,dim=1)
          feature2 = torch.sum(feature2,dim=1)
          inp = torch.cat([feature1, feature2], dim=1)
          logits = model[1](inp)
          loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
          
          loss.backward()
          optimizer.step()
          
      model.eval()
      total=0
      correct=0
      with torch.no_grad():
          for i, (inp1,inp2, target) in enumerate(mnli_val):
              inp1 = inp1.to(current_device)
              inp2 = inp2.to(current_device)
              target = target.to(current_device)
              feature1 = model[0](inp1)
              feature2 = model[0](inp2)
              feature1 = torch.sum(feature1,dim=1)
              feature2 = torch.sum(feature2,dim=1)
              inp = torch.cat([feature1, feature2], dim=1)
              logits = model[1](inp)
              
              outputs = F.softmax(logits,dim=1)
              predicted = outputs.max(1, keepdim=True)[1]
              total += target.size(0)
              correct += predicted.eq(target.view_as(predicted)).sum().item()

          print('Validation acc = {:.{prec}f}'.format((100 * correct / total), prec=4))

### Part D
Evaluate the test accuracy

In [0]:
def calculate_mnli_test_accuracy(feature_extractor, fine_tune_model, mnli_test):
    
    current_device = torch.device("cuda:0")
    correct = 0
    total = 0
    model = nn.Sequential(feature_extractor.lookup, fine_tune_model)
    model.eval()    
    for inp1, inp2, target in mnli_test:
        inp1 = inp1.to(current_device)
        inp2 = inp2.to(current_device)
        target = target.to(current_device)
        feature1 = model[0](inp1)
        feature2 = model[0](inp2)
        feature1 = torch.sum(feature1,dim=1)
        feature2 = torch.sum(feature2,dim=1)
        inp = torch.cat([feature1, feature2], dim=1)
        logits = model[1](inp)

        outputs = F.softmax(logits,dim=1)
        predicted = outputs.max(1, keepdim=True)[1]
        total += target.size(0)
        correct += predicted.eq(target.view_as(predicted)).sum().item()
    
    return 100*correct/total

### Let's grade your results

In [0]:
def grade_mnli():
    # load data
    mnli_train, mnli_val, mnli_test = init_mnli_dataset()

    # no need to load feature extractor because it is fine-tuned
    feature_extractor = feature_extractor_model

    # init the fine_tune model
    fine_tune_model = init_fine_tune_model()
    
    # finetune
    fine_tune_mnli(feature_extractor, fine_tune_model, mnli_train, mnli_val)

    # check test accuracy
    test_accuracy = calculate_mnli_test_accuracy(feature_extractor, fine_tune_model, mnli_test)

    # the real threshold will be released by Oct 11 
    assert test_accuracy > 0.00, 'ummm... your accuracy is too low...'
    
    return test_accuracy
    
test_accuracy = grade_mnli()

100%|██████████| 20000/20000 [00:00<00:00, 49537.28it/s]
100%|██████████| 5000/5000 [00:00<00:00, 78110.58it/s]
 10%|█         | 1/10 [00:03<00:30,  3.43s/it]

Validation acc = 38.2000


 20%|██        | 2/10 [00:06<00:27,  3.38s/it]

Validation acc = 38.5000


 30%|███       | 3/10 [00:09<00:23,  3.33s/it]

Validation acc = 37.2000


 40%|████      | 4/10 [00:13<00:19,  3.29s/it]

Validation acc = 38.8000


 50%|█████     | 5/10 [00:16<00:16,  3.26s/it]

Validation acc = 39.3000


 60%|██████    | 6/10 [00:20<00:14,  3.50s/it]

Validation acc = 40.0000


 70%|███████   | 7/10 [00:24<00:11,  3.67s/it]

Validation acc = 40.0000


 80%|████████  | 8/10 [00:28<00:07,  3.79s/it]

Validation acc = 41.8000


 90%|█████████ | 9/10 [00:32<00:03,  3.88s/it]

Validation acc = 42.8000


100%|██████████| 10/10 [00:36<00:00,  3.96s/it]

Validation acc = 41.7000





In [0]:
test_accuracy

42.16

---  
## Question 4 (BERT)

A major direction in research came from a model called BERT, released last year.  

In this question you'll use BERT as your feature_extractor instead of the model you
designed yourself.

To get BERT, head on over to (https://github.com/huggingface/transformers) and load your BERT model here

#### Utilities

In [0]:
!pip install transformers
import os
import sys
import shutil
import argparse
import tempfile
import urllib.request
import zipfile
from tqdm import tqdm, trange
import pickle as pkl
import json
from collections import defaultdict
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, TensorDataset, DataLoader, random_split, RandomSampler
import torchvision.models as models
from torchvision import transforms
from torchvision.datasets import  MNIST
from transformers.data.processors.glue import MnliProcessor
from transformers import glue_convert_examples_to_features as convert_examples_to_features
from transformers import (
    BertModel,
    BertTokenizer
)

In [0]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False
        
def unfreeze_model(model):
    for param in model.parameters():
        param.requires_grad = True

In [0]:
TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
TASK2PATH = {
    "CoLA": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4",  # noqa
    "SST": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8",  # noqa
    "MRPC": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc",  # noqa
    "QQP": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP-clean.zip?alt=media&token=11a647cb-ecd3-49c9-9d31-79f8ca8fe277",  # noqa
    "STS": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5",  # noqa
    "MNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce",  # noqa
    "SNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df",  # noqa
    "QNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601",  # noqa
    "RTE": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb",  # noqa
    "WNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf",  # noqa
    "diagnostic": [
        "https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D",  # noqa
        "https://www.dropbox.com/s/ju7d95ifb072q9f/diagnostic-full.tsv?dl=1",
    ],
}

MRPC_TRAIN = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt"
MRPC_TEST = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt"


def download_and_extract(task, data_dir):
    print("Downloading and extracting %s..." % task)
    data_file = "%s.zip" % task
    urllib.request.urlretrieve(TASK2PATH[task], data_file)
    with zipfile.ZipFile(data_file) as zip_ref:
        zip_ref.extractall(data_dir)
    os.remove(data_file)
    print("\tCompleted!")

download_and_extract('MNLI', '.')
processor = MnliProcessor()

Downloading and extracting MNLI...
	Completed!


### Part A (init BERT)
In this section you need to create an instance of BERT and return if from the function

In [0]:
def init_bert():
    bert = BertModel.from_pretrained('bert-base-cased', output_attentions=True)
    return bert

In [0]:
def init_mnli_dataset(train=True):
  # ----------------------
  # TRAIN/VAL DATALOADERS
  # ----------------------
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    if train == True:
        train = processor.get_train_examples('MNLI')
        features = convert_examples_to_features(train,tokenizer,label_list=['contradiction','neutral','entailment'],
                            max_length=128,output_mode='classification',
                            pad_on_left=False,pad_token=tokenizer.pad_token_id,pad_token_segment_id=0)
        train_dataset = TensorDataset(torch.tensor([f.input_ids for f in features], dtype=torch.long), 
                                    torch.tensor([f.attention_mask for f in features], dtype=torch.long), 
                                    torch.tensor([f.token_type_ids for f in features], dtype=torch.long), 
                                    torch.tensor([f.label for f in features], dtype=torch.long))

        nb_train_samples = int(0.75 * len(train_dataset))
        nb_val_samples = len(train_dataset) - nb_train_samples

        bert_mnli_train_dataset, bert_mnli_val_dataset = random_split(train_dataset, [nb_train_samples, nb_val_samples])

      # train loader
        train_sampler = RandomSampler(bert_mnli_train_dataset)
        bert_mnli_train_dataloader = DataLoader(bert_mnli_train_dataset, sampler=train_sampler, batch_size=32)

      # val loader
        val_sampler = RandomSampler(bert_mnli_val_dataset)
        bert_mnli_val_dataloader = DataLoader(bert_mnli_val_dataset, sampler=val_sampler, batch_size=32)

  # ----------------------
  # TEST DATALOADERS
  # ----------------------
    dev = processor.get_dev_examples('MNLI')
    features = convert_examples_to_features(dev,tokenizer,label_list=['contradiction','neutral','entailment'],
                         max_length=128,output_mode='classification',
                         pad_on_left=False,pad_token=tokenizer.pad_token_id,pad_token_segment_id=0)

    bert_mnli_test_dataset = TensorDataset(torch.tensor([f.input_ids for f in features], dtype=torch.long), 
                                torch.tensor([f.attention_mask for f in features], dtype=torch.long), 
                                torch.tensor([f.token_type_ids for f in features], dtype=torch.long), 
                                torch.tensor([f.label for f in features], dtype=torch.long))

  # test dataset
    test_sampler = RandomSampler(bert_mnli_test_dataset)
    bert_mnli_test_dataloader = DataLoader(bert_mnli_test_dataset, sampler=test_sampler, batch_size=32)
  
    return bert_mnli_train_dataloader, bert_mnli_val_dataloader, bert_mnli_test_dataloader

In [0]:
def init_finetune_model():
    class fine_tune_model(nn.Module):
        def __init__(self):
            super().__init__()
            self.s= nn.Sequential(nn.Linear(768,512), nn.ReLU(), 
                                  nn.Linear(512,128), nn.ReLU(), nn.Dropout(0.1), nn.Linear(128,3))   
        def forward(self, data):
            logits=self.s(data)
            return logits
    return fine_tune_model()    

### Part B (fine-tune with BERT)

Use BERT as your feature extractor to finetune MNLI. Use a new finetune model (reset weights).

In [0]:
def fine_tune_mnli_BERT(BERT_feature_extractor, fine_tune_model, mnli_train, mnli_val):
    # YOUR CODE HERE
    current_device = torch.device("cuda")
    model = nn.Sequential(BERT_feature_extractor, fine_tune_model)
    model = model.to(current_device)
    unfreeze_model(model)
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],lr=2e-5)
    criterion = torch.nn.CrossEntropyLoss(ignore_index=-1).to(current_device)

    model.train()
    for inp,mask,token,target in tqdm(mnli_train):
        inp = inp.to(current_device)
        mask = mask.to(current_device)
        token = token.to(current_device)
        target = target.to(current_device) 
        h, _, attn = model[0](input_ids=inp, 
                     attention_mask=mask, 
                     token_type_ids=token)
        h = h[:,0]
        logits = model[1](h)
        loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
    torch.save(model.state_dict(), "BERT_finetune.ckpt")

    model.eval()
    total=0
    correct=0
    with torch.no_grad():
        for inp,mask,token,target in mnli_val:
            inp = inp.to(current_device)
            mask = mask.to(current_device)
            token = token.to(current_device)
            target = target.to(current_device) 
            h, _, attn = model[0](input_ids=inp,attention_mask=mask, 
                         token_type_ids=token)
            h = h[:,0]
            logits = model[1](h)
            outputs = F.softmax(logits,dim=1)
            predicted = outputs.max(1, keepdim=True)[1]
            total += target.size(0)
            correct += predicted.eq(target.view_as(predicted)).sum().item()
    print('Validation acc = {:.{prec}f}'.format((100 * correct / total), prec=4))
    pass

### Part C(Evaluate how well we did)

In [0]:
def calculate_mnli_test_accuracy_BERT(feature_extractor, fine_tune_model, mnli_test):   
    current_device = torch.device("cuda:0")
    correct = 0
    total = 0
    model = nn.Sequential(feature_extractor, fine_tune_model)
    model = model.to(current_device)
    model.eval()    
    for inp,mask,token,target in mnli_test:
        inp = inp.to(current_device)
        mask = mask.to(current_device)
        token = token.to(current_device)
        target = target.to(current_device) 
        h, _, attn = model[0](input_ids=inp, attention_mask=mask, 
                     token_type_ids=token)
        h = h[:,0]
        logits = model[1](h)
        outputs = F.softmax(logits,dim=1)
        predicted = outputs.max(1, keepdim=True)[1]
        total += target.size(0)
        correct += predicted.eq(target.view_as(predicted)).sum().item()  
    return 100*correct/total

### Let's grade your BERT results!

In [0]:
def grade_mnli_BERT():
    BERT_feature_extractor = init_bert()
    
    # load data
    mnli_train, mnli_val, mnli_test = init_mnli_dataset()

    # init the fine_tune model
    fine_tune_model = init_finetune_model()
    
    # finetune
    fine_tune_mnli_BERT(BERT_feature_extractor, fine_tune_model, mnli_train, mnli_val)

    # check test accuracy
    test_accuracy = calculate_mnli_test_accuracy_BERT(BERT_feature_extractor, fine_tune_model, mnli_test)
    
    # the real threshold will be released by Oct 11 
    assert test_accuracy > 0.8, 'ummm... your accuracy is too low...'
    
    return test_accuracy
    
grade_mnli_BERT() 

100%|██████████| 313/313 [00:00<00:00, 239390.44B/s]
100%|██████████| 435779157/435779157 [00:07<00:00, 55074928.32B/s]
100%|██████████| 213450/213450 [00:00<00:00, 2530917.75B/s]
100%|██████████| 9204/9204 [3:26:18<00:00,  1.33s/it]


Validation acc = 82.1555


82.26184411614875

In [0]:
model = nn.Sequential(init_bert(), init_finetune_model())
model.load_state_dict(torch.load('BERT_finetune.ckpt'))
mnli_train, mnli_val, mnli_test = init_mnli_dataset()
test_accuracy = calculate_mnli_test_accuracy_BERT(model[0], model[1], mnli_test)
print('Test Accuracy is: {}'.format(test_accuracy))

100%|██████████| 313/313 [00:00<00:00, 199941.69B/s]
100%|██████████| 435779157/435779157 [00:07<00:00, 55102206.01B/s]
100%|██████████| 213450/213450 [00:00<00:00, 5479839.08B/s]


Test Accuracy is: 82.26184411614875


In [83]:
!md5sum BERT_finetune.ckpt

68b209b1ccaa6ace256983d4e83580c9  BERT_finetune.ckpt
