In [95]:
from datasets import concatenate_datasets, load_from_disk
from transformers import BasicTokenizer, Trainer, EarlyStoppingCallback
import kagglehub
import torch
import base
import os

In [96]:
# Fetch the "thanakomsn/glove6b300dtxt" Kaggle dataset through kagglehub. The returned
# value is printed and later used as the directory prefix of the GloVe text file.
my_glove = kagglehub.dataset_download("thanakomsn/glove6b300dtxt")
print(my_glove)

/home/jovyan/.cache/kagglehub/datasets/thanakomsn/glove6b300dtxt/versions/1


In [97]:
# Location of the GloVe vectors file inside the downloaded dataset directory.
GLOVE_FILE = f"{my_glove}/glove.6B.300d.txt"
# Dataset name; interpolated into the data, results, logs and models paths of this notebook.
DATASET = "trec"

In [98]:
# Load the pre-built splits for DATASET from disk.
# NOTE(review): the "-logits" suffix suggests each split carries stored logits for
# distillation — confirm against the script that exported these directories.
train_data = load_from_disk(f"~/data/{DATASET}/train-logits_coarse")
eval_data = load_from_disk(f"~/data/{DATASET}/eval-logits_coarse")
test_data = load_from_disk(f"~/data/{DATASET}/test-logits_coarse")

# Augmented training split, used as train_dataset by the "-aug" runs.
all_train_data = load_from_disk(f"~/data/{DATASET}/train-logits-augmented_coarse")

# Corpus that the vocabulary is built from. The three datasets loaded above are
# reused instead of being read from disk a second time; the concatenation order
# (eval, test, augmented train) is the same as before.
# NOTE(review): train_data is not part of this corpus, so the later token -> index
# lookup relies on every token of train_data also occurring in these splits — confirm.
all_data = concatenate_datasets([eval_data, test_data, all_train_data])

# Lower-casing tokenizer shared by every split.
tokenizer = BasicTokenizer(do_lower_case=True)

In [99]:
# Pick the compute device once, preferring CUDA when PyTorch reports it, and
# announce the choice.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA H100 PCIe


In [100]:
# Tokenize the "sentence" field of every example, one token list per example.
train_data_tokens = [tokenizer.tokenize(example["sentence"]) for example in train_data]
eval_data_tokens = [tokenizer.tokenize(example["sentence"]) for example in eval_data]
test_data_tokens = [tokenizer.tokenize(example["sentence"]) for example in test_data]

# Augmented training split.
all_train_data_tokens = [tokenizer.tokenize(example["sentence"]) for example in all_train_data]

# Combined corpus; only used to build the vocabulary.
all_data_tokens = [tokenizer.tokenize(example["sentence"]) for example in all_data]

In [101]:
# Build the vocabulary from the token lists of the combined corpus (all_data_tokens);
# the exact construction lives in base.get_vocab.
vocab = base.get_vocab(all_data_tokens)

In [102]:
# Token -> integer id, where the id is the token's position in `vocab` (starting at 0).
word_index = {token: position for position, token in enumerate(vocab)}

In [103]:
# Read the GloVe vectors from GLOVE_FILE with the project helper; the result is
# handed to base.get_embedding_matrix together with word_index.
embeddings_index = base.get_embeddings_indeces(GLOVE_FILE)

Found 400000 word vectors.


In [104]:
# Width of one embedding vector; the GloVe file in use is the 300d variant.
embedding_dim = 300
# Two rows more than the vocabulary size.
# NOTE(review): presumably reserved for padding / unknown tokens — verify against
# base.get_embedding_matrix and base.padd.
num_tokens = 2 + len(vocab)
print(len(vocab))

8766


In [105]:
# Assemble the (num_tokens x embedding_dim) embedding table from the GloVe vectors for the
# words in word_index; the helper reports how many words were converted and how many missed.
# NOTE(review): the matrix shape is inferred from the arguments and the Embedding(8768, 300)
# repr printed for the model — confirm in base.get_embedding_matrix.
embedding_matrix = base.get_embedding_matrix(num_tokens, embedding_dim, word_index, embeddings_index)

Converted 8551 words (215) misses


