In [8]:
%pip install transformers[torch] huggingface_hub datasets evaluate torchvision kagglehub ipywidgets

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [1]:
from transformers import Trainer
import torch.nn.functional as F
from tqdm.notebook import tqdm

import torch.nn as nn

import numpy as np


import torch
import base
import kagglehub
from datasets import load_from_disk
from transformers import BasicTokenizer
from datasets import concatenate_datasets

In [2]:
my_glove = kagglehub.dataset_download("thanakomsn/glove6b300dtxt")
print(my_glove)

/home/jovyan/.cache/kagglehub/datasets/thanakomsn/glove6b300dtxt/versions/1


In [5]:
GLOVE_FILE = f"{my_glove}/glove.6B.300d.txt"

In [6]:
train_data = load_from_disk("./data/sst2/train-logits")
eval_data = load_from_disk("./data/sst2/eval-logits")
test_data = load_from_disk("./data/sst2/test-logits")

all_train_data = load_from_disk("./data/sst2/train-logits-augmented")


all_data = concatenate_datasets([load_from_disk(file) for file in ["./data/sst2/train-logits", "./data/sst2/eval-logits", "./data/sst2/test-logits", "./data/sst2/train-logits-augmented"]])
tokenizer = BasicTokenizer(do_lower_case=True)

In [7]:
def tokenize(dataset):
    if isinstance(dataset["sentence"], str):
        return list(tokenizer.tokenize(dataset["sentence"]))
    else:
        raise ValueError("Input text is not string")

In [8]:
def get_vocab(dataset):
    all_tokens = []
    for data in dataset:
        for token in data:
            all_tokens.append(token)

    vocab = set(all_tokens)
    return vocab


In [9]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


In [10]:
train_data_tokens = list(map(lambda e: tokenize(e), train_data))
eval_data_tokens = list(map(lambda e: tokenize(e), eval_data))
test_data_tokens = list(map(lambda e: tokenize(e), test_data))

all_train_data_tokens = list(map(lambda e: tokenize(e), all_train_data))


all_data_tokens = list(map(lambda e: tokenize(e), all_data))

In [11]:
vocab = get_vocab(all_data_tokens)

In [12]:
word_index = dict(zip(vocab, range(len(vocab))))

In [13]:
embeddings_index = {}
with open(GLOVE_FILE, encoding='utf-8') as f:
    for line in f:
        word, coefs = line.split(maxsplit=1)
        coefs = np.fromstring(coefs, "f", sep=" ")
        embeddings_index[word] = coefs
print(f"Found {len(embeddings_index)} word vectors.")


Found 400000 word vectors.


In [14]:
print(len(vocab))

14621


In [15]:
num_tokens = len(vocab) + 2
embedding_dim = 300
hits = 0
misses = 0
embedding_matrix = np.zeros((num_tokens, embedding_dim))

for word, i in word_index.items():
    embedding_vector = embeddings_index.get(word)
    if embedding_vector is not None:
        embedding_matrix[i] = embedding_vector
        hits += 1
    else:
        misses += 1
embedding_matrix = torch.tensor(embedding_matrix, dtype=torch.float32)
print(f"Converted {hits} words ({misses})")

Converted 14305 words (316)


In [16]:
def padd(data, max_length):
    padding_length = max_length - len(data)
    if padding_length > 0:
        padding = [0 for _ in range(padding_length)]
        data.extend(padding)
    return data[:max_length]

In [17]:
train_data_index = list(map(lambda x: list(map(lambda y: word_index[y], x)),train_data_tokens))
eval_data_index = list(map(lambda x: list(map(lambda y: word_index[y], x)),eval_data_tokens))
test_data_index = list(map(lambda x: list(map(lambda y: word_index[y], x)),test_data_tokens))

all_train_data_index = list(map(lambda x: list(map(lambda y: word_index[y], x)),all_train_data_tokens))

In [None]:
train_padded_data = list(map(lambda x: padd(x,300), train_data_index))
eval_padded_data = list(map(lambda x: padd(x,300), eval_data_index))
test_padded_data = list(map(lambda x: padd(x,300), test_data_index))

all_train_padded_data = list(map(lambda x: padd(x,300), all_train_data_index))

In [23]:
train_data = train_data.add_column("input_ids", train_padded_data)
eval_data = eval_data.add_column("input_ids", eval_padded_data)
test_data = test_data.add_column("input_ids", test_padded_data)

all_train_data = all_train_data.add_column("input_ids", all_train_padded_data)

