In [1]:
from transformers import Trainer, BertForSequenceClassification, BertTokenizer, EarlyStoppingCallback, AutoConfig, TrainingArguments, AutoModelForImageClassification
from datasets import load_from_disk
from torch.utils.data import DataLoader
import torch
from torch import nn
from torch.nn import functional as F
import base
import os 

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/jovyan/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /home/jovyan/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /home/jovyan/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!


In [2]:
base.reset_seed()

In [3]:
DATASET = "trec"

In [4]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


In [5]:
train = load_from_disk(f"~/data/{DATASET}/train-logits_fine")
eval = load_from_disk(f"~/data/{DATASET}/eval-logits_fine")
test = load_from_disk(f"~/data/{DATASET}/test-logits_fine")

train_aug = load_from_disk(f"~/data/{DATASET}/train-logits-augmented_fine")

In [6]:
tokenizer = BertTokenizer.from_pretrained("ndavid/autotrain-trec-fine-bert-739422530")

In [7]:
train = train.map(lambda e: tokenizer(e["sentence"], truncation=True, padding="max_length", return_tensors="pt", max_length=300), batched=True, desc="Tokenizing the train dataset")
eval = eval.map(lambda e: tokenizer(e["sentence"], truncation=True, padding="max_length", return_tensors="pt", max_length=300), batched=True, desc="Tokenizing the eval dataset")
test = test.map(lambda e: tokenizer(e["sentence"], truncation=True, padding="max_length", return_tensors="pt", max_length=300), batched=True, desc="Tokenizing the test dataset")

train_aug = train_aug.map(lambda e: tokenizer(e["sentence"], truncation=True, padding="max_length", return_tensors="pt", max_length=300), batched=True, desc="Tokenizing the augmented dataset")

In [8]:
base.reset_seed()

In [52]:
student_model = BertForSequenceClassification.from_pretrained("google/bert_uncased_L-2_H-128_A-2", num_labels=50)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
print(student_model)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-1

