In [4]:
%pip install transformers[torch]==4.45.2 huggingface_hub datasets evaluate torchvision

Collecting transformers[torch]==4.45.2
  Using cached transformers-4.45.2-py3-none-any.whl (9.9 MB)
Collecting huggingface_hub
  Using cached huggingface_hub-0.26.3-py3-none-any.whl (447 kB)
Collecting datasets
  Using cached datasets-3.1.0-py3-none-any.whl (480 kB)
Collecting evaluate
  Using cached evaluate-0.4.3-py3-none-any.whl (84 kB)
Collecting torchvision
  Using cached torchvision-0.20.1-cp310-cp310-manylinux1_x86_64.whl (7.2 MB)
Collecting tokenizers<0.21,>=0.20
  Using cached tokenizers-0.20.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)
Collecting pyyaml>=5.1
  Using cached PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (751 kB)
Collecting safetensors>=0.4.1
  Using cached safetensors-0.4.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (435 kB)
Collecting filelock
  Using cached filelock-3.16.1-py3-none-any.whl (16 kB)
Collecting regex!=2019.12.17
  Using cached regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.man

In [27]:
from transformers import MobileNetV2Config, MobileNetV2ForImageClassification, AutoModelForImageClassification, Trainer, TrainingArguments
from torchvision import transforms, datasets as dataset
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import evaluate
import random
import torch
import os

In [None]:
def reset_seed(seed=42):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) 
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

In [None]:
class CIFAR10HFDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        return {
            'pixel_values': image,
            'labels': label
        }

    def __len__(self):
        return len(self.dataset)

In [7]:
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    acc = accuracy.compute(references=labels, predictions=np.argmax(predictions, axis=1))
    return {"accuracy": acc["accuracy"]}

In [8]:
def get_training_args(output_dir, logging_dir):
    return (
        TrainingArguments(
        output_dir=output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=5e-5,
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        num_train_epochs=20,
        weight_decay=0.01,
        seed = 42,
        metric_for_best_model="accuracy",
        load_best_model_at_end=True,
        fp16=True, 
        logging_dir=logging_dir,
    ))


In [None]:
def get_random_init_mobilenet():
    student_config = MobileNetV2Config()
    student_config.num_labels = 10
    return MobileNetV2ForImageClassification(student_config)

In [None]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False

    for param in model.classifier.parameters():
        param.requires_grad = True

In [None]:
def get_mobilenet():
    model_pretrained = MobileNetV2ForImageClassification.from_pretrained("google/mobilenet_v2_1.0_224")
    in_features = model_pretrained.classifier.in_features

    model_pretrained.classifier = nn.Linear(in_features,10)
    model_pretrained.num_labels = 10
    model_pretrained.config.num_labels = 10

    return model_pretrained

In [None]:
reset_seed(42)

In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

In [None]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


train = dataset.CIFAR10(root='./data/10', train=True, download=True, transform=transform)
test = dataset.CIFAR10(root='./data/10', train=False, download=True, transform=transform)

train_dataset_hf = CIFAR10HFDataset(train)
test_dataset_hf = CIFAR10HFDataset(test)

In [None]:
training_args = get_training_args("./results/cifar10-random", './logs/cifar10-random')
model = get_random_init_mobilenet()

In [10]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset_hf,
    eval_dataset=test_dataset_hf,
    compute_metrics=compute_metrics,
)

In [11]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.9994,1.465344,0.4619
2,1.338,1.238446,0.5633
3,1.1357,1.040728,0.6262
4,0.8884,0.949196,0.6778
5,0.7736,0.940272,0.6785
6,0.6184,0.835595,0.7268
7,0.5585,1.199692,0.6553
8,0.4315,1.572898,0.6026
9,0.356,0.850076,0.7435
10,0.2889,0.897842,0.7531


TrainOutput(global_step=15640, training_loss=0.45282880482466326, metrics={'train_runtime': 3560.2383, 'train_samples_per_second': 280.88, 'train_steps_per_second': 4.393, 'total_flos': 2.020099608576e+18, 'train_loss': 0.45282880482466326, 'epoch': 20.0})

In [11]:
model.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [12]:
trainer.evaluate()

{'eval_loss': 1.0349912643432617,
 'eval_accuracy': 0.7594,
 'eval_runtime': 12.9818,
 'eval_samples_per_second': 770.311,
 'eval_steps_per_second': 12.094,
 'epoch': 20.0}

