In [5]:
from datasets import concatenate_datasets, load_from_disk
from transformers import BasicTokenizer, EarlyStoppingCallback, Trainer, BertForSequenceClassification, AutoConfig, BertTokenizer
from torch.utils.data import DataLoader
import kagglehub
import torch
import base
import copy
import os

In [2]:
my_glove = kagglehub.dataset_download("thanakomsn/glove6b300dtxt")
print(my_glove)

/home/jovyan/.cache/kagglehub/datasets/thanakomsn/glove6b300dtxt/versions/1


In [3]:
GLOVE_FILE = f"{my_glove}/glove.6B.300d.txt"
DATASET = "sst2"

In [68]:
train_data = load_from_disk(f"~/data/{DATASET}/train-logits")
eval_data = load_from_disk(f"~/data/{DATASET}/eval-logits")
test_data = load_from_disk(f"~/data/{DATASET}/test-logits")

all_train_data = load_from_disk(f"~/data/{DATASET}/train-logits-augmented")

all_data = concatenate_datasets([load_from_disk(file) for file in [f"~/data/{DATASET}/eval-logits", f"~/data/{DATASET}/test-logits", f"~/data/{DATASET}/train-logits-augmented"]])
tokenizer = BasicTokenizer(do_lower_case=True)
teacher_tokenizer = BertTokenizer.from_pretrained("gchhablani/bert-base-cased-finetuned-sst2")

In [7]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA H100 PCIe


In [8]:
train_data_tokens = list(map(lambda e: tokenizer.tokenize(e["sentence"]), train_data))
eval_data_tokens = list(map(lambda e: tokenizer.tokenize(e["sentence"]), eval_data))
test_data_tokens = list(map(lambda e: tokenizer.tokenize(e["sentence"]), test_data))

all_train_data_tokens = list(map(lambda e: tokenizer.tokenize(e["sentence"]), all_train_data))

all_data_tokens = list(map(lambda e: tokenizer.tokenize(e["sentence"]), all_data))

In [9]:
vocab = base.get_vocab(all_data_tokens)

In [10]:
word_index = dict(zip(vocab, range(len(vocab))))

In [11]:
embeddings_index = base.get_embeddings_indeces(GLOVE_FILE)

Found 400000 word vectors.


In [12]:
print(len(vocab))
num_tokens = len(vocab) + 2
embedding_dim = 300

14621


In [13]:
embedding_matrix = base.get_embedding_matrix(num_tokens, embedding_dim, word_index, embeddings_index)

Converted 14305 words (316) misses


In [14]:
train_data_index = list(map(lambda x: list(map(lambda y: word_index[y], x)),train_data_tokens))
eval_data_index = list(map(lambda x: list(map(lambda y: word_index[y], x)),eval_data_tokens))
test_data_index = list(map(lambda x: list(map(lambda y: word_index[y], x)),test_data_tokens))

all_train_data_index = list(map(lambda x: list(map(lambda y: word_index[y], x)),all_train_data_tokens))

In [15]:
train_padded_data = list(map(lambda x: base.padd(x,60), train_data_index))
eval_padded_data = list(map(lambda x: base.padd(x,60), eval_data_index))
test_padded_data = list(map(lambda x: base.padd(x,60), test_data_index))

all_train_padded_data = list(map(lambda x: base.padd(x,60), all_train_data_index))

In [16]:
train_teacher_data = base.prepare_dataset_teacher(train_data, teacher_tokenizer)
eval_teacher_data = base.prepare_dataset_teacher(eval_data, teacher_tokenizer)
test_teacher_data = base.prepare_dataset_teacher(test_data, teacher_tokenizer)

all_train_teacher_data = base.prepare_dataset_teacher(all_train_data, teacher_tokenizer)

Tokenizing the provided dataset:   0%|          | 0/53879 [00:00<?, ? examples/s]

Tokenizing the provided dataset:   0%|          | 0/872 [00:00<?, ? examples/s]

Tokenizing the provided dataset:   0%|          | 0/13470 [00:00<?, ? examples/s]

Tokenizing the provided dataset:   0%|          | 0/293634 [00:00<?, ? examples/s]

