In [1]:
from transformers import Trainer, BertForSequenceClassification, BertTokenizer, EarlyStoppingCallback, AutoConfig, TrainingArguments
from datasets import load_from_disk
from torch.utils.data import DataLoader
import torch
from torch import nn
from torch.nn import functional as F
import base
import os 

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/jovyan/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /home/jovyan/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /home/jovyan/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!


In [2]:
base.reset_seed()

In [3]:
DATASET = "trec"

In [4]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


In [5]:
train = load_from_disk(f"~/data/{DATASET}/train-logits_fine")
eval = load_from_disk(f"~/data/{DATASET}/eval-logits_fine")
test = load_from_disk(f"~/data/{DATASET}/test-logits_fine")

train_aug = load_from_disk(f"~/data/{DATASET}/train-logits-augmented_fine")

In [6]:
tokenizer = BertTokenizer.from_pretrained("ndavid/autotrain-trec-fine-bert-739422530")

In [7]:
train = train.map(lambda e: tokenizer(e["sentence"], truncation=True, padding="max_length", return_tensors="pt", max_length=300), batched=True, desc="Tokenizing the train dataset")
eval = eval.map(lambda e: tokenizer(e["sentence"], truncation=True, padding="max_length", return_tensors="pt", max_length=300), batched=True, desc="Tokenizing the eval dataset")
test = test.map(lambda e: tokenizer(e["sentence"], truncation=True, padding="max_length", return_tensors="pt", max_length=300), batched=True, desc="Tokenizing the test dataset")

train_aug = train_aug.map(lambda e: tokenizer(e["sentence"], truncation=True, padding="max_length", return_tensors="pt", max_length=300), batched=True, desc="Tokenizing the augmented dataset")

In [8]:
base.reset_seed()

In [9]:
student_model = BertForSequenceClassification.from_pretrained("google/bert_uncased_L-2_H-128_A-2", num_labels=50)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
print(student_model)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-1

In [11]:
model_path = f"{os.path.expanduser('~')}/models/{DATASET}/teacher_fine.pth"

config = AutoConfig.from_pretrained("ndavid/autotrain-trec-fine-bert-739422530")
config.max_length = 20 
config.num_labels = 50
config.output_hidden_states = True
teacher_model = BertForSequenceClassification.from_pretrained("ndavid/autotrain-trec-fine-bert-739422530", config=config, ignore_mismatched_sizes=True)
state_dict = torch.load(model_path, map_location=torch.device('cpu')) 

