# Setup

In [None]:
# Mount Google Drive; the dataset archives read by the cells below live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
! pip install monai

Collecting monai
  Downloading monai-1.4.0-py3-none-any.whl.metadata (11 kB)
Downloading monai-1.4.0-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m38.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: monai
Successfully installed monai-1.4.0


In [None]:
!mkdir /content/dataset/
!mkdir -p /content/dataset/train2C/
!mkdir -p /content/dataset/val2C/
!tar -xvzf /content/drive/MyDrive/MaiaSpain/CAD/project/dataset/train.tgz -C /content/dataset/train2C/
!tar -xvzf /content/drive/MyDrive/MaiaSpain/CAD/project/dataset/val.tgz -C /content/dataset/val2C/
!mkdir -p /content/dataset/test2C/
!tar -xvzf /content/drive/MyDrive/MaiaSpain/CAD/project/dataset/test/test2C.tgz -C /content/dataset/test2C/

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
testX/xxx04139.jpg
testX/xxx05763.jpg
testX/xxx01199.jpg
testX/xxx03592.jpg
testX/xxx06086.jpg
testX/xxx02062.jpg
testX/xxx01124.jpg
testX/xxx00143.jpg
testX/xxx05437.jpg
testX/xxx00356.jpg
testX/xxx03594.jpg
testX/xxx02648.jpg
testX/xxx02681.jpg
testX/xxx01684.jpg
testX/xxx02143.jpg
testX/xxx01114.jpg
testX/xxx01592.jpg
testX/xxx02163.jpg
testX/xxx05441.jpg
testX/xxx02756.jpg
testX/xxx01785.jpg
testX/xxx00855.jpg
testX/xxx05754.jpg
testX/xxx00763.jpg
testX/xxx01156.jpg
testX/xxx03947.jpg
testX/xxx04798.jpg
testX/xxx05700.jpg
testX/xxx01011.jpg
testX/xxx05481.jpg
testX/xxx03439.jpg
testX/xxx02784.jpg
testX/xxx04103.jpg
testX/xxx04951.jpg
testX/xxx00865.jpg
testX/xxx05506.jpg
testX/xxx02922.jpg
testX/xxx02887.jpg
testX/xxx05907.jpg
testX/xxx02986.jpg
testX/xxx01465.jpg
testX/xxx03753.jpg
testX/xxx04834.jpg
testX/xxx03365.jpg
testX/xxx01115.jpg
testX/xxx02641.jpg
testX/xxx01640.jpg
testX/xxx00853.jpg
testX/xxx01432.jpg
test

# Initialization

In [None]:
# Work from the project folder on Drive: relative paths used later (embeddings file, checkpoint directory) resolve against it.
%cd /content/drive/MyDrive/MaiaSpain/CAD/project/

/content/drive/.shortcut-targets-by-id/1fTsBspXCDEVY7q8PyaNkRq-51xfzOyIV/MaiaSpain/CAD/project


In [None]:
import os
import numpy as np
from torch.utils.data import WeightedRandomSampler
from monai.transforms import (
    Compose, Rand2DElasticd, RandRotate90d, RandFlipd, RandAffined, ScaleIntensityd,
    RandCoarseShuffled, EnsureTyped, LoadImaged, Resized, ToTensord, NormalizeIntensityd
)
from monai.data import PersistentDataset, Dataset, DataLoader
from monai.losses import FocalLoss
from PIL import Image
import torch
from torch import nn, tensor
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, StepLR
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report, cohen_kappa_score
from tqdm import tqdm
from torchvision import models
from shutil import rmtree
from transformers import DeiTForImageClassification, DeiTConfig
import csv
from torch.nn import functional as F

# Dataloader

In [None]:
cls_to_idx = {'nevus':0,'others':1}
idx_to_cls = {0:'nevus',1:'others'}

def get_weights(data_path):
  """Return normalized inverse-frequency weights for the class folders in `data_path`.

  Every entry of `data_path` is treated as a class directory and its entries are
  counted as samples. A class weight is (1 / count) divided by the sum of
  (1 / count) over all classes, so the weights add up to 1.
  """
  counts = {}
  for name in os.listdir(data_path):
    counts[name] = len(os.listdir(os.path.join(data_path, name)))
  inverse = {name: 1 / count for name, count in counts.items()}
  total = sum(inverse.values())
  return {name: value / total for name, value in inverse.items()}

