# Setup

In [None]:
# Mount Google Drive: the dataset archives and the project folder used below live under MyDrive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
# Pin the version this notebook was run with (the install log shows monai-1.4.0);
# %pip targets the running kernel's environment, unlike !pip.
%pip install -q monai==1.4.0

Collecting monai
  Downloading monai-1.4.0-py3-none-any.whl.metadata (11 kB)
Downloading monai-1.4.0-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m27.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: monai
Successfully installed monai-1.4.0


In [None]:
# Unpack the train/val archives from Drive onto the local /content disk.
# `mkdir -p` creates parents and makes the cell safe to re-run; tar runs without -v
# because listing every extracted file flooded the cell output (thousands of lines).
!mkdir -p /content/dataset/train2C/ /content/dataset/val2C/
!tar -xzf /content/drive/MyDrive/MaiaSpain/CAD/project/dataset/train.tgz -C /content/dataset/train2C/
!tar -xzf /content/drive/MyDrive/MaiaSpain/CAD/project/dataset/val.tgz -C /content/dataset/val2C/

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
train/nevus/nev01090.jpg
train/nevus/nev03737.jpg
train/nevus/nev06203.jpg
train/nevus/nev01992.jpg
train/nevus/nev04542.jpg
train/nevus/nev04173.jpg
train/nevus/nev07309.jpg
train/nevus/nev01549.jpg
train/nevus/nev03977.jpg
train/nevus/nev00101.jpg
train/nevus/nev00499.jpg
train/nevus/nev00232.jpg
train/nevus/nev03062.jpg
train/nevus/nev03100.jpg
train/nevus/nev04661.jpg
train/nevus/nev04998.jpg
train/nevus/nev03432.jpg
train/nevus/nev07724.jpg
train/nevus/nev02586.jpg
train/nevus/nev07550.jpg
train/nevus/nev06346.jpg
train/nevus/nev05863.jpg
train/nevus/nev00483.jpg
train/nevus/nev00460.jpg
train/nevus/nev03889.jpg
train/nevus/nev07257.jpg
train/nevus/nev02541.jpg
train/nevus/nev05323.jpg
train/nevus/nev01568.jpg
train/nevus/nev06884.jpg
train/nevus/nev03661.jpg
train/nevus/nev01681.jpg
train/nevus/nev07438.jpg
train/nevus/nev00646.jpg
train/nevus/nev06119.jpg
train/nevus/nev06064.jpg
train/nevus/nev03147.jpg
train/nevu

# Initialization

In [None]:
# Work from the project folder on Drive (the printed path shows it resolves through a
# Drive shortcut). Relative paths used later, e.g. the "models_new/" checkpoint
# directory in train_and_evaluate_model, are created under this folder.
%cd /content/drive/MyDrive/MaiaSpain/CAD/project/

/content/drive/.shortcut-targets-by-id/1fTsBspXCDEVY7q8PyaNkRq-51xfzOyIV/MaiaSpain/CAD/project


In [None]:
import os
import numpy as np
from torch.utils.data import WeightedRandomSampler
from monai.transforms import (
    Compose, Rand2DElasticd, RandRotate90d, RandFlipd, RandAffined, ScaleIntensityd,
    RandCoarseShuffled, EnsureTyped, LoadImaged, Resized, ToTensord, NormalizeIntensityd
)
from monai.data import PersistentDataset,Dataset,DataLoader
from monai.losses import FocalLoss
from PIL import Image
import torch
from torch import nn, tensor
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, StepLR
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report, cohen_kappa_score
from tqdm import tqdm
from torchvision import models
from shutil import rmtree
import matplotlib.pyplot as plt

# Dataloader

In [None]:
# Integer label encoding for the two class folders: 'nevus' -> 0, 'others' -> 1.
cls_to_idx = {'nevus': 0, 'others': 1}
# Inverse lookup (index -> class name), derived so the two maps cannot drift apart.
idx_to_cls = {idx: cls for cls, idx in cls_to_idx.items()}

