In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

class CustomCNN(nn.Module):
    """Small two-stage CNN used as the distillation teacher.

    Architecture: [conv 3x3 -> ReLU -> 2x2 max-pool] x 2, then a 64-unit
    hidden linear layer and a linear classifier.

    Args:
        num_classes: Width of the output logits.
        in_channels: Channels of the input image (1 = grayscale, the value the
            original layer hard-coded).
        input_size: Height/width of the square input. The two 2x2 pools shrink
            it to ``input_size // 4``. The default of 64 reproduces the original
            ``fc1`` width of ``32 * 16 * 16``, so checkpoints saved with the
            defaults load unchanged.

    Raises:
        ValueError: If ``input_size < 4`` (the pooled feature map would be empty).
    """

    def __init__(self, num_classes=7, in_channels=1, input_size=64):
        super().__init__()
        if input_size < 4:
            raise ValueError(f"input_size must be >= 4, got {input_size}")
        self.conv1 = nn.Conv2d(in_channels, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2)
        # padding=1 keeps conv outputs the same size; each pool floors by 2,
        # and floor(floor(n / 2) / 2) == n // 4.
        pooled = input_size // 4
        self.fc1 = nn.Linear(32 * pooled * pooled, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        """Map a ``[B, in_channels, S, S]`` batch to ``[B, num_classes]`` logits."""
        x = self.pool(F.relu(self.conv1(x)))  # [B, C, S, S]       -> [B, 16, S//2, S//2]
        x = self.pool(F.relu(self.conv2(x)))  # [B, 16, S//2, S//2] -> [B, 32, S//4, S//4]
        x = torch.flatten(x, 1)               # -> [B, 32 * (S//4) ** 2]
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Prefer Apple-silicon MPS when available, otherwise fall back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Rebuild the teacher architecture and load its trained weights.
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs, and blocks arbitrary code execution from a tampered
# checkpoint (requires torch >= 1.13).
teacher_model = CustomCNN(num_classes=7)
teacher_model.load_state_dict(
    torch.load("custom_cnn.pth", map_location=device, weights_only=True)
)
teacher_model.to(device)
# The teacher is only ever evaluated, so freeze it explicitly.
teacher_model.requires_grad_(False)
teacher_model.eval()


CustomCNN(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=8192, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=7, bias=True)
)

In [21]:
from transformers import ViTForImageClassification

