In [1]:
import lightning as L
import torch.nn.functional as F
import timm
import torch
from torcheval.metrics import MulticlassF1Score
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from lightning.pytorch.callbacks import ModelCheckpoint
torch.set_float32_matmul_precision("medium") # Take advantage of the tensor cores on the RTX GPU.

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#A manual seed used for reproducibility throughout this notebook.
MANUAL_SEED = 7
# The number of epochs used for all training done in this notebook.
EPOCHS = 10
# The batch size used for all training done in this notebook.
BATCH_SIZE = 4096 # As much as the GPU can handle for the biggest model.
# The learning rate used for all training done in this notebook.
LEARNINIG_RATE = 1e-3

In [3]:
# Timm is a library for loading pre-trained (or not) models. 
timm.list_models() #Lists the IDs of the available models.

['bat_resnext26ts',
 'beit_base_patch16_224',
 'beit_base_patch16_384',
 'beit_large_patch16_224',
 'beit_large_patch16_384',
 'beit_large_patch16_512',
 'beitv2_base_patch16_224',
 'beitv2_large_patch16_224',
 'botnet26t_256',
 'botnet50ts_256',
 'caformer_b36',
 'caformer_m36',
 'caformer_s18',
 'caformer_s36',
 'cait_m36_384',
 'cait_m48_448',
 'cait_s24_224',
 'cait_s24_384',
 'cait_s36_384',
 'cait_xs24_384',
 'cait_xxs24_224',
 'cait_xxs24_384',
 'cait_xxs36_224',
 'cait_xxs36_384',
 'coat_lite_medium',
 'coat_lite_medium_384',
 'coat_lite_mini',
 'coat_lite_small',
 'coat_lite_tiny',
 'coat_mini',
 'coat_small',
 'coat_tiny',
 'coatnet_0_224',
 'coatnet_0_rw_224',
 'coatnet_1_224',
 'coatnet_1_rw_224',
 'coatnet_2_224',
 'coatnet_2_rw_224',
 'coatnet_3_224',
 'coatnet_3_rw_224',
 'coatnet_4_224',
 'coatnet_5_224',
 'coatnet_bn_0_rw_224',
 'coatnet_nano_cc_224',
 'coatnet_nano_rw_224',
 'coatnet_pico_rw_224',
 'coatnet_rmlp_0_rw_224',
 'coatnet_rmlp_1_rw2_224',
 'coatnet_rmlp_1_r

In [4]:
# Gotta extract the model code into nn.module
# then build a lightning module to handle the training pipeline
# Then throw away this lightning module and create a new one for distillation
# Load the model using .load_state_dict() like shown here : https://github.com/Lightning-AI/pytorch-lightning/issues/20053#issuecomment-2215485554

In [15]:
class TeacherResNet50(torch.nn.Module):
    def __init__(self, num_classes : int, pretrained: bool = True) -> None:
        """An implementation of a fine-tunable Resnet50 pretrained (or not) model.
        The final classification layer is replaced with a new one that predicts the number of classes defined by 'num_classes'.

        Args:
            num_classes: The number of classes to set when replacing the final classification layer for fine-tuning purposes.
            pretrained: A boolean value indicating whether to load the pre-trained weights of the model. Defaults to true.
        """
        super().__init__()
        self.resnet50 = timm.create_model('resnet50', pretrained=pretrained)
        #Replace the final classification layer of the model for fine-tuning purposes.
        self.resnet50.fc = torch.nn.Linear(
            in_features=self.resnet50.fc.in_features,
            out_features=num_classes
        )
    
    def forward(self, x):
        return self.resnet50(x)

In [16]:
class StudentResNet18(torch.nn.Module):
    def __init__(self, num_classes : int, pretrained: bool = True) -> None:
        """An implementation of a fine-tunable Resnet18 pretrained (or not) model.
        The final classification layer is replaced with a new one that predicts the number of classes defined by 'num_classes'.

        Args:
            num_classes: The number of classes to set when replacing the final classification layer for fine-tuning purposes.
            pretrained: A boolean value indicating whether to load the pre-trained weights of the model. Defaults to true.
        """
        super().__init__()
        self.resnet18 = timm.create_model('resnet18', pretrained=pretrained)
        #Replace the final classification layer of the model for fine-tuning purposes.
        self.resnet18.fc = torch.nn.Linear(
            in_features=self.resnet50.fc.in_features,
            out_features=num_classes
        )

    def forward(self, x):
        return self.resnet18(x)

In [17]:
class TrainingTeacherResNet50(L.LightningModule):
    def __init__(self, num_classes : int, pretrained: bool = True, lr : float = 1e-4) -> None:
        """A pytorch lightning implementation of the Resnet50 fine-tuning process.

        Args:
            num_classes: The number of classes to set when replacing the final classification layer for fine-tuning purposes.
            pretrained: A boolean value indicating whether to load the pre-trained weights of the model. Defaults to true.
            lr: The learning rate used during training. Defaults to 1e-4.
        """
        super().__init__()
        self.resnet50 = TeacherResNet50(num_classes=num_classes,
                                        pretrained=pretrained)
        self.lr = lr
        self.f1_metric = MulticlassF1Score(num_classes=num_classes)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return self.resnet50(input_tensor)

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        input_batch, target_batch = batch
        logits = self(input_batch).squeeze()
        loss = torch.nn.functional.cross_entropy(logits, target_batch)
        #Calculate metrics
        self.f1_metric.update(logits, target_batch)
        train_f1_score = self.f1_metric.compute()
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        self.log("train_f1_score", train_f1_score, on_epoch=True, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        # validation_step defines the validation loop.
        input_batch, target_batch = batch
        logits = self(input_batch).squeeze()
        loss = torch.nn.functional.cross_entropy(logits, target_batch)
        #Calculate metrics
        self.f1_metric.update(logits, target_batch)
        validation_f1_score = self.f1_metric.compute()
        self.log("validation_loss", loss, on_epoch=True, prog_bar=True)
        self.log("validation_f1_score", validation_f1_score, on_epoch=True, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        # test_step defines the test loop.
        input_batch, target_batch = batch
        logits = self(input_batch).squeeze()
        loss = torch.nn.functional.cross_entropy(logits, target_batch)
        #Calculate metrics
        self.f1_metric.update(logits, target_batch)
        test_f1_score = self.f1_metric.compute()
        self.log("test_loss", loss, on_epoch=True, prog_bar=True)
        self.log("test_f1_score", test_f1_score, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [18]:
class TrainingStudentResNet18(L.LightningModule):
    def __init__(self, num_classes : int, pretrained: bool = True, lr : float = 1e-4) -> None:
        """A pytorch lightning implementation of the Resnet18 pretrained model.

        Args:
            num_classes: The number of classes to set when replacing the final classification layer for fine-tuning purposes.
            pretrained: A boolean value indicating whether to load the pre-trained weights of the model. Defaults to true.
            lr: The learning rate used during training. Defaults to 1e-4.
        """
        super().__init__()
        self.resnet18 = StudentResNet18(num_classes=num_classes,
                                        pretrained=pretrained)
        self.lr = lr
        self.f1_metric = MulticlassF1Score(num_classes=num_classes)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return self.resnet18(input_tensor)

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        input_batch, target_batch = batch
        logits = self(input_batch).squeeze()
        loss = torch.nn.functional.cross_entropy(logits, target_batch)
        #Calculate metrics
        self.f1_metric.update(logits, target_batch)
        train_f1_score = self.f1_metric.compute()
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        self.log("train_f1_score", train_f1_score, on_epoch=True, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        # validation_step defines the validation loop.
        input_batch, target_batch = batch
        logits = self(input_batch).squeeze()
        loss = torch.nn.functional.cross_entropy(logits, target_batch)
        #Calculate metrics
        self.f1_metric.update(logits, target_batch)
        validation_f1_score = self.f1_metric.compute()
        self.log("validation_loss", loss, on_epoch=True, prog_bar=True)
        self.log("validation_f1_score", validation_f1_score, on_epoch=True, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        # test_step defines the test loop.
        input_batch, target_batch = batch
        logits = self(input_batch).squeeze()
        loss = torch.nn.functional.cross_entropy(logits, target_batch)
        #Calculate metrics
        self.f1_metric.update(logits, target_batch)
        test_f1_score = self.f1_metric.compute()
        self.log("test_loss", loss, on_epoch=True, prog_bar=True)
        self.log("test_f1_score", test_f1_score, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

## Load the Cifar10 dataset

In [19]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load the training and test datasets.
train_dataset = datasets.CIFAR10(root="./CIFAR10", train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root="./CIFAR10", train=False, download=True, transform=transform)

# Create a validation set from the training set (80% train, 20% validation).
train_size = int(0.8 * len(train_dataset))
validation_size = len(train_dataset) - train_size
#Use a manual seed for reproducibility.
generator = torch.Generator().manual_seed(MANUAL_SEED)
train_dataset, validation_dataset = random_split(train_dataset, [train_size, validation_size], generator=generator)

# Create data loaders.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified


# 1. Train the teacher network (pre-trained on ImageNet) on the CIFAR10 dataset while evaluating it.
Ensure proper data split for training, validation, and testing. Calculate the number of parameters

In [20]:
# Instantiate the pre-trained teacher model with 10 classes for the classifier.
torch.manual_seed(MANUAL_SEED) # Seed the weights for the new classification layer.
teacherResnet50 = TrainingTeacherResNet50(num_classes=10,
                                  pretrained=True,
                                  lr=LEARNINIG_RATE)

In [21]:
# Only save to disk the best performing version of the model throughout training (best f1 score).
checkpoint_callback = ModelCheckpoint(
        dirpath="models/pretrained_teacher_training",
        monitor="validation_f1_score",
        filename="best",
        mode="max",
        save_last=True,
        verbose=True
    )

trainer = L.Trainer(max_epochs=EPOCHS,
                    log_every_n_steps=1,
                    val_check_interval=1,
                    callbacks=[checkpoint_callback])

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1)` was configured so validation will run after every batch.


In [None]:
trainer.fit(model=teacherResnet50, train_dataloaders=train_dataloader, val_dataloaders=validation_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type            | Params | Mode 
-----------------------------------------------------
0 | resnet50 | TeacherResNet50 | 23.5 M | train
-----------------------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.114    Total estimated model params size (MB)
218       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\paris_cite\m2\cours\aide décision\knowledge_distillation\venv\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


                                                                           

c:\paris_cite\m2\cours\aide décision\knowledge_distillation\venv\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 0:  30%|███       | 3/10 [00:14<00:33,  0.21it/s, v_num=1, train_loss_step=2.270, train_f1_score_step=0.164, validation_loss=2.130, validation_f1_score=0.155]

In [9]:
# Test the teacher model.
trainer.test(model=teacherResnet50, dataloaders=test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\paris_cite\m2\cours\aide décision\knowledge_distillation\venv\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 3/3 [00:01<00:00,  1.95it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_f1_score          0.702156126499176
        test_loss           1.0119032859802246
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 1.0119032859802246, 'test_f1_score': 0.702156126499176}]

In [10]:
# The number of parameters for the teacher Resnet50 (extracted from training output).
#   | Name     | Type   | Params | Mode 
# --------------------------------------------
# 0 | resnet50 | ResNet | 23.5 M | train
# --------------------------------------------
# 23.5 M    Trainable params
# 0         Non-trainable params
# 23.5 M    Total params
# 94.114    Total estimated model params size (MB)
# 217       Modules in train mode
# 0         Modules in eval mode

# 2. Train the student network (pre-trained on ImageNet) on the CIFAR10 dataset while evaluating it.
Ensure proper data split for training, validation, and testing. Calculate the number of parameters.

Cross entropy loss is used. The batch size is set to 4096 and the learning rate to 1e-4. 

In [11]:
# Instantiate the pre-trained student model with 10 classes for the classifier.
torch.manual_seed(MANUAL_SEED) # Seed the weights for the new classification layer.
studentResnet18 = StudentResNet18(num_classes=10,
                                  pretrained=True,
                                  lr=LEARNINIG_RATE)

In [12]:
# Only save to disk the best performing version of the model throughout training (best f1 score).
checkpoint_callback = ModelCheckpoint(
        dirpath="models/pretrained_student_training",
        monitor="validation_f1_score",
        filename="best",
        mode="max",
        save_last=False,
        verbose=True
    )

trainer = L.Trainer(max_epochs=EPOCHS,
                    log_every_n_steps=1,
                    val_check_interval=1,
                    callbacks=[checkpoint_callback])

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1)` was configured so validation will run after every batch.


In [13]:
trainer.fit(model=studentResnet18, train_dataloaders=train_dataloader, val_dataloaders=validation_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type   | Params | Mode 
--------------------------------------------
0 | resnet18 | ResNet | 11.2 M | train
--------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)
94        Modules in train mode
0         Modules in eval mode


Epoch 0: 100%|██████████| 10/10 [00:19<00:00,  0.51it/s, v_num=1, train_loss_step=1.950, train_f1_score_step=0.248, validation_loss=1.750, validation_f1_score=0.256, train_loss_epoch=2.150, train_f1_score_epoch=0.164]

Epoch 0, global step 10: 'validation_f1_score' reached 0.25605 (best 0.25605), saving model to 'C:\\paris_cite\\m2\\cours\\aide décision\\knowledge_distillation\\models\\pretrained_student_training\\best.ckpt' as top 1


Epoch 1: 100%|██████████| 10/10 [00:19<00:00,  0.52it/s, v_num=1, train_loss_step=1.310, train_f1_score_step=0.368, validation_loss=1.410, validation_f1_score=0.372, train_loss_epoch=1.550, train_f1_score_epoch=0.320]

Epoch 1, global step 20: 'validation_f1_score' reached 0.37163 (best 0.37163), saving model to 'C:\\paris_cite\\m2\\cours\\aide décision\\knowledge_distillation\\models\\pretrained_student_training\\best.ckpt' as top 1


Epoch 2: 100%|██████████| 10/10 [00:20<00:00,  0.50it/s, v_num=1, train_loss_step=0.950, train_f1_score_step=0.438, validation_loss=1.190, validation_f1_score=0.440, train_loss_epoch=1.080, train_f1_score_epoch=0.408]

Epoch 2, global step 30: 'validation_f1_score' reached 0.44034 (best 0.44034), saving model to 'C:\\paris_cite\\m2\\cours\\aide décision\\knowledge_distillation\\models\\pretrained_student_training\\best.ckpt' as top 1


Epoch 3: 100%|██████████| 10/10 [00:26<00:00,  0.38it/s, v_num=1, train_loss_step=0.778, train_f1_score_step=0.495, validation_loss=0.913, validation_f1_score=0.497, train_loss_epoch=0.829, train_f1_score_epoch=0.470]

Epoch 3, global step 40: 'validation_f1_score' reached 0.49700 (best 0.49700), saving model to 'C:\\paris_cite\\m2\\cours\\aide décision\\knowledge_distillation\\models\\pretrained_student_training\\best.ckpt' as top 1


Epoch 4: 100%|██████████| 10/10 [00:19<00:00,  0.51it/s, v_num=1, train_loss_step=0.652, train_f1_score_step=0.541, validation_loss=0.824, validation_f1_score=0.542, train_loss_epoch=0.657, train_f1_score_epoch=0.521]

Epoch 4, global step 50: 'validation_f1_score' reached 0.54243 (best 0.54243), saving model to 'C:\\paris_cite\\m2\\cours\\aide décision\\knowledge_distillation\\models\\pretrained_student_training\\best.ckpt' as top 1


Epoch 5: 100%|██████████| 10/10 [00:19<00:00,  0.51it/s, v_num=1, train_loss_step=0.506, train_f1_score_step=0.576, validation_loss=0.781, validation_f1_score=0.577, train_loss_epoch=0.531, train_f1_score_epoch=0.561]

Epoch 5, global step 60: 'validation_f1_score' reached 0.57717 (best 0.57717), saving model to 'C:\\paris_cite\\m2\\cours\\aide décision\\knowledge_distillation\\models\\pretrained_student_training\\best.ckpt' as top 1


Epoch 6: 100%|██████████| 10/10 [00:19<00:00,  0.51it/s, v_num=1, train_loss_step=0.408, train_f1_score_step=0.604, validation_loss=0.774, validation_f1_score=0.605, train_loss_epoch=0.425, train_f1_score_epoch=0.592]

Epoch 6, global step 70: 'validation_f1_score' reached 0.60510 (best 0.60510), saving model to 'C:\\paris_cite\\m2\\cours\\aide décision\\knowledge_distillation\\models\\pretrained_student_training\\best.ckpt' as top 1


Epoch 7: 100%|██████████| 10/10 [00:19<00:00,  0.51it/s, v_num=1, train_loss_step=0.329, train_f1_score_step=0.627, validation_loss=0.782, validation_f1_score=0.628, train_loss_epoch=0.333, train_f1_score_epoch=0.617]

Epoch 7, global step 80: 'validation_f1_score' reached 0.62800 (best 0.62800), saving model to 'C:\\paris_cite\\m2\\cours\\aide décision\\knowledge_distillation\\models\\pretrained_student_training\\best.ckpt' as top 1


Epoch 8: 100%|██████████| 10/10 [00:20<00:00,  0.49it/s, v_num=1, train_loss_step=0.254, train_f1_score_step=0.647, validation_loss=0.807, validation_f1_score=0.647, train_loss_epoch=0.257, train_f1_score_epoch=0.638]

Epoch 8, global step 90: 'validation_f1_score' reached 0.64712 (best 0.64712), saving model to 'C:\\paris_cite\\m2\\cours\\aide décision\\knowledge_distillation\\models\\pretrained_student_training\\best.ckpt' as top 1


Epoch 9: 100%|██████████| 10/10 [00:21<00:00,  0.47it/s, v_num=1, train_loss_step=0.174, train_f1_score_step=0.663, validation_loss=0.840, validation_f1_score=0.663, train_loss_epoch=0.193, train_f1_score_epoch=0.656]

Epoch 9, global step 100: 'validation_f1_score' reached 0.66334 (best 0.66334), saving model to 'C:\\paris_cite\\m2\\cours\\aide décision\\knowledge_distillation\\models\\pretrained_student_training\\best.ckpt' as top 1
`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 10/10 [00:21<00:00,  0.46it/s, v_num=1, train_loss_step=0.174, train_f1_score_step=0.663, validation_loss=0.840, validation_f1_score=0.663, train_loss_epoch=0.193, train_f1_score_epoch=0.656]


In [14]:
# Test the student model.
trainer.test(model=studentResnet18, dataloaders=test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\paris_cite\m2\cours\aide décision\knowledge_distillation\venv\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 3/3 [00:00<00:00,  3.26it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_f1_score         0.6639996767044067
        test_loss           0.8173630833625793
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.8173630833625793, 'test_f1_score': 0.6639996767044067}]

In [None]:
# The number of parameters for the student Resnet18 (extracted from training output).
#  | Name     | Type   | Params | Mode 
#--------------------------------------------
#0 | resnet18 | ResNet | 11.2 M | train
#--------------------------------------------
#11.2 M    Trainable params
#0         Non-trainable params
#11.2 M    Total params
#44.727    Total estimated model params size (MB)
#94        Modules in train mode
#0         Modules in eval mode

## Training results comparison
 
Teacher model f1 score : 70.21%

Student model f1 score : 66.39%

# Knowldge distillation

In [None]:
class DistilledStudentResnet18(L.LightningModule):
    def __init__(self, num_classes : int,
                 pretrained_student: bool = True,
                 lr : float = 1e-4,
                 temperature : float = 2,
                 soft_target_loss_weight : float = 0.25,
                 ce_loss_weight : float = 0.75) -> None:
        """A pytorch lightning implementation of the distillation process for the pretrained (or not) Resnet18 model.
        Implementation is inspired from https://pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html
        The teacher model is the fine-tuned Resnet50.

        Args:
            num_classes: The number of classes to set when replacing the final classification layer for fine-tuning purposes.
            pretrained_student: A boolean value indicating whether to load the pre-trained weights of the student model. Defaults to true.
            lr: The learning rate used during training. Defaults to 1e-4.
            temperature : Controls the smoothness of the output distributions. Larger T leads to smoother distributions, thus smaller probabilities get a larger boost. Defaults to 2.
            soft_target_loss_weight : A weight assigned to the loss calculated on the scores. Defaults to 0.25.
            ce_loss_weight : A weight assigned to the cross entropy loss calculated on the targets. Defaults to 0.75.
        """
        super().__init__()
        # Load the student model.
        self.student_resnet18 = StudentResNet18(num_classes=num_classes,
                                                pretrained=pretrained_student,
                                                lr=lr)
        # Load the fine-tuned teacher model.
        self.teacher_resnet50 = TeacherResNet50.load_from_checkpoint("finetuned_models/pretrained_teacher_training/best.ckpt", 
                                                                     num_classes=num_classes,
                                                                     lr=lr)
        self.lr = lr
        self.f1_metric = MulticlassF1Score(num_classes=num_classes)
        self.temperature = temperature
        self.soft_target_loss_weight = soft_target_loss_weight
        self.ce_loss_weight = ce_loss_weight

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return self.resnet50(input_tensor)

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        input_batch, target_batch = batch
        logits = self(input_batch).squeeze()
        loss = torch.nn.functional.cross_entropy(logits, target_batch)
        #Calculate metrics
        self.f1_metric.update(logits, target_batch)
        train_f1_score = self.f1_metric.compute()
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        self.log("train_f1_score", train_f1_score, on_epoch=True, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        # validation_step defines the validation loop.
        input_batch, target_batch = batch
        logits = self(input_batch).squeeze()
        loss = torch.nn.functional.cross_entropy(logits, target_batch)
        #Calculate metrics
        self.f1_metric.update(logits, target_batch)
        validation_f1_score = self.f1_metric.compute()
        self.log("validation_loss", loss, on_epoch=True, prog_bar=True)
        self.log("validation_f1_score", validation_f1_score, on_epoch=True, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        # test_step defines the test loop.
        input_batch, target_batch = batch
        logits = self(input_batch).squeeze()
        loss = torch.nn.functional.cross_entropy(logits, target_batch)
        #Calculate metrics
        self.f1_metric.update(logits, target_batch)
        test_f1_score = self.f1_metric.compute()
        self.log("test_loss", loss, on_epoch=True, prog_bar=True)
        self.log("test_f1_score", test_f1_score, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

# Define the student model