In [67]:
class BiLSTMClassifier(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, fc_dim, output_dim, embedding_matrix):
        super(BiLSTMClassifier, self).__init__()
        
        self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=True)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, bidirectional=True, batch_first=True)
        self.fc1 = nn.Linear(hidden_dim * 2, fc_dim)  
        self.dropout = nn.Dropout(.2)
        self.fc2 = nn.Linear(fc_dim, output_dim)

    def forward(self, input_ids, labels=None):
        embedded = self.embedding(input_ids)  
        _, (h_n, _) = self.lstm(embedded)
        h_forward = h_n[-2, :, :]  # Last forward hidden state
        h_backward = h_n[-1, :, :]  # Last backward hidden state
        out_cat = torch.cat((h_forward, h_backward), dim=1)
        fc1_out = F.relu(self.fc1(out_cat))
        dropped = self.dropout(fc1_out)
        logits = self.fc2(dropped)
        
        if labels is not None:
            labels = nn.functional.one_hot(labels, num_classes=self.fc2.out_features) 
            loss_fn = nn.CrossEntropyLoss() 
            loss = loss_fn(logits, labels.float())
            return {"loss" : loss, "logits" : logits}
        return {"loss" : None, "logits": logits}
    
model = BiLSTMClassifier(embedding_matrix=embedding_matrix, embedding_dim=embedding_dim, fc_dim=400, hidden_dim=300, output_dim=2)

In [68]:
print(model)

BiLSTMClassifier(
  (embedding): Embedding(14623, 300)
  (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
  (fc1): Linear(in_features=600, out_features=400, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=400, out_features=2, bias=True)
)


In [69]:
training_args = base.get_training_args(output_dir="./results/bilstm-base", logging_dir='./logs/bilstm-base', lr=.001,  epochs=10, batch_size=128)

In [70]:
base.reset_seed()

In [71]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=test_data,
    compute_metrics=base.compute_metrics,
)

In [72]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.382,0.33108,0.858723,0.859061,0.863918,0.858294
2,0.2928,0.266239,0.889161,0.88711,0.891044,0.88832
3,0.2277,0.221975,0.911656,0.909772,0.911955,0.91071
4,0.1752,0.226675,0.915293,0.917672,0.910864,0.913508
5,0.1339,0.21124,0.925835,0.923899,0.927066,0.925137
6,0.0996,0.21826,0.929696,0.928639,0.92889,0.928763
7,0.0715,0.237316,0.930364,0.928973,0.930133,0.929517
8,0.0497,0.269629,0.932071,0.931195,0.931089,0.931142
9,0.0333,0.293311,0.931997,0.930517,0.931997,0.931198
10,0.0228,0.312872,0.932368,0.930922,0.932313,0.931565


TrainOutput(global_step=4210, training_loss=0.1488402756262845, metrics={'train_runtime': 97.858, 'train_samples_per_second': 5505.836, 'train_steps_per_second': 43.022, 'total_flos': 0.0, 'train_loss': 0.1488402756262845, 'epoch': 10.0})

In [73]:
model.eval()

BiLSTMClassifier(
  (embedding): Embedding(14623, 300)
  (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
  (fc1): Linear(in_features=600, out_features=400, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=400, out_features=2, bias=True)
)

In [74]:
trainer.evaluate(eval_data)

{'eval_loss': 0.7650835514068604,
 'eval_accuracy': 0.8555045871559633,
 'eval_precision': 0.8558324898785425,
 'eval_recall': 0.8552033341752967,
 'eval_f1': 0.8553554502369668,
 'eval_runtime': 4.1204,
 'eval_samples_per_second': 211.631,
 'eval_steps_per_second': 1.699,
 'epoch': 10.0}

In [75]:
student_model = BiLSTMClassifier(embedding_matrix=embedding_matrix, embedding_dim=embedding_dim, fc_dim=400, hidden_dim=300, output_dim=2)

In [76]:
training_args = base.get_training_args(output_dir="./results/bilstm-distill", remove_unused_columns=False, logging_dir='./logs/bilstm-distill', lr=.001,  epochs=10, batch_size=128, lambda_param=.75, temp=5)

In [77]:
base.reset_seed()

In [78]:
trainer = base.ImageDistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=test_data,
    compute_metrics=base.compute_metrics,
)

In [79]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.2202,1.703732,0.866518,0.865024,0.869582,0.865781
2,1.4758,1.298951,0.895694,0.894389,0.894097,0.894241
3,1.031,1.020113,0.914625,0.912776,0.91493,0.913706
4,0.7521,0.834832,0.928731,0.927146,0.928843,0.927914
5,0.5531,0.75324,0.934521,0.933111,0.934469,0.933741
6,0.4167,0.711626,0.937268,0.93631,0.936565,0.936436
7,0.3285,0.657693,0.941203,0.939789,0.94138,0.940517
8,0.2541,0.631549,0.942762,0.94148,0.942743,0.942071
9,0.2058,0.629697,0.94291,0.94137,0.943415,0.942277
10,0.1714,0.612443,0.943356,0.942061,0.94338,0.942676


TrainOutput(global_step=4210, training_loss=0.7408818333279209, metrics={'train_runtime': 102.5924, 'train_samples_per_second': 5251.753, 'train_steps_per_second': 41.036, 'total_flos': 0.0, 'train_loss': 0.7408818333279209, 'epoch': 10.0})

In [80]:
student_model.eval()

