In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import re
import csv
from collections import OrderedDict

import numpy as np
import torch
from torchtext import data
from torchtext.vocab import pretrained_aliases, Vocab

import spacy
from spacy.symbols import ORTH
import argparse
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from transformers import (BertConfig, BertForSequenceClassification, BertTokenizer)
from tqdm.autonotebook import tqdm, trange
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from tensorboardX import SummaryWriter

In [None]:
# English spaCy pipeline used for tokenization (spacy_tokenizer) and POS tagging (augmentation).
# spaCy v3 removed the "en" shortcut link, so prefer the full package name and only fall back
# to the shortcut for older installs that have nothing but the link.
try:
    spacy_en = spacy.load("en_core_web_sm")
except OSError:
    spacy_en = spacy.load("en")
mask_token = "<mask>"
# Keep the augmentation mask token as one token instead of letting spaCy split it into "<", "mask", ">".
spacy_en.tokenizer.add_special_case(mask_token, [{ORTH: mask_token}])

In [None]:
def set_seed(seed):
    """
    Seed every random number generator this notebook relies on, so runs are reproducible.

    Seeds three generators:
      * Python's `random` module. Legacy torchtext iterators shuffle training batches with it.
      * NumPy's global RNG, used by the augmentation helpers.
      * PyTorch, for weight initialisation and dropout.
    """
    import random  # local import: the notebook's import cell does not import `random`
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

class _UnkDefaultDict(OrderedDict):
    """OrderedDict whose missing keys resolve to `unk_index` without being inserted (like torchtext's Vocab.stoi)."""
    unk_index = None

    def __missing__(self, key):
        # Until unk_index is known, behave like a plain dict.
        if self.unk_index is None:
            raise KeyError(key)
        return self.unk_index


#Interface of this class is similar to torchtext.data.Vocab so that BertVocab can be directly assigned to torchtext.data.Vocab for compatibility.
class BertVocab:
    """
    Adapter that exposes a BERT tokenizer vocabulary through a torchtext-Vocab-like interface.

    Args:
        stoi: mapping token -> index, e.g. `BertTokenizer.vocab`. It must contain an "[UNK]"
            entry (matched case-insensitively); otherwise KeyError is raised.

    Attributes:
        stoi: ordered token -> index mapping. Keys are lower-cased, and "[xxx]" special tokens are
            renamed to "<xxx>", torchtext's convention. Looking up an unknown token returns
            `unk_index` instead of raising. torchtext's Field.numericalize indexes `stoi` directly,
            so without this, out-of-vocabulary tokens (including BERT's own "[UNK]" once the Field
            lower-cases it to "[unk]") would raise KeyError.
        itos: list of tokens ordered by index.
        unk_index: index of "<unk>".
    """
    UNK = '<unk>'
    def __init__(self, stoi):
        self.stoi = _UnkDefaultDict()
        # Any token inside [] is renamed to the same token inside <> for torchtext:
        # BERT writes special tokens as [xxx], torchtext as <xxx>.
        pattern = re.compile(r"\[(.*)\]")
        for s, idx in stoi.items():
            s = s.lower()
            m = pattern.match(s)
            if m:
                s = "<%s>" % m.group(1)
            self.stoi[s] = idx
        # Raises KeyError if the vocab has no [UNK]: the missing-key default is not enabled yet.
        self.unk_index = self.stoi[BertVocab.UNK]
        self.stoi.unk_index = self.unk_index
        # Stable sort by index, so the token order matches the original vocabulary ids.
        self.itos = [s for s, _ in sorted(self.stoi.items(), key=lambda kv: kv[1])]
    def _default_unk_index(self):
        return self.unk_index
    def __getitem__(self, token):
        return self.stoi.get(token, self.unk_index)
    def __len__(self):
        return len(self.itos)

def spacy_tokenizer(text):
    """Split `text` into a list of token strings with the module-level spaCy tokenizer."""
    doc = spacy_en.tokenizer(text)
    return list(map(lambda token: token.text, doc))

def load_tsv(path, skip_header=True, encoding="utf-8"):
    """
    Read a tab-separated file into a list of rows (each row a list of strings).

    Args:
        path: file to read.
        skip_header: drop the first row. An empty file simply yields [].
        encoding: text encoding of the file (SST-2 and the generated transfer sets are UTF-8).

    Returns:
        list[list[str]]
    """
    # newline="" is required by the csv module so it can handle line endings itself.
    with open(path, newline="", encoding=encoding) as f:
        reader = csv.reader(f, delimiter='\t')
        if skip_header:
            next(reader, None)  # default avoids StopIteration on an empty file
        rows = list(reader)
    return rows

def load_data(data_dir, tokenizer, vocab=None, batch_first=False, augmented=False, use_teacher=False):
    """
    Build the torchtext training/validation datasets and the shared text Field.

    Args:
        data_dir: directory holding the tsv files.
        tokenizer: callable str -> list[str] used by the text Field.
        vocab: if given (e.g. a BertVocab), assigned to the Field as-is. Otherwise a vocabulary
            with GloVe 6B 300d vectors is built from the training set.
        batch_first: layout of the numericalized text tensors.
        augmented: train on "augmented.tsv", whose labels are space-separated teacher scores.
        use_teacher: train on "noaugmented.tsv" (teacher scores, no augmentation).
            If neither flag is set, "train.tsv" with integer class ids is used.

    Returns:
        (train_dataset, valid_dataset, text_field). Validation always uses "dev.tsv" with class ids.
    """
    text_field = data.Field(sequential=True, tokenize=tokenizer, lower=True,
                            include_lengths=True, batch_first=batch_first)
    # Hard labels: a single integer class id per example.
    label_field_class = data.Field(sequential=False, use_vocab=False, dtype=torch.long)

    if augmented or use_teacher:
        # Teacher-labelled files store space-separated class scores in the label column.
        train_label_field = data.Field(
            sequential=False, batch_first=True, use_vocab=False,
            preprocessing=lambda x: [float(n) for n in x.split(" ")],
            dtype=torch.float32,
        )
    else:
        train_label_field = label_field_class

    if augmented:
        train_file = "augmented.tsv"
    elif use_teacher:
        train_file = "noaugmented.tsv"
    else:
        train_file = "train.tsv"

    def _read_split(file_name, label_field):
        # Every split shares the same text Field so they share one vocabulary.
        return data.TabularDataset(
            path=os.path.join(data_dir, file_name),
            format="tsv", skip_header=True,
            fields=[("text", text_field), ("label", label_field)],
        )

    train_dataset = _read_split(train_file, train_label_field)
    valid_dataset = _read_split("dev.tsv", label_field_class)

    if vocab is not None:
        # Reuse the supplied vocabulary (e.g. the BERT tokenizer's) unchanged.
        text_field.vocab = vocab
    else:
        # build_vocab creates the word -> index mapping plus the embedding matrix
        # (vocab.vectors) later used to initialise nn.Embedding.
        glove = pretrained_aliases["glove.6B.300d"]()
        text_field.build_vocab(train_dataset, vectors=glove)
        del glove
    return train_dataset, valid_dataset, text_field

