# Training

## Training a simple Transformer model

We start by training a simple Transformer model on the `Setfit/emotion` dataset. We chose `microsoft/xtremedistil-l6-h256-uncased` as it is a relatively lightweight base model.

Note: because this is demo code, the following sections cap the number of training steps with the `Trainer`'s `max_steps` parameter. To reach decent performance, increase `max_steps` (or unset it).
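When sizing `max_steps` for a real run, it helps to convert epochs into optimizer steps. The sketch below assumes one optimizer step per batch (no gradient accumulation) and a batch size of 32. The batch size is our inference from the 500 steps per epoch that the progress bars report for the 16,000 training rows (16,000 / 32 = 500), so check the value actually used by `TrainAssistant` before relying on it. The helper `steps_for_epochs` is hypothetical.

```python
import math


def steps_for_epochs(num_train_rows: int, batch_size: int, num_epochs: int) -> int:
    """Hypothetical helper: optimizer steps needed to cover `num_epochs` full epochs.

    Assumes one optimizer step per batch (no gradient accumulation) and that the
    last, possibly incomplete, batch is kept (i.e. drop_last=False).
    """
    steps_per_epoch = math.ceil(num_train_rows / batch_size)
    return steps_per_epoch * num_epochs


# 16,000 training rows with an assumed batch size of 32.
print(steps_for_epochs(16_000, 32, 1))  # 500
print(steps_for_epochs(16_000, 32, 3))  # 1500
```

The result can be passed as `Trainer(max_steps=...)`. Alternatively, Lightning's `Trainer` accepts `max_epochs` directly, which avoids the conversion altogether.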

In [1]:
from bert_squeeze.assistants import TrainAssistant
from lightning.pytorch import Trainer


In [2]:
config_assistant = {
    "name": "bert",
    "train_kwargs": {
        "objective": "ce"
    },
    "model_kwargs": {
        "pretrained_model_name_or_path": "microsoft/xtremedistil-l6-h256-uncased",
        "num_labels": 6
    },
    "data_kwargs": {
        "max_length": 64,
        "tokenizer_name": "microsoft/xtremedistil-l6-h256-uncased",
        "dataset_config": {
            "path": "Setfit/emotion"
        }
    }
}

In [3]:
assistant = TrainAssistant(**config_assistant)



In [4]:
model = assistant.model

train_dataloader = assistant.data.train_dataloader()
test_dataloader = assistant.data.test_dataloader()

INFO:root:Dataset 'Setfit/emotion' successfully loaded.
DatasetDict({
    train: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 16000
    })
    test: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
    validation: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
})
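`num_labels` in `model_kwargs` has to match the number of classes in the `labels` column (6 for `Setfit/emotion`). The check below is sketched in plain Python so that it runs without downloading anything. With the `datasets` library, the label list could come from something like `dataset["train"]["labels"]`, where `dataset` is a hypothetical name for the `DatasetDict` printed above.

```python
from collections import Counter
from typing import Iterable


def check_num_labels(labels: Iterable[int], num_labels: int) -> Counter:
    """Verify that integer class ids fall in [0, num_labels) and return per-class counts."""
    counts = Counter(labels)
    out_of_range = [label for label in counts if not 0 <= label < num_labels]
    if out_of_range:
        raise ValueError(f"labels {sorted(out_of_range)} are outside [0, {num_labels})")
    missing = set(range(num_labels)) - set(counts)
    if missing:
        # Not fatal, but a class that never appears in the split is worth knowing about.
        print(f"warning: classes {sorted(missing)} never appear")
    return counts


# Toy labels standing in for the real `labels` column.
print(check_num_labels([0, 1, 1, 2, 3, 4, 5, 5], num_labels=6))
```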





In [5]:
basic_trainer = Trainer(
    max_steps=10
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
basic_trainer.fit(
    model=model, 
    train_dataloaders=train_dataloader, 
    val_dataloaders=test_dataloader
)


  | Name       | Type             | Params
------------------------------------------------
0 | objective  | CrossEntropyLoss | 0     
1 | encoder    | CustomBertModel  | 12.8 M
2 | classifier | Sequential       | 67.8 K
------------------------------------------------
12.8 M    Trainable params
0         Non-trainable params
12.8 M    Total params
51.272    Total estimated model params size (MB)
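The "Total estimated model params size (MB)" figures in this notebook are consistent with 4 bytes per parameter (32-bit floats) and 1 MB = 10^6 bytes: 51.272 MB / 4 B gives about 12.82 M parameters, matching the 12.8 M reported above. The snippet reproduces that arithmetic. The function name is ours, and the parameter counts are backed out from the reported sizes rather than read from the models.

```python
def estimated_size_mb(num_params: int, bits_per_param: int = 32) -> float:
    """Size in megabytes (10**6 bytes), assuming every parameter uses `bits_per_param` bits."""
    bytes_per_param = bits_per_param / 8
    return num_params * bytes_per_param / 1e6


# Parameter counts backed out from the sizes reported for each model in this notebook.
for name, reported_mb in [("bert", 51.272), ("fastbert", 58.669), ("theseusbert", 70.226)]:
    approx_params = reported_mb * 1e6 / 4
    print(f"{name}: ~{approx_params / 1e6:.2f} M params")

print(round(estimated_size_mb(12_818_000), 3))  # 51.272
# Storing the same weights in 16 bits halves the estimate.
print(round(estimated_size_mb(12_818_000, bits_per_param=16), 3))  # 25.636
```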

Epoch 0:   2%|███▌                    | 10/500 [00:10<08:34,  1.05s/it, v_num=14]

`Trainer.fit` stopped: `max_steps=10` reached.


## Training FastBert

Fine-tuning a `FastBert` model is as easy as fine-tuning a regular BERT. The only difference is that the `FastBertLogic` callback must be passed to the `Trainer`; here it is provided through `fastbert_assistant.callbacks`. This callback is in charge of freezing the model's backbone after a given number of steps.
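In the FastBERT paper, training runs in two phases. First, the backbone and the final ("teacher") classifier are fine-tuned. Then the backbone is frozen and the per-layer "student" classifiers are trained by self-distillation, minimising a KL divergence between each student's prediction and the teacher's. The plain-Python sketch below illustrates that schedule and loss. It is our reading of the paper, not `bert_squeeze`'s code: the `freeze_step` threshold, the function names, and the exact loss reduction used by `FastBertLogic` are assumptions.

```python
import math
from typing import List, Sequence


def training_phase(global_step: int, freeze_step: int) -> str:
    """Hypothetical schedule: fine-tune everything until `freeze_step`, then self-distill."""
    return "fine_tune" if global_step < freeze_step else "self_distill"


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for two discrete distributions; assumes q > 0 wherever p > 0 (true for softmax outputs)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def self_distillation_loss(teacher: Sequence[float], students: List[Sequence[float]]) -> float:
    """Sum over student classifiers of KL(student || teacher); only the students would be updated."""
    return sum(kl_divergence(student, teacher) for student in students)


teacher = [0.7, 0.1, 0.1, 0.05, 0.03, 0.02]
students = [
    [1 / 6] * 6,                         # shallow layer: still uncertain
    [0.6, 0.15, 0.1, 0.07, 0.05, 0.03],  # deeper layer: closer to the teacher
]
print(training_phase(50, freeze_step=100), training_phase(150, freeze_step=100))
print(self_distillation_loss(teacher, students))
```

The deeper student contributes a much smaller term than the uniform one, which is the behaviour the distillation phase pushes all students towards.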

In [7]:
config_assistant_fastbert = {
    "name": "fastbert",
    "train_kwargs": {
        "objective": "ce"
    },
    "model_kwargs": {
        "pretrained_model_name_or_path": "microsoft/xtremedistil-l6-h256-uncased",
        "num_labels": 6
    },
    "data_kwargs": {
        "max_length": 64,
        "tokenizer_name": "microsoft/xtremedistil-l6-h256-uncased",
        "dataset_config": {
            "path": "Setfit/emotion"
        }
    }
}

fastbert_assistant = TrainAssistant(**config_assistant_fastbert)

basic_trainer = Trainer(
    max_steps=10,
    callbacks=fastbert_assistant.callbacks
)

basic_trainer.fit(
    model=fastbert_assistant.model, 
    train_dataloaders=fastbert_assistant.data.train_dataloader(), 
    val_dataloaders=fastbert_assistant.data.test_dataloader()
)



GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

INFO:root:Dataset 'Setfit/emotion' successfully loaded.




  | Name       | Type             | Params
------------------------------------------------
0 | objective  | CrossEntropyLoss | 0     
1 | embeddings | BertEmbeddings   | 7.9 M 
2 | encoder    | FastBertGraph    | 6.7 M 
------------------------------------------------
14.7 M    Trainable params
0         Non-trainable params
14.7 M    Total params
58.669    Total estimated model params size (MB)


DatasetDict({
    train: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 16000
    })
    test: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
    validation: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
})
Epoch 0:   2%|███▌                    | 10/500 [00:11<09:11,  1.13s/it, v_num=15]

`Trainer.fit` stopped: `max_steps=10` reached.


## Training TheseusBert

Fine-tuning a `TheseusBert` model is just as simple as fine-tuning a regular BERT, and it does not even require a callback: the substitution of the original submodules by their successors is handled by a scheduler instead.
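In BERT-of-Theseus, each original ("predecessor") module is swapped for its compact "successor" with a replacement probability that a curriculum scheduler raises linearly during training until it reaches 1. The sketch below shows that mechanism in plain Python. The base rate and coefficient values, as well as the function names, are illustrative assumptions rather than `bert_squeeze` defaults.

```python
import random
from typing import List


def replacement_rate(step: int, base_rate: float, coefficient: float) -> float:
    """Linear curriculum: p(t) = min(1, coefficient * t + base_rate)."""
    return min(1.0, coefficient * step + base_rate)


def sample_modules(num_modules: int, rate: float, rng: random.Random) -> List[str]:
    """For each module, draw a Bernoulli(rate) to decide whether its successor is used."""
    return ["successor" if rng.random() < rate else "predecessor" for _ in range(num_modules)]


rng = random.Random(0)
for step in (0, 200, 700):
    rate = replacement_rate(step, base_rate=0.3, coefficient=0.001)
    print(step, round(rate, 2), sample_modules(3, rate, rng))
```

Because the rate eventually reaches 1, every forward pass at the end of training goes through the successors only, which is what allows the predecessors to be dropped afterwards.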

In [8]:
config_assistant_theseusbert = {
    "name": "theseusbert",
    "train_kwargs": {
        "objective": "ce"
    },
    "model_kwargs": {
        "pretrained_model_name_or_path": "microsoft/xtremedistil-l6-h256-uncased",
        "num_labels": 6
    },
    "data_kwargs": {
        "max_length": 64,
        "tokenizer_name": "microsoft/xtremedistil-l6-h256-uncased",
        "dataset_config": {
            "path": "Setfit/emotion"
        }
    }
}

theseusbert_assistant = TrainAssistant(**config_assistant_theseusbert)

basic_trainer = Trainer(
    max_steps=10
)

basic_trainer.fit(
    model=theseusbert_assistant.model,
    train_dataloaders=theseusbert_assistant.data.train_dataloader(),
    val_dataloaders=theseusbert_assistant.data.test_dataloader()
)



GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Some weights of TheseusBertModel were not initialized from the model checkpoint at microsoft/xtremedistil-l6-h256-uncased and are newly initialized: ['encoder.successor_layers.2.attention.self.key.weight', 'encoder.successor_layers.2.attention.output.dense.weight', 'encoder.successor_layers.3.attention.output.dense.bias', ...]

INFO:root:Dataset 'Setfit/emotion' successfully loaded.




  | Name       | Type             | Params
------------------------------------------------
0 | objective  | CrossEntropyLoss | 0     
1 | encoder    | TheseusBertModel | 17.5 M
2 | classifier | Sequential       | 67.8 K
------------------------------------------------
17.6 M    Trainable params
0         Non-trainable params
17.6 M    Total params
70.226    Total estimated model params size (MB)


DatasetDict({
    train: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 16000
    })
    test: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
    validation: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
})
Epoch 0:   2%|███▌                    | 10/500 [00:10<08:34,  1.05s/it, v_num=16]

`Trainer.fit` stopped: `max_steps=10` reached.