teacher_model.load_state_dict(state_dict)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ndavid/autotrain-trec-fine-bert-739422530 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([47, 768]) in the checkpoint and torch.Size([50, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([47]) in the checkpoint and torch.Size([50]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))


<All keys matched successfully>

In [12]:
teacher_model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [13]:
teacher_model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [14]:
class DistilTrainerInner(Trainer):
    """Distilation trainer, computes loss with logits from teacher in mind. Logits are precomputed."""
    def __init__(self, student_model=None, teacher_model = None, *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.student = student_model
        self.teacher = teacher_model
        self.layer_loss_function = nn.MSELoss()
        self.logit_loss_function = nn.KLDivLoss(reduction="batchmean")
        self.temperature = self.args.temperature
        self.lambda_param = self.args.lambda_param
        self.alpha_param = self.args.alpha_param


        self.student_to_teacher = nn.Linear(128, 768).to(device)
        self.model_parameters = list(self.model.parameters()) + list(self.student_to_teacher.parameters())

    def compute_loss(self, student, inputs, return_outputs=False, num_items_in_batch=None):
        logits = inputs.pop("logits")
        
        student_output = student(**inputs, output_hidden_states=True)
        student_target_loss = student_output["loss"]

        with torch.no_grad():
            teacher_output = self.teacher(**inputs, output_hidden_states=True)

        teacher_hidden_states = teacher_output.hidden_states
        student_hidden_states = student_output.hidden_states

        teacher_l6 = teacher_hidden_states[6] / self.temperature
        teacher_l12 = teacher_hidden_states[12] / self.temperature
        student_l1 = student_hidden_states[1]
        student_l2 = student_hidden_states[2] 

        student_l1_projection = self.student_to_teacher(student_l1) / self.temperature
        student_l2_projection = self.student_to_teacher(student_l2) / self.temperature

        layer_distillation_loss = (
            self.layer_loss_function(student_l1_projection, teacher_l6) +
            self.layer_loss_function(student_l2_projection, teacher_l12)
        )

        

        soft_teacher = F.softmax(logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output['logits'] / self.temperature, dim=-1)

        logit_distillation_loss = self.logit_loss_function(soft_student, soft_teacher) * (self.temperature ** 2)
        logit_label_loss = ((1. - self.lambda_param) * student_target_loss + self.lambda_param * logit_distillation_loss)

        
        loss = (1 - self.alpha_param) * logit_label_loss + self.alpha_param * layer_distillation_loss

        
        return (loss, student_output) if return_outputs else loss
    
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        logits = inputs.pop("logits")
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            # For evaluation, we disable extra outputs.
            outputs = model(**inputs, output_hidden_states=False)
            loss = outputs.loss if "loss" in outputs else None
            logits = outputs.logits
        labels = inputs.get("labels")
        return loss, logits, labels

In [15]:
class Custom_training_args(TrainingArguments):
    """Custom wrapper of training args for distillation."""
    def __init__(self, lambda_param, alpha_param, temperature, *args, **kwargs):
        super().__init__(*args, **kwargs)    
        self.lambda_param = lambda_param
        self.alpha_param = alpha_param
        self.temperature = temperature

In [16]:
def get_training_args(output_dir, logging_dir, remove_unused_columns=True, lr=5e-5, epochs=5, weight_decay=0, adam_beta1 = .9, lambda_param=.5, alpha_param = .5, temp=5, batch_size=128, num_workers=4, warmup_steps=0):
    """Returns training args that can be adjusted."""
    return (
        Custom_training_args(
        output_dir=output_dir,
        eval_strategy="epoch",
        adam_beta1 = adam_beta1,
        warmup_steps = warmup_steps,
        save_strategy="epoch",
        logging_strategy="epoch",
        learning_rate=lr, #Defaultní hodnota 
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        weight_decay=weight_decay,
        seed = 42,  #Defaultní hodnota 
        metric_for_best_model="f1",
        load_best_model_at_end = True,
        fp16=True, 
        logging_dir=logging_dir,
        remove_unused_columns=remove_unused_columns,
        lambda_param = lambda_param,
        alpha_param = alpha_param, 
        temperature = temp,
        dataloader_num_workers=num_workers,
        )
    )

In [20]:
training_args = get_training_args(output_dir=f"~/results/{DATASET}/hokus_pokus", logging_dir=f"~/logs/{DATASET}/hokus_pokus", remove_unused_columns=False, warmup_steps=4, lr=5e-4, weight_decay=.003, batch_size=128, epochs=20, temp=2.5, lambda_param=.4, alpha_param=.5)

In [21]:
trainer = DistilTrainerInner(
    student_model = student_model,
    teacher_model = teacher_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [22]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.2485,2.304686,0.505041,0.200722,0.153706,0.138322
2,0.9176,1.770681,0.63428,0.218068,0.241304,0.217888
3,0.7064,1.415867,0.703941,0.291258,0.296876,0.27655
4,0.5717,1.279,0.715857,0.303754,0.325029,0.300446
5,0.4812,1.175202,0.734189,0.339924,0.354964,0.33347
6,0.4132,1.109791,0.745188,0.385101,0.385757,0.369391
7,0.3602,1.08845,0.753437,0.454085,0.418329,0.414791
8,0.3302,1.051247,0.769019,0.444176,0.442952,0.428646
9,0.2989,1.001823,0.774519,0.48377,0.467252,0.466109
10,0.2744,0.995045,0.781852,0.531094,0.491267,0.492554


TrainOutput(global_step=700, training_loss=0.39037716933659145, metrics={'train_runtime': 800.8023, 'train_samples_per_second': 108.916, 'train_steps_per_second': 0.874, 'total_flos': 65900954952000.0, 'train_loss': 0.39037716933659145, 'epoch': 20.0})

In [24]:
student_model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-1

In [25]:
trainer.evaluate(test)

{'eval_loss': 1.0279921293258667,
 'eval_accuracy': 0.76,
 'eval_precision': 0.5204448048364814,
 'eval_recall': 0.5517825303517305,
 'eval_f1': 0.5117426131149784,
 'eval_runtime': 3.1235,
 'eval_samples_per_second': 160.077,
 'eval_steps_per_second': 1.281,
 'epoch': 20.0}

In [26]:
base.reset_seed()

In [27]:
student_model = BertForSequenceClassification.from_pretrained("google/bert_uncased_L-2_H-128_A-2", num_labels=50)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [28]:
training_args = get_training_args(output_dir=f"~/results/{DATASET}/hokus_pokus", logging_dir=f"~/logs/{DATASET}/hokus_pokus", remove_unused_columns=False, warmup_steps=4, lr=5e-4, weight_decay=.003, batch_size=128, epochs=20, temp=2.5, lambda_param=.4, alpha_param=.5)

In [None]:
class DistilTrainerInnerAVG(Trainer):
    """Distilation trainer, computes loss with logits from teacher in mind. Logits are precomputed."""
    def __init__(self, student_model=None, teacher_model = None, *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.student = student_model
        self.teacher = teacher_model
        self.layer_loss_function = nn.MSELoss()
        self.logit_loss_function = nn.KLDivLoss(reduction="batchmean")
        self.temperature = self.args.temperature
        self.lambda_param = self.args.lambda_param
        self.alpha_param = self.args.alpha_param


        self.student_to_teacher = nn.Linear(128, 768).to(device)
        self.model_parameters = list(self.model.parameters()) + list(self.student_to_teacher.parameters())

    def compute_loss(self, student, inputs, return_outputs=False, num_items_in_batch=None):
        logits = inputs.pop("logits")
        
        student_output = student(**inputs, output_hidden_states=True)
        student_target_loss = student_output["loss"]

        with torch.no_grad():
            teacher_output = self.teacher(**inputs, output_hidden_states=True)

        teacher_hidden_states = teacher_output.hidden_states
        student_hidden_states = student_output.hidden_states

        
        
        teacher_l6 = torch.stack(teacher_hidden_states[2:7], dim=0).mean(dim=0) / self.temperature
        teacher_l12 = torch.stack(teacher_hidden_states[7:12], dim=0).mean(dim=0) / self.temperature
        student_l1 = student_hidden_states[1]
        student_l2 = student_hidden_states[2] 

        student_l1_projection = self.student_to_teacher(student_l1) / self.temperature
        student_l2_projection = self.student_to_teacher(student_l2) / self.temperature

        layer_distillation_loss = (
            self.layer_loss_function(student_l1_projection, teacher_l6) +
            self.layer_loss_function(student_l2_projection, teacher_l12)
        )

        

        soft_teacher = F.softmax(logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output['logits'] / self.temperature, dim=-1)

        logit_distillation_loss = self.logit_loss_function(soft_student, soft_teacher) * (self.temperature ** 2)
        logit_label_loss = ((1. - self.lambda_param) * student_target_loss + self.lambda_param * logit_distillation_loss)

        
        loss = (1 - self.alpha_param) * logit_label_loss + self.alpha_param * layer_distillation_loss

        
        return (loss, student_output) if return_outputs else loss
    
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        logits = inputs.pop("logits")
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            # For evaluation, we disable extra outputs.
            outputs = model(**inputs, output_hidden_states=False)
            loss = outputs.loss if "loss" in outputs else None
            logits = outputs.logits
        labels = inputs.get("labels")
        return loss, logits, labels

In [33]:
trainer = DistilTrainerInnerAVG(
    student_model = student_model,
    teacher_model = teacher_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=eval,
    compute_metrics=base.compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

In [34]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.4801,2.86377,0.418882,0.087182,0.101284,0.076397
2,1.1032,2.138775,0.562786,0.208789,0.197205,0.174814
3,0.8217,1.635681,0.655362,0.267159,0.258225,0.237654
4,0.6315,1.374435,0.713107,0.302741,0.315566,0.295314
5,0.5105,1.25904,0.728689,0.369574,0.356522,0.33538
6,0.4308,1.177486,0.739688,0.364386,0.368775,0.350246
7,0.3692,1.139008,0.757104,0.475266,0.42413,0.416781
8,0.3322,1.116509,0.75802,0.405699,0.433149,0.410593
9,0.2949,1.078361,0.754354,0.425039,0.422598,0.410365
10,0.2665,1.06017,0.764436,0.461547,0.435401,0.431252


TrainOutput(global_step=700, training_loss=0.4159424332209996, metrics={'train_runtime': 801.2602, 'train_samples_per_second': 108.854, 'train_steps_per_second': 0.874, 'total_flos': 65900954952000.0, 'train_loss': 0.4159424332209996, 'epoch': 20.0})