def get_model_wrapper(model_weights, text_field, device=None):
    """
    Build a callable that scores one raw sentence with a trained BiLSTM student.

    Args:
        model_weights: a state_dict, or a path to one saved with torch.save (e.g. by save_bilstm).
        text_field: the torchtext Field used at training time (with vocab and vectors), or a path to it.
        device: torch.device to run on; defaults to CPU.

    Returns:
        A function text -> {"Negative": p0, "Positive": p1} of softmax probabilities.

    The architecture hyper-parameters below must match those used in train_bilstm_student.
    """
    if device is None:
        device = torch.device("cpu")
    if isinstance(text_field, str):
        # The Field's vectors are only used for their shape here, so CPU is fine.
        text_field = torch.load(text_field, map_location="cpu")
    if isinstance(model_weights, str):
        # Remap tensors saved during a GPU run onto `device`. Without map_location, loading
        # CUDA tensors on a CPU-only machine fails.
        model_weights = torch.load(model_weights, map_location=device)

    vocab = text_field.vocab
    model = BiLSTMClassifier(2, len(vocab.itos), vocab.vectors.shape[-1],
        lstm_hidden_size=300, classif_hidden_size=400, dropout_rate=0.15).to(device)
    model.load_state_dict(model_weights)
    trainer = LSTMTrainer(model, device)

    def model_wrapper(text):
        outputs = trainer.infer_one(text, text_field, softmax=True)
        return {
            "Negative": outputs[0],
            "Positive": outputs[1]
        }
    return model_wrapper


In [None]:
def build_pos_dict(sentences):
    """
    Map each POS tag to the distinct lower-cased words seen with it.

    Used for POS-guided replacement during augmentation, as described in
    https://blog.floydhub.com/knowledge-distillation/.

    Args:
        sentences: iterable of token sequences; each token needs `.pos_` and `.text`
            (e.g. spaCy Docs).

    Returns:
        dict pos_tag -> list of words, in first-seen order and without duplicates.
    """
    pos_dict = {}
    # Per-tag sets make the duplicate check O(1) instead of scanning the growing lists.
    seen = {}
    for sentence in sentences:
        for word in sentence:
            pos_tag = word.pos_
            lowered = word.text.lower()
            words_for_tag = pos_dict.setdefault(pos_tag, [])
            seen_for_tag = seen.setdefault(pos_tag, set())
            if lowered not in seen_for_tag:
                seen_for_tag.add(lowered)
                words_for_tag.append(lowered)
    return pos_dict

def make_sample(input_sentence, pos_dict, p_mask=0.1, p_pos=0.1, p_ng=0.25, max_ng=5):
    """
    Generate one augmented variant of a sentence.

    Uses the three techniques described in https://blog.floydhub.com/knowledge-distillation/.

    Args:
        input_sentence: sequence of tokens with `.text` and `.pos_` (e.g. a spaCy Doc).
        pos_dict: POS tag -> candidate words, as built by build_pos_dict.
        p_mask: per-token probability of replacing the token with the mask token.
        p_pos: per-token probability of replacing it with a random word of the same POS tag.
        p_ng: probability of additionally masking one contiguous n-gram.
        max_ng: maximum n-gram length (n is drawn uniformly from 1..max_ng).
            Values < 1 disable n-gram masking.

    Returns:
        list of lower-cased token strings.
    """
    sentence = []
    for word in input_sentence:
        # One uniform draw decides between masking, POS-guided replacement and keeping the word.
        u = np.random.uniform()
        if u < p_mask:
            sentence.append(mask_token)
        elif u < (p_mask + p_pos):
            same_pos = pos_dict[word.pos_]
            sentence.append(np.random.choice(same_pos))
        else:
            sentence.append(word.text.lower())
    # n-gram sampling: mask a random contiguous span of length 1..max_ng.
    if max_ng >= 1 and len(sentence) > 2 and np.random.uniform() < p_ng:
        # Cap n at len-1 so the span never covers every position.
        n = min(np.random.choice(range(1, max_ng + 1)), len(sentence) - 1)
        start = np.random.choice(len(sentence) - n)
        for idx in range(start, start + n):
            sentence[idx] = mask_token
    return sentence

    
def augmentation(sentences, pos_dict, n_iter=20, p_mask=0.1, p_pos=0.1, p_ng=0.25, max_ng=5):
    """
    Build the augmented transfer set for a whole corpus.

    Uses the techniques described in https://blog.floydhub.com/knowledge-distillation/.
    Each sentence contributes its lower-cased original followed by up to `n_iter` distinct
    variants from make_sample (duplicates are dropped).

    Returns:
        flat list of token lists.
    """
    augmented = []
    for parsed in tqdm(sentences, "Generation"):
        # The unmodified (lower-cased) sentence always comes first.
        variants = [[token.text.lower() for token in parsed]]
        for _ in range(n_iter):
            candidate = make_sample(parsed, pos_dict, p_mask, p_pos, p_ng, max_ng)
            if candidate in variants:
                continue
            variants.append(candidate)
        augmented += variants
    return augmented

