In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from datasets import load_dataset
import datasets
from transformers import BertTokenizerFast, AutoModelForTokenClassification, DataCollatorForTokenClassification

In [2]:
conll_03 = load_dataset("conll2003")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [3]:
conll_03

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})

In [4]:
conll_03.shape

{'train': (14041, 5), 'validation': (3250, 5), 'test': (3453, 5)}

In [5]:
conll_03['train'][0]

{'id': '0',
 'tokens': ['EU',
  'rejects',
  'German',
  'call',
  'to',
  'boycott',
  'British',
  'lamb',
  '.'],
 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7],
 'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0],
 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}

In [6]:
conll_03['train'].features['ner_tags']

Sequence(feature=ClassLabel(names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'], id=None), length=-1, id=None)

In [7]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

In [8]:
example_text = conll_03['train'][0]

In [9]:
tokenizer_input = tokenizer(example_text['tokens'], is_split_into_words=True)
tokens = tokenizer.convert_ids_to_tokens(tokenizer_input['input_ids'])
word_ids = tokenizer_input.word_ids()

In [10]:
word_ids

[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, None]

In [11]:
tokens

['[CLS]',
 'eu',
 'rejects',
 'german',
 'call',
 'to',
 'boycott',
 'british',
 'lamb',
 '.',
 '[SEP]']

In [12]:
len(example_text['ner_tags']), len(tokenizer_input['input_ids']), len(tokenizer_input['attention_mask'])

(9, 11, 11)

## Problem

As shown above, the ner tags and tokenizer input doesn't have same length due to tokenizer using subwords token and cli tokens as addition.

To solve this we have to refer the word_ids

## Solution

define new function `def tokenize_and_align_labels()` which performs two tasks

1. set -100 as label for special tokens such as CLI token and the subwords we wish to mask during the training. Becase, PyTorch ignores the -100 index during the training time.
2. mask the subword representations after the first subword

Then we aligns the labels with tokens ids using the strategy we picked.

In [13]:
def tokenize_and_aligns_labels(example, label_all_tokens=True):
    tokenizer_input = tokenizer(example['tokens'], is_split_into_words=True, truncation=True)
    labels = []
    
    for i, label in enumerate(example['ner_tags']):
        word_ids = tokenizer_input.word_ids(batch_index=i) # word_ids() => List of mapped token indices to their actual word in the initial sentence.
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            else:
                label_ids.append(label[word_idx] if label_all_tokens else -100)
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenizer_input['labels'] = labels
    return tokenizer_input

In [14]:
q = tokenize_and_aligns_labels(conll_03['train'][4:5])

In [15]:
print(q)

{'input_ids': [[101, 2762, 1005, 1055, 4387, 2000, 1996, 2647, 2586, 1005, 1055, 15651, 2837, 14121, 1062, 9328, 5804, 2056, 2006, 9317, 10390, 2323, 4965, 8351, 4168, 4017, 2013, 3032, 2060, 2084, 3725, 2127, 1996, 4045, 6040, 2001, 24509, 1012, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'labels': [[-100, 5, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, -100]]}


In [16]:
for token, label in zip(tokenizer.convert_ids_to_tokens(q['input_ids'][0]), q['labels'][0]):
    print(f'{token:_<40}:{label}')

[CLS]___________________________________:-100
germany_________________________________:5
'_______________________________________:0
s_______________________________________:0
representative__________________________:0
to______________________________________:0
the_____________________________________:0
european________________________________:3
union___________________________________:4
'_______________________________________:0
s_______________________________________:0
veterinary______________________________:0
committee_______________________________:0
werner__________________________________:1
z_______________________________________:2
##wing__________________________________:2
##mann__________________________________:2
said____________________________________:0
on______________________________________:0
wednesday_______________________________:0
consumers_______________________________:0
should__________________________________:0
buy_____________________________________:0
sheep___

In [17]:
tokenized_datasets = conll_03.map(tokenize_and_aligns_labels, batched=True)

In [18]:
teacher = AutoModelForTokenClassification.from_pretrained('bert-base-uncased', num_labels=conll_03['train'].features['ner_tags'].feature.num_classes)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:
conll_03['train'].features['ner_tags'].feature.num_classes

9

In [20]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 3453
    })
})

In [21]:
from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    'test-ner',
    evaluation_strategy='epoch',
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=2,
    logging_dir='./logs',
    logging_steps=10,
)

In [22]:
data_collator = DataCollatorForTokenClassification(tokenizer) 
# DataCollatorForTokenClassification => Collate function that is used to dynamically pad the inputs received by the model.
# This is useful to pad the inputs to the maximum length of the batch with the maximum sequence length in the batch, which is not possible with the default data collate function provided by PyTorch.

In [23]:
metric = datasets.load_metric('seqeval')

  metric = datasets.load_metric('seqeval')
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [24]:
label_list = conll_03['train'].features['ner_tags'].feature.names

# Example for Mertics

In [25]:
exp_labels = [label_list[i] for i in example_text['ner_tags']]
exp_labels

['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']

In [26]:
metric.compute(predictions=[exp_labels], references=[exp_labels])

{'MISC': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 2},
 'ORG': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1},
 'overall_precision': 1.0,
 'overall_recall': 1.0,
 'overall_f1': 1.0,
 'overall_accuracy': 1.0}

