In [None]:
%matplotlib inline

# Mini-Bert:Using Distillation Knowledge forTraining a Minimal Bert Model




## Define the model

We train a ``nn.TransformerEncoder`` model on a
language modeling task. The ``nn.TransformerEncoder`` consists of multiple layers of ``nn.TransformerEncoderLayer``. A square attention mask is required because the
self-attention layers in ``nn.TransformerEncoder`` are only allowed to attend
the earlier positions in the sequence. For the language modeling task, any
tokens on the future positions should be masked. 

To produce a probability
distribution over output words, the output of the ``nn.TransformerEncoder``
model is passed through a linear layer followed by a log-softmax function.




In [3]:
import math
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

class TransformerModel(nn.Module):

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.decoder = nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
        """
        Args:
            src: Tensor, shape [seq_len, batch_size]
            src_mask: Tensor, shape [seq_len, seq_len]

        Returns:
            output Tensor of shape [seq_len, batch_size, ntoken]
        """
        src = self.encoder(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output


def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Generates an upper-triangular matrix of -inf, with zeros on diag."""
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

``PositionalEncoding`` module injects  information about the
relative or absolute position of the tokens in the sequence. The
positional encodings have the same dimension as the embeddings so that
the two can be summed.We use ``sine`` and ``cosine`` functions to force having 
different frequencies.




In [4]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

# Load data




Our training and testing data is a pre curared plain text file called **wikipedia16-large.txt**. This is a 1.8GB size plain text file thatcontains  a  total  of 19 499 139 sentences. Notice that we are using a 70 percent of the total dataset for training the student model.

In [3]:
from transformers import BertTokenizer
import torch
import random
import numpy as np

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class SentenceIterator:

    def __init__(self, dataset, batch_size, device, n = 0):
        self.dataset = dataset
        self.batch_size = batch_size
        self.total_lines = len(dataset)
        self.device = device
        self.n = n
    def __iter__(self):
        return self

    def __next__(self):
        if self.n <= self.total_lines:
            batch = tokenizer(
                      self.dataset[self.n : self.n + self.batch_size],
                      padding=True, 
                      truncation=True,
                      add_special_tokens=True, # Add '[CLS]' and '[SEP]'
                      return_token_type_ids=True,
                      return_attention_mask=True,
                      return_tensors='pt',  # Return PyTorch tensors
                    ).to(self.device)
            self.n += self.batch_size
            target = batch["input_ids"]
            batch["input_ids"] = torch.tensor([np.concatenate((item[0:np.where(item == 102)[0][0]-1], [103], item[np.where(item == 102)[0][0]:len(item)]),axis=0) for item in batch["input_ids"].to("cpu").numpy()], dtype=torch.long).to(self.device)
            return batch, target
        else:
            raise StopIteration


def get_data(file_path, train_size):
  with open(file_path, "r",encoding="utf-8") as f:
    sentences = f.readlines()
    random.shuffle(sentences)
    len_sentences = len(sentences)
    len_train_sentences = int(len_sentences * train_size)
    len_val_sentences = int((len_sentences - len_train_sentences) // 2)
    train_sentences = sentences[0:len_train_sentences]
    val_sentences = sentences[len_train_sentences:len_train_sentences + len_val_sentences]
    test_sentences = sentences[len_train_sentences + len_val_sentences:len_sentences]
    del sentences
    return train_sentences, val_sentences, test_sentences

train_data, val_data, test_data = get_data(file_path = "wikipedia16-large.txt", train_size = 0.7)

#train_iterable = SentenceIterator(dataset = train_data, batch_size=35, device = device)
#train_iterator = iter(train_iterable)

#batch_input, batch_target = next(train_iterator)
#print(batch_input)
#print(batch_input["input_ids"].shape)
#print(batch_target.shape)
#for batch_input, batch_target in train_iterator:
#  print(batch_input.shape, batch_target.shape)
#  break

## Functions to generate input and target sequence




``get_batch()`` generates a pair of input-target sequences for
the transformer model. It subdivides the source data into chunks of
length ``bptt``. For the language modeling task, the model needs the
following words as ``Target``. For example, with a ``bptt`` value of 2,
we’d get the following two Variables for ``i`` = 0:

![](https://github.com/pytorch/tutorials/blob/gh-pages/_downloads/_static/img/transformer_input_target.png?raw=1)


It should be noted that the chunks are along dimension 0, consistent
with the ``S`` dimension in the Transformer model. The batch dimension
``N`` is along dimension 1.




In [None]:
#This is not necessary
bptt = 35
def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
    """
    Args:
        source: Tensor, shape [full_seq_len, batch_size]
        i: int

    Returns:
        tuple (data, target), where data has shape [seq_len, batch_size] and
        target has shape [seq_len * batch_size]
    """
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

## Initialize Hyper Parameters and Variables




The model hyperparameters are defined below. The vocab size is
equal to the length of the vocab object.




In [None]:
ntokens = 30522  # size of vocabulary in BERT
batch_size = 20
emsize = 512  # embedding dimension
d_hid = emsize * 4  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2  # number of heads in nn.MultiheadAttention
dropout = 0.2  # dropout probability
temperature = 0.5 # temperature for kd
alpha = 0.3 # how important is the output of the student
from os import path
path_file = "/content/drive/MyDrive/Project2/transformer1.pt"
model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)
if path.exists(path_file):
  model.load_state_dict(torch.load(path_file, map_location=torch.device(device)))


## Load Teacher Model 

For our teacher model we are using the **bert-base-uncased** variant. 

In [5]:
from transformers import BertForMaskedLM

teacher_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
teacher_model.eval()
teacher_model.to(device)



BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

## Student Model training




We use CrossEntropyLoss with  Stochastic Gradient Descent as the optimizer. The learning rate is initially set to 5.0. We use `nn.utils.clip_grad_norm\_ <https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html>`__
to prevent gradients from exploding.

During the training, a single batch consists of 200 sentences. In this particular case we have a total of 19 499 139 sentences wich is about 1.8GB of text data. 







In [None]:
import copy
import time

criterion = nn.CrossEntropyLoss()
criterionKL = nn.KLDivLoss(reduction = "batchmean")
lr = 5.0  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

def dist_loss(t, s, T):
    prob_t = F.softmax(t/T, dim=1)
    log_prob_s = F.log_softmax(s/T, dim=1)
    dist_loss = -(prob_t*log_prob_s).sum(dim=1).mean()
    return dist_loss

def train(model: nn.Module) -> None:
    model.train()  # turn on train mode
    total_loss = 0.
    log_interval = 200
    start_time = time.time()
    src_mask = generate_square_subsequent_mask(batch_size).to(device)
    num_batches = len(train_data) // batch_size
    batch = 0
    train_iterable = SentenceIterator(dataset = train_data, batch_size=batch_size, device = device)
    train_iterator = iter(train_iterable)
    for batch_input, batch_target in train_iterator:

        current_batch_size = batch_input["input_ids"].size(0)
        if batch_size != current_batch_size:  # only on last batch
            src_mask = src_mask[:current_batch_size, :current_batch_size]

        with torch.no_grad():
          output_teacher = teacher_model(**batch_input, labels=batch_target)
        
        output = model(batch_input["input_ids"], src_mask)
        student_loss = criterion(output.view(-1, ntokens), batch_target.reshape(-1))

        #ditillation_loss = criterionKL(
        #    F.log_softmax(output / temperature, dim=1),
        #    F.softmax(output_teacher.logits / temperature, dim=1)
        #)

        #loss = alpha * student_loss + (1 - alpha) * ditillation_loss

        distillation_loss = dist_loss(output_teacher.logits, output, temperature)

        loss = student_loss + distillation_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            ppl = math.exp(cur_loss)
            print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                  f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
            total_loss = 0
            start_time = time.time()
        if batch % 1000 == 0 and batch > 0:
            torch.save(model.state_dict(), path_file)
        batch += 1

def evaluate(model: nn.Module, eval_data) -> float:
    model.eval()  # turn on evaluation mode
    total_loss = 0.
    src_mask = generate_square_subsequent_mask(batch_size).to(device)
    with torch.no_grad():
        val_iterable = SentenceIterator(dataset = eval_data, batch_size=batch_size, device = device)
        val_iterator = iter(val_iterable)
        for batch_input, batch_target in val_iterator:

            current_batch_size = batch_input["input_ids"].size(0)
            if batch_size != current_batch_size:  # only on last batch
                src_mask = src_mask[:current_batch_size, :current_batch_size]

            output = model(batch_input["input_ids"], src_mask)
            output_flat = output.view(-1, ntokens)
            total_loss += batch_size * criterion(output_flat, batch_target).item()
    return total_loss / (len(eval_data) - 1)

Loop over epochs. Save the model if the validation loss is the best
we've seen so far. Adjust the learning rate after each epoch.



In [None]:
best_val_loss = float('inf')
epochs = 1
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model)
    val_loss = evaluate(model, val_data)
    val_ppl = math.exp(val_loss)
    elapsed = time.time() - epoch_start_time
    print('-' * 89)
    print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
          f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        torch.save(best_model.state_dict(), path_file)
    scheduler.step()

| epoch   1 |   200/682469 batches | lr 5.00 | ms/batch 44.85 | loss  7.91 | ppl  2712.08
| epoch   1 |   400/682469 batches | lr 5.00 | ms/batch 43.56 | loss  5.35 | ppl   210.74
| epoch   1 |   600/682469 batches | lr 5.00 | ms/batch 43.49 | loss  5.01 | ppl   149.50
| epoch   1 |   800/682469 batches | lr 5.00 | ms/batch 42.72 | loss  4.71 | ppl   111.57
| epoch   1 |  1000/682469 batches | lr 5.00 | ms/batch 43.01 | loss  4.57 | ppl    96.61
| epoch   1 |  1200/682469 batches | lr 5.00 | ms/batch 45.31 | loss  4.50 | ppl    90.42
| epoch   1 |  1400/682469 batches | lr 5.00 | ms/batch 43.67 | loss  4.35 | ppl    77.30
| epoch   1 |  1600/682469 batches | lr 5.00 | ms/batch 43.01 | loss  4.25 | ppl    69.91
| epoch   1 |  1800/682469 batches | lr 5.00 | ms/batch 43.11 | loss  4.17 | ppl    64.64
| epoch   1 |  2000/682469 batches | lr 5.00 | ms/batch 44.43 | loss  4.12 | ppl    61.33
| epoch   1 |  2200/682469 batches | lr 5.00 | ms/batch 46.21 | loss  4.08 | ppl    59.24
| epoch   

Evaluate the best model on the test dataset
-------------------------------------------




In [None]:
test_loss = evaluate(best_model, test_data)
test_ppl = math.exp(test_loss)
print('=' * 89)
print(f'| End of training | test loss {test_loss:5.2f} | '
      f'test ppl {test_ppl:8.2f}')
print('=' * 89)

| End of training | test loss  0.22 | test ppl     1.25


## Student Model Evaluation

In [11]:
#torch.save(best_model.state_dict(), "/content/drive/MyDrive/Project2/transformer3.pt")
from transformers import BertTokenizer, BertForMaskedLM
import torch

device = 'cpu'#torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
teacher_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
teacher_model.to(device)

ntokens = 30522  # size of vocabulary in BERT
batch_size = 20
emsize = 512  # embedding dimension
d_hid = emsize * 4  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2  # number of heads in nn.MultiheadAttention
dropout = 0.2  # dropout probability
path_file = "/content/drive/MyDrive/Project2/transformer1.pt"

best_model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout)
best_model.load_state_dict(torch.load(path_file, map_location=torch.device(device)))
sentence = ["Paris is the capital of [MASK]."]
target = tokenizer(["Paris is the capital of France."], return_tensors='pt')["input_ids"].to(device)

src_mask = generate_square_subsequent_mask(1).to(device)
sentence_data = tokenizer(
                      sentence,
                      padding=True, 
                      truncation=True,
                      add_special_tokens=True, # Add '[CLS]' and '[SEP]'
                      return_token_type_ids=True,
                      return_attention_mask=True,
                      return_tensors='pt',  # Return PyTorch tensors
                    ).to(device)
print(sentence_data["input_ids"])
print(tokenizer.convert_ids_to_tokens(sentence_data["input_ids"][0]))
best_model.eval()
with torch.no_grad():
  print("input size", sentence_data["input_ids"].shape)
  print("src mask", src_mask)
  output = best_model(sentence_data["input_ids"].to(device), src_mask.to(device))#.reshape(-1, ntokens)
  """
  output_flat = output.view(-1, ntokens)
  print("output size", output.shape)
  print(output)
  print("output flat size", output_flat.shape)
  print(output_flat)
  print(nn.Softmax(dim=1)(output_flat))
  result_index = torch.argmax(nn.Softmax(dim=1)(output_flat), dim=1)
  print(result_index)
  print(sentence)
  print(tokenizer.convert_ids_to_tokens(sentence_data[0]))
  print(tokenizer.convert_ids_to_tokens(result_index))
  """
  output_teacher = teacher_model(**sentence_data, labels=target)

  masked_index = 6

  print("Teacher prediction")
  predicted_index = torch.argmax(output_teacher.logits[0, masked_index]).item()
  print(predicted_index)
  predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])
  print(predicted_token)


  print("Student prediction")
  predicted_index = torch.argmax(output.reshape(-1, ntokens), dim=1)
  print(predicted_index)
  predicted_token = tokenizer.convert_ids_to_tokens(predicted_index)
  print(predicted_token)



Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tensor([[ 101, 3000, 2003, 1996, 3007, 1997,  103, 1012,  102]])
['[CLS]', 'paris', 'is', 'the', 'capital', 'of', '[MASK]', '.', '[SEP]']
input size torch.Size([1, 9])
src mask tensor([[0.]])
Teacher prediction
2605
['france']
Student prediction
tensor([ 101, 3000, 2003, 1996, 3007, 1997, 1012, 1012,  102])
['[CLS]', 'paris', 'is', 'the', 'capital', 'of', '.', '.', '[SEP]']