def generate_dataset_student(input_dir, output_dir, model_to_load, no_augment= False, batch_size= 16, no_cuda= False, seed=42):
    """
    Build the transfer set for the BiLSTM student, labelled by the fine-tuned BERT teacher.

    Args:
        input_dir: path to the original training tsv file (despite the name, a file path).
            Each row must have two columns, text and label; the label is ignored.
        output_dir: path of the tsv file to write (despite the name, a file path).
        model_to_load: directory/name of the fine-tuned BERT model and tokenizer.
        no_augment: if True, label only the original sentences. Otherwise also add variants from
            POS-guided replacement, token masking and n-gram masking.
        batch_size: inference batch size for the teacher.
        no_cuda: force CPU even if CUDA is available.
        seed: RNG seed for the augmentation (defaults to the previously hard-coded 42).

    Output format: a "sentence\tscores" header, then one row per sentence. The second column
    holds the teacher's raw logits (no softmax), space-separated with 6 decimals.
    """
    device = torch.device("cuda" if not no_cuda and torch.cuda.is_available() else "cpu")
    set_seed(seed)
    # Load original tsv file
    input_tsv = load_tsv(input_dir)
    if not no_augment:
        sentences = [spacy_en(text) for text, _ in tqdm(input_tsv, desc="Loading dataset")]
        # Build lists of words indexed by POS tag
        pos_dict = build_pos_dict(sentences)
        # Generate augmented samples (lists of tokens)
        sentences = augmentation(sentences, pos_dict)
    else:
        sentences = [text for text, _ in input_tsv]

    # Load teacher model
    model = BertForSequenceClassification.from_pretrained(model_to_load).to(device)
    tokenizer = BertTokenizer.from_pretrained(model_to_load, do_lower_case=True)

    # Augmented samples are token lists, originals are raw strings: normalise to text once.
    texts = sentences if no_augment else [" ".join(words) for words in sentences]

    # Assign labels with teacher
    teacher_field = data.Field(sequential=True, tokenize=tokenizer.tokenize, lower=True, include_lengths=True, batch_first=True)
    fields = [("text", teacher_field)]
    examples = [data.Example.fromlist([text], fields) for text in texts]
    augmented_dataset = data.Dataset(examples, fields)
    teacher_field.vocab = BertVocab(tokenizer.vocab)
    new_labels = BertTrainer(model, device, batch_size=batch_size).infer(augmented_dataset)

    # Write as UTF-8: torchtext's TabularDataset reads its input files as UTF-8.
    with open(output_dir, "w", encoding="utf-8") as f:
        f.write("sentence\tscores\n")
        for text, scores in zip(texts, new_labels):
            # Space-separated scores, one per class (matches load_data's x.split(" ")).
            f.write("%s\t%s\n" % (text, " ".join("%.6f" % s for s in scores)))

In [None]:
def save_bert(model, tokenizer, config, output_dir):
    """
    Save a BERT config, model and tokenizer into `output_dir` via their save_pretrained methods.

    The directory and any missing parents are created, and an existing directory is reused.
    """
    os.makedirs(output_dir, exist_ok=True)
    config.save_pretrained(output_dir)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

def finetune_bert_teacher(data_dir, output_dir, lr_schedule, cache_dir, epochs=1, batch_size=16, gradient_accumulation_steps=1, lr= 5e-5, warmup_steps=0, epochs_per_cycle=1, 
             do_train= False, seed=42,  no_cuda= False, checkpoint_interval=1):
    """
    Fine-tune bert-base-uncased (the teacher) for sequence classification on the original hard labels.

    The fine-tuned model is later used to produce class scores for the transfer set
    (augmented or not), which the BiLSTM student is trained on.

    Args:
        data_dir: directory with train.tsv / dev.tsv.
        output_dir: where the final model (and checkpt_<step> subdirectories) are saved;
            missing parent directories are created.
        lr_schedule: "constant", "warmup" or "cyclic".
        cache_dir: transformers download cache.
        checkpoint_interval: save a full BERT checkpoint every this many optimizer updates.
            The default of 1 checkpoints after every update; negative values disable checkpointing.
        do_train: if False, only evaluate and save the pretrained model.
    """
    if lr_schedule == "constant":
        lr_schedule = None
    device = torch.device("cuda" if not no_cuda and torch.cuda.is_available() else "cpu")
    set_seed(seed)
    os.makedirs(output_dir, exist_ok=True)

    bert_config = BertConfig.from_pretrained("bert-base-uncased", cache_dir=cache_dir)
    bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", config=bert_config, cache_dir=cache_dir).to(device)
    bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True, cache_dir=cache_dir)
    # Data for fine-tuning BERT on the original train data (class id 1/0).
    train_dataset, valid_dataset, _ = load_data(data_dir, bert_tokenizer.tokenize,
        vocab=BertVocab(bert_tokenizer.vocab), batch_first=True)

    # The teacher is fine-tuned with cross-entropy on hard labels; BERT outputs logits (not softmax probs).
    trainer = BertTrainer(bert_model, device,
        loss="cross_entropy",
        train_dataset=train_dataset,
        val_dataset=valid_dataset, val_interval=250,
        checkpt_callback=lambda m, step: save_bert(m, bert_tokenizer, bert_config, os.path.join(output_dir, "checkpt_%d" % step)),
        checkpt_interval=checkpoint_interval,
        batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps,
        lr=lr)
    if do_train:
        trainer.train(epochs, schedule=lr_schedule,
            warmup_steps=warmup_steps, epochs_per_cycle=epochs_per_cycle)

    print("Evaluating bert (teacher) model:")
    print(trainer.evaluate())
    # Save the model so it can generate class scores for the transfer set.
    save_bert(bert_model, bert_tokenizer, bert_config, output_dir)