# ViT-Tiny student for fast distillation.
# The checkpoint ships a 1000-way classifier, so ignore_mismatched_sizes=True
# makes from_pretrained freshly initialise the 7-way head instead of loading it.
# NOTE(review): 7 matches the teacher's CustomCNN(num_classes=7); the label
# names here are placeholders "0".."6", not the dataset's class names.
student_model = ViTForImageClassification.from_pretrained(
    "WinKawaks/vit-tiny-patch16-224",
    num_labels=7,
    id2label=dict(enumerate(map(str, range(7)))),
    label2id={str(idx): idx for idx in range(7)},
    ignore_mismatched_sizes=True,
)
student_model = student_model.to(device)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([7]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 192]) in the checkpoint and torch.Size([7, 192]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [22]:
from datasets import Dataset
from glob import glob
from PIL import Image
import torchvision.transforms as T

# Collect image paths; each image's class name is its parent directory.
# sorted() removes dependence on filesystem listing order so that, together
# with the fixed seed below, the train/test split is reproducible.
image_paths = sorted(glob("./data/MedNIST/*/*.jpeg"))
if not image_paths:
    raise FileNotFoundError("No images found matching ./data/MedNIST/*/*.jpeg")
labels = [os.path.basename(os.path.dirname(p)) for p in image_paths]

# Class name <-> integer id, assigned in alphabetical order of class names.
label2id = {name: idx for idx, name in enumerate(sorted(set(labels)))}
id2label = {idx: name for name, idx in label2id.items()}
# NOTE(review): teacher and student both emit 7 logits; confirm len(label2id)
# does not exceed that, or CrossEntropyLoss will fail on out-of-range ids.

# Build a Hugging Face dataset of {"image": path, "label": id} rows and split
# 80/20 with a fixed seed so reruns see the same train/test membership.
examples = [{"image": p, "label": label2id[name]} for p, name in zip(image_paths, labels)]
dataset = Dataset.from_list(examples).train_test_split(test_size=0.2, seed=42)

# ✅ Define HuggingFace ViT processor
from transformers import ViTImageProcessor
processor = ViTImageProcessor.from_pretrained(
    "WinKawaks/vit-tiny-patch16-224",  # same checkpoint as the student, so its preprocessor config is used
)

# ✅ Transform function for HuggingFace format
def transform(example):
    """Load one image and attach student-ready ``pixel_values``.

    Args:
        example: A dataset row whose ``"image"`` field is an image file path.

    Returns:
        The same row with ``"pixel_values"`` added. This is the processor
        output for the RGB-converted image, with the leading batch axis of
        size 1 removed.
    """
    # Close the file as soon as the pixels are decoded; otherwise every row
    # (tens of thousands here) keeps a handle open until garbage collection.
    # convert() returns a new, fully loaded image, so it outlives the file.
    with Image.open(example["image"]) as img:
        rgb = img.convert("RGB")
    inputs = processor(rgb, return_tensors="pt")
    # squeeze(0) drops only the batch axis; a bare squeeze() would also drop
    # any other axis that happens to have size 1.
    example["pixel_values"] = inputs["pixel_values"].squeeze(0)
    return example

# Run `transform` over both splits, materialising pixel_values for every row.
# NOTE(review): this stores a full float image tensor per row. Consider
# `dataset[split].with_transform(...)` to compute pixel_values lazily instead.
train_hf, test_hf = (dataset[split].map(transform) for split in ("train", "test"))


Map: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 47163/47163 [13:18<00:00, 59.04 examples/s]
Map: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11791/11791 [03:30<00:00, 56.02 examples/s]


In [23]:
from torch.utils.data import DataLoader

# ✅ Collate to handle batch stacking
def collate_fn(batch):
    """Stack a list of dataset rows into one model-ready batch.

    Args:
        batch: Rows, each holding a ``"pixel_values"`` tensor and a ``"label"``.

    Returns:
        ``{"pixel_values": Tensor[B, ...], "label": Tensor[B]}``.
    """
    images = [row["pixel_values"] for row in batch]
    targets = [row["label"] for row in batch]
    return dict(pixel_values=torch.stack(images), label=torch.tensor(targets))

# Make both splits yield torch tensors when indexed.
train_hf.set_format(type="torch")
test_hf.set_format(type="torch")

# Only the training loader shuffles; evaluation order stays fixed.
# NOTE(review): confirm these loaders are actually used; a Trainer builds its
# own loaders from the datasets it is given.
train_loader = DataLoader(
    train_hf,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
)
test_loader = DataLoader(
    test_hf,
    batch_size=16,
    shuffle=False,
    collate_fn=collate_fn,
)


In [24]:
from transformers import Trainer
import torch.nn.functional as F
import torch.nn as nn

class DistillationTrainer(Trainer):
    """Hugging Face ``Trainer`` that distils a frozen CNN teacher into the student.

    The loss is::

        alpha * CE(student, labels) + (1 - alpha) * T**2 * KL(teacher_T || student_T)

    ``*_T`` are temperature-softened distributions (Hinton et al., 2015). The
    ``T**2`` factor keeps the soft-target gradients on the same scale as the
    hard-label term when ``temperature`` changes.

    Args:
        teacher_model: Module mapping a ``[B, 1, H, W]`` grayscale batch to
            logits with as many classes as the student.
        temperature: Softmax temperature ``T`` for teacher and student.
        alpha: Weight of the hard-label cross-entropy term.
        *args, **kwargs: Forwarded unchanged to ``Trainer``.
        teacher_input_size: Keyword-only ``(H, W)`` the teacher expects. The
            student's ``pixel_values`` are resized to this before reaching the
            teacher. The default ``(64, 64)`` matches ``CustomCNN``'s ``fc1``
            width (``32 * 16 * 16`` after two 2x2 pools).
    """

    def __init__(self, teacher_model, temperature=4.0, alpha=0.5, *args,
                 teacher_input_size=(64, 64), **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self.teacher.eval()
        self.temperature = temperature
        self.alpha = alpha
        self.teacher_input_size = tuple(teacher_input_size)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the blended hard-label / distillation loss for one batch.

        Args:
            model: The student. It is called as ``model(pixel_values=...)`` and
                must return an object with ``.logits``.
            inputs: Batch dict with ``"pixel_values"`` and integer targets under
                ``"labels"`` (Hugging Face convention) or ``"label"`` (this
                notebook's ``collate_fn``). The dict is not modified.
            return_outputs: When True, also return the student outputs.
            num_items_in_batch: Accepted because newer ``Trainer`` versions pass
                it. Unused, since both loss terms are already batch means.

        Returns:
            ``loss``, or ``(loss, {"logits": student_logits})`` when
            ``return_outputs`` is True. The dict form matters: for non-dict
            outputs, ``Trainer.prediction_step`` uses ``outputs[1:]`` as the
            logits, which on a bare tensor silently drops the first sample.
        """
        labels = inputs.get("labels")
        if labels is None:
            labels = inputs["label"]
        pixel_values = inputs["pixel_values"].to(model.device)
        labels = labels.to(model.device)

        student_logits = model(pixel_values=pixel_values).logits
        teacher_logits = self._teacher_logits(pixel_values)

        T = self.temperature
        ce_loss = F.cross_entropy(student_logits, labels)
        kd_loss = F.kl_div(
            F.log_softmax(student_logits / T, dim=1),
            F.softmax(teacher_logits / T, dim=1),
            reduction="batchmean",
        ) * (T * T)
        loss = self.alpha * ce_loss + (1 - self.alpha) * kd_loss
        return (loss, {"logits": student_logits}) if return_outputs else loss

    @torch.no_grad()
    def _teacher_logits(self, pixel_values):
        """Compute teacher logits for the student's ``pixel_values`` batch, without gradients."""
        # Collapse RGB to a single grayscale channel, as the teacher expects.
        gray = pixel_values.mean(dim=1, keepdim=True)
        # The student's images (e.g. 224x224) would flatten to far more than the
        # 32 * 16 * 16 features fc1 accepts, so resize to the teacher's size.
        if tuple(gray.shape[-2:]) != self.teacher_input_size:
            gray = F.interpolate(
                gray, size=self.teacher_input_size, mode="bilinear", align_corners=False
            )
        # Trainer does not manage the teacher, so keep it on the batch's device.
        self.teacher.to(gray.device)
        # NOTE(review): gray carries the student processor's normalisation;
        # confirm it matches what the teacher saw during its own training.
        return self.teacher(gray)


In [25]:
!pip install accelerate -U



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [26]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    logging_dir="./logs",
    logging_steps=10,
    # NOTE(review): renamed to `eval_strategy` in newer transformers releases.
    evaluation_strategy="epoch",
    save_strategy="no",
    # Keep every dataset column (including "label") in the batches.
    remove_unused_columns=False,
    # This notebook's batches carry targets under "label", not the default
    # "labels". Without this, evaluation treats batches as unlabeled and calls
    # the student directly with an unexpected `label=` argument instead of
    # going through compute_loss.
    label_names=["label"],
    report_to="none",  # disables wandb if used
)


ImportError: Using the `Trainer` with `PyTorch` requires `accelerate>=0.21.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`