In [106]:
# Replace every token by its integer id from word_index. A token that is missing
# from word_index raises KeyError.
train_data_index = [[word_index[token] for token in tokens] for tokens in train_data_tokens]
eval_data_index = [[word_index[token] for token in tokens] for tokens in eval_data_tokens]
test_data_index = [[word_index[token] for token in tokens] for tokens in test_data_tokens]

# Augmented training split.
all_train_data_index = [[word_index[token] for token in tokens] for tokens in all_train_data_tokens]

In [107]:
# Run every id sequence through base.padd with a target length of 60.
# NOTE(review): whether longer sequences are truncated and which id is used for
# padding is decided inside base.padd — confirm there.
train_padded_data = [base.padd(ids, 60) for ids in train_data_index]
eval_padded_data = [base.padd(ids, 60) for ids in eval_data_index]
test_padded_data = [base.padd(ids, 60) for ids in test_data_index]

# Augmented training split.
all_train_padded_data = [base.padd(ids, 60) for ids in all_train_data_index]

In [108]:
def _set_input_ids(dataset, padded_ids):
    """Return `dataset` with an "input_ids" column holding `padded_ids`.

    An already existing "input_ids" column is dropped first, so re-running this
    cell replaces the column instead of failing on a duplicate column name.
    """
    if "input_ids" in dataset.column_names:
        dataset = dataset.remove_columns("input_ids")
    return dataset.add_column("input_ids", padded_ids)


# Attach the padded id sequences to each split as the model input column.
train_data = _set_input_ids(train_data, train_padded_data)
eval_data = _set_input_ids(eval_data, eval_padded_data)
test_data = _set_input_ids(test_data, test_padded_data)

# Augmented training split.
all_train_data = _set_input_ids(all_train_data, all_train_padded_data)

In [109]:
# Baseline classifier: receives the GloVe-derived embedding_matrix with
# freeze_embed=False and predicts 6 classes.
model = base.BiLSTMClassifier(
    embedding_matrix=embedding_matrix,
    embedding_dim=embedding_dim,
    hidden_dim=300,
    fc_dim=400,
    output_dim=6,
    freeze_embed=False,
)

In [110]:
# Print the module layout to check the layer sizes.
print(model)

BiLSTMClassifier(
  (embedding): Embedding(8768, 300)
  (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
  (fc1): Linear(in_features=600, out_features=400, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=400, out_features=6, bias=True)
)


In [111]:
# Training configuration for the baseline run on the original training split.
training_args = base.get_training_args(
    output_dir=f"~/results/{DATASET}/bilstm-base_coarse_embedd",
    logging_dir=f"~/logs/{DATASET}/bilstm-base_coarse_embedd",
    epochs=20,
    batch_size=128,
    lr=.005,
    weight_decay=.001,
    warmup_steps=4,
)

In [112]:
# Reset the random seed through the project helper before the trainer is built.
# NOTE(review): `model` was constructed in an earlier cell, i.e. before this reset, so
# any random weight initialisation is not covered by it — confirm this is intended.
base.reset_seed()

In [113]:
# Baseline run: plain Trainer on the original training split, validated on eval_data.
# The early-stopping callback is configured with a patience of 4.
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=train_data,
    eval_dataset=eval_data,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=4)],
    compute_metrics=base.compute_metrics,
)

In [114]:
# Train the baseline model. Up to 20 epochs are configured; the early-stopping
# callback (patience 4) can end the run sooner.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.0179,0.698765,0.774519,0.677601,0.657372,0.656584
2,0.3718,0.442209,0.861595,0.835801,0.808452,0.816831
3,0.1215,0.495563,0.879927,0.889976,0.801268,0.826402
4,0.0426,0.492769,0.877177,0.857197,0.838664,0.846835
5,0.0101,0.69182,0.873511,0.855567,0.823549,0.837071
6,0.0067,0.764008,0.870761,0.847876,0.832321,0.838546
7,0.0041,0.688693,0.871677,0.850788,0.834702,0.841099
8,0.0029,0.72073,0.875344,0.854717,0.837652,0.84505