def load_from_npz(file_name):
    """Read the 'paths', 'features' and 'labels' arrays from an .npz archive.

    Args:
        file_name: path of the .npz file to read.

    Returns:
        Tuple ``(paths, features, labels)`` of arrays stored under those keys.

    Raises:
        KeyError: if one of the three keys is missing from the archive.
    """
    # NOTE(review): allow_pickle=True lets the archive execute arbitrary code while
    # unpickling object arrays; only load embedding files from a trusted source.
    # The context manager closes the underlying file handle once the arrays are read
    # (members are materialized in memory when indexed).
    with np.load(file_name, allow_pickle=True) as data:
        paths = data['paths']
        features = data['features']
        labels = data['labels']
    return paths, features, labels


def get_data(data_path, cls_weights, embeddings_path=None,
             local_prefix='/content/dataset/',
             embeddings_prefix='/Users/ayaelgebaly/Downloads/maiaUdg/CAD/DLproject/'):
  """Collect per-image records and sampling weights from a class-per-folder dataset.

  Args:
    data_path: directory whose sub-directories are named after the keys of `cls_to_idx`.
    cls_weights: mapping from class name to the sampling weight of its images.
    embeddings_path: optional .npz file (read with `load_from_npz`) holding a
      precomputed embedding and label per image path. When given, each record also
      carries an 'embedding' entry and images without a stored embedding are
      printed and skipped.
    local_prefix: leading part of the local image paths that is replaced to obtain
      the path under which the embedding was stored.
    embeddings_prefix: replacement for `local_prefix`; the default is the root the
      embeddings file was written with.

  Returns:
    Tuple ``(data, sample_weights)``: a list of dicts with 'image', 'label' and
    optionally 'embedding', and a parallel list of per-sample weights.

  Raises:
    AssertionError: if the label stored with an embedding disagrees with the folder label.
  """
  data = []
  sample_weights = []
  embedding_index = None
  if embeddings_path is not None:
    npz_paths, npz_embeddings, npz_labels = load_from_npz(embeddings_path)
    # Build the path -> row lookup once instead of scanning the array for every image.
    # setdefault keeps the first occurrence, matching list.index() semantics.
    embedding_index = {}
    for position, stored_path in enumerate(npz_paths):
      embedding_index.setdefault(str(stored_path), position)
  for cls in os.listdir(data_path):
    cls_path = os.path.join(data_path, cls)
    for image_name in os.listdir(cls_path):
      image_path = os.path.join(cls_path, image_name)
      record = {'image': image_path, 'label': cls_to_idx[cls]}
      if embedding_index is not None:
        stored_path = image_path.replace(local_prefix, embeddings_prefix)
        position = embedding_index.get(stored_path)
        if position is None:
          # No embedding for this image: report it and leave it out of the dataset.
          print(stored_path)
          continue
        assert cls_to_idx[cls] == npz_labels[position], \
            f"label mismatch between folder and embeddings file for {image_path}"
        record['embedding'] = npz_embeddings[position]
      data.append(record)
      sample_weights.append(cls_weights[cls])
  return data, sample_weights

In [None]:
# Settings shared by the transform, dataset and loader definitions below.
IMAGE_KEY = 'image'
TRAIN_DIR = "/content/dataset/train2C/train"
VAL_DIR = "/content/dataset/val2C/val"
TRAIN_EMBEDDINGS = 'dataset/test/google_derm_train2c_embeddings.npz'
BATCH_SIZE = 32
NUM_WORKERS = os.cpu_count()

# Deterministic preprocessing applied to every split: load, resize, normalize.
base_transforms = [
    LoadImaged(keys=IMAGE_KEY, ensure_channel_first=True),
    Resized(keys=IMAGE_KEY, spatial_size=(224, 224)),
    NormalizeIntensityd(keys=[IMAGE_KEY]),
]