BiLSTMClassifier(
  (embedding): Embedding(14623, 300)
  (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
  (fc1): Linear(in_features=600, out_features=400, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=400, out_features=2, bias=True)
)

In [81]:
trainer.evaluate(eval_data)

{'eval_loss': 1.7558796405792236,
 'eval_accuracy': 0.8555045871559633,
 'eval_precision': 0.8555213948787062,
 'eval_recall': 0.855371726867054,
 'eval_f1': 0.8554285353375861,
 'eval_runtime': 3.0682,
 'eval_samples_per_second': 284.206,
 'eval_steps_per_second': 2.281,
 'epoch': 10.0}

In [82]:
model = BiLSTMClassifier(embedding_matrix=embedding_matrix, embedding_dim=embedding_dim, fc_dim=400, hidden_dim=300, output_dim=2)

In [83]:
training_args = base.get_training_args(output_dir="./results/bilstm-base", logging_dir='./logs/bilstm-base', lr=.001,  epochs=10, batch_size=128)

In [84]:
base.reset_seed()

In [85]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=all_train_data,
    eval_dataset=test_data,
    compute_metrics=base.compute_metrics,
)

In [86]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.2682,0.179852,0.931923,0.930278,0.932296,0.93117
2,0.1341,0.17189,0.948255,0.947043,0.948294,0.947629
3,0.0812,0.189071,0.949146,0.948179,0.948831,0.948495
4,0.0498,0.23011,0.951745,0.951118,0.951055,0.951087
5,0.0304,0.290225,0.949146,0.947913,0.949249,0.948536
6,0.0192,0.36742,0.95167,0.951093,0.950919,0.951006
7,0.0122,0.412773,0.94922,0.948532,0.948532,0.948532
8,0.0075,0.491498,0.950408,0.949924,0.94951,0.949713
9,0.0048,0.538246,0.951225,0.950529,0.950607,0.950568
10,0.003,0.59634,0.951299,0.950543,0.950761,0.950651


TrainOutput(global_step=39290, training_loss=0.0610427748674532, metrics={'train_runtime': 562.4468, 'train_samples_per_second': 8940.4, 'train_steps_per_second': 69.855, 'total_flos': 0.0, 'train_loss': 0.0610427748674532, 'epoch': 10.0})

In [87]:
model.eval()

BiLSTMClassifier(
  (embedding): Embedding(14623, 300)
  (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
  (fc1): Linear(in_features=600, out_features=400, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=400, out_features=2, bias=True)
)

In [88]:
trainer.evaluate(eval_data)

{'eval_loss': 0.5181599259376526,
 'eval_accuracy': 0.8704128440366973,
 'eval_precision': 0.8703937504931745,
 'eval_recall': 0.8705165445819651,
 'eval_f1': 0.8703990382781601,
 'eval_runtime': 4.3202,
 'eval_samples_per_second': 201.842,
 'eval_steps_per_second': 1.62,
 'epoch': 10.0}

In [89]:
student_model = BiLSTMClassifier(embedding_matrix=embedding_matrix, embedding_dim=embedding_dim, fc_dim=400, hidden_dim=300, output_dim=2)

In [90]:
training_args = base.get_training_args(output_dir="./results/bilstm-distill", remove_unused_columns=False, logging_dir='./logs/bilstm-distill', lr=.001,  epochs=10, batch_size=128, lambda_param=.75, temp=5)

In [91]:
base.reset_seed()

In [None]:
trainer = base.ImageDistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=all_train_data,
    eval_dataset=test_data,
    compute_metrics=base.compute_metrics,
)

In [93]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.0579,0.608205,0.942465,0.94119,0.942424,0.941768
2,0.4311,0.419901,0.954566,0.953394,0.954733,0.95402
3,0.2732,0.359612,0.960134,0.959187,0.960124,0.959635
4,0.1949,0.328496,0.961618,0.960677,0.961646,0.96114
5,0.1487,0.319208,0.96147,0.960515,0.961513,0.960991
6,0.12,0.296007,0.963103,0.962291,0.962995,0.962632
7,0.1008,0.29269,0.962806,0.961878,0.96285,0.962342
8,0.0865,0.281432,0.963697,0.962985,0.963475,0.963224
9,0.0756,0.273645,0.9634,0.962711,0.963139,0.962921
10,0.0679,0.272132,0.963846,0.963043,0.963747,0.963384


TrainOutput(global_step=39290, training_loss=0.25565577871352246, metrics={'train_runtime': 597.9402, 'train_samples_per_second': 8409.704, 'train_steps_per_second': 65.709, 'total_flos': 0.0, 'train_loss': 0.25565577871352246, 'epoch': 10.0})

In [94]:
student_model.eval()

BiLSTMClassifier(
  (embedding): Embedding(14623, 300)
  (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
  (fc1): Linear(in_features=600, out_features=400, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=400, out_features=2, bias=True)
)

In [None]:
trainer.evaluate(eval_data)

{'eval_loss': 0.8172466158866882,
 'eval_accuracy': 0.8795871559633027,
 'eval_precision': 0.8798068564383821,
 'eval_recall': 0.879357160899217,
 'eval_f1': 0.8794881008319744,
 'eval_runtime': 2.9664,
 'eval_samples_per_second': 293.962,
 'eval_steps_per_second': 2.36,
 'epoch': 10.0}