In [11]:
config = AutoConfig.from_pretrained("ndavid/autotrain-trec-fine-bert-739422530")
config.max_length = 20 #revert to default ot skip warning 
config.num_labels = 50
teacher_model = BertForSequenceClassification.from_pretrained("ndavid/autotrain-trec-fine-bert-739422530", config=config, ignore_mismatched_sizes=True)
model_path = f"{os.path.expanduser('~')}/models/{DATASET}/teacher_fine.pth"
state_dict = torch.load(model_path, map_location=torch.device('cpu')) 
teacher_model.load_state_dict(state_dict)
teacher_model.to(device)
teacher_model.eval()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ndavid/autotrain-trec-fine-bert-739422530 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([47, 768]) in the checkpoint and torch.Size([50, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([47]) in the checkpoint and torch.Size([50]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [18]:
class DistilTrainerInner(Trainer):
    """Distilation trainer, computes loss with logits from teacher in mind. Logits are precomputed."""
    def __init__(self, student_model=None, teacher_model = None, *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.student = student_model
        self.teacher = teacher_model
        self.layer_loss_function = nn.MSELoss()
        self.logit_loss_function = nn.KLDivLoss(reduction="batchmean")
        self.temperature = self.args.temperature
        self.lambda_param = self.args.lambda_param
        self.alpha_param = self.args.alpha_param

        self.student_to_teacher = nn.Linear(128, 768).to(device)


    def compute_loss(self, student, inputs, return_outputs=False, num_items_in_batch=None):
        logits = inputs.pop("logits")
        
        student_output = student(**inputs, output_hidden_states=True)
        student_target_loss = student_output["loss"]

        with torch.no_grad():
            teacher_output = self.teacher(**inputs, output_hidden_states=True)

        teacher_hidden_states = teacher_output.hidden_states
        student_hidden_states = student_output.hidden_states

        
        teacher_l6 = teacher_hidden_states[6] / self.temperature
        teacher_l12 = teacher_hidden_states[12] / self.temperature
        student_l1 = student_hidden_states[1]
        student_l2 = student_hidden_states[2] 

        student_l1_projection = self.student_to_teacher(student_l1) / self.temperature
        student_l2_projection = self.student_to_teacher(student_l2) / self.temperature

        layer_distillation_loss = (
            self.layer_loss_function(student_l1_projection, teacher_l6) +
            self.layer_loss_function(student_l2_projection, teacher_l12)
        )

        

        soft_teacher = F.softmax(logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output['logits'] / self.temperature, dim=-1)

        logit_distillation_loss = self.logit_loss_function(soft_student, soft_teacher) * (self.temperature ** 2)
        logit_label_loss = ((1. - self.lambda_param) * student_target_loss + self.lambda_param * logit_distillation_loss)

        
        loss = (1 - self.alpha_param) * logit_label_loss + self.alpha_param * layer_distillation_loss

        
        return (loss, student_output) if return_outputs else loss
    
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        logits = inputs.pop("logits")
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=False)
            loss = outputs.loss if "loss" in outputs else None
            logits = outputs.logits
        labels = inputs.get("labels")
        return loss, logits, labels

In [19]:
class Custom_training_args(TrainingArguments):
    """Custom wrapper of training args for distillation."""
    def __init__(self, lambda_param, alpha_param, temperature, *args, **kwargs):
        super().__init__(*args, **kwargs)    
        self.lambda_param = lambda_param
        self.alpha_param = alpha_param
        self.temperature = temperature

In [20]:
def get_training_args(output_dir, logging_dir, remove_unused_columns=True, lr=5e-5, epochs=5, weight_decay=0, adam_beta1 = .9, lambda_param=.5, alpha_param = .5, temp=5, batch_size=128, num_workers=4, warmup_steps=0):
    """Returns training args that can be adjusted."""
    return (
        Custom_training_args(
        output_dir=output_dir,
        eval_strategy="epoch",
        adam_beta1 = adam_beta1,
        warmup_steps = warmup_steps,
        save_strategy="epoch",
        logging_strategy="epoch",
        learning_rate=lr, #Defaultní hodnota 
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        weight_decay=weight_decay,
        seed = 42,  #Defaultní hodnota 
        metric_for_best_model="f1",
        load_best_model_at_end = True,
        fp16=True, 
        logging_dir=logging_dir,
        remove_unused_columns=remove_unused_columns,
        lambda_param = lambda_param,
        alpha_param = alpha_param, 
        temperature = temp,
        dataloader_num_workers=num_workers,
        )
    )

In [53]:
training_args = get_training_args(output_dir=f"~/results/{DATASET}/hokus_pokus", logging_dir=f"~/logs/{DATASET}/hokus_pokus", remove_unused_columns=False, lr=8e-4, batch_size=128, epochs=20, temp=2, lambda_param=0, alpha_param=.3)

In [54]:
trainer = DistilTrainerInner(
    student_model = student_model,
    teacher_model = teacher_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [55]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.3517,2.622619,0.450046,0.124208,0.126271,0.104896
2,1.6341,1.84112,0.6022,0.215776,0.250789,0.224072
3,1.1457,1.420704,0.695692,0.376744,0.338916,0.323342
4,0.8456,1.239802,0.730522,0.413277,0.38097,0.368196
5,0.6416,1.134269,0.752521,0.448728,0.440872,0.423798
6,0.5122,1.074933,0.759853,0.483631,0.462137,0.455504
7,0.4037,1.051079,0.756187,0.512179,0.477353,0.478819
8,0.3366,1.044749,0.774519,0.586416,0.539534,0.54769
9,0.2804,1.036319,0.771769,0.616361,0.559525,0.566792
10,0.2341,1.04296,0.770852,0.628541,0.565279,0.572681


TrainOutput(global_step=700, training_loss=0.4959136758531843, metrics={'train_runtime': 355.5705, 'train_samples_per_second': 245.296, 'train_steps_per_second': 1.969, 'total_flos': 65900954952000.0, 'train_loss': 0.4959136758531843, 'epoch': 20.0})

In [56]:
student_model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-1

In [57]:
trainer.evaluate(test)

{'eval_loss': 1.0209091901779175,
 'eval_accuracy': 0.786,
 'eval_precision': 0.6741836658624986,
 'eval_recall': 0.6495346381733221,
 'eval_f1': 0.6295482650549977,
 'eval_runtime': 3.4098,
 'eval_samples_per_second': 146.638,
 'eval_steps_per_second': 1.173,
 'epoch': 20.0}

In [None]:
training_args = base.get_training_args(output_dir=f"~/results/{DATASET}/hokus_pokus", logging_dir=f"~/logs/{DATASET}/hokus_pokus", remove_unused_columns=False, lr=8e-4, batch_size=128, epochs=20,  temp=2, lambda_param=.4)

In [59]:
student_model = BertForSequenceClassification.from_pretrained("google/bert_uncased_L-2_H-128_A-2", num_labels=50)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [60]:
trainer = base.DistilTrainer(
    student_model = student_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [61]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.3077,1.931338,0.404216,0.066452,0.093272,0.06383
2,1.6243,1.37064,0.613199,0.248521,0.25051,0.23079
3,1.1516,1.084126,0.695692,0.291288,0.303464,0.280613
4,0.8665,0.952221,0.728689,0.349467,0.357325,0.338163
5,0.6797,0.87876,0.747938,0.399649,0.400221,0.382936
6,0.5478,0.84847,0.75527,0.429065,0.408227,0.401909
7,0.4478,0.817848,0.76077,0.491511,0.433454,0.436979
8,0.3779,0.810743,0.773602,0.50977,0.479757,0.480758
9,0.3158,0.79083,0.779102,0.547607,0.494367,0.502936
10,0.2721,0.802734,0.764436,0.543249,0.494737,0.501678


TrainOutput(global_step=700, training_loss=0.5186817496163505, metrics={'train_runtime': 107.5116, 'train_samples_per_second': 811.262, 'train_steps_per_second': 6.511, 'total_flos': 65900954952000.0, 'train_loss': 0.5186817496163505, 'epoch': 20.0})

In [62]:
student_model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-1

In [63]:
trainer.evaluate(test)

{'eval_loss': 0.766819179058075,
 'eval_accuracy': 0.762,
 'eval_precision': 0.6137723868676089,
 'eval_recall': 0.6109359742797066,
 'eval_f1': 0.5813720599288332,
 'eval_runtime': 3.5397,
 'eval_samples_per_second': 141.254,
 'eval_steps_per_second': 1.13,
 'epoch': 20.0}

In [26]:
base.reset_seed()

In [64]:
student_model = BertForSequenceClassification.from_pretrained("google/bert_uncased_L-2_H-128_A-2", num_labels=50)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [69]:
training_args = get_training_args(output_dir=f"~/results/{DATASET}/hokus_pokus", logging_dir=f"~/logs/{DATASET}/hokus_pokus", remove_unused_columns=False, lr=8e-4, batch_size=128, epochs=20, temp=2, lambda_param=0, alpha_param=.3)

In [70]:
class DistilTrainerInnerAVG(Trainer):
    """Distilation trainer, computes loss with logits from teacher in mind. Logits are precomputed."""
    def __init__(self, student_model=None, teacher_model = None, *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.student = student_model
        self.teacher = teacher_model
        self.layer_loss_function = nn.MSELoss()
        self.logit_loss_function = nn.KLDivLoss(reduction="batchmean")
        self.temperature = self.args.temperature
        self.lambda_param = self.args.lambda_param
        self.alpha_param = self.args.alpha_param


        self.student_to_teacher = nn.Linear(128, 768).to(device)
        self.model_parameters = list(self.model.parameters()) + list(self.student_to_teacher.parameters())

    def compute_loss(self, student, inputs, return_outputs=False, num_items_in_batch=None):
        logits = inputs.pop("logits")
        
        student_output = student(**inputs, output_hidden_states=True)
        student_target_loss = student_output["loss"]

        with torch.no_grad():
            teacher_output = self.teacher(**inputs, output_hidden_states=True)

        teacher_hidden_states = teacher_output.hidden_states
        student_hidden_states = student_output.hidden_states

        
        
        teacher_l6 = torch.stack(teacher_hidden_states[1:7], dim=0).mean(dim=0) / self.temperature
        teacher_l12 = torch.stack(teacher_hidden_states[7:13], dim=0).mean(dim=0) / self.temperature
        student_l1 = student_hidden_states[1]
        student_l2 = student_hidden_states[2] 

        student_l1_projection = self.student_to_teacher(student_l1) / self.temperature
        student_l2_projection = self.student_to_teacher(student_l2) / self.temperature

        layer_distillation_loss = (
            self.layer_loss_function(student_l1_projection, teacher_l6) +
            self.layer_loss_function(student_l2_projection, teacher_l12)
        )

        

        soft_teacher = F.softmax(logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output['logits'] / self.temperature, dim=-1)

        logit_distillation_loss = self.logit_loss_function(soft_student, soft_teacher) * (self.temperature ** 2)
        logit_label_loss = ((1. - self.lambda_param) * student_target_loss + self.lambda_param * logit_distillation_loss)

        
        loss = (1 - self.alpha_param) * logit_label_loss + self.alpha_param * layer_distillation_loss

        
        return (loss, student_output) if return_outputs else loss
    
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        logits = inputs.pop("logits")
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            # For evaluation, we disable extra outputs.
            outputs = model(**inputs, output_hidden_states=False)
            loss = outputs.loss if "loss" in outputs else None
            logits = outputs.logits
        labels = inputs.get("labels")
        return loss, logits, labels

In [71]:
trainer = DistilTrainerInnerAVG(
    student_model = student_model,
    teacher_model = teacher_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [72]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.3086,2.604124,0.442713,0.067964,0.109693,0.077831
2,1.6164,1.82571,0.586618,0.250198,0.224954,0.20227
3,1.1172,1.388694,0.705775,0.331629,0.316616,0.291173
4,0.8161,1.204824,0.737855,0.383825,0.386627,0.370825
5,0.6246,1.085495,0.744271,0.412612,0.413672,0.395178
6,0.4863,1.032137,0.761687,0.479304,0.45425,0.448534
7,0.3759,1.047209,0.756187,0.549809,0.488388,0.487681
8,0.3073,1.03051,0.771769,0.546652,0.541012,0.529507
9,0.2472,1.027601,0.770852,0.569284,0.553769,0.546604
10,0.2074,0.982512,0.776352,0.580372,0.556895,0.551229


TrainOutput(global_step=700, training_loss=0.468877078805651, metrics={'train_runtime': 366.2642, 'train_samples_per_second': 238.134, 'train_steps_per_second': 1.911, 'total_flos': 65900954952000.0, 'train_loss': 0.468877078805651, 'epoch': 20.0})

In [73]:
student_model.eval()
trainer.evaluate(test)

{'eval_loss': 0.9729981422424316,
 'eval_accuracy': 0.778,
 'eval_precision': 0.6108235367341464,
 'eval_recall': 0.6287281800256842,
 'eval_f1': 0.585669900362865,
 'eval_runtime': 3.7778,
 'eval_samples_per_second': 132.353,
 'eval_steps_per_second': 1.059,
 'epoch': 20.0}