# Random augmentations: rescale to [0, 1], apply the random warps/flips/rotations,
# then normalize again. The same list is reused for test-time augmentation.
augmentations = [
    ScaleIntensityd(keys=IMAGE_KEY, minv=0, maxv=1),
    Rand2DElasticd(keys=IMAGE_KEY, prob=0.75, spacing=56, magnitude_range=(1, 3), padding_mode="zeros"),
    RandRotate90d(keys=IMAGE_KEY, prob=0.75, spatial_axes=[0, 1]),
    RandFlipd(keys=IMAGE_KEY, prob=0.5, spatial_axis=0),
    RandFlipd(keys=IMAGE_KEY, prob=0.5, spatial_axis=1),
    RandAffined(keys=IMAGE_KEY, prob=0.75, rotate_range=(0.75, 0.75), padding_mode="zeros"),
    NormalizeIntensityd(keys=IMAGE_KEY),
]

train_transform = Compose(base_transforms + augmentations)
val_transform = Compose(base_transforms)

# Uncomment to drop the on-disk transform cache (e.g. after changing the transforms).
# rmtree('/content/cache/',True)

# Class weights come from the training split and are reused for validation records.
cls_weights = get_weights(TRAIN_DIR)
train_data, train_weights = get_data(TRAIN_DIR, cls_weights, TRAIN_EMBEDDINGS)
val_data, val_weights = get_data(VAL_DIR, cls_weights)

# Training: cached dataset drawn through a class-weighted sampler.
train_dataset = PersistentDataset(data=train_data, transform=train_transform, cache_dir='/content/cache/train')
train_sampler = WeightedRandomSampler(train_weights, len(train_weights))
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=train_sampler,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Validation: cached dataset read in its stored order.
val_dataset = PersistentDataset(data=val_data, transform=val_transform, cache_dir='/content/cache/val')
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Visualization

In [None]:
# import matplotlib.pyplot as plt

# train_loader.dataset.transform.transforms = train_loader.dataset.transform.transforms+(ScaleIntensityd('image',0,1),)
# plt.figure(figsize=(18,10))
# batch = next(iter(train_loader))
# for i in range(32):
#   image, label = batch['image'][i], batch['label'][i]
#   plt.subplot(4,8,i+1)
#   plt.imshow(image.permute(1,2,0).numpy())
#   plt.title(idx_to_cls[label.item()])
#   plt.axis('off')
# plt.tight_layout()
# plt.show()

# Training and Evaluation

In [None]:
def calculate_metrics(labels, preds):
    """Score predictions against ground-truth labels.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, kappa)``; precision, recall and F1
        are macro-averaged, kappa is Cohen's kappa.
    """
    macro_precision, macro_recall, macro_f1 = (
        scorer(labels, preds, average="macro")
        for scorer in (precision_score, recall_score, f1_score)
    )
    return (
        accuracy_score(labels, preds),
        macro_precision,
        macro_recall,
        macro_f1,
        cohen_kappa_score(labels, preds),
    )