def get_weights(data_path):
    """Inverse-frequency class weights for the class sub-folders of `data_path`.

    Each immediate sub-directory of `data_path` is one class. Its weight is
    1 / (number of entries in it), normalised so that all weights sum to 1.
    Plain files directly under `data_path` and hidden entries (names starting
    with '.', e.g. .DS_Store) are ignored rather than crashing os.listdir or
    inflating the counts.

    Args:
        data_path: dataset root laid out as <data_path>/<class_name>/<image files>.

    Returns:
        dict mapping class name -> weight; the weights sum to 1.

    Raises:
        ValueError: if a class folder contains no (non-hidden) entries.
    """
    classes = sorted(
        name for name in os.listdir(data_path)
        if not name.startswith('.') and os.path.isdir(os.path.join(data_path, name))
    )
    num_samples = {
        cls: sum(1 for entry in os.listdir(os.path.join(data_path, cls)) if not entry.startswith('.'))
        for cls in classes
    }
    empty = [cls for cls, count in num_samples.items() if count == 0]
    if empty:
        raise ValueError(f"no images found in class folder(s): {empty}")
    # Normaliser computed once instead of once per class.
    total_inverse = sum(1 / count for count in num_samples.values())
    return {cls: (1 / count) / total_inverse for cls, count in num_samples.items()}

def get_data(data_path, cls_weights, class_to_idx=None):
    """Build MONAI-style sample dicts and matching sampler weights for an image-folder dataset.

    Class folders and the files inside them are visited in sorted order so the
    sample list is reproducible across filesystems. Plain files next to the
    class folders and hidden entries (names starting with '.') are skipped.

    Args:
        data_path: dataset root laid out as <data_path>/<class_name>/<image files>.
        cls_weights: class name -> per-sample weight (e.g. the output of get_weights
            on the training set).
        class_to_idx: class name -> integer label. Defaults to the notebook-level
            `cls_to_idx`.

    Returns:
        (data, sample_weights): `data` is a list of {'image': path, 'label': int}
        dicts and `sample_weights[i]` is the weight of `data[i]`'s class.

    Raises:
        KeyError: if a class folder is missing from `class_to_idx` or `cls_weights`.
    """
    if class_to_idx is None:
        class_to_idx = cls_to_idx
    data, sample_weights = [], []
    for cls in sorted(os.listdir(data_path)):
        cls_path = os.path.join(data_path, cls)
        if cls.startswith('.') or not os.path.isdir(cls_path):
            continue
        label, weight = class_to_idx[cls], cls_weights[cls]
        for image_name in sorted(os.listdir(cls_path)):
            if image_name.startswith('.'):
                continue
            data.append({'image': os.path.join(cls_path, image_name), 'label': label})
            sample_weights.append(weight)
    return data, sample_weights

In [None]:
# Preprocessing shared by train and validation: load channel-first, resize to
# 224x224, then z-score the intensities.
base_transforms = [
    LoadImaged('image', ensure_channel_first=True),
    Resized('image', (224, 224)),
    NormalizeIntensityd('image'),
]

# Training-only augmentations. Intensities are first rescaled to [0, 1]
# (presumably so the "zeros" padding of the geometric transforms equals the image
# minimum) and re-normalised at the end so train and val inputs share a scale.
# RandCoarseShuffled was tried and is currently disabled.
augmentations = [
    ScaleIntensityd('image', 0, 1),
    Rand2DElasticd('image', prob=0.5, spacing=56, magnitude_range=(1, 4), padding_mode="zeros"),
    RandRotate90d('image', prob=0.5, spatial_axes=[0, 1]),
    *(RandFlipd('image', prob=0.5, spatial_axis=axis) for axis in (0, 1)),
    RandAffined('image', prob=0.5, rotate_range=(0.75, 0.75), shear_range=(0.1, 0.1), padding_mode="zeros"),
    NormalizeIntensityd('image'),
]

train_transform = Compose([*base_transforms, *augmentations])
val_transform = Compose(base_transforms)

TRAIN_ROOT = "/content/dataset/train2C/train"
VAL_ROOT = "/content/dataset/val2C/val"
LOADER_KWARGS = dict(batch_size=64, num_workers=os.cpu_count(), pin_memory=True)