In [None]:
class MultiChannelEmbedding(nn.Module):
    """
    Token embedding followed by parallel 1-D convolutions ("channels") whose ReLU'd outputs are concatenated.

    forward expects sequence-first token ids of shape (seq_len, batch) and returns
    (seq_len, batch, filters_size * len(filters)).
    """
    def __init__(self, vocab_size, embed_size, filters_size=64, filters=[2, 4, 6], dropout_rate=0.0):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.filters_size = filters_size
        # Copy so instances never share (or mutate) the default list object.
        self.filters = list(filters)
        self.dropout_rate = dropout_rate  # stored for reference; no dropout is applied in forward
        self.embedding = nn.Embedding(self.vocab_size, self.embed_size)
        self.conv1 = nn.ModuleList([
            nn.Conv1d(self.embed_size, filters_size, kernel_size=f, padding=f//2)
            for f in self.filters
        ])
        # Kept for attribute/module compatibility; forward applies F.relu directly.
        self.act = nn.Sequential(
            nn.ReLU(inplace=True),
            #nn.Dropout(p=dropout_rate)
        )
    def init_embedding(self, weight):
        self.embedding.weight = nn.Parameter(weight.to(self.embedding.weight.device))
    def forward(self, x):
        seq_len = x.size(0)
        x = x.transpose(0, 1)                    # (batch, seq_len)
        x = self.embedding(x).transpose(1, 2)    # (batch, embed_size, seq_len)
        # With padding=f//2 an even kernel yields seq_len + 1 outputs. Trim every channel back to
        # seq_len so the output stays aligned with the input and torch.cat sees equal lengths
        # when odd and even kernels are mixed.
        channels = [conv(x)[:, :, :seq_len] for conv in self.conv1]
        x = F.relu(torch.cat(channels, 1))       # (batch, channels, seq_len)
        return x.transpose(1, 2).transpose(0, 1)  # (seq_len, batch, channels)

class BiLSTMClassifier(nn.Module): 
    """
    BiLSTM classifier used as the student.

    It can be trained on augmented/non-augmented teacher scores or on the original class ids.
    forward(seq, length) takes sequence-first token ids `seq` of shape (seq_len, batch) and the
    `length` of each sequence. It returns (logits, hidden_states): logits has shape
    (batch, num_classes) in the original batch order. hidden_states is the LSTM's (h, c) for the
    length-sorted batch, i.e. NOT un-permuted.
    """
    def __init__(self, num_classes, vocab_size, embed_size, lstm_hidden_size, classif_hidden_size,
        lstm_layers=1, dropout_rate=0.0, use_multichannel_embedding=False):
        super().__init__()
        self.vocab_size = vocab_size
        self.lstm_hidden_size = lstm_hidden_size
        self.use_multichannel_embedding = use_multichannel_embedding
        if self.use_multichannel_embedding:
            self.embedding = MultiChannelEmbedding(self.vocab_size, embed_size, dropout_rate=dropout_rate)
            self.embed_size = len(self.embedding.filters) * self.embedding.filters_size
        else:
            self.embedding = nn.Embedding(self.vocab_size, embed_size)
            self.embed_size = embed_size
        self.lstm = nn.LSTM(self.embed_size, self.lstm_hidden_size, lstm_layers, bidirectional=True, dropout=dropout_rate)
        self.classifier = nn.Sequential(
            nn.Linear(lstm_hidden_size*2, classif_hidden_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_rate),
            nn.Linear(classif_hidden_size, num_classes)
        )
    def init_embedding(self, weight):
        if self.use_multichannel_embedding:
            self.embedding.init_embedding(weight)
        else:
            self.embedding.weight = nn.Parameter(weight.to(self.embedding.weight.device))
    def forward(self, seq, length):
        # TODO use sort_within_batch?
        # Sort the batch by decreasing length, as pack_padded_sequence (enforce_sorted=True) requires.
        seq_size, batch_size = seq.size(0), seq.size(1)
        length_perm = (-length).argsort()
        length_perm_inv = length_perm.argsort()
        seq = torch.gather(seq, 1, length_perm[None, :].expand(seq_size, batch_size))
        length = torch.gather(length, 0, length_perm)
        # Pack sequence. Lengths must be a CPU tensor (required since PyTorch 1.7) even when seq is on GPU.
        seq = self.embedding(seq)
        seq = pack_padded_sequence(seq, length.cpu())
        # Send through LSTM. hidden_states is (h, c), each (num_layers * 2, batch, hidden).
        features, hidden_states = self.lstm(seq)
        # Unpack to (seq_size, batch, 2 * hidden), e.g. (45, 64, 600) for hidden=300.
        # total_length keeps the time dimension at seq_size even if the batch was padded past its
        # longest sequence, so the view below is always valid.
        features = pad_packed_sequence(features, total_length=seq_size)[0]
        # Separate last dimension into forward/backward features: (seq_size, batch, 2, hidden)
        features = features.view(seq_size, batch_size, 2, -1)
        # Padded steps are not processed, so the last forward state of each sequence sits at
        # index length - 1.
        last_indexes = (length - 1)[None, :, None, None].expand((1, batch_size, 2, features.size(-1)))
        forward_features = torch.gather(features, 0, last_indexes)  # (1, batch, 2, hidden)
        # Squeeze seq dimension, take forward direction: (batch, hidden)
        forward_features = forward_features[0, :, 0]
        # The backward direction finishes at time step 0 regardless of length: (batch, hidden)
        backward_features = features[0, :, 1]
        features = torch.cat((forward_features, backward_features), -1)  # (batch, 2 * hidden)
        # Send through classifier: (batch, num_classes) logits
        logits = self.classifier(features)
        # The batch was sorted by length; restore the caller's original order.
        logits = torch.gather(logits, 0, length_perm_inv[:, None].expand((batch_size, logits.size(-1))))
        return logits, hidden_states

def save_bilstm(model, output_dir):
    """
    Save `model.state_dict()` to `<output_dir>/weights.pth`.

    The directory and any missing parents are created, and an existing directory is reused.
    """
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(output_dir, "weights.pth"))


