<center><h3>**Welcome to the Knowledge Distillation Notebook.**</h3></center>

This notebook is an experimental part of the homework and not worth points. It is not guaranteed to work correctly. 

A trend in Natural Language Processing is to pretrain large models that can then be fine-tuned for specific problems. However, the state-of-the-art models can be quite large: the "base" BERT model has about 110M parameters and the "large" BERT model has about 340M parameters! In many applications, such as client-side mobile apps, we do not have the compute to run the BERT model even in an evaluation setting.
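
To make these sizes concrete, the sketch below counts the weights of a BERT encoder from its dimensions. The helper name `bert_encoder_params` is hypothetical, and the defaults (a 30522-token vocabulary, 512 positions, 2 segment types) are assumptions taken from the public uncased checkpoints; the masked-LM head is not counted.

```python
def bert_encoder_params(vocab_size, hidden, layers, intermediate,
                        max_positions=512, type_vocab=2):
    """Count the parameters of a BERT encoder (embeddings, blocks, pooler)."""
    layer_norm = 2 * hidden  # scale and bias
    # Word, position and segment embeddings followed by one LayerNorm.
    embeddings = (vocab_size + max_positions + type_vocab) * hidden + layer_norm
    # Query, key, value and output projections (weights + biases), then LayerNorm.
    attention = 4 * (hidden * hidden + hidden) + layer_norm
    # Two dense layers of the feed-forward block, then LayerNorm.
    feed_forward = ((hidden * intermediate + intermediate)
                    + (intermediate * hidden + hidden) + layer_norm)
    pooler = hidden * hidden + hidden
    return embeddings + layers * (attention + feed_forward) + pooler

base = bert_encoder_params(30522, 768, 12, 3072)
large = bert_encoder_params(30522, 1024, 24, 4096)
print(f"base: {base:,}  large: {large:,}")
```

Under these assumptions the counts come out at roughly 109M and 335M, in the same ballpark as the figures quoted above. Note how much of the total sits in the per-block feed-forward layers, which is why shrinking the hidden and intermediate sizes pays off quickly.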

Here we look at a method for reducing model size, called Knowledge Distillation. Specifically, we will follow the paper __[TinyBERT: Distilling BERT for Natural Language Understanding](https://arxiv.org/abs/1909.10351)__.
In this assignment you will:
- Use an off-the-shelf API to replicate a paper method
- Implement loss functions for KD as specified by the paper

**Before You Get Started**

Read the paper first. The API we will be using is the Transformers library released by HuggingFace, so it may also be helpful to look at the __[documentation](https://huggingface.co/transformers/)__.

# Library Imports

In [None]:
%%capture
# Suppress the pip output so it does not clutter the notebook.
# (The %%capture cell magic must be the first line of the cell.)
!pip install transformers
!pip install datasets

In [None]:
from google.colab import drive
drive.mount('/content/drive')
DRIVE=True

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
root_folder = "" if not DRIVE else "/content/drive/My Drive/cs182_hw3/"
import os
import sys
sys.path.append(root_folder)
import json
from utils import validate_to_array, model_out_to_list
import torch as th
from torch.nn import functional as F
from torch import nn
from torch import optim
import numpy as np
import math
device = th.device("cuda" if th.cuda.is_available() else "cpu")
# device = th.device("cpu")
print(device)

from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, BertForPreTraining

# BERT Architecture

First load the BERT base model and take a look at the architecture. Don't mind the warnings for now. Based on the nn.Module names, what major component from the Transformer architecture in "Attention is All You Need" is substantially smaller in the BERT model? What is the purpose of the component?

Answer:

In [None]:
%%capture
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",do_lower_case=True)
teacher_model = BertForMaskedLM.from_pretrained("bert-base-uncased").to(device)

In [None]:
print(teacher_model)

# KD Losses

First, we need to access the intermediate layer outputs of the model. Read section 3 of the paper and take a look at the documentation for __[the forward function](https://huggingface.co/transformers/model_doc/bert.html#transformers.BertModel.forward)__, or look at the docstring below. Fill in the kwargs to retrieve the necessary outputs from the model. Note that returning the embeddings is not an option of the forward function; you can instead retrieve them via the BERT method `get_input_embeddings()`. Consider what an Embedding is: why wouldn't we need to return an embedding for every sample in a batch?

In [None]:
forward_kwargs = dict(
    
    
    
    
    
    return_dict=True,
)

In [None]:
help(teacher_model.forward)

Implement EmbeddingLayerLoss, AttentionLayerLoss, HiddenLayerLoss, PredictionLoss, and KnowledgeDistillationLoss as specified in section 3 of the paper. The output of BERT will be a dictionary; look at the return value in the documentation or the docstring for the relevant keys. We will add the embeddings under the key 'embeddings'.

## (1) Implementing the Attention Layer Loss

This part is located in AttentionLayerLoss in kd_loss.py. You must implement the call function of the class. You will need to implement formula (7) in section 3 of __[TinyBERT](https://arxiv.org/pdf/1909.10351.pdf)__. Note that the paper's implementation compares the raw (pre-softmax) attention scores, whereas the transformers API returns the attention weights after the softmax.
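
For reference, formula (7) as we read it from the paper is reproduced below. Treat it as a reading aid and check it against the PDF; how the loss should be reduced over the batch dimension is not specified here and is left to the implementation.

$$
\mathcal{L}_{\text{attn}} = \frac{1}{h} \sum_{i=1}^{h} \mathrm{MSE}\left(A_i^S, A_i^T\right)
$$

Here $h$ is the number of attention heads, and $A_i^S$ and $A_i^T$ are the $l \times l$ attention matrices of head $i$ in the student and the teacher for an input of length $l$.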

In [None]:
from kd_loss import AttentionLayerLoss
num_channels = 10
batch_size = 2


with open(root_folder+"kd_checks/kd_attention_loss.json",'r') as f:
  io = json.load(f)
  teacher_attn = th.tensor(io['teacher_attention'])
  student_attn = th.tensor(io['student_attention'])
  expected_output = th.tensor(io['expected_output'])

attn_loss = AttentionLayerLoss()
output = attn_loss(teacher_attn, student_attn)
validate_to_array(model_out_to_list,((teacher_attn,student_attn),attn_loss),'kdattnloss', root_folder)
print("Total error on the output:",th.sum(th.abs(expected_output-output)).item(), "(should be 0.0 or close to 0.0)")

## (2) Implementing the Hidden Layer Loss

This part is located in HiddenLayerLoss in kd_loss.py. You must implement the call function of the class. You will need to implement the formula (8) in section 3 of __[TinyBERT](https://arxiv.org/pdf/1909.10351.pdf)__
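
For reference, formula (8) as we read it from the paper is reproduced below; check it against the PDF before relying on it.

$$
\mathcal{L}_{\text{hidn}} = \mathrm{MSE}\left(H^S W_h, H^T\right)
$$

Here $H^S \in \mathbb{R}^{l \times d'}$ and $H^T \in \mathbb{R}^{l \times d}$ are the hidden states of the student and the teacher, and $W_h \in \mathbb{R}^{d' \times d}$ is a learnable linear map that projects the smaller student states into the teacher's space. This is presumably why the check below constructs the loss with both hidden dimensions and loads a saved state dict.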

In [None]:
from kd_loss import HiddenLayerLoss
teacher_hidden_dim = 50
student_hidden_dim = 10
batch_size = 2

with open(root_folder+"kd_checks/kd_hidden_loss.json",'r') as f:
  io = json.load(f)
  teacher_hddn = th.tensor(io['teacher_hidden'])
  student_hddn = th.tensor(io['student_hidden'])
  expected_output = th.tensor(io['expected_output'])

hddn_loss = HiddenLayerLoss(teacher_hidden_dim,student_hidden_dim)
hddn_loss.load_state_dict(th.load(root_folder+"kd_checks/kd_hidden_loss"))
output = hddn_loss(teacher_hddn, student_hddn)
validate_to_array(model_out_to_list,((teacher_hddn,student_hddn),hddn_loss),'kdhddnloss', root_folder)
print("Total error on the output:",th.sum(th.abs(expected_output-output)).item(), "(should be 0.0 or close to 0.0)")

## (3) Implementing the Embedding Layer Loss

This part is located in EmbeddingLayerLoss in kd_loss.py. You must implement the call function of the class. You will need to implement formula (9) in section 3 of __[TinyBERT](https://arxiv.org/pdf/1909.10351.pdf)__.
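
For reference, formula (9) as we read it from the paper is reproduced below; check it against the PDF before relying on it.

$$
\mathcal{L}_{\text{embd}} = \mathrm{MSE}\left(E^S W_e, E^T\right)
$$

Here $E^S$ and $E^T$ are the embedding matrices of the student and the teacher, and $W_e$ is a learnable linear map that plays the same role for the embeddings as the projection does in the hidden-state loss.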

In [None]:
from kd_loss import EmbeddingLayerLoss
teacher_embed_dim = 50
student_embed_dim = 10
batch_size = 2

with open(root_folder+"kd_checks/kd_embed_loss.json",'r') as f:
  io = json.load(f)
  teacher_embd = th.tensor(io['teacher_embed'])
  student_embd = th.tensor(io['student_embed'])
  expected_output = th.tensor(io['expected_output'])

embd_loss = EmbeddingLayerLoss(teacher_embed_dim,student_embed_dim)
embd_loss.load_state_dict(th.load(root_folder+"kd_checks/kd_embed_loss"))
output = embd_loss(teacher_embd, student_embd)
validate_to_array(model_out_to_list,((teacher_embd,student_embd),embd_loss),'kdembdloss', root_folder)
print("Total error on the output:",th.sum(th.abs(expected_output-output)).item(), "(should be 0.0 or close to 0.0)")

## (4) Implementing the Prediction Loss

This part is located in PredictionLoss in kd_loss.py. You must implement the call function of the class. You will need to implement the formula (10) in section 3 of __[TinyBERT](https://arxiv.org/pdf/1909.10351.pdf)__
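
As a dependency-free illustration of the idea behind formula (10), the sketch below computes a soft cross-entropy between teacher and student logits for a single example, with a temperature `t`. The helper names are hypothetical, and whether the homework's PredictionLoss uses a temperature or how it averages over the batch is an assumption to verify against the paper and kd_loss.py.

```python
import math

def softmax(logits, t=1.0):
    """Numerically stable softmax of logits / t."""
    scaled = [z / t for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(logits, t=1.0):
    """Numerically stable log-softmax of logits / t."""
    scaled = [z / t for z in logits]
    m = max(scaled)
    log_total = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_total for z in scaled]

def soft_cross_entropy(teacher_logits, student_logits, t=1.0):
    """Cross-entropy of the student distribution against the teacher's soft targets."""
    p_teacher = softmax(teacher_logits, t)
    log_q_student = log_softmax(student_logits, t)
    return -sum(p * lq for p, lq in zip(p_teacher, log_q_student))

uniform = soft_cross_entropy([0.0, 0.0], [0.0, 0.0])     # two equally likely classes
matched = soft_cross_entropy([2.0, 0.0], [2.0, 0.0])     # student agrees with teacher
mismatched = soft_cross_entropy([2.0, 0.0], [0.0, 2.0])  # student disagrees
```

The loss is smallest when the student reproduces the teacher's distribution and grows as the two diverge. Even a perfect match does not reach zero, because the entropy of the teacher's soft targets remains.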

In [None]:
from kd_loss import PredictionLoss
word_count = 10
batch_size = 2

with open(root_folder+"kd_checks/kd_pred_loss.json",'r') as f:
  io = json.load(f)
  teacher_pred = th.tensor(io['teacher_pred'])
  student_pred = th.tensor(io['student_pred'])
  expected_output = th.tensor(io['expected_output'])

pred_loss = PredictionLoss()
output = pred_loss(teacher_pred, student_pred)
validate_to_array(model_out_to_list,((teacher_pred,student_pred),pred_loss),'kdpredloss', root_folder)
print("Total error on the output:",th.sum(th.abs(expected_output-output)).item(), "(should be 0.0 or close to 0.0)")

## (5) Implementing the Knowledge Distillation Loss

This part is located in KnowledgeDistillationLoss in kd_loss.py. You must implement the call function of the class. You will need to implement the formula (11) in section 3 of __[TinyBERT](https://arxiv.org/pdf/1909.10351.pdf)__
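
For reference, formula (11) as we read it from the paper selects which loss applies to each layer $m$ of a student with $M$ Transformer blocks; check it against the PDF before relying on it.

$$
\mathcal{L}_{\text{layer}} =
\begin{cases}
\mathcal{L}_{\text{embd}}, & m = 0 \\
\mathcal{L}_{\text{hidn}} + \mathcal{L}_{\text{attn}}, & M \ge m > 0 \\
\mathcal{L}_{\text{pred}}, & m = M + 1
\end{cases}
$$

Layer $0$ is the embedding layer and layer $M + 1$ is the prediction layer. Each student block $m$ is compared with teacher block $g(m)$ given by the layer mapping, and the total loss sums the per-layer terms. Any per-layer weighting is an assumption to check against the paper and kd_loss.py.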

In [None]:
from kd_loss import KnowledgeDistillationLoss
num_channels = 12
teacher_hidden_dim = 60
student_hidden_dim = 15
teacher_embed_dim = 50
student_embed_dim = 10
word_count = 5
teacher_num_blocks = 6
student_num_blocks = 2
batch_size = 2
layer_mapping = range(2,6,3)

with open(root_folder+"kd_checks/kd_loss.json",'r') as f:
  io = json.load(f)
  teacher_out = io['teacher_out']
  student_out = io['student_out']
  teacher_out = dict(
      embeddings=th.tensor(teacher_out['embeddings']),
      attentions=[th.tensor(o) for o in teacher_out['attentions']],
      hidden_states=[th.tensor(o) for o in teacher_out['hidden_states']],
      logits=th.tensor(teacher_out['embeddings'])
  )
  student_out = dict(
      embeddings=th.tensor(student_out['embeddings']),
      attentions=[th.tensor(o) for o in student_out['attentions']],
      hidden_states=[th.tensor(o) for o in student_out['hidden_states']],
      logits=th.tensor(student_out['embeddings'])
  )
  expected_output = th.tensor(io['expected_output'])

kd_loss = KnowledgeDistillationLoss(teacher_embed_dim,student_embed_dim,teacher_hidden_dim,student_hidden_dim,layer_mapping)
kd_loss.load_state_dict(th.load(root_folder+"kd_checks/kd_loss"))
output = kd_loss(teacher_out, student_out)
validate_to_array(model_out_to_list,((teacher_out,student_out),kd_loss),'kdloss', root_folder)
print("Total error on the output:",th.sum(th.abs(expected_output-output)).item(), "(should be 0.0 or close to 0.0)")
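
The check above uses `layer_mapping = range(2,6,3)` for a 6-block teacher and a 2-block student. One plausible reading, sketched below with a hypothetical helper, is a uniform mapping in which every third teacher block is matched to a student block, expressed as zero-based indices. How KnowledgeDistillationLoss actually indexes the attention and hidden-state lists is an assumption to verify in kd_loss.py.

```python
def uniform_layer_mapping(teacher_blocks, student_blocks):
    """Zero-based indices of the teacher blocks matched to student blocks 1..M."""
    step = teacher_blocks // student_blocks
    # Student block m (1-based) maps to teacher block m * step (1-based),
    # i.e. index m * step - 1 when counting from zero.
    return range(step - 1, teacher_blocks, step)

check_mapping = list(uniform_layer_mapping(6, 2))
print(check_mapping)
```

For the 6-to-2 case this reproduces the indices of `range(2,6,3)`, that is, the third and sixth teacher blocks.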

# Experimental Setup

## General Distillation

### Data Loading

Below is the text parsing setup. We will be using the wikitext dataset as used in the paper. Since we are only demonstrating the method, we will use the small wikitext-2 dataset instead of the standard wikitext-103 set. Wikitext contains thousands of cleaned English Wikipedia articles separated by sentence. Since the order of sentences is left intact, the dataset can be used to model long-term dependencies between words.

In [None]:
%%capture
from datasets import load_dataset
datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')

See that the data has been split into training, validation, and testing sets.

In [None]:
print(datasets)

A sample of the dataset:

In [None]:
print("".join(datasets['train'][:100]['text']))

The words must be tokenized and mapped to indices in the vocabulary of our model. Instead of padding (and masking) sentences to equal length, this time we will split the contiguous token sequence into equal-size blocks, possibly breaking up whole sentences.
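
A toy illustration of the blocking step, using a hypothetical helper on made-up token ids (the real tokenizer output carries several fields, such as `input_ids` and `attention_mask`, which are all blocked the same way):

```python
def group_into_blocks(token_lists, block_size):
    """Concatenate tokenized texts and cut the result into fixed-size blocks."""
    flat = [tok for tokens in token_lists for tok in tokens]
    # Drop the remainder that does not fill a whole block.
    usable = (len(flat) // block_size) * block_size
    return [flat[i:i + block_size] for i in range(0, usable, block_size)]

# Three "sentences" of different lengths, nine tokens in total.
blocks = group_into_blocks([[1, 2, 3], [4, 5], [6, 7, 8, 9]], block_size=4)
print(blocks)
```

Sentence boundaries are ignored: the second sentence ends up split across two blocks, and the final token is dropped because it cannot fill a block on its own.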

In [None]:
%%capture
tokenized_datasets = datasets.map(lambda samples: tokenizer(samples['text']), batched=True, num_proc=4, remove_columns=["text"])

In [None]:
block_size = 128
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder. We could add padding instead if the model supported it;
    # customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

In [None]:
%%capture
lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
    num_proc=4,
    load_from_cache_file=False
)

In [None]:
tokenizer.decode(lm_datasets["train"][1]["input_ids"])

In [None]:
lm_datasets["train"][1].keys()

### Set Up Student Model

Fill in the dimensions of BERT and the student network as specified in section 4.2 of the paper

In [None]:
vocab_size = int(1e4)
teacher_hddn_dim = 
student_hddn_dim = 
teacher_num_hddn_layers = 
student_num_hddn_layers =  
teacher_num_attn_heads = 
student_num_attn_heads = 
teacher_ff_dim = 
student_ff_dim = 
teacher_embd_dim = 
student_embd_dim = 
layer_mapping = range(,
                      ,
                      )

student_config = BertConfig(
    hidden_size=student_hddn_dim,
    num_hidden_layers=student_num_hddn_layers,
    num_attention_heads=student_num_attn_heads,
    intermediate_size=student_ff_dim,
)

### Training

In [None]:
from kd_loss import KnowledgeDistillationLoss
teacher_model = BertForMaskedLM.from_pretrained("bert-base-uncased").to(device)
teacher_model.load_state_dict(th.load(root_folder+'bert_models/teacher_wikitext.pt'))
student_model = BertForMaskedLM(student_config).to(device)
criterion = KnowledgeDistillationLoss(teacher_embd_dim,student_embd_dim,teacher_hddn_dim,student_hddn_dim,layer_mapping).to(device)


In [None]:
from tqdm.notebook import tqdm
import gc
gc.collect()
optimizer = optim.Adam(params=student_model.parameters(),lr=5e-5,weight_decay=0.01)
student_model.to(device)
lr = 1e-4
batch_size = 10
epochs=10
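teacher_model.eval()   # the teacher is frozen, so disable dropout
student_model.train()
for epoch in range(epochs):
    # Dataset.shuffle returns a new dataset instead of shuffling in place, so keep the result.
    train_data = lm_datasets["train"].shuffle(load_from_cache_file=False)
    t = tqdm(range(0,len(train_data),batch_size))
    accuracies = []
    losses = []
    for i in t:
        data = train_data[i:i+batch_size]
        data = {k: th.tensor(v).to(device) for k,v in data.items()}
        with th.no_grad():  # no gradients are needed through the teacher
            teacher_out = teacher_model(**data,**forward_kwargs)
        student_out = student_model(**data,**forward_kwargs)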
        teacher_out['embeddings'] = teacher_model.get_input_embeddings().weight
        student_out['embeddings'] = student_model.get_input_embeddings().weight
        loss = criterion(teacher_out,student_out,penalize_prediction=False)
        losses.append(loss.detach().cpu().numpy())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        accuracy = th.eq(student_out['logits'].argmax(dim=2,keepdim=False).float(),data['labels']).float().mean()
        accuracies.append(accuracy.detach().cpu().numpy())
        loss = np.around(np.mean(losses[-100:]),3)
        accuracy = np.around(np.mean(accuracies[-100:]),2)
        t.set_description("Epoch: "+str(epoch)+" Loss: "+str(loss))
    os.makedirs(root_folder+'bert_models',exist_ok=True)
    th.save(student_model.state_dict(),root_folder+'bert_models/student_wikitext.pt')

We next train a control, which is just the same BERT shaped student model trained from scratch on general and task specific data.

In [None]:
control_model = BertForMaskedLM(student_config).to(device)
gc.collect()
control_model.train()
lr = 1e-4
batch_size = 10
epochs=1
# Optimize the control model's own parameters (not the distilled student's).
optimizer = optim.Adam(params=control_model.parameters(),lr=lr,weight_decay=0.01)
for epoch in range(epochs):
    # Dataset.shuffle returns a new dataset instead of shuffling in place, so keep the result.
    train_data = lm_datasets["train"].shuffle()
    t = tqdm(range(0,len(train_data),batch_size))
    accuracies = []
    losses = []
    for i in t:
        data = train_data[i:i+batch_size]
        data = {k: th.tensor(v).to(device) for k,v in data.items()}
        control_out = control_model(**data,**forward_kwargs)
        losses.append(control_out['loss'].detach().cpu().numpy())
        
        optimizer.zero_grad()
        control_out['loss'].backward()
        optimizer.step()
        accuracy = th.eq(control_out['logits'].argmax(dim=2,keepdim=False).float(),data['labels']).float().mean()
        accuracies.append(accuracy.detach().cpu().numpy())
        loss = np.around(np.mean(losses[-100:]),3)
        accuracy = np.around(np.mean(accuracies[-100:]),2)
        t.set_description("Epoch: "+str(epoch)+" Loss: "+str(loss))
    os.makedirs(root_folder+'bert_models',exist_ok=True)
    th.save(control_model.state_dict(),root_folder+'bert_models/control_wikitext.pt')

## Task Specific Distillation

### Data Loading

In [None]:
%%capture
from datasets import load_dataset
datasets = load_dataset('glue', 'mrpc')

See that the data has been split into training, validation, and testing sets.

In [None]:
datasets

In [None]:
print(datasets['train'][0]['sentence1'], datasets['train'][0]['sentence2'])

In [None]:
%%capture
mrpc_tok = datasets.map(lambda samples: tokenizer(samples['sentence1'], samples['sentence2'],padding='max_length',max_length=150),
                       remove_columns=['sentence1', 'sentence2','idx'],
                       load_from_cache_file=False,
                      )

In [None]:
%%capture
def filter_texts(examples):
    examples["labels"] = examples["label"].copy()
    examples.pop('label',None)
    return examples
mrpc = mrpc_tok.map(
    filter_texts,
    batched=True,
    batch_size=1000,
    num_proc=4,
    load_from_cache_file=False
)

### Training

In [None]:
from kd_loss import KnowledgeDistillationLoss
from transformers import BertForNextSentencePrediction, BertForSequenceClassification
teacher_model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased").to(device)
teacher_model.load_state_dict(th.load(root_folder+'bert_models/teacher_mrpc.pt'))
student_model = BertForNextSentencePrediction(student_config).to(device)
student_model.load_state_dict(th.load(root_folder+'bert_models/student_wikitext.pt'),strict=False)
criterion = KnowledgeDistillationLoss(teacher_embd_dim,student_embd_dim,teacher_hddn_dim,student_hddn_dim,layer_mapping).to(device)


In [None]:
from tqdm.notebook import tqdm
import gc
gc.collect()
optimizer = optim.Adam(params=student_model.parameters(),lr=5e-5,weight_decay=0.01)
student_model.to(device)
lr = 1e-4
batch_size = 10
epochs=10
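teacher_model.eval()   # the teacher is frozen, so disable dropout
student_model.train()
for epoch in range(epochs):
    # Dataset.shuffle returns a new dataset instead of shuffling in place, so keep the result.
    train_data = mrpc["train"].shuffle(load_from_cache_file=False)
    t = tqdm(range(0,len(train_data),batch_size))
    accuracies = []
    losses = []
    for i in t:
        data = train_data[i:i+batch_size]
        data = {k: th.tensor(v).to(device) for k,v in data.items()}
        with th.no_grad():  # no gradients are needed through the teacher
            teacher_out = teacher_model(**data,**forward_kwargs)
        student_out = student_model(**data,**forward_kwargs)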
        teacher_out['embeddings'] = teacher_model.get_input_embeddings().weight
        student_out['embeddings'] = student_model.get_input_embeddings().weight
        loss = criterion(teacher_out,student_out,penalize_prediction=True)
        losses.append(loss.detach().cpu().numpy())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        accuracy = th.eq(student_out['logits'].argmax(dim=1,keepdim=False).float(),data['labels']).float().mean()
        accuracies.append(accuracy.detach().cpu().numpy())
        loss = np.around(np.mean(losses[-100:]),3)
        accuracy = np.around(np.mean(accuracies[-100:]),2)
        t.set_description("Epoch: "+str(epoch)+" Loss: "+str(loss)+" Accuracy: "+str(accuracy))
    os.makedirs(root_folder+'bert_models',exist_ok=True)
    th.save(student_model.state_dict(),root_folder+'bert_models/student_mrpc.pt')

We next train a control, which is just the same BERT shaped student model trained from scratch on general and task specific data.

In [None]:
control_model = BertForNextSentencePrediction(student_config).to(device)
# Start from the control model pretrained on wikitext; the task head stays randomly initialized.
control_model.load_state_dict(th.load(root_folder+'bert_models/control_wikitext.pt'),strict=False)
gc.collect()
control_model.train()
lr = 1e-4
batch_size = 10
epochs=1
# Optimize the control model's own parameters (not the distilled student's).
optimizer = optim.Adam(params=control_model.parameters(),lr=lr,weight_decay=0.01)
for epoch in range(epochs):
    # Dataset.shuffle returns a new dataset instead of shuffling in place, so keep the result.
    train_data = mrpc["train"].shuffle()
    t = tqdm(range(0,len(train_data),batch_size))
    accuracies = []
    losses = []
    for i in t:
        data = train_data[i:i+batch_size]
        data = {k: th.tensor(v).to(device) for k,v in data.items()}
        control_out = control_model(**data,**forward_kwargs)
        losses.append(control_out['loss'].detach().cpu().numpy())
        
        optimizer.zero_grad()
        control_out['loss'].backward()
        optimizer.step()
        # Sentence-pair logits have shape (batch, 2), so the class axis is dim=1.
        accuracy = th.eq(control_out['logits'].argmax(dim=1,keepdim=False).float(),data['labels']).float().mean()
        accuracies.append(accuracy.detach().cpu().numpy())
        loss = np.around(np.mean(losses[-100:]),3)
        accuracy = np.around(np.mean(accuracies[-100:]),2)
        t.set_description("Epoch: "+str(epoch)+" Loss: "+str(loss)+" Accuracy: "+str(accuracy))
    os.makedirs(root_folder+'bert_models',exist_ok=True)
    # Save under a separate name so the wikitext control checkpoint is not overwritten.
    th.save(control_model.state_dict(),root_folder+'bert_models/control_mrpc.pt')

# Preprocessing

## BERT Wikitext Specific Training

In [None]:
%%capture
from datasets import load_dataset
datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
tokenized_datasets = datasets.map(lambda samples: tokenizer(samples['text']), batched=True, num_proc=4, remove_columns=["text"])

In [None]:
block_size = 128
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder. We could add padding instead if the model supported it;
    # customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

In [None]:
%%capture
lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
    num_proc=4,
    load_from_cache_file=False
)

In [None]:
from tqdm.autonotebook import tqdm
import gc
gc.collect()
teacher_model.train()
# Define the hyperparameters before the optimizer that uses them.
lr = 1e-4
batch_size = 10
epochs=1
optimizer = optim.Adam(params=teacher_model.parameters(),lr=lr,weight_decay=0.01)
for epoch in range(epochs):
    # Dataset.shuffle returns a new dataset instead of shuffling in place, so keep the result.
    train_data = lm_datasets["train"].shuffle()
    t = tqdm(range(0,len(train_data),batch_size))
    accuracies = []
    losses = []
    for i in t:
        data = train_data[i:i+batch_size]
        data = {k: th.tensor(v).to(device) for k,v in data.items()}
        teacher_out = teacher_model(**data,**forward_kwargs)
        losses.append(teacher_out['loss'].detach().cpu().numpy())
        
        optimizer.zero_grad()
        teacher_out['loss'].backward()
        optimizer.step()
        accuracy = th.eq(teacher_out['logits'].argmax(dim=2,keepdim=False).float(),data['labels']).float().mean()
        accuracies.append(accuracy.detach().cpu().numpy())
        loss = np.around(np.mean(losses[-100:]),3)
        accuracy = np.around(np.mean(accuracies[-100:]),2)
        t.set_description("Epoch: "+str(epoch)+" Loss: "+str(loss)+" Accuracy: "+str(accuracy))

In [None]:
teacher_model.eval()
# Reset the running statistics so training values do not leak into the validation metrics.
losses = []
accuracies = []
val_data = lm_datasets["validation"]
t = tqdm(range(0,len(val_data),batch_size))
for i in t:
    data = val_data[i:i+batch_size]
    data = {k: th.tensor(v).to(device) for k,v in data.items()}
    with th.no_grad():  # evaluation only, no gradients needed
        teacher_out = teacher_model(**data,**forward_kwargs)
    losses.append(teacher_out['loss'].detach().cpu().numpy())
    
    accuracy = th.eq(teacher_out['logits'].argmax(dim=2,keepdim=False).float(),data['labels']).float().mean()
    accuracies.append(accuracy.detach().cpu().numpy())
    loss = np.around(np.mean(losses),3)
    accuracy = np.around(np.mean(accuracies),2)
    t.set_description("Validation - "+"Loss: "+str(loss)+" Accuracy: "+str(accuracy))

In [None]:
os.makedirs(root_folder+'bert_models',exist_ok=True)
th.save(teacher_model.state_dict(),root_folder+'bert_models/teacher_wikitext.pt')

## BERT MRPC Specific Training

In [None]:
%%capture
from datasets import load_dataset
datasets = load_dataset('glue', 'mrpc')
mrpc_tok = datasets.map(lambda samples: tokenizer(samples['sentence1'], samples['sentence2'],padding='max_length',max_length=150),
                       remove_columns=['sentence1', 'sentence2','idx'],
                       load_from_cache_file=False,
                      )

In [None]:
%%capture
def filter_texts(examples):
    examples["labels"] = examples["label"].copy()
    examples.pop('label',None)
    return examples
mrpc = mrpc_tok.map(
    filter_texts,
    batched=True,
    batch_size=1000,
    num_proc=4,
    load_from_cache_file=False
)

In [None]:
%%capture
from transformers import BertForNextSentencePrediction, BertForSequenceClassification
teacher_model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased").to(device)

In [None]:
from tqdm.autonotebook import tqdm
import gc
gc.collect()
teacher_model.train()
lr = 2e-5
batch_size = 10
epochs=10
optimizer = optim.Adam(params=teacher_model.parameters(),lr=lr,weight_decay=0.01)
for epoch in range(epochs):
    # Dataset.shuffle returns a new dataset instead of shuffling in place, so keep the result.
    train_data = mrpc["train"].shuffle(load_from_cache_file=False)
    t = tqdm(range(0,len(train_data),batch_size))
    accuracies = []
    losses = []
    for i in t:
        data = train_data[i:i+batch_size]
        data = {k: th.tensor(v).to(device) for k,v in data.items()}
        teacher_out = teacher_model(**data,**forward_kwargs)
        losses.append(teacher_out['loss'].detach().cpu().numpy())
        
        optimizer.zero_grad()
        teacher_out['loss'].backward()
        optimizer.step()
        accuracy = th.eq(teacher_out['logits'].argmax(dim=1,keepdim=False).float(),data['labels']).float().mean()
        accuracies.append(accuracy.detach().cpu().numpy())
        loss = np.around(np.mean(losses[-100:]),3)
        accuracy = np.around(np.mean(accuracies[-100:]),2)
        t.set_description("Epoch: "+str(epoch)+" Loss: "+str(loss)+" Accuracy: "+str(accuracy))

In [None]:
teacher_model.eval()
# Reset the running statistics so training values do not leak into the validation metrics.
losses = []
accuracies = []
val_data = mrpc["validation"]
t = tqdm(range(0,len(val_data),batch_size))
for i in t:
    data = val_data[i:i+batch_size]
    data = {k: th.tensor(v).to(device) for k,v in data.items()}
    with th.no_grad():  # evaluation only, no gradients needed
        teacher_out = teacher_model(**data,**forward_kwargs)
    losses.append(teacher_out['loss'].detach().cpu().numpy())
    
    accuracy = th.eq(teacher_out['logits'].argmax(dim=1,keepdim=False).float(),data['labels']).float().mean()
    accuracies.append(accuracy.detach().cpu().numpy())
    loss = np.around(np.mean(losses),3)
    accuracy = np.around(np.mean(accuracies),2)
    t.set_description("Validation - "+"Loss: "+str(loss)+" Accuracy: "+str(accuracy))

In [None]:
os.makedirs(root_folder+'bert_models',exist_ok=True)
th.save(teacher_model.state_dict(),root_folder+'bert_models/teacher_mrpc.pt')