In [1]:
from datasets import concatenate_datasets, load_from_disk
from transformers import BasicTokenizer, EarlyStoppingCallback, Trainer
import kagglehub
import torch
import base
import os

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/jovyan/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /home/jovyan/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /home/jovyan/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!


In [2]:
my_glove = kagglehub.dataset_download("thanakomsn/glove6b300dtxt")
print(my_glove)

/home/jovyan/.cache/kagglehub/datasets/thanakomsn/glove6b300dtxt/versions/1


In [3]:
GLOVE_FILE = f"{my_glove}/glove.6B.300d.txt"
DATASET = "trec"

In [4]:
train_data = load_from_disk(f"~/data/{DATASET}/train-logits_fine")
eval_data = load_from_disk(f"~/data/{DATASET}/eval-logits_fine")
test_data = load_from_disk(f"~/data/{DATASET}/test-logits_fine")

all_train_data = load_from_disk(f"~/data/{DATASET}/train-logits-augmented_fine")

all_data = concatenate_datasets([load_from_disk(file) for file in [f"~/data/{DATASET}/eval-logits_fine", f"~/data/{DATASET}/test-logits_fine", f"~/data/{DATASET}/train-logits-augmented_fine"]])
tokenizer = BasicTokenizer(do_lower_case=True)

In [5]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


In [6]:
train_data_tokens = list(map(lambda e: tokenizer.tokenize(e["sentence"]), train_data))
eval_data_tokens = list(map(lambda e: tokenizer.tokenize(e["sentence"]), eval_data))
test_data_tokens = list(map(lambda e: tokenizer.tokenize(e["sentence"]), test_data))

all_train_data_tokens = list(map(lambda e: tokenizer.tokenize(e["sentence"]), all_train_data))

all_data_tokens = list(map(lambda e: tokenizer.tokenize(e["sentence"]), all_data))

In [7]:
vocab = base.get_vocab(all_data_tokens)

In [8]:
word_index = dict(zip(vocab, range(len(vocab))))

In [9]:
embeddings_index = base.get_embeddings_indeces(GLOVE_FILE)

Found 400000 word vectors.


In [10]:
print(len(vocab))
num_tokens = len(vocab) + 2
embedding_dim = 300

8766


In [11]:
embedding_matrix = base.get_embedding_matrix(num_tokens, embedding_dim, word_index, embeddings_index)

Converted 8551 words (215) misses


In [12]:
train_data_index = list(map(lambda x: list(map(lambda y: word_index[y], x)),train_data_tokens))
eval_data_index = list(map(lambda x: list(map(lambda y: word_index[y], x)),eval_data_tokens))
test_data_index = list(map(lambda x: list(map(lambda y: word_index[y], x)),test_data_tokens))

all_train_data_index = list(map(lambda x: list(map(lambda y: word_index[y], x)),all_train_data_tokens))

In [13]:
train_padded_data = list(map(lambda x: base.padd(x,60), train_data_index))
eval_padded_data = list(map(lambda x: base.padd(x,60), eval_data_index))
test_padded_data = list(map(lambda x: base.padd(x,60), test_data_index))

all_train_padded_data = list(map(lambda x: base.padd(x,60), all_train_data_index))

In [14]:
train_data = train_data.add_column("input_ids", train_padded_data)
eval_data = eval_data.add_column("input_ids", eval_padded_data)
test_data = test_data.add_column("input_ids", test_padded_data)

all_train_data = all_train_data.add_column("input_ids", all_train_padded_data)

In [15]:
model = base.BiLSTMClassifier(embedding_matrix=embedding_matrix, embedding_dim=embedding_dim, fc_dim=400, hidden_dim=300, output_dim=50)

In [16]:
print(model)