def train_bilstm_student(data_dir, output_dir, augmented= False, use_teacher=False, epochs=1, batch_size=64, gradient_accumulation_steps=1,
                        lr= 5e-5, lr_schedule= "constant", warmup_steps=0, epochs_per_cycle_bilstm_bilstm=1,  do_train= False, seed=42, checkpoint_interval= -1, no_cuda= False):
    """
    Optionally train, then evaluate, the BiLSTM student, and save its weights to `output_dir`.

    Args:
        data_dir: directory with the tsv files (see load_data for which one is used).
        output_dir: where weights.pth and checkpt_<step> directories go; parents are created.
        augmented / use_teacher: train on teacher scores (MSE loss) from augmented.tsv /
            noaugmented.tsv. If both are False, train on hard labels (cross-entropy) from train.tsv.
        lr_schedule: "constant", "warmup" or "cyclic".
        epochs_per_cycle_bilstm_bilstm: epochs per learning-rate cycle for the "cyclic" schedule
            (the parameter name is kept unchanged for existing callers).
        checkpoint_interval: checkpoint every this many optimizer updates; negative disables.
    """
    # Trainer.train accepts None / "warmup" / "cyclic"; "constant" means no scheduler
    # (same mapping as finetune_bert_teacher).
    if lr_schedule == "constant":
        lr_schedule = None
    os.makedirs(output_dir, exist_ok=True)
    device = torch.device("cuda" if not no_cuda and torch.cuda.is_available() else "cpu")
    set_seed(seed)
    # Load the data for the BiLSTM classifier according to the flags above
    train_dataset, valid_dataset, text_field = load_data(data_dir, spacy_tokenizer, augmented=augmented, use_teacher=use_teacher)
    vocab = text_field.vocab

    model = BiLSTMClassifier(2, len(vocab.itos), vocab.vectors.shape[-1],
        lstm_hidden_size=300, classif_hidden_size=400, dropout_rate=0.15).to(device)
    # Initialise the word embeddings with the pretrained vectors attached to the vocab (GloVe in load_data).
    # vocab.vectors is a 2-D tensor of shape (vocab_size, embed_size).
    model.init_embedding(vocab.vectors.to(device))

    # With teacher scores as targets (augmented or use_teacher) the student is trained with MSE on the
    # raw scores rather than KL divergence. As the temperature T grows, KL divergence approaches MSE on
    # raw scores, and MSE has no T hyper-parameter to tune. KL divergence could be used instead, but T
    # would then have to be tuned.
    # Without a teacher (both flags False) the BiLSTM is trained alone on hard labels, hence cross-entropy.
    trainer = LSTMTrainer(model, device,
        loss="mse" if augmented or use_teacher else "cross_entropy",
        train_dataset=train_dataset, val_dataset=valid_dataset, val_interval=250,
        checkpt_interval=checkpoint_interval,
        checkpt_callback=lambda m, step: save_bilstm(m, os.path.join(output_dir, "checkpt_%d" % step)),
        batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps,
        lr=lr)

    if do_train:
        # Trainer.train reads the cycle length from the "epochs_per_cycle" keyword.
        trainer.train(epochs, schedule=lr_schedule,
            warmup_steps=warmup_steps, epochs_per_cycle=epochs_per_cycle_bilstm_bilstm)

    print("Evaluating model:")
    print(trainer.evaluate())
    # Save the trained BiLSTM so it can be used to label test data
    save_bilstm(model, output_dir)