## Compute Metrics Function

In [27]:
def compute_metrics(eval_preds):
    
    pred_logits, labels = eval_preds 
    
    pred_logits = np.argmax(pred_logits, axis=2) 
    # the logits and the probabilities are in the same order,
    # so we don’t need to apply the softmax
    
    # We remove all the values where the label is -100
    predictions = [ 
        [label_list[eval_preds] for (eval_preds, l) in zip(prediction, label) if l != -100] 
        for prediction, label in zip(pred_logits, labels) 
    ] 
    
    true_labels = [ 
      [label_list[l] for (eval_preds, l) in zip(prediction, label) if l != -100] 
       for prediction, label in zip(pred_logits, labels) 
    ] 
    
    results = metric.compute(predictions=predictions, references=true_labels) 
    
    return {
        'precision': results['overall_precision'],
        'recall': results['overall_recall'],
        'f1': results['overall_f1'],
        'accuracy': results['overall_accuracy']
    }

In [28]:
trainer = Trainer(
    teacher,
    args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False)


In [None]:
trainer.train()

teacher.save_pretrained('ner-model')
tokenizer.save_pretrained('ner-model')

In [30]:
id2label = {i: label for i, label in enumerate(label_list)}

label2id = {label: i for i, label in enumerate(label_list)}

In [31]:
import json

config = json.load(open('ner-model/config.json'))

In [32]:
config['id2label'] = id2label
config['label2id'] = label2id

json.dump(config, open('ner-model/config.json', 'w'))

In [33]:
from transformers import pipeline, AutoModelForTokenClassification, BertTokenizerFast

model_finetuned = AutoModelForTokenClassification.from_pretrained('ner-model')

tokenizer_finetuned = BertTokenizerFast.from_pretrained('ner-model')

nlp = pipeline('ner', model=model_finetuned, tokenizer=tokenizer_finetuned)

nlp('Hugging Face is a French company founded in 2016.')

[{'entity': 'B-ORG',
  'score': 0.8851787,
  'index': 1,
  'word': 'hugging',
  'start': 0,
  'end': 7},
 {'entity': 'I-ORG',
  'score': 0.88827014,
  'index': 2,
  'word': 'face',
  'start': 8,
  'end': 12},
 {'entity': 'B-MISC',
  'score': 0.99571615,
  'index': 5,
  'word': 'french',
  'start': 18,
  'end': 24}]

In [34]:
# get model number of parameters
model_finetuned.num_parameters()

108898569

## Distill BERT

In [35]:
label_list = conll_03['train'].features['ner_tags'].feature.names

id2label = {i: label for i, label in enumerate(label_list)}

label2id = {label: i for i, label in enumerate(label_list)}

In [36]:
from transformers import AutoModelForTokenClassification, BertTokenizerFast

student_model_card = "distilbert/distilbert-base-uncased"
student_model = AutoModelForTokenClassification.from_pretrained(student_model_card, num_labels=conll_03['train'].features['ner_tags'].feature.num_classes)
student_tokenizer = BertTokenizerFast.from_pretrained(student_model_card)

# Assuming `conll_03` is a dataset object and `id2label`, `label2id` are dictionaries defined earlier:
student_tokenizer = BertTokenizerFast.from_pretrained(
    student_model_card, 
    num_labels=conll_03['train'].features['ner_tags'].feature.num_classes, 
    id2label=id2label, 
    label2id=label2id
)

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DistilBertTokenizer'. 
The class this function is called from is 'BertTokenizerFast'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DistilBertTokenizer'. 
The class this function is called from is 'BertTokenizerFast'.


In [37]:
def tokenize_and_aligns_labels_student(example, label_all_tokens=True):
    tokenizer_input = student_tokenizer(example['tokens'], is_split_into_words=True, truncation=True)
    labels = []
    
    for i, label in enumerate(example['ner_tags']):
        word_ids = tokenizer_input.word_ids(batch_index=i) # word_ids() => List of mapped token indices to their actual word in the initial sentence.
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            else:
                label_ids.append(label[word_idx] if label_all_tokens else -100)
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenizer_input['labels'] = labels
    return tokenizer_input

tokenized_datasets_student = conll_03.map(tokenize_and_aligns_labels_student, batched=True)

In [38]:
tokenized_datasets_student

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 3453
    })
})

In [39]:
data_collator_student = DataCollatorForTokenClassification(student_tokenizer)

In [40]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
# load teacher model
teacher = AutoModelForTokenClassification.from_pretrained('ner-model')

# load student model and teacher model on gpu

teacher = teacher.to('cuda')
student_model = student_model.to('cuda')

In [41]:
from transformers import TrainingArguments, Trainer

args_student = TrainingArguments(
    'test-ner-student',
    evaluation_strategy='epoch',
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=2,
    logging_dir='./logs',
    logging_steps=10,
)