# Start from an empty cache so PersistentDataset never serves tensors produced by an older pipeline.
rmtree('/content/cache/', ignore_errors=True)

# Class weights come from the training split only and are reused for validation.
cls_weights = get_weights(TRAIN_ROOT)
train_data, train_weights = get_data(TRAIN_ROOT, cls_weights)
val_data, val_weights = get_data(VAL_ROOT, cls_weights)

# Per-image inverse-frequency weights make the sampler draw the classes roughly evenly.
train_dataset = PersistentDataset(data=train_data, transform=train_transform, cache_dir='/content/cache/train')
train_sampler = WeightedRandomSampler(train_weights, len(train_weights))
train_loader = DataLoader(train_dataset, sampler=train_sampler, **LOADER_KWARGS)

val_dataset = PersistentDataset(data=val_data, transform=val_transform, cache_dir='/content/cache/val')
val_loader = DataLoader(val_dataset, **LOADER_KWARGS)

In [None]:
def show_batch(loader, n_images=16, n_cols=8):
    """Plot the first `n_images` images of one batch from `loader` with their class names.

    Each image is min-max rescaled to [0, 1] for display only. The loader's
    transform pipeline is left untouched: appending a ScaleIntensityd to the
    training transform in place would silently change every later training batch.

    Args:
        loader: DataLoader yielding dicts with 'image' (per-image C, H, W) and 'label' tensors.
        n_images: number of images to show (capped at the batch size).
        n_cols: images per row.
    """
    batch = next(iter(loader))
    n_images = min(n_images, len(batch['image']))
    n_rows = -(-n_images // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 2.5 * n_rows), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')
    for ax, image, label in zip(axes.flat, batch['image'][:n_images], batch['label'][:n_images]):
        image = image.float()
        image = (image - image.min()) / (image.max() - image.min()).clamp_min(1e-8)
        ax.imshow(image.permute(1, 2, 0).cpu().numpy())
        ax.set_title(idx_to_cls[int(label)])
    fig.tight_layout()
    plt.show()


show_batch(train_loader)

# Training and Evaluation

In [None]:
def calculate_metrics(labels, preds, average="macro"):
    """Summary classification metrics for one evaluation pass.

    Args:
        labels: ground-truth class indices.
        preds: predicted class indices, same length as `labels`.
        average: sklearn averaging mode for precision/recall/F1. The default
            "macro" is the unweighted mean over classes.

    Returns:
        (accuracy, precision, recall, f1, kappa), where kappa is Cohen's kappa.
    """
    # zero_division=0 scores an undefined 0/0 precision/recall as 0 (the value the
    # sklearn default also returns) without emitting an UndefinedMetricWarning each
    # epoch, e.g. when an early epoch never predicts one of the classes.
    averaged = dict(average=average, zero_division=0)
    return (
        accuracy_score(labels, preds),
        precision_score(labels, preds, **averaged),
        recall_score(labels, preds, **averaged),
        f1_score(labels, preds, **averaged),
        cohen_kappa_score(labels, preds),
    )

def train_and_evaluate_model(model, train_loader, val_loader, device, epochs=30, patience=20, min_delta=0.0001,
                             save_dir="models_new/", lr=1e-3):
    """Fine-tune `model` with focal loss and checkpoint it whenever validation macro-F1 improves.

    Each epoch does two things:
    - It trains over `train_loader` with AdamW and StepLR (lr x0.2 every 10 epochs).
    - It evaluates on `val_loader`, printing accuracy/precision/recall/F1/kappa and
      the confusion matrix.

    Whenever F1 beats the best so far by more than `min_delta`, the state dict is
    saved as `<save_dir>/best_model_epoch_<k>.pth`.

    Args:
        model: classifier whose forward returns either a logits tensor or an object
            with a `.logits` attribute (e.g. a Hugging Face image-classification model).
        train_loader: loader of dicts with 'image' and 'label' keys, used for training.
        val_loader: loader of the same format, used for evaluation.
        device: torch device to train on.
        epochs: maximum number of epochs.
        patience: stop after this many *consecutive* epochs without improvement.
        min_delta: minimum F1 gain that counts as an improvement.
        save_dir: checkpoint directory (created if missing).
        lr: AdamW learning rate.

    Returns:
        (best_f1, best_model_path). `best_model_path` is None if no epoch improved on 0.
    """
    model = model.to(device)
    criterion = FocalLoss(to_onehot_y=True, use_softmax=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.2)
    best_f1 = 0
    best_model_path = None
    epochs_without_improvement = 0
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device,
                                      desc=f"Epoch {epoch + 1}/{epochs}")
        scheduler.step()
        print(f"Epoch {epoch + 1}/{epochs}, Training Loss: {train_loss:.4f}")

        all_labels, all_preds = _predict(model, val_loader, device)
        accuracy, precision, recall, f1, kappa = calculate_metrics(all_labels, all_preds)
        cm = confusion_matrix(all_labels, all_preds)
        print(f"Validation Metrics - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Kappa: {kappa:.4f}")
        print(f"Confusion Matrix:\n{cm}")

        if f1 > best_f1 + min_delta:
            best_f1 = f1
            # Reset on improvement so `patience` counts consecutive bad epochs, not total ones.
            epochs_without_improvement = 0
            best_model_path = os.path.join(save_dir, f"best_model_epoch_{epoch+1}.pth")
            torch.save(model.state_dict(), best_model_path)
            print(f"Model saved to {best_model_path}")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print("Early stopping!")
                break

    return best_f1, best_model_path