BiLSTMClassifier(
  (embedding): Embedding(8768, 300)
  (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
  (fc1): Linear(in_features=600, out_features=400, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=400, out_features=50, bias=True)
)


In [20]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/bilstm-base_fine", logging_dir=f"~/logs/{DATASET}/bilstm-base_fine", lr=.001,  epochs=10, batch_size=128)

In [21]:
base.reset_seed()

In [22]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [23]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,3.0782,2.507161,0.390467,0.070064,0.086817,0.059064
2,2.1657,1.93477,0.519707,0.167539,0.14541,0.124516
3,1.724,1.592946,0.601283,0.236949,0.211851,0.20426
4,1.3812,1.361398,0.660862,0.340172,0.273271,0.274921
5,1.1269,1.249793,0.686526,0.355648,0.317757,0.317563
6,0.9333,1.189549,0.683776,0.391895,0.334762,0.341869
7,0.7722,1.1564,0.696609,0.423865,0.35556,0.369691
8,0.6952,1.131603,0.705775,0.440811,0.385572,0.397846
9,0.5919,1.115437,0.706691,0.441,0.396359,0.405291
10,0.5338,1.110169,0.706691,0.44629,0.392478,0.403815


TrainOutput(global_step=350, training_loss=1.3002333504813057, metrics={'train_runtime': 92.9123, 'train_samples_per_second': 469.367, 'train_steps_per_second': 3.767, 'total_flos': 0.0, 'train_loss': 1.3002333504813057, 'epoch': 10.0})

In [24]:
model.eval()

BiLSTMClassifier(
  (embedding): Embedding(8768, 300)
  (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
  (fc1): Linear(in_features=600, out_features=400, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=400, out_features=50, bias=True)
)

In [25]:
trainer.evaluate(test_data)

{'eval_loss': 1.0224372148513794,
 'eval_accuracy': 0.738,
 'eval_precision': 0.43001110037454765,
 'eval_recall': 0.48362209479935037,
 'eval_f1': 0.42725538526628765,
 'eval_runtime': 4.6588,
 'eval_samples_per_second': 107.325,
 'eval_steps_per_second': 0.859,
 'epoch': 10.0}

In [26]:
torch.save(model.state_dict(), f"{os.path.expanduser('~')}/models/{DATASET}/bilstm-base_fine.pth")

In [27]:
student_model = base.BiLSTMClassifier(embedding_matrix=embedding_matrix, embedding_dim=embedding_dim, fc_dim=400, hidden_dim=300, output_dim=50)

In [28]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/bilstm-distill_fine", remove_unused_columns=False, logging_dir=f"~/logs/{DATASET}/bilstm-distill_fine", lr=.001,  epochs=10, batch_size=128, lambda_param=.4, temp=2)

In [29]:
base.reset_seed()

In [30]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [31]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.7356,2.214236,0.372136,0.061409,0.076439,0.051624
2,1.9192,1.715433,0.488543,0.09706,0.123207,0.098378
3,1.5855,1.461795,0.567369,0.194694,0.170749,0.155853
4,1.3465,1.300753,0.633364,0.251351,0.231007,0.22217
5,1.173,1.194623,0.665445,0.286682,0.27128,0.266521
6,1.0085,1.108314,0.692026,0.322703,0.297319,0.294363
7,0.8957,1.082811,0.692026,0.330278,0.302123,0.303787
8,0.8237,1.028491,0.705775,0.35562,0.321527,0.316046
9,0.7498,1.021081,0.702108,0.363828,0.324351,0.32469


KeyboardInterrupt: 

In [None]:
student_model.eval()

BiLSTMClassifier(
  (embedding): Embedding(8768, 300)
  (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
  (fc1): Linear(in_features=600, out_features=400, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=400, out_features=50, bias=True)
)

In [None]:
trainer.evaluate(test_data)

{'eval_loss': 0.9858419895172119,
 'eval_accuracy': 0.686,
 'eval_precision': 0.336147111289856,
 'eval_recall': 0.3725331908094909,
 'eval_f1': 0.3238966398492582,
 'eval_runtime': 3.2628,
 'eval_samples_per_second': 153.243,
 'eval_steps_per_second': 1.226,
 'epoch': 10.0}

In [None]:
torch.save(student_model.state_dict(), f"{os.path.expanduser('~')}/models/{DATASET}/bilstm-distill_fine.pth")

In [None]:
model = base.BiLSTMClassifier(embedding_matrix=embedding_matrix, embedding_dim=embedding_dim, fc_dim=400, hidden_dim=300, output_dim=50)

In [None]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/bilstm-base-aug_fine", logging_dir=f"~/logs/{DATASET}/bilstm-base-aug_fine", lr=.001,  epochs=10, batch_size=128)

In [None]:
base.reset_seed()

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=all_train_data,
    eval_dataset=eval_data,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [None]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.5355,1.090564,0.712191,0.436485,0.391231,0.394178
2,0.4759,0.954173,0.766269,0.707687,0.584136,0.614689
3,0.1861,1.08806,0.781852,0.677774,0.63013,0.639764
4,0.0799,1.250184,0.769936,0.715949,0.620547,0.646466
5,0.0372,1.248047,0.785518,0.726865,0.650527,0.672268
6,0.0163,1.359968,0.799267,0.727283,0.648819,0.670847
7,0.0078,1.461539,0.793767,0.732195,0.662001,0.685314
8,0.0039,1.482695,0.799267,0.724557,0.674018,0.68352
9,0.0024,1.485037,0.792851,0.69982,0.657599,0.666796
10,0.0016,1.489118,0.796517,0.717694,0.66968,0.68046


TrainOutput(global_step=2830, training_loss=0.23466455228758365, metrics={'train_runtime': 95.4416, 'train_samples_per_second': 3787.133, 'train_steps_per_second': 29.652, 'total_flos': 0.0, 'train_loss': 0.23466455228758365, 'epoch': 10.0})

In [None]:
model.eval()

BiLSTMClassifier(
  (embedding): Embedding(8768, 300)
  (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
  (fc1): Linear(in_features=600, out_features=400, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=400, out_features=50, bias=True)
)

In [None]:
trainer.evaluate(test_data)

{'eval_loss': 1.2564932107925415,
 'eval_accuracy': 0.826,
 'eval_precision': 0.6806092675124096,
 'eval_recall': 0.6792263233940713,
 'eval_f1': 0.6625506422671901,
 'eval_runtime': 3.2507,
 'eval_samples_per_second': 153.812,
 'eval_steps_per_second': 1.23,
 'epoch': 10.0}

In [None]:
torch.save(model.state_dict(), f"{os.path.expanduser('~')}/models/{DATASET}/bilstm-base-aug_fine.pth")

In [None]:
student_model = base.BiLSTMClassifier(embedding_matrix=embedding_matrix, embedding_dim=embedding_dim, fc_dim=400, hidden_dim=300, output_dim=50)

In [None]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/bilstm-distill-aug_fine", remove_unused_columns=False, logging_dir=f"~/logs/{DATASET}/bilstm-distill-aug_fine", lr=.001,  epochs=10, batch_size=128, lambda_param=.4, temp=2)

In [None]:
base.reset_seed()

In [None]:
trainer = base.DistilTrainer(
    student_model=student_model,
    args=training_args,
    train_dataset=all_train_data,
    eval_dataset=eval_data,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [None]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.398,0.972544,0.710357,0.384375,0.334128,0.336958
2,0.5456,0.773753,0.782768,0.530266,0.495128,0.504416
3,0.3056,0.735239,0.8011,0.607206,0.568059,0.578294
4,0.2057,0.722042,0.810266,0.696164,0.621794,0.643894
5,0.1599,0.701917,0.817599,0.743609,0.664334,0.688383
6,0.1391,0.682544,0.826764,0.783324,0.691911,0.719749
7,0.126,0.687981,0.824931,0.821678,0.705702,0.742098
8,0.1186,0.684311,0.823098,0.814479,0.701373,0.734262
9,0.1141,0.682187,0.824015,0.794419,0.693118,0.723827
10,0.1114,0.67629,0.824931,0.815664,0.698016,0.735


TrainOutput(global_step=2830, training_loss=0.3224059303741994, metrics={'train_runtime': 73.3845, 'train_samples_per_second': 4925.43, 'train_steps_per_second': 38.564, 'total_flos': 0.0, 'train_loss': 0.3224059303741994, 'epoch': 10.0})

In [None]:
student_model.eval()

BiLSTMClassifier(
  (embedding): Embedding(8768, 300)
  (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
  (fc1): Linear(in_features=600, out_features=400, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=400, out_features=50, bias=True)
)

In [None]:
trainer.evaluate(test_data)

{'eval_loss': 0.5525984764099121,
 'eval_accuracy': 0.828,
 'eval_precision': 0.7166157976308094,
 'eval_recall': 0.6772778862976302,
 'eval_f1': 0.6770289294273095,
 'eval_runtime': 3.3678,
 'eval_samples_per_second': 148.466,
 'eval_steps_per_second': 1.188,
 'epoch': 10.0}

In [None]:
torch.save(model.state_dict(), f"{os.path.expanduser('~')}/models/{DATASET}/bilstm-distill-aug_fine.pth")