class KDTrainer(Trainer):
    
    def __init__(self, *args, alpha=0.5, temperature=2.0, teacher=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.temperature = temperature
        self.teacher = teacher
    
    def compute_loss(self, model, inputs, return_outputs=False):
        
        inputs = inputs.to(model.device)
        
        teacher_outputs = self.teacher(**inputs)
        student_outputs = model(**inputs)
        
        # We don't need the logits, only the loss
        teacher_logits = teacher_outputs.logits
        student_logits = student_outputs.logits
        
        # print(teacher_logits.shape)
        # print(student_logits.shape)

        
        # kd loss
        loss_fct = nn.KLDivLoss(reduction='batchmean')
        
        loss_kd = loss_fct(
            F.log_softmax(student_logits / self.temperature, dim=-1),
            F.softmax(teacher_logits / self.temperature, dim=-1)
        )
        
        # ce loss
        loss_ce = student_outputs.loss
        
        # total loss
        loss = (1 - self.alpha) * loss_ce + self.alpha * self.temperature ** 2 * loss_kd
        
        return (loss, student_outputs) if return_outputs else loss

In [42]:
from datasets import load_metric

metric = datasets.load_metric('seqeval')

def compute_metrics(eval_preds):
    
    pred_logits, labels = eval_preds 
    
    pred_logits = np.argmax(pred_logits, axis=2) 
    # the logits and the probabilities are in the same order,
    # so we don’t need to apply the softmax
    
    # We remove all the values where the label is -100
    predictions = [ 
        [label_list[eval_preds] for (eval_preds, l) in zip(prediction, label) if l != -100] 
        for prediction, label in zip(pred_logits, labels) 
    ] 
    
    true_labels = [ 
      [label_list[l] for (eval_preds, l) in zip(prediction, label) if l != -100] 
       for prediction, label in zip(pred_logits, labels) 
    ] 
    
    results = metric.compute(predictions=predictions, references=true_labels) 
    
    return {
        'precision': results['overall_precision'],
        'recall': results['overall_recall'],
        'f1': results['overall_f1'],
        'accuracy': results['overall_accuracy']
    }

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [43]:
# trainer

trainer_student = KDTrainer(
    student_model,
    args_student,
    train_dataset=tokenized_datasets_student['train'],
    eval_dataset=tokenized_datasets_student['validation'],
    data_collator=data_collator_student,
    tokenizer=student_tokenizer,
    compute_metrics=compute_metrics,
    teacher=teacher,
    alpha=0.5,
    temperature=2.0
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False)


In [44]:
trainer_student.train()

  0%|          | 0/1756 [00:00<?, ?it/s]

{'loss': 120.0117, 'learning_rate': 1.988610478359909e-05, 'epoch': 0.01}
{'loss': 64.7227, 'learning_rate': 1.977220956719818e-05, 'epoch': 0.02}
{'loss': 37.7539, 'learning_rate': 1.9658314350797268e-05, 'epoch': 0.03}
{'loss': 32.4821, 'learning_rate': 1.9544419134396357e-05, 'epoch': 0.05}
{'loss': 32.2053, 'learning_rate': 1.9430523917995446e-05, 'epoch': 0.06}
{'loss': 26.823, 'learning_rate': 1.9316628701594535e-05, 'epoch': 0.07}
{'loss': 22.7437, 'learning_rate': 1.9202733485193623e-05, 'epoch': 0.08}
{'loss': 22.0703, 'learning_rate': 1.9088838268792712e-05, 'epoch': 0.09}
{'loss': 20.1889, 'learning_rate': 1.89749430523918e-05, 'epoch': 0.1}
{'loss': 17.9789, 'learning_rate': 1.886104783599089e-05, 'epoch': 0.11}
{'loss': 16.478, 'learning_rate': 1.874715261958998e-05, 'epoch': 0.13}
{'loss': 15.728, 'learning_rate': 1.8633257403189068e-05, 'epoch': 0.14}
{'loss': 15.5714, 'learning_rate': 1.8519362186788156e-05, 'epoch': 0.15}
{'loss': 14.5941, 'learning_rate': 1.8405466970

Checkpoint destination directory test-ner-student\checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.


{'loss': 6.2463, 'learning_rate': 1.4305239179954442e-05, 'epoch': 0.57}
{'loss': 6.249, 'learning_rate': 1.4191343963553532e-05, 'epoch': 0.58}
{'loss': 5.7947, 'learning_rate': 1.407744874715262e-05, 'epoch': 0.59}
{'loss': 5.1118, 'learning_rate': 1.396355353075171e-05, 'epoch': 0.6}
{'loss': 5.6704, 'learning_rate': 1.3849658314350799e-05, 'epoch': 0.62}
{'loss': 5.1358, 'learning_rate': 1.3735763097949887e-05, 'epoch': 0.63}
{'loss': 6.0096, 'learning_rate': 1.3621867881548976e-05, 'epoch': 0.64}
{'loss': 6.3806, 'learning_rate': 1.3507972665148065e-05, 'epoch': 0.65}
{'loss': 6.1514, 'learning_rate': 1.3394077448747154e-05, 'epoch': 0.66}
{'loss': 5.8057, 'learning_rate': 1.3280182232346241e-05, 'epoch': 0.67}
{'loss': 6.1939, 'learning_rate': 1.3166287015945332e-05, 'epoch': 0.68}
{'loss': 5.5551, 'learning_rate': 1.3052391799544419e-05, 'epoch': 0.69}
{'loss': 5.0283, 'learning_rate': 1.293849658314351e-05, 'epoch': 0.71}
{'loss': 4.2156, 'learning_rate': 1.2824601366742598e-05

  0%|          | 0/204 [00:00<?, ?it/s]

{'eval_loss': 3.05456280708313, 'eval_precision': 0.8764436696448028, 'eval_recall': 0.8998769437297237, 'eval_f1': 0.888005740464757, 'eval_accuracy': 0.9758685878596279, 'eval_runtime': 6.0873, 'eval_samples_per_second': 533.901, 'eval_steps_per_second': 33.513, 'epoch': 1.0}
{'loss': 4.5301, 'learning_rate': 9.977220956719819e-06, 'epoch': 1.0}
{'loss': 4.1304, 'learning_rate': 9.863325740318908e-06, 'epoch': 1.01}
{'loss': 5.2209, 'learning_rate': 9.749430523917997e-06, 'epoch': 1.03}
{'loss': 4.5231, 'learning_rate': 9.635535307517085e-06, 'epoch': 1.04}
{'loss': 3.8869, 'learning_rate': 9.521640091116174e-06, 'epoch': 1.05}
{'loss': 4.0269, 'learning_rate': 9.407744874715261e-06, 'epoch': 1.06}
{'loss': 3.8067, 'learning_rate': 9.293849658314352e-06, 'epoch': 1.07}
{'loss': 3.3414, 'learning_rate': 9.17995444191344e-06, 'epoch': 1.08}
{'loss': 4.4124, 'learning_rate': 9.06605922551253e-06, 'epoch': 1.09}
{'loss': 3.7504, 'learning_rate': 8.952164009111618e-06, 'epoch': 1.1}
{'los

Checkpoint destination directory test-ner-student\checkpoint-1000 already exists and is non-empty.Saving will proceed but saved results may be invalid.


{'loss': 3.7391, 'learning_rate': 8.610478359908885e-06, 'epoch': 1.14}
{'loss': 4.3186, 'learning_rate': 8.496583143507974e-06, 'epoch': 1.15}
{'loss': 3.2366, 'learning_rate': 8.382687927107063e-06, 'epoch': 1.16}
{'loss': 4.7049, 'learning_rate': 8.26879271070615e-06, 'epoch': 1.17}
{'loss': 3.3936, 'learning_rate': 8.15489749430524e-06, 'epoch': 1.18}
{'loss': 4.6457, 'learning_rate': 8.041002277904329e-06, 'epoch': 1.2}
{'loss': 4.7592, 'learning_rate': 7.927107061503418e-06, 'epoch': 1.21}
{'loss': 4.2475, 'learning_rate': 7.813211845102507e-06, 'epoch': 1.22}
{'loss': 4.1489, 'learning_rate': 7.699316628701596e-06, 'epoch': 1.23}
{'loss': 4.4393, 'learning_rate': 7.585421412300684e-06, 'epoch': 1.24}
{'loss': 4.2335, 'learning_rate': 7.471526195899773e-06, 'epoch': 1.25}
{'loss': 3.4916, 'learning_rate': 7.357630979498862e-06, 'epoch': 1.26}
{'loss': 4.1418, 'learning_rate': 7.243735763097951e-06, 'epoch': 1.28}
{'loss': 3.2802, 'learning_rate': 7.129840546697039e-06, 'epoch': 1

Checkpoint destination directory test-ner-student\checkpoint-1500 already exists and is non-empty.Saving will proceed but saved results may be invalid.


{'loss': 3.408, 'learning_rate': 2.9157175398633257e-06, 'epoch': 1.71}
{'loss': 3.2871, 'learning_rate': 2.801822323462415e-06, 'epoch': 1.72}
{'loss': 3.5728, 'learning_rate': 2.687927107061504e-06, 'epoch': 1.73}
{'loss': 3.6608, 'learning_rate': 2.5740318906605926e-06, 'epoch': 1.74}
{'loss': 3.0918, 'learning_rate': 2.4601366742596815e-06, 'epoch': 1.75}
{'loss': 3.3896, 'learning_rate': 2.34624145785877e-06, 'epoch': 1.77}
{'loss': 2.8706, 'learning_rate': 2.232346241457859e-06, 'epoch': 1.78}
{'loss': 3.4026, 'learning_rate': 2.118451025056948e-06, 'epoch': 1.79}
{'loss': 3.6588, 'learning_rate': 2.0045558086560364e-06, 'epoch': 1.8}
{'loss': 4.1256, 'learning_rate': 1.8906605922551254e-06, 'epoch': 1.81}
{'loss': 3.5231, 'learning_rate': 1.7767653758542143e-06, 'epoch': 1.82}
{'loss': 3.3512, 'learning_rate': 1.662870159453303e-06, 'epoch': 1.83}
{'loss': 3.4139, 'learning_rate': 1.5489749430523921e-06, 'epoch': 1.85}
{'loss': 3.5857, 'learning_rate': 1.4350797266514807e-06, 'e

  0%|          | 0/204 [00:00<?, ?it/s]

{'eval_loss': 2.4997339248657227, 'eval_precision': 0.9012549537648613, 'eval_recall': 0.9158742588656449, 'eval_f1': 0.9085057981468124, 'eval_accuracy': 0.9793794779735333, 'eval_runtime': 10.1415, 'eval_samples_per_second': 320.465, 'eval_steps_per_second': 20.115, 'epoch': 2.0}
{'train_runtime': 210.3598, 'train_samples_per_second': 133.495, 'train_steps_per_second': 8.348, 'train_loss': 7.4272418000431975, 'epoch': 2.0}


TrainOutput(global_step=1756, training_loss=7.4272418000431975, metrics={'train_runtime': 210.3598, 'train_samples_per_second': 133.495, 'train_steps_per_second': 8.348, 'train_loss': 7.4272418000431975, 'epoch': 2.0})

In [45]:

import json

# load student model
student_model.save_pretrained('ner-student')
student_tokenizer.save_pretrained('ner-student')

config = json.load(open('ner-student/config.json'))

config['id2label'] = id2label
config['label2id'] = label2id


json.dump(config, open('ner-student/config.json', 'w'))

In [145]:

import torch
from transformers import AutoModelForTokenClassification, BertTokenizerFast, pipeline
import torch.nn.functional as F

# Load student model and tokenizer
student_model = AutoModelForTokenClassification.from_pretrained('ner-student')
student_tokenizer = BertTokenizerFast.from_pretrained('ner-student')

# Define label mapping (id2label)
id2label = student_model.config.id2label

# Modify pipeline to remove token_type_ids and get human-readable output
def custom_pipeline(text):
    # Get the inputs from the tokenizer
    inputs = student_tokenizer(text, return_tensors="pt", truncation=True)
    
    # Remove 'token_type_ids' if present (as DistilBERT doesn't use them)
    if 'token_type_ids' in inputs:
        inputs.pop('token_type_ids')

    # Get model predictions
    with torch.no_grad():
        output = student_model(**inputs)

    logits = output.logits
    probabilities = F.softmax(logits, dim=-1)  # Get probabilities using softmax
    predictions = torch.argmax(probabilities, dim=-1).squeeze()  # Get the index of the highest score
    
    # Convert predictions to label and score
    results = []
    tokens = student_tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze())
    for i, token in enumerate(tokens):
        if token.startswith('##'):  # Skip subword tokens to avoid broken token entities
            continue
        
        label_id = predictions[i].item()
        label = id2label[label_id]
        score = probabilities[0][i][label_id].item()  # Get the score for the predicted label
        
        # Add the entity only if it is not 'O' (no entity)
        if label != 'O':
            results.append({
                'entity': label,
                'score': score,
                'index': i,
                'word': token,
                'start': inputs['input_ids'][0][i].item(),
                'end': inputs['input_ids'][0][i].item() + len(token)  # Estimate start/end
            })
    
    return results

# Test the custom pipeline
text = "Hugging Face is a French company founded in 2016."
result = custom_pipeline(text)

result


[{'entity': 'B-ORG',
  'score': 0.5237874388694763,
  'index': 1,
  'word': 'hugging',
  'start': 17662,
  'end': 17669},
 {'entity': 'B-MISC',
  'score': 0.9951995611190796,
  'index': 5,
  'word': 'french',
  'start': 2413,
  'end': 2419},
 {'entity': 'B-MISC',
  'score': 0.24349796772003174,
  'index': 11,
  'word': '[SEP]',
  'start': 102,
  'end': 107}]

## Compare BERT and DistilBERT

In [47]:
import numpy as np
parameters = filter(lambda p: p.requires_grad, student_model.parameters())
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
print('Trainable Student Parameters: %.3fM' % parameters)

Trainable Student Parameters: 66.370M


In [48]:
# teacher model
parameters = filter(lambda p: p.requires_grad, teacher.parameters())
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
print('Trainable Teacher Parameters: %.3fM' % parameters)

Trainable Teacher Parameters: 108.899M


## Custom Student Model

In [49]:
class ScheduledOptim():
    '''A simple wrapper class for learning rate scheduling'''

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0
        self.init_lr = np.power(d_model, -0.5)

    def step_and_update_lr(self):
        "Step with the inner optimizer"
        self._update_learning_rate()
        self._optimizer.step()

    def step(self):
        "Step with the inner optimizer"
        self.step_and_update_lr()

    def zero_grad(self):
        "Zero out the gradients by the inner optimizer"
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        return np.min([
            np.power(self.n_current_steps, -0.5),
            np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])

    def _update_learning_rate(self):
        ''' Learning rate scheduling per step '''

        self.n_current_steps += 1
        lr = self.init_lr * self._get_lr_scale()

        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr


In [50]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_outputs import TokenClassifierOutput

class LayerNormalization(nn.Module):
    def __init__(self, features: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.bias

class FeedForwardBlock(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear_2(self.dropout(F.relu(self.linear_1(x))))

class InputEmbeddings(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        # Ensure that the scaling tensor is on the same device as x
        device = x.device
        multiply_by_sqrt_d_model = torch.sqrt(torch.tensor(self.d_model, dtype=torch.float, device=device))
        return self.embedding(x) * multiply_by_sqrt_d_model

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :].requires_grad_(False)
        return self.dropout(x)

class ResidualConnection(nn.Module):
    def __init__(self, features: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, attn_dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.attn_dropout = attn_dropout

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.fc_out = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, x, mask=None):
        n_samples, n_patches, dim = x.shape

        if dim != self.embed_dim:
            raise ValueError(f"Input embedding dimension ({dim}) doesn't match model embedding dimension ({self.embed_dim})")

        qkv = self.qkv(x)
        qkv = qkv.reshape(n_samples, n_patches, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0], qkv[1], qkv[2]

        k_t = k.transpose(-2, -1)
        attn = (q @ k_t) / math.sqrt(self.head_dim)

        if mask is not None:
            attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2) == 0, float("-inf"))

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        weighted_avg = attn @ v

        weighted_avg = weighted_avg.transpose(1, 2)
        weighted_avg = weighted_avg.flatten(2)

        x = self.fc_out(weighted_avg)

        return x

class EncoderBlock(nn.Module):
    def __init__(self, features: int, self_attention_block: MultiHeadAttention, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x

class Encoder(nn.Module):
    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

class BERT(nn.Module):
    def __init__(self, features: int, num_layers: int, num_heads: int, d_ff: int, dropout: float, max_len: int, vocab_size: int, num_labels: int) -> None:
        super().__init__()
        self.input_embeddings = InputEmbeddings(features, vocab_size)
        self.positional_encoding = PositionalEncoding(features, max_len, dropout)
        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(features, MultiHeadAttention(features, num_heads), FeedForwardBlock(features, d_ff, dropout), dropout)
            for _ in range(num_layers)
        ])
        self.encoder = Encoder(features, self.encoder_blocks)
        self.classifier = nn.Linear(features, num_labels)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None, device="cuda"):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        
        if labels is not None:
            labels = labels.to(device)
        
        x = self.input_embeddings(input_ids)
        x = self.positional_encoding(x)
        x = self.encoder(x, attention_mask)
        logits = self.classifier(x)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Only consider the active parts of the loss
            active_loss = attention_mask.view(-1) == 1
            active_logits = logits.view(-1, logits.size(-1))
            active_labels = torch.where(
                active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
            )
            loss = loss_fct(active_logits, active_labels)

        return TokenClassifierOutput(loss=loss, logits=logits)

# Example usage
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BERT(features=768, num_layers=8, num_heads=8, d_ff=2048, dropout=0.1, max_len=512, vocab_size=30522, num_labels=9).to(device)

inputs = {
    'input_ids': torch.tensor([[101, 7327, 19164, 2446, 2655, 2000, 17757, 2329, 12559, 1012, 102], [101, 7327, 19164, 2446, 2655, 2000, 17757, 2329, 12559, 1012, 102]]).to(device),
    'attention_mask': torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]).to(device),
    'token_type_ids': torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]).to(device),
    'labels': torch.tensor([[-100, 3, 0, 7, 0, 0, 0, 7, 0, 0, -100], [-100, 3, 0, 7, 0, 0, 0, 7, 0, 0, -100]]).to(device)
}