In [69]:
train_data = train_data.add_column("input_ids", train_padded_data)
train_data = train_data.add_column("teacher_ids", train_teacher_data[0])
train_data = train_data.add_column("teacher_attention", train_teacher_data[1])

eval_data = eval_data.add_column("input_ids", eval_padded_data)
eval_data = eval_data.add_column("teacher_ids", eval_teacher_data[0])
eval_data = eval_data.add_column("teacher_attention", eval_teacher_data[1])

test_data = test_data.add_column("input_ids", test_padded_data)
test_data = test_data.add_column("teacher_ids", test_teacher_data[0])
test_data = test_data.add_column("teacher_attention", test_teacher_data[1])

all_train_data = all_train_data.add_column("input_ids", all_train_padded_data)
all_train_data = all_train_data.add_column("teacher_ids", all_train_teacher_data[0])
all_train_data = all_train_data.add_column("teacher_attention", all_train_teacher_data[1])

In [70]:
student_model = base.BiLSTMClassifier(embedding_matrix=embedding_matrix, embedding_dim=embedding_dim, fc_dim=400, hidden_dim=300, output_dim=2)

In [71]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/bilstm-distill_fine", remove_unused_columns=False, logging_dir=f"~/logs/{DATASET}/bilstm-distill_fine")

In [72]:
base.reset_seed()

In [73]:
train_data.set_format(type="torch", columns=["input_ids", "logits", "labels"], device="cpu")
eval_data.set_format(type="torch", columns=["input_ids", "logits", "labels"], device="cpu")

In [74]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    compute_metrics=base.compute_metrics
)

In [75]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.5994,1.746571,0.78211,0.783666,0.782752,0.782017
2,1.792,1.693846,0.784404,0.784762,0.783994,0.784113
3,1.6637,1.639896,0.779817,0.781173,0.780416,0.779742
4,1.6124,1.712009,0.766055,0.770215,0.76484,0.764538
5,1.5822,1.619792,0.78555,0.785862,0.785162,0.785279


TrainOutput(global_step=2105, training_loss=1.8499325278819032, metrics={'train_runtime': 50.6844, 'train_samples_per_second': 5315.15, 'train_steps_per_second': 41.532, 'total_flos': 0.0, 'train_loss': 1.8499325278819032, 'epoch': 5.0})

In [24]:
base.reset_seed()

In [25]:
student_model = base.BiLSTMClassifier(embedding_matrix=embedding_matrix, embedding_dim=embedding_dim, fc_dim=400, hidden_dim=300, output_dim=2)
teacher_model = BertForSequenceClassification.from_pretrained("gchhablani/bert-base-cased-finetuned-sst2", num_labels=2)
teacher_model.to(device)
teacher_model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [76]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/bilstm-distill_fine_infer", remove_unused_columns=False, logging_dir=f"~/logs/{DATASET}/bilstm-distill_fine_infer")

In [77]:
base.reset_seed()

In [78]:
train_data.reset_format()
eval_data.reset_format()   

In [79]:
trainer = base.DistilTrainerInferText(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    compute_metrics=base.compute_metrics
)

In [80]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.586,1.641049,0.783257,0.789805,0.784594,0.782513
2,1.5311,1.571059,0.78211,0.782283,0.781784,0.781885
3,1.4953,1.53404,0.793578,0.79376,0.793803,0.793577
4,1.4764,1.582667,0.780963,0.784331,0.7799,0.779814
5,1.4582,1.522424,0.792431,0.792462,0.792214,0.792287


TrainOutput(global_step=2105, training_loss=1.5094139407196407, metrics={'train_runtime': 67.3658, 'train_samples_per_second': 3998.986, 'train_steps_per_second': 31.247, 'total_flos': 0.0, 'train_loss': 1.5094139407196407, 'epoch': 5.0})

In [31]:
student_model = base.BiLSTMClassifier(embedding_matrix=embedding_matrix, embedding_dim=embedding_dim, fc_dim=400, hidden_dim=300, output_dim=2)