In [32]:
class ImageDistilTrainer(Trainer):
    def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None, *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.teacher = teacher_model
        self.student = student_model
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.teacher.to(device)
        self.teacher.eval()
        self.temperature = temperature
        self.lambda_param = lambda_param



    def compute_loss(self, student, inputs, return_outputs=False):
        student_output = self.student(**inputs)

        with torch.no_grad():
          teacher_output = self.teacher(**inputs)


        soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)

        
        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

 
        student_target_loss = student_output.loss


        loss = (1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss
        return (loss, student_output) if return_outputs else loss

In [42]:
reset_seed(42)

In [43]:
teacher_model = AutoModelForImageClassification.from_pretrained(
    "aaraki/vit-base-patch16-224-in21k-finetuned-cifar10",
    num_labels=10
)

student_model = get_random_init_mobilenet()

In [None]:
training_args = get_training_args("./results/cifar10-random-KD", './logs/cifar10-random-KD')

In [46]:
trainer = ImageDistilTrainer(
    student_model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train_dataset_hf,
    eval_dataset=test_dataset_hf,
    compute_metrics=compute_metrics,
    temperature = 5,
    lambda_param = 0.6
)

In [47]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.3791,1.075377,0.4148
2,0.9672,0.87,0.5553
3,0.8363,0.705798,0.6483
4,0.6664,0.651311,0.6821
5,0.5899,0.83738,0.6345
6,0.4976,0.562861,0.7271
7,0.4608,0.651885,0.702
8,0.3919,0.790115,0.6255
9,0.3512,0.604098,0.7171
10,0.3144,0.544215,0.7583


TrainOutput(global_step=15640, training_loss=0.4183128375226579, metrics={'train_runtime': 24817.0423, 'train_samples_per_second': 40.295, 'train_steps_per_second': 0.63, 'total_flos': 2.020099608576e+18, 'train_loss': 0.4183128375226579, 'epoch': 20.0})

In [None]:
student_model.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [None]:
trainer.evaluate()

{'eval_loss': 0.4978441298007965,
 'eval_accuracy': 0.7746,
 'eval_runtime': 91.479,
 'eval_samples_per_second': 109.315,
 'eval_steps_per_second': 1.716,
 'epoch': 20.0}

In [17]:
reset_seed(42)

In [18]:
model_pretrained = get_mobilenet()

In [19]:
print(model_pretrained)

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [None]:
freeze_model(model_pretrained)

In [21]:
training_args = get_training_args("./results/cifar10-pretrained-head", './logs/cifar10-pretrained-head')

In [22]:
trainer = Trainer(
    model=model_pretrained,
    args=training_args,
    train_dataset=train_dataset_hf,
    eval_dataset=test_dataset_hf,
    compute_metrics=compute_metrics,
)

In [23]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.8904,1.359479,0.6311
2,1.1226,1.20726,0.6242
3,1.0168,0.971185,0.7098
4,0.9041,0.907877,0.715
5,0.8768,0.984696,0.6812
6,0.8324,0.870024,0.722
7,0.8274,1.021015,0.6553
8,0.8021,1.10505,0.6295
9,0.7919,0.866775,0.7074
10,0.785,0.875023,0.7127


TrainOutput(global_step=15640, training_loss=0.8626149672681414, metrics={'train_runtime': 1668.8023, 'train_samples_per_second': 599.232, 'train_steps_per_second': 9.372, 'total_flos': 2.020099608576e+18, 'train_loss': 0.8626149672681414, 'epoch': 20.0})

In [24]:
model_pretrained.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [25]:
trainer.evaluate()

{'eval_loss': 0.7992069125175476,
 'eval_accuracy': 0.7357,
 'eval_runtime': 12.9925,
 'eval_samples_per_second': 769.677,
 'eval_steps_per_second': 12.084,
 'epoch': 20.0}

In [26]:
reset_seed(42)

In [30]:
model_pretrained_whole = get_mobilenet()

In [31]:
training_args = get_training_args("./results/cifar10-pretrained", './logs/cifar10-pretrained')

In [32]:
trainer = Trainer(
    model=model_pretrained_whole,
    args=training_args,
    train_dataset=train_dataset_hf,
    eval_dataset=test_dataset_hf,
    compute_metrics=compute_metrics,
)

In [33]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.6657,0.351591,0.8819
2,0.1825,0.456693,0.859
3,0.1093,0.312467,0.9066
4,0.0525,0.29301,0.9152
5,0.0351,0.563375,0.872
6,0.0165,0.432066,0.9057
7,0.0134,0.597252,0.8762
8,0.0082,0.704292,0.8609
9,0.0052,0.474817,0.9095
10,0.0046,0.451907,0.9178


TrainOutput(global_step=15640, training_loss=0.048197138210868136, metrics={'train_runtime': 2689.1165, 'train_samples_per_second': 371.869, 'train_steps_per_second': 5.816, 'total_flos': 2.020099608576e+18, 'train_loss': 0.048197138210868136, 'epoch': 20.0})

In [34]:
model_pretrained_whole.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [35]:
trainer.evaluate()

{'eval_loss': 0.41933944821357727,
 'eval_accuracy': 0.9266,
 'eval_runtime': 13.1012,
 'eval_samples_per_second': 763.287,
 'eval_steps_per_second': 11.984,
 'epoch': 20.0}

In [42]:
reset_seed(42)

In [43]:
student_model_pretrained = get_mobilenet()

In [44]:
freeze_model(student_model_pretrained)

In [45]:
training_args = get_training_args("./results/cifar10-pretrained-head-KD", './logs/cifar10-pretrained-head-KD')

In [46]:
trainer = ImageDistilTrainer(
    student_model=student_model_pretrained,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train_dataset_hf,
    eval_dataset=test_dataset_hf,
    compute_metrics=compute_metrics,
    temperature = 5,
    lambda_param = 0.5
)

In [47]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.3103,0.962174,0.637
2,0.8388,0.90321,0.6369
3,0.7938,0.769327,0.7129
4,0.7484,0.747457,0.7151
5,0.7396,0.805402,0.6856
6,0.7223,0.742205,0.7209
7,0.7209,0.832782,0.6644
8,0.7109,0.877281,0.6403
9,0.7076,0.753492,0.7052
10,0.7059,0.750468,0.7209


TrainOutput(global_step=15640, training_loss=0.7418093366696097, metrics={'train_runtime': 11090.4968, 'train_samples_per_second': 90.167, 'train_steps_per_second': 1.41, 'total_flos': 2.020099608576e+18, 'train_loss': 0.7418093366696097, 'epoch': 20.0})

In [48]:
student_model_pretrained.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [49]:
trainer.evaluate()

{'eval_loss': 0.715542197227478,
 'eval_accuracy': 0.737,
 'eval_runtime': 91.5307,
 'eval_samples_per_second': 109.253,
 'eval_steps_per_second': 1.715,
 'epoch': 20.0}

In [50]:
reset_seed(42)

In [51]:
student_model_pretrained_whole = get_mobilenet()

In [52]:
training_args = get_training_args("./results/cifar10-pretrained-KD", './logs/cifar10-pretrained-KD')

In [53]:
trainer = ImageDistilTrainer(
    student_model=student_model_pretrained_whole,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=train_dataset_hf,
    eval_dataset=test_dataset_hf,
    compute_metrics=compute_metrics,
    temperature = 5,
    lambda_param = 0.5
)

In [54]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.5448,0.346203,0.8852
2,0.2294,0.377549,0.8675
3,0.1839,0.246646,0.9216
4,0.1462,0.229126,0.9265
5,0.1317,0.368331,0.8788
6,0.116,0.258062,0.9151
7,0.1123,0.340128,0.8845
8,0.1045,0.333817,0.8801
9,0.1007,0.238064,0.9207
10,0.0991,0.213275,0.9324


TrainOutput(global_step=15640, training_loss=0.12985367409103668, metrics={'train_runtime': 12071.5298, 'train_samples_per_second': 82.84, 'train_steps_per_second': 1.296, 'total_flos': 2.020099608576e+18, 'train_loss': 0.12985367409103668, 'epoch': 20.0})

In [55]:
student_model_pretrained.eval()

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [57]:
trainer.evaluate()

{'eval_loss': 0.1961217224597931,
 'eval_accuracy': 0.9371,
 'eval_runtime': 91.3167,
 'eval_samples_per_second': 109.509,
 'eval_steps_per_second': 1.719,
 'epoch': 20.0}