TrainOutput(global_step=280, training_loss=0.19721224983888014, metrics={'train_runtime': 58.7067, 'train_samples_per_second': 1485.691, 'train_steps_per_second': 11.924, 'total_flos': 0.0, 'train_loss': 0.19721224983888014, 'epoch': 8.0})

In [115]:
# Switch the baseline model to evaluation mode; the module repr is the cell output.
model.eval()

BiLSTMClassifier(
  (embedding): Embedding(8768, 300)
  (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
  (fc1): Linear(in_features=600, out_features=400, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=400, out_features=6, bias=True)
)

In [116]:
# Score the baseline model on the test split; the metrics dict is the cell output.
trainer.evaluate(test_data)

{'eval_loss': 0.4047532081604004,
 'eval_accuracy': 0.882,
 'eval_precision': 0.8462707017073257,
 'eval_recall': 0.8645004090856151,
 'eval_f1': 0.8512382056259588,
 'eval_runtime': 3.5734,
 'eval_samples_per_second': 139.921,
 'eval_steps_per_second': 1.119,
 'epoch': 8.0}

In [117]:
# Persist the baseline weights under the user's home directory.
torch.save(
    model.state_dict(),
    f"{os.path.expanduser('~')}/models/{DATASET}/bilstm-base_coarse_embedd.pth",
)

In [118]:
# Fresh student network for the distillation run; same architecture and arguments
# as the baseline classifier.
student_model = base.BiLSTMClassifier(
    embedding_matrix=embedding_matrix,
    embedding_dim=embedding_dim,
    hidden_dim=300,
    fc_dim=400,
    output_dim=6,
    freeze_embed=False,
)

In [119]:
# Training configuration for the distillation run on the original training split.
# NOTE(review): lambda_param and temp look like the distillation loss weight and
# softmax temperature — confirm how base.DistilTrainer uses them.
training_args = base.get_training_args(
    output_dir=f"~/results/{DATASET}/bilstm-distill_coarse_embedd",
    logging_dir=f"~/logs/{DATASET}/bilstm-distill_coarse_embedd",
    remove_unused_columns=False,
    epochs=20,
    batch_size=128,
    lr=.005,
    weight_decay=0.004,
    warmup_steps=1,
    lambda_param=.6,
    temp=4,
)

In [120]:
# Reset the random seed through the project helper before the distillation trainer is built.
# NOTE(review): `student_model` was constructed before this reset, so any random
# weight initialisation is not covered by it — confirm this is intended.
base.reset_seed()

In [121]:
# Distillation run on the original training split, validated on eval_data.
# NOTE(review): no teacher model is passed; presumably base.DistilTrainer reads
# teacher logits stored in the dataset rows — confirm in base.
trainer = base.DistilTrainer(
    args=training_args,
    student_model=student_model,
    train_dataset=train_data,
    eval_dataset=eval_data,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=4)],
    compute_metrics=base.compute_metrics,
)

In [122]:
# Train the distilled student. Up to 20 epochs are configured; the early-stopping
# callback (patience 4) can end the run sooner.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.5433,1.53259,0.7956,0.686991,0.675375,0.675736
2,0.8821,0.939075,0.862511,0.726828,0.734705,0.729329
3,0.3392,0.911219,0.867094,0.883666,0.802039,0.826958
4,0.1767,0.77931,0.883593,0.890326,0.823614,0.84614
5,0.1159,0.742111,0.888176,0.896102,0.845837,0.865158
6,0.0932,0.79831,0.875344,0.887423,0.826039,0.848028
7,0.0765,0.75773,0.887259,0.895388,0.836163,0.857594
8,0.07,0.766474,0.879927,0.89015,0.829804,0.851547
9,0.0627,0.733519,0.885426,0.892971,0.834284,0.855406


TrainOutput(global_step=315, training_loss=0.4844025112333752, metrics={'train_runtime': 39.7041, 'train_samples_per_second': 2196.753, 'train_steps_per_second': 17.63, 'total_flos': 0.0, 'train_loss': 0.4844025112333752, 'epoch': 9.0})

In [123]:
# Switch the student to evaluation mode; the module repr is the cell output.
student_model.eval()