In [None]:
class Trainer():
    """
    Shared training / evaluation / inference loop for the BERT teacher and the BiLSTM student.

    Subclasses implement:
      * process_batch(batch) -> (model kwargs dict, label tensor or None, batch size)
      * process_one(tokens, length) -> model kwargs dict for one numericalized example
    The wrapped model must return a tuple/sequence whose first element is the logits.

    Args:
        model: network to train (already on `device`).
        device: torch.device given to the torchtext iterators.
        loss: "cross_entropy" (hard class-id targets), "mse" or "kl_div" (teacher scores as targets).
        train_dataset / val_dataset: optional torchtext datasets; iterators are built only when given.
        temperature: softening temperature, used only by "kl_div".
        val_interval / checkpt_interval: evaluate / call checkpt_callback(model, step) after every
            that many optimizer updates. Values <= 0 disable them.
        max_grad_norm: gradient-norm clipping threshold applied before each optimizer step.
        gradient_accumulation_steps: micro-batches whose gradients are averaged into one update.
        lr, weight_decay: Adam hyper-parameters.
    """
    def __init__(self, model, device,
        loss="cross_entropy",
        train_dataset=None,
        temperature=1.0,
        val_dataset=None, val_interval=1,
        checkpt_callback=None, checkpt_interval=1,
        max_grad_norm=1.0, batch_size=64, gradient_accumulation_steps=1,
        lr=5e-5, weight_decay=0.0):
        # Storing
        self.model = model
        self.device = device
        self.loss_option = loss
        self.train_dataset = train_dataset
        self.temperature = temperature
        self.val_dataset = val_dataset
        self.val_interval = val_interval
        self.checkpt_callback = checkpt_callback
        self.checkpt_interval = checkpt_interval
        self.max_grad_norm = max_grad_norm
        self.batch_size = batch_size
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.lr = lr
        self.weight_decay = weight_decay
        # Initialization
        assert self.loss_option in ["cross_entropy", "mse", "kl_div"]
        if self.loss_option == "cross_entropy":
            self.loss_function = nn.CrossEntropyLoss(reduction="sum")
        elif self.loss_option == "mse":
            self.loss_function = nn.MSELoss(reduction="sum")
        elif self.loss_option == "kl_div":
            self.loss_function = nn.KLDivLoss(reduction="sum")
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        if self.train_dataset is not None:
            self.train_it = data.BucketIterator(self.train_dataset, self.batch_size, train=True, sort_key=lambda x: len(x.text), device=self.device)
        else:
            self.train_it = None
        if self.val_dataset is not None:
            self.val_it = data.BucketIterator(self.val_dataset, self.batch_size, train=False, sort_key=lambda x: len(x.text), device=self.device)
        else:
            self.val_it = None
    def get_loss(self, model_output, label, curr_batch_size):
        """
        Batch-mean loss between the model's logits and `label`.

        Cross-entropy uses hard class ids as targets. MSE and KL divergence use teacher class
        scores. CrossEntropyLoss applies log_softmax itself and MSE compares raw scores, so both
        take raw logits, not probabilities.
        """
        if self.loss_option in ["cross_entropy", "mse"]:
            loss = self.loss_function(
                model_output,  # logits from BERT or the BiLSTM
                label  # hard label for cross-entropy, teacher class scores for mse
            ) / curr_batch_size  # Mean over batch
        elif self.loss_option == "kl_div":
            # KLDivLoss expects log-probabilities for the model output and probabilities for the target.
            loss = self.loss_function(
                F.log_softmax(model_output / self.temperature, dim=-1),
                F.softmax(label / self.temperature, dim=-1)
            ) / (self.temperature * self.temperature) / curr_batch_size
        return loss
    def train_step(self, batch):
        """
        Forward/backward on one batch; every `gradient_accumulation_steps` batches, update the weights.

        Requires self.scheduler, self.tb_writer, self.global_step and self.training_step, all set by train().
        """
        self.model.train()
        batch, label, curr_batch_size = self.process_batch(batch)
        # Both models return a tuple whose first element is the logits.
        s_logits = self.model(**batch)[0]
        loss = self.get_loss(s_logits, label, curr_batch_size)
        # Scale so the accumulated gradient is the mean over the accumulated micro-batches, not their sum.
        (loss / self.gradient_accumulation_steps).backward()
        self.training_step += 1
        if self.training_step % self.gradient_accumulation_steps == 0:
            nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
            self.optimizer.step()
            if self.scheduler is not None:
                # The scheduler advances once per optimizer update, not once per micro-batch.
                self.scheduler.step()
            self.model.zero_grad()
            # Save stats to tensorboard
            self.tb_writer.add_scalar("lr",
                self.scheduler.get_lr()[0] if self.scheduler is not None else self.lr,
                self.global_step)
            self.tb_writer.add_scalar("loss", loss.item(), self.global_step)
            # global_step now counts completed optimizer updates.
            self.global_step += 1
            # Evaluate after every val_interval updates (<= 0 disables; avoids modulo by zero).
            if self.val_interval > 0 and self.global_step % self.val_interval == 0:
                results = self.evaluate()
                print(results)
                for k, v in results.items():
                    self.tb_writer.add_scalar("val_" + k, v, self.global_step)
            # Checkpoint after every checkpt_interval updates (<= 0 or no callback disables).
            if (self.checkpt_callback is not None and self.checkpt_interval > 0
                    and self.global_step % self.checkpt_interval == 0):
                self.checkpt_callback(self.model, self.global_step)
    def train(self, epochs=1, schedule=None, **kwargs):
        """
        Train for `epochs` passes over the training iterator.

        schedule:
            None: constant lr.
            "warmup": triangular CyclicLR, rising from lr/100 to lr over kwargs["warmup_steps"]
                updates, then falling over the remaining updates.
            "cyclic": triangular cycles between lr/25 and lr, one cycle per
                kwargs.get("epochs_per_cycle", 1) epochs.
        """
        # Initialization
        self.global_step = 0  # counts optimizer updates only
        self.training_step = 0  # counts every forward/backward pass
        self.model.zero_grad()  # drop gradients left over from a previous, partially accumulated run
        self.tb_writer = SummaryWriter()
        steps_per_epoch = len(self.train_dataset) // self.batch_size // self.gradient_accumulation_steps
        total_steps = epochs * steps_per_epoch
        # Initialize the learning rate scheduler if one has been chosen
        assert schedule is None or schedule in ["warmup", "cyclic"]
        if schedule is None:
            self.scheduler = None
            for grp in self.optimizer.param_groups: grp['lr'] = self.lr
        elif schedule == "warmup":
            warmup_steps = kwargs["warmup_steps"]
            # Step sizes must be positive even for tiny datasets or warmup_steps >= total_steps.
            self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=self.lr/100, max_lr=self.lr,
                step_size_up=max(1, warmup_steps), step_size_down=max(1, total_steps - warmup_steps), cycle_momentum=False)
        elif schedule == "cyclic":
            epochs_per_cycle = kwargs.get("epochs_per_cycle", 1)
            # Half a cycle up and half down, so one full cycle spans epochs_per_cycle epochs.
            self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=self.lr/25, max_lr=self.lr,
                step_size_up=max(1, steps_per_epoch * epochs_per_cycle // 2), cycle_momentum=False)
        # The loop over epochs, batches; always close the writer, even if training fails.
        try:
            for epoch in trange(epochs, desc="Training"):
                for batch in tqdm(self.train_it, desc="Epoch %d" % epoch):
                    self.train_step(batch)
        finally:
            self.tb_writer.close()
            del self.tb_writer

    def evaluate(self):
        '''
        Evaluate the model on the dev/valid data.

        Returns a dict with the mean cross-entropy "loss", its exp as "perplexity", and "accuracy".
        '''
        self.model.eval()
        val_loss = val_accuracy = 0.0
        loss_func = nn.CrossEntropyLoss(reduction="sum")
        for batch in tqdm(self.val_it, desc="Evaluation"):
            with torch.no_grad():
                batch, label, _ = self.process_batch(batch)
                output = self.model(**batch)[0]  # logits
                loss = loss_func(output, label)
                val_loss += loss.item()
                val_accuracy += (output.argmax(dim=-1) == label).sum().item()
        val_loss /= len(self.val_dataset)
        val_accuracy /= len(self.val_dataset)
        return {
            "loss": val_loss,
            "perplexity": np.exp(val_loss),
            "accuracy": val_accuracy
        }
    def infer(self, dataset, softmax=False):
        '''
        Run the model over `dataset` in its original order.

        Returns a float64 array of shape (len(dataset), num_classes) holding logits, or softmax
        probabilities when `softmax` is True. An empty dataset yields shape (0, 2).
        '''
        self.model.eval()
        outputs = None
        filled = 0
        infer_it = data.Iterator(dataset, self.batch_size, train=False, sort=False, device=self.device)
        for batch in tqdm(infer_it, desc="Inference"):
            with torch.no_grad():
                batch, _, n_rows = self.process_batch(batch)
                output = self.model(**batch)[0]  # logits
                if softmax:
                    output = F.softmax(output, dim=-1)
                scores = output.detach().cpu().numpy()
                # Size the result from the model's class count instead of assuming 2 classes.
                if outputs is None:
                    outputs = np.empty(shape=(len(dataset), scores.shape[-1]))
                outputs[filled:filled + n_rows] = scores
                filled += n_rows
                del output
        if outputs is None:
            outputs = np.empty(shape=(len(dataset), 2))
        return outputs
    def infer_one(self, example, text_field=None, softmax=False):
        '''
        Score one raw example string with `text_field` (defaults to the training dataset's "text" field).

        Returns a 1-D numpy array of logits, or probabilities when `softmax` is True.
        '''
        self.model.eval()
        if text_field is None:
            text_field = self.train_dataset.fields["text"]
        example = text_field.preprocess(example)
        tokens, length = text_field.process([example])
        with torch.no_grad():
            batch = self.process_one(tokens, length)
            output = self.model(**batch)[0]
            if softmax:
                output = F.softmax(output, dim=-1)
            output = output.detach().cpu().numpy()
        return output[0]
    def process_batch(self, *args):
        # Implemented by subclasses
        raise NotImplementedError()
    def process_one(self, *args):
        # Implemented by subclasses
        raise NotImplementedError()

class BertTrainer(Trainer):
    """
    Trainer for BERT sequence classifiers fed batch-first token ids.

    Turns a torchtext batch whose `text` field uses include_lengths=True and batch_first=True
    into the `input_ids` / `attention_mask` keyword arguments of the BERT model.
    """
    def process_batch(self, batch):
        # batch.text is (tokens, lengths): tokens is (batch_size, max_len) of vocab ids and lengths
        # is (batch_size,). E.g. batch 64 with longest sentence 54 -> tokens (64, 54), mask (64, 54).
        tokens, length = batch.text
        label = batch.label if "label" in batch.__dict__ else None
        # Out-of-place unsqueeze, so the lengths tensor owned by the torchtext batch is left intact.
        length = length.unsqueeze(1).expand(tokens.size())
        positions = torch.arange(tokens.size(1), device=tokens.device).unsqueeze(0).expand(tokens.size())
        # BERT's attention_mask: 1.0 for real tokens (position < length), 0.0 for padding,
        # so no attention is paid to padded positions. Same shape as tokens.
        attention_mask = (positions < length).type(torch.float32)
        batch = {
            "input_ids": tokens,
            "attention_mask": attention_mask
        }
        return batch, label, tokens.size(0)  # batch_first=True for BERT, so dim 0 is the batch
    def process_one(self, tokens, length):
        # A single example has no padding, so every position is attended to.
        return {
            "input_ids": tokens.to(self.device),
            "attention_mask": torch.ones(tokens.size(), dtype=torch.float32, device=self.device)
        }

class LSTMTrainer(Trainer):
    """Trainer for the BiLSTM student, whose text field is sequence-first (batch_first=False)."""

    def process_batch(self, batch):
        """Unpack a torchtext batch into the BiLSTM's ``seq``/``length`` inputs.

        Returns:
            A tuple ``(inputs, label, batch_size)``:
            - ``inputs`` is ``{"seq": tokens, "length": lengths}``.
            - ``label`` is ``batch.label``, or ``None`` when absent.
            - ``batch_size`` is ``tokens.size(1)``, because dim 1 is the batch
              axis in the sequence-first layout.
        """
        seq, seq_lengths = batch.text
        has_label = "label" in batch.__dict__
        label = batch.label if has_label else None
        model_inputs = {"seq": seq, "length": seq_lengths}
        batch_size = seq.size(1)
        return model_inputs, label, batch_size

    def process_one(self, tokens, length):
        """Move one ``(tokens, length)`` pair onto ``self.device`` as BiLSTM inputs."""
        model_inputs = dict(
            seq=tokens.to(self.device),
            length=length.to(self.device),
        )
        return model_inputs


In [None]:
# ---- Configuration: BERT teacher -> BiLSTM student distillation on SST-2 ----

# Paths (Kaggle layout).
data_dir = "../input/the-stanford-sentiment-treebank-v2-sst2/SST-2"  # Directory containing the SST-2 tsv files.
output_dir = "../output/"  # Directory where models and other outputs are saved.
input_dir = "../input/the-stanford-sentiment-treebank-v2-sst2/SST-2/train.tsv"  # Input TSV file (despite the name) used to build the student dataset.
cache_dir = "/kaggle/working/"  # Custom cache directory for transformer models.
model = "../output/"  # Directory holding the fine-tuned BERT weights used to label the student dataset.

# Learning-rate schedules: each must be one of {"constant", "warmup", "cyclic"}.
lr_schedule_bert = "warmup"
lr_schedule_bilstm = "warmup"

# BERT (teacher) hyper-parameters.
epochs_bert = 1  # Number of epochs to fine-tune BERT for.
batch_bert = 16  # Batch size for BERT; also used when labelling the student dataset.
gradient_accumulation_steps_bert = 1  # Steps to accumulate gradients over before each parameter update.
lr_bert = 1e-5
warmup_steps_bert = 100  # Steps of LR increase before it starts decaying; only used by the 'warmup' schedule.
epochs_per_cycle_bert = 1  # Epochs per cycle; only used by the 'cyclic' schedule.
do_train_bert = True  # Fine-tune BERT (teacher) on the hard labels.

# BiLSTM (student) hyper-parameters.
epochs_bilstm = 1  # Number of epochs to train the BiLSTM for.
batch_bilstm = 50  # Batch size for the BiLSTM.
gradient_accumulation_steps_bilstm = 1  # Steps to accumulate gradients over before each parameter update.
lr_bilstm = 1e-3
warmup_steps_bilstm = 100  # Steps of LR increase before it starts decaying; only used by the 'warmup' schedule.
epochs_per_cycle_bilstm = 1  # Epochs per cycle; only used by the 'cyclic' schedule.
do_train_bilstm = True  # Train the BiLSTM (student) on soft labels (class scores from BERT).

# Run options.
seed = 42  # Seed for reproducible runs.
no_cuda = not torch.cuda.is_available()  # False when a GPU is available, True (CPU only) otherwise.
checkpoint_interval = 1  # Interval at which models are saved.
augmented = False  # Distil on augmented (True) or non-augmented (False) data.
use_teacher = False  # Learn from teacher scores (True) or train the student in isolation (False).
no_augment = False  # False: generate an augmented student dataset; True: do not augment.

# File written by generate_dataset_student; its name reflects whether it is augmented.
output_dir_student = os.path.join(output_dir, "noaugmented.tsv" if no_augment else "augmented.tsv")
# NOTE(review): `no_augment` picks which file is generated while `augmented` presumably picks
# which file the student reads; keep augmented == (not no_augment) when the student should
# consume the freshly generated file.

# Fail fast on a typo'd schedule instead of after fine-tuning has started.
for _schedule_name, _schedule in (("lr_schedule_bert", lr_schedule_bert), ("lr_schedule_bilstm", lr_schedule_bilstm)):
    if _schedule not in {"constant", "warmup", "cyclic"}:
        raise ValueError("%s must be one of constant, warmup, cyclic; got %r" % (_schedule_name, _schedule))

# Step 1: fine-tune the BERT teacher.
finetune_bert_teacher(data_dir, output_dir, lr_schedule_bert, cache_dir, epochs_bert, batch_bert, gradient_accumulation_steps_bert, lr_bert, warmup_steps_bert,
                      epochs_per_cycle_bert, do_train_bert, seed, no_cuda, checkpoint_interval)

# Step 2: label the student dataset with the teacher and write it to output_dir_student.
# Skip this call if that file (augmented.tsv or noaugmented.tsv) has already been generated.
generate_dataset_student(input_dir, output_dir_student, model, no_augment, batch_bert, no_cuda)

# Step 3: train the BiLSTM student.
train_bilstm_student(data_dir, output_dir, augmented, use_teacher, epochs_bilstm, batch_bilstm, gradient_accumulation_steps_bilstm,
                     lr_bilstm, lr_schedule_bilstm, warmup_steps_bilstm, epochs_per_cycle_bilstm, do_train_bilstm, seed, checkpoint_interval, no_cuda)

In [None]:
# import spacy

# nlp = spacy.load("en_core_web_sm")
# doc = nlp("Apple is looking at buying U.K. startup for $1 billion")
# print(doc)
# for token in doc:
#     print(token.text, token.lemma_, token.pos_, token.tag_, token.dep_,
#             token.shape_, token.is_alpha, token.is_stop)

In [None]:
# if True:
#     sentences = [spacy_en(text) for text, _ in tqdm(input_tsv, desc="Loading dataset")]
#     # build lists of words indexes by POS tab
#     pos_dict = build_pos_dict(sentences)
#     # Generate augmented samples
#     sentences = augmentation(sentences, pos_dict)
# else:
#     sentences = [text for text, _ in input_tsv]


In [None]:
# device= "cpu"
# bert_config = BertConfig.from_pretrained("bert-large-uncased")
# bert_model = BertForSequenceClassification.from_pretrained("bert-large-uncased", config=bert_config).to(device)
# bert_tokenizer = BertTokenizer.from_pretrained("bert-large-uncased", do_lower_case=True)
# train_dataset, valid_dataset, _ = load_data("../input/the-stanford-sentiment-treebank-v2-sst2/SST-2", bert_tokenizer.tokenize,
#     vocab=BertVocab(bert_tokenizer.vocab), batch_first=False)

In [None]:
# train_dataset, valid_dataset, text_field = load_data("../input/the-stanford-sentiment-treebank-v2-sst2/SST-2", spacy_tokenizer)
# vocab = text_field.vocab

In [None]:
# train_it = data.BucketIterator(train_dataset,64, train=True, sort_key=lambda x: len(x.text), device=device)
# i=1
# for epoch in trange(1, desc="Training"):
#     for batch in tqdm(train_it, desc="Epoch %d" % epoch):
#         batch1= batch
#         if(i==1): break
            

In [None]:
# tokens, length = batch1.text
# label = batch.label if "label" in batch.__dict__ else None
# print(length)
# # length = length.unsqueeze(1).expand(tokens.size())
# # length, length.size()

In [None]:
# seq= tokens
# seq_size, batch_size = seq.size(0), seq.size(1)
# length_perm = (-length).argsort()
# length_perm_inv = length_perm.argsort()
# seq = torch.gather(seq, 1, length_perm[None, :].expand(seq_size, batch_size))
# length = torch.gather(length, 0, length_perm)

In [None]:
# embedding = nn.Embedding(vocab.vectors.shape[0], vocab.vectors.shape[1])
# embedding.weight = nn.Parameter(vocab.vectors.to(device))
# seq = embedding(seq)
# seq = pack_padded_sequence(seq, length)
# seq

In [None]:
# lstm = nn.LSTM(300, 300, 1, bidirectional=True, dropout=0.0)
# features, hidden_states = lstm(seq)
# features = pad_packed_sequence(features)[0]
# print(features.shape)
# features = features.view(seq_size, batch_size, 2, -1)

In [None]:
# last_indexes = (length - 1)[None, :, None, None].expand((1, batch_size, 2, features.size(-1)))
# forward_features = torch.gather(features, 0, last_indexes)
# forward_features = forward_features[0, :, 0]
# # Take first word, backward features
# backward_features = features[0, :, 1]
# features = torch.cat((forward_features, backward_features), -1)
# forward_features.shape, backward_features.shape
# features.shape

In [None]:
# classifier = nn.Sequential(
#             nn.Linear(600, 300),
#             nn.ReLU(inplace=True),
#             nn.Dropout(p=0.0),
#             nn.Linear(300, 2)
#         )
# logits = classifier(features)
# logits = torch.gather(logits, 0, length_perm_inv[:, None].expand((64, logits.size(-1))))
# logits.shape

In [None]:
# (length-1)[None, :, None, None].expand((1, batch_size, 2, features.size(-1))).shape, forward_features.shape

In [None]:
# seq_size, batch_size, length_perm[None, :].expand(seq_size, batch_size)

In [None]:
# rg = torch.arange(tokens.size(1), device=device).unsqueeze_(0).expand(tokens.size())
# # print(rg, rg.shape)
# attention_mask = (rg < length).type(torch.float32)
# batch = {
#             "input_ids": tokens,
#             "attention_mask": attention_mask
#         }
# # attention_mask, attention_mask.shape
# batch

In [None]:
# # device= "cuda"
# # model = BertForSequenceClassification.from_pretrained("bert-large-uncased").to(device)
# # tokenizer = BertTokenizer.from_pretrained("bert-large-uncased", do_lower_case=True)

# # # Assign labels with teacher
# # teacher_field = data.Field(sequential=True, tokenize=tokenizer.tokenize, lower=True, include_lengths=True, batch_first=True)
# # fields = [("text", teacher_field)]
# # if True:
# #     examples = [data.Example.fromlist([" ".join(words)], fields) for words in sentences]
# # else:
# #     examples = [data.Example.fromlist([text], fields) for text in sentences]
# # augmented_dataset = data.Dataset(examples, fields)
# # teacher_field.vocab = BertVocab(tokenizer.vocab)
# new_labels = BertTrainer(model, device, batch_size=args.batch_size).infer(augmented_dataset)
# # Write to file
# with open(args.output, "w") as f:
#     f.write("sentence\tscores\n")
#     for sentence, rating in zip(sentences, new_labels):
#         if not args.no_augment:
#             text = " ".join(sentence)
#         else: text = sentence
#         f.write("%s\t%.6f %.6f\n" % (text, *rating))

In [None]:
# augmented_dataset.examples[1].__dict__.keys()

In [None]:
# bert_tokenizer.vocab["apple"]

In [None]:
# vocab=BertVocab(bert_tokenizer.vocab)
# vocab["naquib"]

In [None]:
# for name, layer in model.named_parameters():
#     print(name, layer.shape, layer)
# [print("nnnn", layer) for idx, layer in enumerate(model.children())]
# for name, module in model.named_modules():
#     print(name)

In [None]:
# length= torch.LongTensor([5,3,8,4,7,12,43,21,1])
# length_perm = (-length).argsort() 
# length_perm

In [None]:
# length_perm[None, :].expand(9, 9)