def _logits(outputs):
    """Raw logits from a forward pass: `.logits` for HF-style outputs, else the tensor itself."""
    return getattr(outputs, "logits", outputs)


def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over `loader` and return the mean per-batch loss."""
    model.train()
    total_loss = 0.0
    n_batches = 0
    for batch in tqdm(loader, desc=desc):
        images, labels = batch['image'].to(device), batch['label'].to(device)
        optimizer.zero_grad()
        loss = criterion(_logits(model(images)), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    return total_loss / max(n_batches, 1)  # guard against an empty loader


def _predict(model, loader, device):
    """Return (labels, predictions) as flat lists of class indices for every sample in `loader`."""
    model.eval()
    all_labels, all_preds = [], []
    with torch.no_grad():
        for batch in loader:
            preds = torch.argmax(_logits(model(batch['image'].to(device))), dim=1)
            all_preds.extend(preds.cpu().numpy())
            # Labels are only needed on the CPU for sklearn; no device round-trip.
            all_labels.extend(batch['label'].cpu().numpy())
    return all_labels, all_preds

# Experiments

In [None]:
from transformers import DeiTForImageClassification

# Fine-tune the distilled DeiT checkpoint through the plain classification head; the
# `classifier` layer is freshly initialised for our classes (see the load warning).
# NOTE(review): per the HF docs this head reads only the CLS token, not the distillation token.
# Input normalisation is done in the MONAI pipeline (NormalizeIntensityd), so the
# Hugging Face image processor is not loaded.
MODEL_NAME = "facebook/deit-base-distilled-patch16-224"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = DeiTForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(cls_to_idx),
    id2label=idx_to_cls,   # store readable class names in the model config
    label2id=cls_to_idx,
)

# Alternative backbone tried earlier:
# model = models.efficientnet_v2_l(weights=models.EfficientNet_V2_L_Weights.IMAGENET1K_V1)
# num_ftrs = model.classifier[1].in_features
# model.classifier[1] = nn.Linear(num_ftrs, len(cls_to_idx.keys()))

train_and_evaluate_model(model, train_loader, val_loader, device=DEVICE, epochs=100)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/69.6k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/349M [00:00<?, ?B/s]

Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


preprocessor_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 1/100: 100%|██████████| 238/238 [13:38<00:00,  3.44s/it]

Epoch 1/100, Training Loss: 0.0889





Validation Metrics - Accuracy: 0.7432, Precision: 0.7432, Recall: 0.7428, F1: 0.7429, Kappa: 0.4859
Confusion Matrix:
[[1469  462]
 [ 513 1352]]
Model saved to models_new/best_model_epoch_1.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 2/100: 100%|██████████| 238/238 [09:44<00:00,  2.46s/it]

Epoch 2/100, Training Loss: 0.0689



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.7426, Precision: 0.7596, Recall: 0.7403, F1: 0.7371, Kappa: 0.4828
Confusion Matrix:
[[1685  246]
 [ 731 1134]]


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 3/100: 100%|██████████| 238/238 [08:13<00:00,  2.07s/it]

Epoch 3/100, Training Loss: 0.0642



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.7813, Precision: 0.7813, Recall: 0.7813, F1: 0.7813, Kappa: 0.5626
Confusion Matrix:
[[1516  415]
 [ 415 1450]]
Model saved to models_new/best_model_epoch_3.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 4/100: 100%|██████████| 238/238 [08:03<00:00,  2.03s/it]

Epoch 4/100, Training Loss: 0.0600



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.7898, Precision: 0.7920, Recall: 0.7890, F1: 0.7890, Kappa: 0.5788
Confusion Matrix:
[[1613  318]
 [ 480 1385]]
Model saved to models_new/best_model_epoch_4.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 5/100: 100%|██████████| 238/238 [07:59<00:00,  2.02s/it]

Epoch 5/100, Training Loss: 0.0577



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.7911, Precision: 0.7916, Recall: 0.7914, F1: 0.7911, Kappa: 0.5824
Confusion Matrix:
[[1489  442]
 [ 351 1514]]
Model saved to models_new/best_model_epoch_5.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 6/100: 100%|██████████| 238/238 [07:59<00:00,  2.02s/it]

Epoch 6/100, Training Loss: 0.0575



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.7958, Precision: 0.7975, Recall: 0.7965, F1: 0.7957, Kappa: 0.5921
Confusion Matrix:
[[1469  462]
 [ 313 1552]]
Model saved to models_new/best_model_epoch_6.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 7/100: 100%|██████████| 238/238 [07:59<00:00,  2.01s/it]

Epoch 7/100, Training Loss: 0.0561



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.7964, Precision: 0.7989, Recall: 0.7955, F1: 0.7956, Kappa: 0.5920
Confusion Matrix:
[[1631  300]
 [ 473 1392]]


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 8/100: 100%|██████████| 238/238 [08:01<00:00,  2.02s/it]

Epoch 8/100, Training Loss: 0.0563



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.7964, Precision: 0.8019, Recall: 0.7975, F1: 0.7958, Kappa: 0.5936
Confusion Matrix:
[[1413  518]
 [ 255 1610]]


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 9/100: 100%|██████████| 238/238 [07:58<00:00,  2.01s/it]

Epoch 9/100, Training Loss: 0.0550



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.7921, Precision: 0.8022, Recall: 0.7905, F1: 0.7897, Kappa: 0.5829
Confusion Matrix:
[[1707  224]
 [ 565 1300]]


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 10/100: 100%|██████████| 238/238 [07:58<00:00,  2.01s/it]

Epoch 10/100, Training Loss: 0.0554



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.7655, Precision: 0.7813, Recall: 0.7634, F1: 0.7612, Kappa: 0.5290
Confusion Matrix:
[[1709  222]
 [ 668 1197]]


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 11/100: 100%|██████████| 238/238 [08:02<00:00,  2.03s/it]

Epoch 11/100, Training Loss: 0.0517



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8119, Precision: 0.8128, Recall: 0.8124, F1: 0.8119, Kappa: 0.6241
Confusion Matrix:
[[1519  412]
 [ 302 1563]]
Model saved to models_new/best_model_epoch_11.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 12/100: 100%|██████████| 238/238 [07:59<00:00,  2.01s/it]

Epoch 12/100, Training Loss: 0.0516



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8174, Precision: 0.8175, Recall: 0.8176, F1: 0.8174, Kappa: 0.6349
Confusion Matrix:
[[1561  370]
 [ 323 1542]]
Model saved to models_new/best_model_epoch_12.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 13/100: 100%|██████████| 238/238 [07:58<00:00,  2.01s/it]

Epoch 13/100, Training Loss: 0.0496



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8164, Precision: 0.8165, Recall: 0.8162, F1: 0.8162, Kappa: 0.6325
Confusion Matrix:
[[1602  329]
 [ 368 1497]]


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 14/100: 100%|██████████| 238/238 [07:58<00:00,  2.01s/it]

Epoch 14/100, Training Loss: 0.0486



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8177, Precision: 0.8180, Recall: 0.8180, F1: 0.8177, Kappa: 0.6355
Confusion Matrix:
[[1549  382]
 [ 310 1555]]
Model saved to models_new/best_model_epoch_14.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 15/100: 100%|██████████| 238/238 [07:59<00:00,  2.01s/it]

Epoch 15/100, Training Loss: 0.0490



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8161, Precision: 0.8161, Recall: 0.8161, F1: 0.8161, Kappa: 0.6322
Confusion Matrix:
[[1575  356]
 [ 342 1523]]


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 16/100: 100%|██████████| 238/238 [07:59<00:00,  2.02s/it]

Epoch 16/100, Training Loss: 0.0476



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8188, Precision: 0.8205, Recall: 0.8181, F1: 0.8182, Kappa: 0.6370
Confusion Matrix:
[[1655  276]
 [ 412 1453]]
Model saved to models_new/best_model_epoch_16.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 17/100: 100%|██████████| 238/238 [07:58<00:00,  2.01s/it]

Epoch 17/100, Training Loss: 0.0485



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8193, Precision: 0.8213, Recall: 0.8186, F1: 0.8187, Kappa: 0.6380
Confusion Matrix:
[[1660  271]
 [ 415 1450]]
Model saved to models_new/best_model_epoch_17.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 18/100: 100%|██████████| 238/238 [07:58<00:00,  2.01s/it]

Epoch 18/100, Training Loss: 0.0476



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8190, Precision: 0.8191, Recall: 0.8191, F1: 0.8190, Kappa: 0.6381
Confusion Matrix:
[[1568  363]
 [ 324 1541]]
Model saved to models_new/best_model_epoch_18.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 19/100: 100%|██████████| 238/238 [08:01<00:00,  2.02s/it]

Epoch 19/100, Training Loss: 0.0451



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8288, Precision: 0.8292, Recall: 0.8284, F1: 0.8286, Kappa: 0.6572
Confusion Matrix:
[[1639  292]
 [ 358 1507]]
Model saved to models_new/best_model_epoch_19.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 20/100: 100%|██████████| 238/238 [08:00<00:00,  2.02s/it]

Epoch 20/100, Training Loss: 0.0461



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8311, Precision: 0.8325, Recall: 0.8306, F1: 0.8307, Kappa: 0.6618
Confusion Matrix:
[[1669  262]
 [ 379 1486]]
Model saved to models_new/best_model_epoch_20.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 21/100: 100%|██████████| 238/238 [07:58<00:00,  2.01s/it]

Epoch 21/100, Training Loss: 0.0430



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8335, Precision: 0.8335, Recall: 0.8334, F1: 0.8335, Kappa: 0.6669
Confusion Matrix:
[[1617  314]
 [ 318 1547]]
Model saved to models_new/best_model_epoch_21.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 22/100: 100%|██████████| 238/238 [07:59<00:00,  2.02s/it]

Epoch 22/100, Training Loss: 0.0409



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8332, Precision: 0.8333, Recall: 0.8331, F1: 0.8331, Kappa: 0.6663
Confusion Matrix:
[[1628  303]
 [ 330 1535]]


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 23/100: 100%|██████████| 238/238 [08:00<00:00,  2.02s/it]

Epoch 23/100, Training Loss: 0.0396



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8293, Precision: 0.8294, Recall: 0.8295, F1: 0.8293, Kappa: 0.6586
Confusion Matrix:
[[1583  348]
 [ 300 1565]]


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 24/100: 100%|██████████| 238/238 [08:03<00:00,  2.03s/it]

Epoch 24/100, Training Loss: 0.0401



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8314, Precision: 0.8314, Recall: 0.8315, F1: 0.8314, Kappa: 0.6628
Confusion Matrix:
[[1597  334]
 [ 306 1559]]


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 25/100: 100%|██████████| 238/238 [08:01<00:00,  2.02s/it]

Epoch 25/100, Training Loss: 0.0387



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8264, Precision: 0.8281, Recall: 0.8257, F1: 0.8259, Kappa: 0.6523
Confusion Matrix:
[[1668  263]
 [ 396 1469]]


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 26/100: 100%|██████████| 238/238 [08:01<00:00,  2.02s/it]

Epoch 26/100, Training Loss: 0.0374



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8264, Precision: 0.8275, Recall: 0.8269, F1: 0.8264, Kappa: 0.6531
Confusion Matrix:
[[1540  391]
 [ 268 1597]]


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 27/100: 100%|██████████| 238/238 [08:01<00:00,  2.02s/it]

Epoch 27/100, Training Loss: 0.0375



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8359, Precision: 0.8364, Recall: 0.8355, F1: 0.8357, Kappa: 0.6714
Confusion Matrix:
[[1655  276]
 [ 347 1518]]
Model saved to models_new/best_model_epoch_27.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 28/100:  68%|██████▊   | 161/238 [05:29<02:34,  2.00s/it]

In [None]:
from transformers import DeiTConfig, DeiTForImageClassification

# Probe for distillation: expose the hidden states so the DeiT distillation token can be
# read next to the classification logits. A separate name keeps the trained `model`
# from the experiment cell from being overwritten by this fresh checkpoint.
DISTILLED_MODEL_NAME = "facebook/deit-base-distilled-patch16-224"

deit_config = DeiTConfig.from_pretrained(DISTILLED_MODEL_NAME)
deit_config.num_labels = len(cls_to_idx)
deit_config.output_hidden_states = True
deit_probe = DeiTForImageClassification.from_pretrained(DISTILLED_MODEL_NAME, config=deit_config)

probe_outputs = deit_probe(torch.randn(1, 3, 224, 224))
pred = probe_outputs.logits
# DeiT's token sequence is [CLS, DIST, patch_1, ..., patch_N], so the distillation token
# is at index 1; index -1 (used before) is the last *patch* token.
# NOTE(review): hidden_states[-1] is presumably the last encoder layer before the model's
# final LayerNorm — confirm whether the normalised sequence output is wanted instead.
distill_token = probe_outputs.hidden_states[-1][:, 1, :]


Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


torch.Size([1, 2]) torch.Size([1, 768])


In [None]:
# Shape sanity check for the probe: the recorded output is (1, 2) logits and a (1, 768) token.
print(*(tensor_.shape for tensor_ in (pred, distill_token)))

torch.Size([1, 2]) torch.Size([1, 768])


In [None]:
# Sketch of the planned objective: focal loss on the class logits plus a feature-
# distillation term pulling the DeiT distillation token toward a teacher embedding.
# NOTE(review): the teacher ("google") model is not wired up yet, so `google_token` is a
# placeholder and the distillation term is 0 until it is set.

def distill_loss(student_token, teacher_token):
    """Feature-distillation term between a student token and a teacher embedding.

    Args:
        student_token: student features, e.g. the DeiT distillation token (batch, hidden).
        teacher_token: teacher features of the same shape, or None if no teacher yet.

    Returns:
        Scalar tensor: the MSE between the two (placeholder distance — TODO choose the
        final one), or a zero scalar when `teacher_token` is None.

    Raises:
        ValueError: if the shapes differ (a projection head would be needed first).
    """
    if teacher_token is None:
        return student_token.new_zeros(())
    if teacher_token.shape != student_token.shape:
        raise ValueError(
            f"teacher shape {tuple(teacher_token.shape)} != student shape {tuple(student_token.shape)}"
        )
    return nn.functional.mse_loss(student_token, teacher_token)


google_token = None  # TODO: teacher embedding for the same batch

# Dummy labels: this cell only checks that the pieces combine; no real targets here.
dummy_labels = torch.zeros(pred.shape[0], dtype=torch.long)
focal_loss = FocalLoss(to_onehot_y=True, use_softmax=True)(pred, dummy_labels)
distill_term = distill_loss(distill_token, google_token)
total_loss = focal_loss + distill_term
total_loss

torch.Size([1, 768])