def train_and_evaluate_model(model, train_loader, val_loader, device, epochs=30, min_delta=0.0001,
                             save_dir="models_alberb_2cls_distill/"):
    """Fine-tune `model` with a classification loss plus an embedding-distillation loss.

    Every epoch trains on `train_loader`, then evaluates on `val_loader` and prints
    the mean losses, the validation metrics and the confusion matrix. Whenever the
    validation accuracy exceeds the best one so far by more than `min_delta`, the
    model's ``state_dict`` is written to ``best_model_<accuracy>.pth`` inside
    `save_dir` and the checkpoint saved earlier in this run is deleted.

    Args:
        model: network whose forward output exposes ``.logits`` and ``.hidden_states``.
        train_loader: yields dict batches with 'image', 'label' and 'embedding'.
        val_loader: yields dict batches with 'image' and 'label'.
        device: torch device used for the model and the batches.
        epochs: number of training epochs.
        min_delta: minimum accuracy improvement required to save a new checkpoint.
        save_dir: directory for checkpoints (created if missing).

    Returns:
        None. Results are printed and the best checkpoint is written to disk.
    """
    model = model.to(device)
    # Maps the 768-d model feature to 6144 dimensions so it can be compared with the
    # precomputed embedding. NOTE(review): assumes the stored embeddings are 6144-d - confirm.
    projector = nn.Sequential(nn.Linear(768, 3072), nn.ReLU(), nn.Linear(3072, 6144), nn.ReLU()).to(device)
    cls_criterion = FocalLoss(to_onehot_y=True, use_softmax=True)
    distill_criterion = nn.KLDivLoss(reduction='batchmean')
    # The projector is trained jointly with the model but is not part of the checkpoint.
    optimizer = torch.optim.AdamW(list(model.parameters()) + list(projector.parameters()), lr=1e-4, weight_decay=1e-4)
    scheduler = StepLR(optimizer, step_size=15, gamma=0.2)
    best_accuracy = 0
    best_model_path = None
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(epochs):
        model.train()
        epoch_cls_loss, epoch_distill_loss = 0, 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}"):
            images, labels, gt_embedding = batch['image'].to(device), batch['label'].to(device), batch['embedding'].to(device)
            optimizer.zero_grad()
            outputs = model(images)
            # Feature used for distillation: the final token of the last hidden layer.
            # NOTE(review): DeiT places its class/distillation tokens first, so index -1
            # is a patch token - confirm this is the intended feature.
            outputs, embedding = outputs.logits, outputs.hidden_states[-1][:, -1, :]
            cls_loss = cls_criterion(outputs, labels)
            # KL divergence between the softmax-normalized projected feature and target embedding.
            distill_loss = distill_criterion(F.log_softmax(projector(embedding), dim=-1), F.softmax(gt_embedding, dim=-1))
            loss = cls_loss + distill_loss
            loss.backward()
            optimizer.step()
            epoch_cls_loss += cls_loss.item()
            epoch_distill_loss += distill_loss.item()
        scheduler.step()

        print(f"Epoch {epoch + 1}/{epochs}, Classification Loss: {epoch_cls_loss / len(train_loader):.6f}, Distill Loss: {epoch_distill_loss / len(train_loader):.6f}")

        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for batch in val_loader:
                images, labels = batch['image'].to(device), batch['label'].to(device)
                logits = model(images).logits
                preds = torch.argmax(logits, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        accuracy, precision, recall, f1, kappa = calculate_metrics(all_labels, all_preds)
        cm = confusion_matrix(all_labels, all_preds)
        print(f"Validation Metrics - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Kappa: {kappa:.4f}")
        print(f"Confusion Matrix:\n{cm}")

        if accuracy > best_accuracy + min_delta:
            model_path = os.path.join(save_dir, f"best_model_{accuracy:.4f}.pth")
            torch.save(model.state_dict(), model_path)
            # Remove the superseded checkpoint, but never the file just written: with a
            # small min_delta both accuracies can format to the same 4-decimal name.
            if best_model_path is not None and best_model_path != model_path and os.path.exists(best_model_path):
                os.remove(best_model_path)
            best_accuracy = accuracy
            best_model_path = model_path
            print(f"Model saved to {model_path}")

In [None]:
# DeiTForImageClassification / DeiTConfig are imported once in the Initialization cell.
# Pretrained distilled DeiT-base backbone; the 2-way classifier head is newly
# initialized, and hidden states are returned because training reads them.
MODEL_NAME = "facebook/deit-base-distilled-patch16-224"

config = DeiTConfig.from_pretrained(MODEL_NAME)
config.num_labels = 2
config.output_hidden_states = True
model = DeiTForImageClassification.from_pretrained(MODEL_NAME, config=config)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Experiments

In [None]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_and_evaluate_model(model, train_loader, val_loader, device=device, epochs=100)

  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 1/100: 100%|██████████| 475/475 [09:05<00:00,  1.15s/it]

Epoch 1/100, Classification Loss: 0.053860, Distill Loss: 0.193884



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8388, Precision: 0.8422, Recall: 0.8396, F1: 0.8386, Kappa: 0.6780
Confusion Matrix:
[[1525  406]
 [ 206 1659]]
Model saved to models_alberb_2cls_distill/best_model_0.8388.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 2/100: 100%|██████████| 475/475 [08:49<00:00,  1.11s/it]

Epoch 2/100, Classification Loss: 0.040954, Distill Loss: 0.124359



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8512, Precision: 0.8524, Recall: 0.8517, F1: 0.8511, Kappa: 0.7026
Confusion Matrix:
[[1586  345]
 [ 220 1645]]
Model saved to models_alberb_2cls_distill/best_model_0.8512.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 3/100: 100%|██████████| 475/475 [08:52<00:00,  1.12s/it]

Epoch 3/100, Classification Loss: 0.035087, Distill Loss: 0.107187



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8612, Precision: 0.8612, Recall: 0.8610, F1: 0.8611, Kappa: 0.7222
Confusion Matrix:
[[1680  251]
 [ 276 1589]]
Model saved to models_alberb_2cls_distill/best_model_0.8612.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 4/100: 100%|██████████| 475/475 [08:48<00:00,  1.11s/it]

Epoch 4/100, Classification Loss: 0.031649, Distill Loss: 0.098941



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8614, Precision: 0.8654, Recall: 0.8624, F1: 0.8612, Kappa: 0.7233
Confusion Matrix:
[[1563  368]
 [ 158 1707]]
Model saved to models_alberb_2cls_distill/best_model_0.8614.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 5/100: 100%|██████████| 475/475 [08:44<00:00,  1.10s/it]

Epoch 5/100, Classification Loss: 0.028390, Distill Loss: 0.089162



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8772, Precision: 0.8810, Recall: 0.8764, F1: 0.8767, Kappa: 0.7540
Confusion Matrix:
[[1787  144]
 [ 322 1543]]
Model saved to models_alberb_2cls_distill/best_model_0.8772.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 6/100: 100%|██████████| 475/475 [08:44<00:00,  1.10s/it]

Epoch 6/100, Classification Loss: 0.024468, Distill Loss: 0.082894



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8833, Precision: 0.8848, Recall: 0.8828, F1: 0.8830, Kappa: 0.7663
Confusion Matrix:
[[1764  167]
 [ 276 1589]]
Model saved to models_alberb_2cls_distill/best_model_0.8833.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 7/100: 100%|██████████| 475/475 [08:46<00:00,  1.11s/it]

Epoch 7/100, Classification Loss: 0.021380, Distill Loss: 0.081743



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8772, Precision: 0.8786, Recall: 0.8778, F1: 0.8772, Kappa: 0.7547
Confusion Matrix:
[[1634  297]
 [ 169 1696]]


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 8/100: 100%|██████████| 475/475 [08:43<00:00,  1.10s/it]

Epoch 8/100, Classification Loss: 0.019752, Distill Loss: 0.072364



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8844, Precision: 0.8844, Recall: 0.8845, F1: 0.8843, Kappa: 0.7687
Confusion Matrix:
[[1688  243]
 [ 196 1669]]
Model saved to models_alberb_2cls_distill/best_model_0.8844.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 9/100: 100%|██████████| 475/475 [08:40<00:00,  1.10s/it]

Epoch 9/100, Classification Loss: 0.016832, Distill Loss: 0.070557



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8580, Precision: 0.8598, Recall: 0.8586, F1: 0.8579, Kappa: 0.7163
Confusion Matrix:
[[1588  343]
 [ 196 1669]]


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 10/100: 100%|██████████| 475/475 [08:41<00:00,  1.10s/it]


Epoch 10/100, Classification Loss: 0.016225, Distill Loss: 0.068276


  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8830, Precision: 0.8839, Recall: 0.8835, F1: 0.8830, Kappa: 0.7662
Confusion Matrix:
[[1657  274]
 [ 170 1695]]


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 11/100: 100%|██████████| 475/475 [08:38<00:00,  1.09s/it]

Epoch 11/100, Classification Loss: 0.013947, Distill Loss: 0.065551



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8880, Precision: 0.8881, Recall: 0.8882, F1: 0.8880, Kappa: 0.7761
Confusion Matrix:
[[1698  233]
 [ 192 1673]]
Model saved to models_alberb_2cls_distill/best_model_0.8880.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 12/100: 100%|██████████| 475/475 [08:36<00:00,  1.09s/it]

Epoch 12/100, Classification Loss: 0.011651, Distill Loss: 0.060666



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8920, Precision: 0.8919, Recall: 0.8920, F1: 0.8920, Kappa: 0.7839
Confusion Matrix:
[[1720  211]
 [ 199 1666]]
Model saved to models_alberb_2cls_distill/best_model_0.8920.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 13/100: 100%|██████████| 475/475 [08:43<00:00,  1.10s/it]

Epoch 13/100, Classification Loss: 0.010712, Distill Loss: 0.061132



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8988, Precision: 0.8992, Recall: 0.8986, F1: 0.8988, Kappa: 0.7975
Confusion Matrix:
[[1762  169]
 [ 215 1650]]
Model saved to models_alberb_2cls_distill/best_model_0.8988.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 14/100: 100%|██████████| 475/475 [08:50<00:00,  1.12s/it]

Epoch 14/100, Classification Loss: 0.010284, Distill Loss: 0.060851



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8930, Precision: 0.8940, Recall: 0.8926, F1: 0.8929, Kappa: 0.7859
Confusion Matrix:
[[1769  162]
 [ 244 1621]]


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 15/100: 100%|██████████| 475/475 [08:39<00:00,  1.09s/it]

Epoch 15/100, Classification Loss: 0.008644, Distill Loss: 0.058724



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.8994, Precision: 0.8993, Recall: 0.8994, F1: 0.8994, Kappa: 0.7987
Confusion Matrix:
[[1729  202]
 [ 180 1685]]
Model saved to models_alberb_2cls_distill/best_model_0.8994.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 16/100: 100%|██████████| 475/475 [08:43<00:00,  1.10s/it]

Epoch 16/100, Classification Loss: 0.003697, Distill Loss: 0.046706



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.9089, Precision: 0.9092, Recall: 0.9086, F1: 0.9088, Kappa: 0.8176
Confusion Matrix:
[[1780  151]
 [ 195 1670]]
Model saved to models_alberb_2cls_distill/best_model_0.9089.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 17/100: 100%|██████████| 475/475 [09:00<00:00,  1.14s/it]

Epoch 17/100, Classification Loss: 0.001646, Distill Loss: 0.042387



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.9110, Precision: 0.9111, Recall: 0.9108, F1: 0.9109, Kappa: 0.8218
Confusion Matrix:
[[1774  157]
 [ 181 1684]]
Model saved to models_alberb_2cls_distill/best_model_0.9110.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 18/100: 100%|██████████| 475/475 [08:44<00:00,  1.10s/it]

Epoch 18/100, Classification Loss: 0.001165, Distill Loss: 0.042477



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.9117, Precision: 0.9118, Recall: 0.9116, F1: 0.9117, Kappa: 0.8234
Confusion Matrix:
[[1774  157]
 [ 178 1687]]
Model saved to models_alberb_2cls_distill/best_model_0.9117.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 19/100: 100%|██████████| 475/475 [08:48<00:00,  1.11s/it]

Epoch 19/100, Classification Loss: 0.000906, Distill Loss: 0.039884



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.9131, Precision: 0.9133, Recall: 0.9129, F1: 0.9130, Kappa: 0.8260
Confusion Matrix:
[[1783  148]
 [ 182 1683]]
Model saved to models_alberb_2cls_distill/best_model_0.9131.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 20/100: 100%|██████████| 475/475 [08:51<00:00,  1.12s/it]

Epoch 20/100, Classification Loss: 0.000513, Distill Loss: 0.038294



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.9181, Precision: 0.9186, Recall: 0.9178, F1: 0.9180, Kappa: 0.8360
Confusion Matrix:
[[1803  128]
 [ 183 1682]]
Model saved to models_alberb_2cls_distill/best_model_0.9181.pth


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 21/100: 100%|██████████| 475/475 [08:41<00:00,  1.10s/it]

Epoch 21/100, Classification Loss: 0.000576, Distill Loss: 0.037191



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.9131, Precision: 0.9130, Recall: 0.9131, F1: 0.9130, Kappa: 0.8261
Confusion Matrix:
[[1761  170]
 [ 160 1705]]


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 22/100: 100%|██████████| 475/475 [08:32<00:00,  1.08s/it]

Epoch 22/100, Classification Loss: 0.000543, Distill Loss: 0.036437



  return torch.load(hashfile)
  return torch.load(hashfile)


Validation Metrics - Accuracy: 0.9170, Precision: 0.9172, Recall: 0.9169, F1: 0.9170, Kappa: 0.8339
Confusion Matrix:
[[1787  144]
 [ 171 1694]]


  return torch.load(hashfile)
  return torch.load(hashfile)
Epoch 23/100:  52%|█████▏    | 246/475 [04:27<04:22,  1.15s/it]

# TTA

In [None]:
# Test-time augmentation reuses the random training augmentations.
tta_augs = Compose(augmentations)

def augment(images):
  """Augment every image of a batch independently and re-stack them along dim 0."""
  augmented = []
  for image in images:
    augmented.append(tta_augs({'image': image})['image'])
  return torch.stack(augmented, 0)

In [None]:
# Validate the best checkpoint with test-time augmentation: class probabilities of
# the plain image and of `tta_iters` randomly augmented copies are summed before argmax.
tta_iters = 7
model_path = 'models_alberb_2cls_distill/best_model_0.9181.pth'
# Use the GPU when present instead of assuming CUDA, and map the checkpoint accordingly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# weights_only=True restricts unpickling to tensors/containers, which is all a
# state_dict holds, and avoids the torch.load FutureWarning.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for batch in val_loader:
        images, labels = batch['image'].to(device), batch['label'].to(device)
        outputs = torch.softmax(model(images).logits, 1)
        for _ in range(tta_iters):
            outputs += torch.softmax(model(augment(images.clone())).logits, 1)
        preds = torch.argmax(outputs, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

accuracy, precision, recall, f1, kappa = calculate_metrics(all_labels, all_preds)
cm = confusion_matrix(all_labels, all_preds)
print(f"Validation Metrics - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Kappa: {kappa:.4f}")
print(f"Confusion Matrix:\n{cm}")

  model.load_state_dict(torch.load('models_alberb_2cls_distill/best_model_0.9181.pth'))


Validation Metrics - Accuracy: 0.9207, Precision: 0.9215, Recall: 0.9204, F1: 0.9206, Kappa: 0.8413
Confusion Matrix:
[[1815  116]
 [ 185 1680]]


# Prediction

In [None]:
def get_test_data(data_path):
  """Return one {'image': path} record per entry of `data_path`, ordered by file name."""
  return [
      {'image': os.path.join(data_path, image_name)}
      for image_name in sorted(os.listdir(data_path))
  ]
# Unlabeled test images get the deterministic preprocessing only and are read in
# sorted file order, so prediction i belongs to the i-th sorted file name.
test_data = get_test_data("/content/dataset/test2C/testX")
test_transforms = Compose(base_transforms)
test_dataset = Dataset(data=test_data, transform=test_transforms)
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    num_workers=os.cpu_count(),
    pin_memory=True,
)

In [None]:
# Restore the best checkpoint for inference on whichever device is available.
model_path = 'models_alberb_2cls_distill/best_model_0.9181.pth'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# weights_only=True restricts unpickling to tensors/containers, which is all a
# state_dict holds, and avoids the torch.load FutureWarning.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.eval();

  model.load_state_dict(torch.load(model_path, map_location=device))


In [None]:
# Predict the test set with test-time augmentation. `augment` (and the `tta_augs`
# pipeline it uses) is defined once in the TTA section above and reused here rather
# than being re-defined.
tta_iters = 7
all_preds = []
with torch.no_grad():
    for batch in test_loader:
        images = batch['image'].to(device)
        # Sum the class probabilities of the plain image and of each augmented copy.
        outputs = torch.softmax(model(images).logits, 1)
        for _ in range(tta_iters):
            outputs += torch.softmax(model(augment(images.clone())).logits, 1)
        preds = torch.argmax(outputs, dim=1)
        # Store plain Python ints so the CSV rows do not depend on numpy scalar formatting.
        all_preds.extend(preds.cpu().tolist())

# One predicted class index per row, in the order of the test loader.
results_path = model_path.replace('.pth', '_test_results.csv')
with open(results_path, "w", newline="") as f:
    csv.writer(f).writerows([[pred] for pred in all_preds])

In [None]:
# Sanity check: how many test images were assigned to each predicted class index.
predicted_classes, predicted_counts = np.unique(all_preds, return_counts=True)
print((predicted_classes, predicted_counts))

(array([0, 1]), array([3365, 2975]))