output = model(**inputs)
print(output)


TokenClassifierOutput(loss=tensor(2.7006, device='cuda:0', grad_fn=<NllLossBackward0>), logits=tensor([[[ 4.8585e-01,  1.1119e-01, -1.1505e+00,  8.9645e-01, -1.7438e-01,
          -3.7746e-01, -3.5475e-01, -1.9035e-01, -4.6150e-01],
         [ 3.0458e-01,  5.0577e-01, -6.0738e-01,  9.8700e-02, -1.0883e-01,
           3.2272e-01, -2.8224e-01, -5.0890e-01, -7.9754e-02],
         [ 8.1097e-01, -6.8989e-01,  6.6812e-01,  1.7677e-01, -4.6742e-01,
          -7.3528e-01, -5.8350e-02,  3.5619e-01,  1.6835e-01],
         [-4.4839e-01,  7.8674e-01, -2.5306e-01,  2.5320e-01, -8.9091e-01,
           5.2139e-01, -7.9422e-01, -1.8371e-01, -1.5040e-01],
         [-1.2329e+00,  8.9540e-03, -3.4395e-01,  1.0705e-01, -2.3584e-01,
           3.2006e-01,  4.3780e-01, -8.5086e-01,  9.9013e-01],
         [ 1.0983e-01,  2.5225e-01, -1.5263e-01,  9.9148e-01, -1.0004e-01,
          -3.8962e-01, -1.0051e-01,  7.9216e-01,  1.6516e-02],
         [ 7.0605e-02,  3.3741e-01, -3.9578e-01, -4.4323e-02, -2.1545e-01,
  

In [51]:
import numpy as np
parameters = filter(lambda p: p.requires_grad, model_finetuned.parameters())
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
print('Trainable Teacher Parameters: %.3fM' % parameters)

Trainable Teacher Parameters: 108.899M


In [52]:
parameters = filter(lambda p: p.requires_grad, model.parameters())
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
print('Trainable Student Parameters: %.3fM' % parameters)

Trainable Student Parameters: 67.561M


In [53]:
# Define Trainer for knowledge distillation of Student model with Teacher model
class DistillationTrainer():
    
    def __init__(self, teacher, student, train_dataloader, val_dataloader, optimizer, criterion, temperature, device, scheduler=None):
        self.teacher = teacher
        self.student = student
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.scheduler = scheduler
        self.temperature = temperature
        self.alpha = 0.5

    def train(self, n_epochs):
        self.teacher.eval()
        self.student.train()
        
        for epoch in range(n_epochs):
            train_loss = 0
            for i, batch in enumerate(self.train_dataloader):
                
                
                filtered_inputs = {
                    'input_ids': torch.tensor(batch['input_ids']).unsqueeze(0),  # unsqueeze to add batch dimension
                    'attention_mask': torch.tensor(batch['attention_mask']).unsqueeze(0),
                    'token_type_ids': torch.tensor(batch.get('token_type_ids', [])).unsqueeze(0) if 'token_type_ids' in batch else None,
                    'labels': torch.tensor(batch.get('labels', [])).unsqueeze(0) if 'labels' in batch else None
                }
                
                filtered_inputs = {k: v for k, v in filtered_inputs.items() if v is not None}
                
                # print(filtered_inputs)
                self.optimizer.zero_grad()
                teacher_output = self.teacher(**filtered_inputs)
                student_output = self.student(**filtered_inputs)
                
                loss_ce = student_output.loss
                
                logits_student = student_output.logits
                logits_teacher = teacher_output.logits
                
                loss_kd = self.temperature ** 2 * loss_fct(
                    F.log_softmax(logits_student / self.temperature, dim=-1),
                    F.softmax(logits_teacher / self.temperature, dim=-1)
                )
                
                loss = self.alpha * loss_ce + (1 - self.alpha) * loss_kd

                
                loss.backward()
                self.optimizer.step_and_update_lr()
                train_loss += loss.item()
                
                if i % 10 == 0:
                    print(f'Epoch: {epoch}, Iteration: {i}, Loss: {train_loss / (i + 1)}')
                    
            val_loss = self.evaluate()
            print(f'Epoch: {epoch}, Validation Loss: {val_loss}')
            
            if self.scheduler is not None:
                self.scheduler.step(val_loss)
                
            # save checkpoint after certain epochs
            if epoch % 5 == 0:
                torch.save(self.student.state_dict(), f'student_epoch_{epoch}.pth')
                
                
                
    def evaluate(self):
        self.student.eval()
        val_loss = 0
        for i, batch in enumerate(self.val_dataloader):
            filtered_inputs = {
                'input_ids': torch.tensor(batch['input_ids']).to(self.device),
                'attention_mask': torch.tensor(batch['attention_mask']).to(self.device),
                'token_type_ids': torch.tensor(batch.get('token_type_ids', [])).to(self.device) if 'token_type_ids' in batch else None,
                'labels': torch.tensor(batch.get('labels', [])).to(self.device) if 'labels' in batch else None
            }

            filtered_inputs = {k: v for k, v in filtered_inputs.items() if v is not None}

            with torch.no_grad():
                student_output = self.student(**filtered_inputs)
                val_loss += self.criterion(student_output.logits.view(-1, student_output.logits.size(-1)), filtered_inputs['labels'].view(-1)).item()
                
        return val_loss / len(self.val_dataloader)
    
    

# Define the optimizer and scheduler
optimizer = ScheduledOptim(torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9), 768, 2000)

# Define the loss function
loss_fct = nn.KLDivLoss(reduction="batchmean")

# Define the trainer
trainer = DistillationTrainer(model_finetuned, model, tokenized_datasets['train'], tokenized_datasets['validation'], optimizer, loss_fct, 2, 'cuda')

# Train the student model
# trainer.train(2)

In [54]:
class DistillationTrainer():
    
    def __init__(self, teacher, student, train_dataloader, val_dataloader, optimizer, criterion, temperature, device, scheduler=None):
        self.teacher = teacher.to(device)
        self.student = student.to(device)
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.scheduler = scheduler
        self.temperature = temperature
        self.alpha = 0.5

    def train(self, n_epochs):
        self.teacher.eval()
        # self.student.train()
        
        for epoch in range(n_epochs):
            train_loss = 0
            for i, batch in enumerate(self.train_dataloader):
                
                filtered_inputs = {
                    'input_ids': batch['input_ids'].to(self.device),  
                    'attention_mask': batch['attention_mask'].to(self.device),
                    'token_type_ids': batch.get('token_type_ids', torch.tensor([])).to(self.device) if 'token_type_ids' in batch else None,
                    'labels': batch.get('labels', torch.tensor([])).to(self.device) if 'labels' in batch else None
                }
                
                
                # Ensure that all necessary inputs are provided
                filtered_inputs = {k: v for k, v in filtered_inputs.items() if v is not None}
                
                # set fileterd inputs to the device
                filtered_inputs = {k: v.to(self.device) for k, v in filtered_inputs.items()}
                
                self.optimizer.zero_grad()
                teacher_output = self.teacher(**filtered_inputs)
                student_output = self.student(**filtered_inputs)
                
                loss_ce = student_output.loss
                
                logits_student = student_output.logits
                logits_teacher = teacher_output.logits
                
                loss_kd = self.temperature ** 2 * self.criterion(
                    F.log_softmax(logits_student / self.temperature, dim=-1),
                    F.softmax(logits_teacher / self.temperature, dim=-1)
                )
                
                loss = self.alpha * loss_ce + (1 - self.alpha) * loss_kd

                loss.backward()
                self.optimizer.step_and_update_lr()  # Update this line
                if self.scheduler:
                    self.scheduler.step()
                train_loss += loss.item()
                
                if i % 10 == 0:
                    print(f'Epoch: {epoch}, Iteration: {i}, Loss: {train_loss / (i + 1)}')
                    
            val_loss = self.evaluate()
            print(f'Epoch: {epoch}, Validation Loss: {val_loss}')
            
            # Save checkpoint after certain epochs
            if epoch % 1 == 0:
                torch.save(self.student.state_dict(), f'student_epoch_t_{epoch}.pth')
                
    def evaluate(self):
        self.student.eval()
        val_loss = 0
        with torch.no_grad():
            for i, batch in enumerate(self.val_dataloader):
                filtered_inputs = {
                    'input_ids': batch['input_ids'].to(self.device),
                    'attention_mask': batch['attention_mask'].to(self.device),
                    'token_type_ids': batch.get('token_type_ids', torch.tensor([])).to(self.device) if 'token_type_ids' in batch else None,
                    'labels': batch.get('labels', torch.tensor([])).to(self.device) if 'labels' in batch else None
                }

                filtered_inputs = {k: v for k, v in filtered_inputs.items() if v is not None}

                student_output = self.student(**filtered_inputs)
                teacher_output = self.teacher(**filtered_inputs)
                
                val_loss += student_output.loss.item()
                
                
                
        return val_loss / len(self.val_dataloader)


In [55]:
# Train using trainer object
from torch.utils.data import DataLoader
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    input_ids = pad_sequence([torch.tensor(item['input_ids']) for item in batch], batch_first=True, padding_value=0)
    attention_mask = pad_sequence([torch.tensor(item['attention_mask']) for item in batch], batch_first=True, padding_value=0)
    
    # Handle optional fields
    token_type_ids = pad_sequence([torch.tensor(item['token_type_ids']) for item in batch], batch_first=True, padding_value=0) if 'token_type_ids' in batch[0] else None
    labels = pad_sequence([torch.tensor(item['labels']) for item in batch], batch_first=True, padding_value=-100) if 'labels' in batch[0] else None
    
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'token_type_ids': token_type_ids,
        'labels': labels
    }



# Define the optimizer and scheduler
optimizer = ScheduledOptim(torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9), 768, 2000)

# Define the loss function
loss_fct = nn.KLDivLoss(reduction="batchmean")

batch_size = 8

from torch.utils.data import DataLoader

# Define the DataLoader with the custom collate_fn
train_dataloader = DataLoader(
    tokenized_datasets['train'],
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn
)

val_dataloader = DataLoader(
    tokenized_datasets['validation'],
    batch_size=8,
    shuffle=False,
    collate_fn=collate_fn
)

# train_dataloader = DataLoader(tokenized_datasets['train'], batch_size=batch_size, shuffle=True)
# val_dataloader = DataLoader(tokenized_datasets['validation'], batch_size=batch_size, shuffle=False)

# Define the trainer
trainer = DistillationTrainer(model_finetuned, model, train_dataloader, val_dataloader, optimizer, loss_fct, 1, 'cuda')

In [56]:
next(iter(train_dataloader)).keys()

dict_keys(['input_ids', 'attention_mask', 'token_type_ids', 'labels'])

In [57]:
print(next(iter(train_dataloader)))


{'input_ids': tensor([[  101, 16480,  7229,  1018, 21728,  3070,  1015,  1006, 22589,  1015,
          1011,  1014,  1007,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  1996,  2194,  2036, 16360, 14326,  2070,  1002,  2538,  2454,
          1997,  7016,  1010,  2092,  2682,  1996,  2761,  3740,  1002,  1022,
          2454,  2000,  1002,  1023,  2454,  1012,   102,     0,     0,     0],
        [  101,  2745,  8915, 10322,  4904,  2102,  1006,  2660,  1007,  5443,
          1012,  4138,  3240, 10731,  4059,  1006,  1057,  1012,  1055,  1012,
          1007,   102,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  1996,  5829,  2081, 10791,  1996,  2034,  2350,  2924,  2000,
          3013,  6165,  1999,  3433,  2000,  1037,  2655,  2006,  5958,  2013,
          2430,  2924,  3099,  2016,  2226, 11237,  1011, 11947,  1012,   102],
        [  101, 12037,  1011,  241

In [58]:
trainer.train(1)

Epoch: 0, Iteration: 0, Loss: 37.29097366333008
Epoch: 0, Iteration: 10, Loss: 43.02040758999911
Epoch: 0, Iteration: 20, Loss: 42.79359890165783
Epoch: 0, Iteration: 30, Loss: 44.08187349380985
Epoch: 0, Iteration: 40, Loss: 42.46434355945122
Epoch: 0, Iteration: 50, Loss: 39.87305659873813
Epoch: 0, Iteration: 60, Loss: 36.84897607271789
Epoch: 0, Iteration: 70, Loss: 34.42202940121503
Epoch: 0, Iteration: 80, Loss: 32.68631500667996
Epoch: 0, Iteration: 90, Loss: 30.96294330764603
Epoch: 0, Iteration: 100, Loss: 29.408854007720947
Epoch: 0, Iteration: 110, Loss: 28.13211576358692
Epoch: 0, Iteration: 120, Loss: 27.123944199774876
Epoch: 0, Iteration: 130, Loss: 26.169170346878868
Epoch: 0, Iteration: 140, Loss: 25.32811566278444
Epoch: 0, Iteration: 150, Loss: 24.587374557722484
Epoch: 0, Iteration: 160, Loss: 23.884181529098417
Epoch: 0, Iteration: 170, Loss: 23.35410077390615
Epoch: 0, Iteration: 180, Loss: 22.791840919473554
Epoch: 0, Iteration: 190, Loss: 22.36741368808047
Epoch