In [2]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"]=""
import os.path
import torch
import torch.optim as optim
# from allennlp.data.dataset_readers.stanford_sentiment_tree_bank import \
#     StanfordSentimentTreeBankDatasetReader
from reader_new import StanfordSentimentTreeBankDatasetReader_NEW
from allennlp.data.data_loaders import SimpleDataLoader
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import _read_pretrained_embeddings_file
from allennlp.modules.token_embedders import Embedding
from allennlp.nn.util import get_text_field_mask
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.training.trainer import Trainer
from allennlp.common.util import lazy_groups_of
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.nn.util import move_to_device
import pandas as pd

2022-05-15 03:29:25.647957: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0


In [3]:
class LstmClassifier(Model):
    def __init__(self, word_embeddings, encoder, vocab):
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        self.linear = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                      out_features=vocab.get_vocab_size('labels'))
        self.accuracy = CategoricalAccuracy()
        self.loss_function = torch.nn.CrossEntropyLoss()

    def forward(self, tokens, label):
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        encoder_out = self.encoder(embeddings, mask)
        logits = self.linear(encoder_out)
        output = {"logits": logits}
        if label is not None:
            self.accuracy(logits, label)
            output["loss"] = self.loss_function(logits, label)
        return output

    def get_metrics(self, reset=False):
        return {'accuracy': self.accuracy.get_metric(reset)}

In [4]:
# load the binary SST dataset.
single_id_indexer = SingleIdTokenIndexer(lowercase_tokens=True) # word tokenizer

# use_subtrees gives us a bit of extra data by breaking down each example into sub sentences.
reader = StanfordSentimentTreeBankDatasetReader_NEW(granularity="2-class",
                                                token_indexers={"tokens": single_id_indexer},
                                                use_subtrees=True)

train_data = reader.read('./data/train.txt')
reader = StanfordSentimentTreeBankDatasetReader_NEW(granularity="2-class",
                                                token_indexers={"tokens": single_id_indexer})
dev_data = reader.read('./data/dev.txt')

In [5]:
vocab_path = "./lstm_main_sst_model/w2v_" + "vocab"
vocab = Vocabulary.from_files(vocab_path)

In [6]:
embedding_path = "./data/crawl-300d-2M.vec.zip"
weight = _read_pretrained_embeddings_file(embedding_path,
                                          embedding_dim=300,
                                          vocab=vocab,
                                          namespace="tokens")
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=300,
                            weight=weight,
                            trainable=False)
word_embedding_dim = 300
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})


encoder = PytorchSeq2VecWrapper(torch.nn.LSTM(word_embedding_dim,
                                              hidden_size=512,
                                              num_layers=2,
                                              batch_first=True))
# word_embeddings2 = BasicTextFieldEmbedder({"tokens": token_embedding})
# encoder2 = PytorchSeq2VecWrapper(torch.nn.LSTM(word_embedding_dim,
#                                               hidden_size=512,
#                                               num_layers=2,
#                                               batch_first=True))

# student_model = LstmClassifier(word_embeddings2, encoder2, vocab)
teacher_model = LstmClassifier(word_embeddings, encoder, vocab)
model_path = "./lstm_main_sst_model/w2v_model.th"
    
with open(model_path, 'rb') as f:
    teacher_model.load_state_dict(torch.load(f,map_location='cpu'))

  0%|          | 0/1999995 [00:00<?, ?it/s]

In [1]:
train_dl = SimpleDataLoader(list(train_data), batch_size=128, shuffle=True)
train_dl.index_with(vocab)

NameError: name 'SimpleDataLoader' is not defined

In [7]:
val_dl = SimpleDataLoader(list(dev_data), batch_size=128, shuffle=True)
val_dl.index_with(vocab)

In [8]:
student_model.train().cuda()
teacher_model.train().cuda()
optimizer = optim.Adam(student_model.parameters())

In [9]:
import torch.nn.functional as F
import torch.nn as nn
import copy

In [10]:
def loss_fn_kd(outputs, labels, teacher_outputs):
    params = {
            "alpha": 0.95,
            "temperature": 6,
    }
    alpha = params['alpha']
    T = params['temperature']
    KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
                             F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \
              F.cross_entropy(outputs, labels) * (1. - alpha)
    return KD_loss

In [11]:
for ep in range(3):
    student_model.train()
    for k in train_dl:
        k = move_to_device(k, device=0)
        toks = k['tokens']
        labs = k['label']
        output_teacher = teacher_model(toks, labs)
        output_student = student_model(toks, labs)
        loss = (loss_fn_kd(output_student['logits'].view(-1, 2), labs.view(-1), output_teacher['logits']))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
    student_model.eval()
    student_model.get_metrics(reset=True)
    for k in val_dl:
        k = move_to_device(k, device=0)
        toks = k['tokens']
        labs = k['label']
        output_student = student_model(toks, labs)
    print(student_model.get_metrics()['accuracy'])



0.8520642201834863
0.8509174311926605
0.8646788990825688


In [14]:
model_path = 'lstm_distilled_sst_model/w2v_model.th'
with open(model_path, 'wb') as f:
    torch.save(student_model.state_dict(), f)
# vocab.save_to_files(vocab_path)