In [32]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/bilstm-distill_fine", remove_unused_columns=False, logging_dir=f"~/logs/{DATASET}/bilstm-distill_fine")

In [33]:
base.reset_seed()

In [34]:
all_train_data.set_format(type="torch", columns=["input_ids", "logits", "labels"], device="cpu")
eval_data.set_format(type="torch", columns=["input_ids", "logits", "labels"], device="cpu")

In [35]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset= all_train_data,
    eval_dataset=eval_data,
    compute_metrics=base.compute_metrics
)

In [36]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.6319,1.602721,0.787844,0.788501,0.788257,0.78783
2,1.2836,1.49912,0.793578,0.7937,0.793298,0.793394
3,1.1996,1.446129,0.800459,0.800939,0.800812,0.800455
4,1.1396,1.399032,0.802752,0.803284,0.802307,0.802452
5,1.1022,1.400693,0.797018,0.797183,0.796718,0.796824


TrainOutput(global_step=11475, training_loss=1.2713790168845316, metrics={'train_runtime': 169.8933, 'train_samples_per_second': 8641.717, 'train_steps_per_second': 67.542, 'total_flos': 0.0, 'train_loss': 1.2713790168845316, 'epoch': 5.0})

In [37]:
student_model = base.BiLSTMClassifier(embedding_matrix=embedding_matrix, embedding_dim=embedding_dim, fc_dim=400, hidden_dim=300, output_dim=2)

In [38]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/bilstm-distill_fine_infer", remove_unused_columns=False, logging_dir=f"~/logs/{DATASET}/bilstm-distill_fine_infer")

In [39]:
base.reset_seed()

In [43]:
all_train_data.reset_format()
eval_data.reset_format()   

In [44]:
trainer = base.DistilTrainerInferText(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=all_train_data,
    eval_dataset=eval_data,
    compute_metrics=base.compute_metrics
)

In [45]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.6319,1.602721,0.787844,0.788501,0.788257,0.78783
2,1.2836,1.49912,0.793578,0.7937,0.793298,0.793394
3,1.1996,1.44614,0.800459,0.800939,0.800812,0.800455
4,1.1396,1.398991,0.803899,0.804374,0.803475,0.803618
5,1.1022,1.400865,0.797018,0.797183,0.796718,0.796824


TrainOutput(global_step=11475, training_loss=1.271378102022059, metrics={'train_runtime': 252.0103, 'train_samples_per_second': 5825.834, 'train_steps_per_second': 45.534, 'total_flos': 0.0, 'train_loss': 1.271378102022059, 'epoch': 5.0})

In [46]:
student_model = BertForSequenceClassification.from_pretrained("google/bert_uncased_L-2_H-128_A-2", num_labels=2)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [47]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/bilstm-distill_fine", remove_unused_columns=False, logging_dir=f"~/logs/{DATASET}/bilstm-distill_fine")

In [48]:
train_data = train_data.remove_columns(["input_ids"])
train_data = train_data.rename_column("teacher_attention", "attention_mask")
train_data = train_data.rename_column("teacher_ids", "input_ids")

eval_data = eval_data.remove_columns(["input_ids"])
eval_data = eval_data.rename_column("teacher_attention", "attention_mask")
eval_data = eval_data.rename_column("teacher_ids", "input_ids")

train_data.set_format(type="torch", columns=["input_ids", "attention_mask", "logits", "labels"], device="cpu")
eval_data.set_format(type="torch", columns=["input_ids", "attention_mask", "logits", "labels"], device="cpu")

In [49]:
base.reset_seed()

In [50]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    compute_metrics=base.compute_metrics
)

In [51]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,3.0974,2.423601,0.705275,0.705181,0.705197,0.705188
2,2.355,2.062898,0.748853,0.748876,0.748958,0.748837
3,1.9023,1.943749,0.764908,0.764845,0.764766,0.764797
4,1.6662,1.900899,0.779817,0.780674,0.78029,0.779788
5,1.5646,1.87653,0.780963,0.781898,0.781458,0.780928