BiLSTMClassifier(
  (embedding): Embedding(8768, 300)
  (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
  (fc1): Linear(in_features=600, out_features=400, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=400, out_features=6, bias=True)
)

In [124]:
# Score the distilled student on the test split; the metrics dict is the cell output.
trainer.evaluate(test_data)

{'eval_loss': 0.5530129075050354,
 'eval_accuracy': 0.926,
 'eval_precision': 0.9343320322857044,
 'eval_recall': 0.9177282995349499,
 'eval_f1': 0.9247244401573691,
 'eval_runtime': 3.8253,
 'eval_samples_per_second': 130.709,
 'eval_steps_per_second': 1.046,
 'epoch': 9.0}

In [125]:
# Persist the distilled student's weights under the user's home directory.
torch.save(
    student_model.state_dict(),
    f"{os.path.expanduser('~')}/models/{DATASET}/bilstm-distill_coarse_embedd.pth",
)

In [126]:
# Fresh baseline classifier for the run on the augmented training split; this
# rebinds `model`, replacing the earlier baseline instance.
model = base.BiLSTMClassifier(
    embedding_matrix=embedding_matrix,
    embedding_dim=embedding_dim,
    hidden_dim=300,
    fc_dim=400,
    output_dim=6,
    freeze_embed=False,
)

In [127]:
# Training configuration for the baseline run on the augmented training split.
training_args = base.get_training_args(
    output_dir=f"~/results/{DATASET}/bilstm-base-aug_coarse_embedd",
    logging_dir=f"~/logs/{DATASET}/bilstm-base-aug_coarse_embedd",
    epochs=20,
    batch_size=128,
    lr=.0045,
    weight_decay=.007,
    warmup_steps=5,
)

In [128]:
# Reset the random seed through the project helper before the augmented baseline trainer is built.
# NOTE(review): `model` was re-created before this reset, so any random weight
# initialisation is not covered by it — confirm this is intended.
base.reset_seed()

In [129]:
# Baseline run on the augmented training split (all_train_data), validated on eval_data.
# The early-stopping callback is configured with a patience of 4.
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=all_train_data,
    eval_dataset=eval_data,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=4)],
    compute_metrics=base.compute_metrics,
)

In [130]:
# Train the baseline on the augmented split. Up to 20 epochs are configured; the
# early-stopping callback (patience 4) can end the run sooner.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.2755,0.566823,0.862511,0.828948,0.836157,0.830687
2,0.0294,0.742938,0.873511,0.864484,0.807929,0.826157
3,0.0154,0.959887,0.855179,0.853852,0.813879,0.826547
4,0.0121,0.820969,0.855179,0.842003,0.813022,0.823155
5,0.0089,0.975391,0.848763,0.813738,0.823703,0.817903


TrainOutput(global_step=1525, training_loss=0.06826410043435019, metrics={'train_runtime': 42.0394, 'train_samples_per_second': 18515.953, 'train_steps_per_second': 145.102, 'total_flos': 0.0, 'train_loss': 0.06826410043435019, 'epoch': 5.0})

In [131]:
# Switch the augmented baseline to evaluation mode; the module repr is the cell output.
model.eval()