TrainOutput(global_step=2105, training_loss=2.117128349539786, metrics={'train_runtime': 58.3853, 'train_samples_per_second': 4614.09, 'train_steps_per_second': 36.054, 'total_flos': 40108928454000.0, 'train_loss': 2.117128349539786, 'epoch': 5.0})

In [52]:
base.reset_seed()

In [53]:
student_model = BertForSequenceClassification.from_pretrained("google/bert_uncased_L-2_H-128_A-2", num_labels=2)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [54]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/bilstm-distill_fine_infer", remove_unused_columns=False, logging_dir=f"~/logs/{DATASET}/bilstm-distill_fine_infer")

In [55]:
trainer = base.DistilTrainerInfer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    compute_metrics=base.compute_metrics
)

In [56]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,3.1197,2.401331,0.715596,0.715643,0.715711,0.715583
2,2.3943,2.021847,0.755734,0.756378,0.755168,0.755244
3,1.9405,1.865274,0.784404,0.786395,0.783573,0.783634
4,1.6935,1.872592,0.775229,0.782205,0.773722,0.773124
5,1.5938,1.794727,0.784404,0.785969,0.783657,0.783748


TrainOutput(global_step=2105, training_loss=2.1483438451046615, metrics={'train_runtime': 68.9719, 'train_samples_per_second': 3905.868, 'train_steps_per_second': 30.52, 'total_flos': 40108928454000.0, 'train_loss': 2.1483438451046615, 'epoch': 5.0})

In [57]:
student_model = BertForSequenceClassification.from_pretrained("google/bert_uncased_L-2_H-128_A-2", num_labels=2)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [58]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/bilstm-distill_fine", remove_unused_columns=False, logging_dir=f"~/logs/{DATASET}/bilstm-distill_fine")

In [59]:
base.reset_seed()

In [60]:
all_train_data = all_train_data.remove_columns(["input_ids"])
all_train_data = all_train_data.rename_column("teacher_attention", "attention_mask")
all_train_data = all_train_data.rename_column("teacher_ids", "input_ids")

all_train_data.set_format(type="torch", columns=["input_ids", "attention_mask", "logits", "labels"], device="cpu")

In [61]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=all_train_data,
    eval_dataset=eval_data,
    compute_metrics=base.compute_metrics
)

In [62]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.6716,1.489252,0.813073,0.81322,0.813284,0.813071
2,0.8237,1.534105,0.818807,0.818746,0.818746,0.818746
3,0.6694,1.534851,0.815367,0.815566,0.815073,0.81519
4,0.5979,1.582777,0.816514,0.816563,0.816662,0.816505
5,0.5658,1.590183,0.819954,0.819889,0.819915,0.819901


TrainOutput(global_step=11475, training_loss=0.8656818704044118, metrics={'train_runtime': 211.5224, 'train_samples_per_second': 6940.968, 'train_steps_per_second': 54.25, 'total_flos': 218588784084000.0, 'train_loss': 0.8656818704044118, 'epoch': 5.0})

In [63]:
base.reset_seed()

In [64]:
student_model = BertForSequenceClassification.from_pretrained("google/bert_uncased_L-2_H-128_A-2", num_labels=2)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [65]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/bilstm-distill_fine_infer", remove_unused_columns=False, logging_dir=f"~/logs/{DATASET}/bilstm-distill_fine_infer")

In [66]:
trainer = base.DistilTrainerInfer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=all_train_data,
    eval_dataset=eval_data,
    compute_metrics=base.compute_metrics
)

In [67]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.6997,1.560551,0.805046,0.805094,0.805191,0.805037
2,0.8325,1.577026,0.816514,0.816456,0.816536,0.816479
3,0.6753,1.596344,0.809633,0.809713,0.809401,0.809489
4,0.6028,1.635521,0.805046,0.805024,0.804896,0.804943
5,0.5663,1.647905,0.807339,0.807271,0.807317,0.80729


TrainOutput(global_step=11475, training_loss=0.8753401267190905, metrics={'train_runtime': 262.1551, 'train_samples_per_second': 5600.387, 'train_steps_per_second': 43.772, 'total_flos': 218588784084000.0, 'train_loss': 0.8753401267190905, 'epoch': 5.0})