BiLSTMClassifier(
  (embedding): Embedding(8768, 300)
  (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
  (fc1): Linear(in_features=600, out_features=400, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=400, out_features=6, bias=True)
)

In [132]:
# Score the augmented baseline on the test split; the metrics dict is the cell output.
trainer.evaluate(test_data)

{'eval_loss': 0.4136075973510742,
 'eval_accuracy': 0.888,
 'eval_precision': 0.8404510256287869,
 'eval_recall': 0.8892604063476749,
 'eval_f1': 0.8569037633430963,
 'eval_runtime': 3.4539,
 'eval_samples_per_second': 144.764,
 'eval_steps_per_second': 1.158,
 'epoch': 5.0}

In [133]:
# Persist the augmented baseline's weights under the user's home directory.
torch.save(
    model.state_dict(),
    f"{os.path.expanduser('~')}/models/{DATASET}/bilstm-base-aug_coarse_embedd.pth",
)

In [134]:
# Fresh student network for the distillation run on the augmented split; this
# rebinds `student_model`, replacing the earlier student instance.
student_model = base.BiLSTMClassifier(
    embedding_matrix=embedding_matrix,
    embedding_dim=embedding_dim,
    hidden_dim=300,
    fc_dim=400,
    output_dim=6,
    freeze_embed=False,
)

In [135]:
# Training configuration for the distillation run on the augmented training split.
# NOTE(review): lambda_param and temp look like the distillation loss weight and
# softmax temperature — confirm how base.DistilTrainer uses them.
training_args = base.get_training_args(
    output_dir=f"~/results/{DATASET}/bilstm-distill-aug_coarse_embedd",
    logging_dir=f"~/logs/{DATASET}/bilstm-distill-aug_coarse_embedd",
    remove_unused_columns=False,
    epochs=20,
    batch_size=128,
    lr=.0035,
    weight_decay=.008,
    warmup_steps=12,
    lambda_param=.7,
    temp=5,
)

In [136]:
# Reset the random seed through the project helper before the augmented distillation trainer is built.
# NOTE(review): `student_model` was re-created before this reset, so any random
# weight initialisation is not covered by it — confirm this is intended.
base.reset_seed()

In [137]:
# Distillation run on the augmented training split (all_train_data), validated on eval_data.
# NOTE(review): no teacher model is passed; presumably base.DistilTrainer reads
# teacher logits stored in the dataset rows — confirm in base.
trainer = base.DistilTrainer(
    args=training_args,
    student_model=student_model,
    train_dataset=all_train_data,
    eval_dataset=eval_data,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=4)],
    compute_metrics=base.compute_metrics,
)

In [138]:
# Train the distilled student on the augmented split. Up to 20 epochs are configured;
# the early-stopping callback (patience 4) can end the run sooner.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.8143,0.68058,0.898258,0.902333,0.84724,0.865952
2,0.1573,0.679826,0.894592,0.887876,0.843854,0.859208
3,0.1164,0.644492,0.890926,0.896394,0.841148,0.859966
4,0.0964,0.635338,0.901008,0.903955,0.849049,0.868187
5,0.0837,0.636511,0.900092,0.904118,0.857922,0.874972
6,0.0758,0.602172,0.911091,0.913637,0.865987,0.883916
7,0.0716,0.607285,0.908341,0.91041,0.854392,0.874183
8,0.0655,0.620164,0.909258,0.9131,0.864525,0.882648
9,0.0606,0.588093,0.907424,0.910744,0.86308,0.880912
10,0.0565,0.595842,0.912924,0.915519,0.867124,0.885401


TrainOutput(global_step=4270, training_loss=0.12854679942968578, metrics={'train_runtime': 118.2713, 'train_samples_per_second': 6581.477, 'train_steps_per_second': 51.576, 'total_flos': 0.0, 'train_loss': 0.12854679942968578, 'epoch': 14.0})

In [139]:
# Switch the augmented-split student to evaluation mode; the module repr is the cell output.
student_model.eval()

BiLSTMClassifier(
  (embedding): Embedding(8768, 300)
  (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
  (fc1): Linear(in_features=600, out_features=400, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=400, out_features=6, bias=True)
)

In [140]:
# Score the augmented-split student on the test split; the metrics dict is the cell output.
trainer.evaluate(test_data)

{'eval_loss': 0.4494105875492096,
 'eval_accuracy': 0.938,
 'eval_precision': 0.9268877903431405,
 'eval_recall': 0.9119383501549946,
 'eval_f1': 0.9176413927416944,
 'eval_runtime': 4.2219,
 'eval_samples_per_second': 118.431,
 'eval_steps_per_second': 0.947,
 'epoch': 14.0}

In [141]:
# Persist the distilled student trained on the augmented split.
# The weights must come from `student_model`: at this point `model` still refers to
# the non-distilled augmented baseline, which already has its own checkpoint file
# (bilstm-base-aug_coarse_embedd.pth).
torch.save(
    student_model.state_dict(),
    f"{os.path.expanduser('~')}/models/{DATASET}/bilstm-distill-aug_coarse_embedd.pth",
)