In [1]:
import os
import random
from PIL import Image

import torch
from torch import nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

import timm

import pandas as pd
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, classification_report, roc_auc_score, recall_score

# Single seed shared by every random number generator this notebook imports.
SEED = 16

# Seed the Python and NumPy generators as well as PyTorch; both modules are
# imported above, and leaving them unseeded makes any use of them
# non-reproducible across kernel restarts.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU. The chosen device name is echoed below.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(DEVICE)

cuda


In [3]:
# --- Inputs -----------------------------------------------------------------
# Checkpoint file to evaluate, and the image directory it is evaluated on.
model_path = '../MODELS/VIT_DataSet01_M1/fold3_model.pth'
dataset_path = '../DATA/Train_Val_set/Val'

# --- Parameters -------------------------------------------------------------
# timm model identifier, size of the classification head, and loader batch size.
model_name = 'tiny_vit_21m_512.dist_in22k_ft_in1k'
num_classes, batch_size = 4, 30

In [4]:
# Build a throwaway model only to read the preprocessing configuration that
# timm attaches to it. pretrained=False: nothing below uses the weights, so
# there is no reason to fetch/load them just to inspect `default_cfg`.
m = timm.create_model(model_name, pretrained=False, num_classes=num_classes, drop_rate=0.3)
model_info = m.default_cfg
del m  # release the temporary model once its config has been captured

# Drop the leading entry of `input_size` and keep the rest
# (printed as (512, 512) for this model).
input_shape = model_info['input_size'][1:]
# Per-channel normalisation statistics, reused by the transform pipeline.
transform_mean = model_info['mean']
transform_std = model_info['std']

print(f"USING MODEL ARCHITECTURE {model_info['architecture']} ")
print(f"INPUT SHAPE = {input_shape}")
print(f"       MEAN = {transform_mean}")
print(f"        STD = {transform_std}")

USING MODEL ARCHITECTURE tiny_vit_21m_512 
INPUT SHAPE = (512, 512)
       MEAN = (0.485, 0.456, 0.406)
        STD = (0.229, 0.224, 0.225)


In [5]:
# Preprocessing applied to every image: resize to the resolution taken from
# the model config, convert to a tensor, then normalise with the model's
# mean/std.
preprocessing_steps = [
    transforms.Resize(input_shape),
    transforms.ToTensor(),
    transforms.Normalize(mean=transform_mean, std=transform_std),
]
data_transform = transforms.Compose(preprocessing_steps)

# One sub-directory per class under `dataset_path`; batches are served in a
# fixed order (shuffle=False) so predictions line up with the dataset order.
data_dataset = datasets.ImageFolder(root=dataset_path, transform=data_transform)
data_loader = DataLoader(
    data_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

print("CLASS MAPPING", data_dataset.class_to_idx, sep="\n")

CLASS MAPPING
{'GNB': 0, 'GNC': 1, 'GPB': 2, 'GPC': 3}


In [6]:
# Build the architecture without pretrained weights: every parameter is
# replaced by the checkpoint below (load_state_dict is strict by default), so
# loading ImageNet weights first would be wasted work.
model = timm.create_model(model_name, pretrained=False, num_classes=num_classes, drop_rate=0.3)

# weights_only=True restricts unpickling to tensors and plain containers,
# which avoids executing arbitrary pickled code and silences the torch.load
# FutureWarning.
# NOTE(review): assumes the checkpoint is a plain state_dict (it is passed
# straight to load_state_dict) - confirm if loading fails.
state_dict = torch.load(model_path, map_location=torch.device(DEVICE), weights_only=True)
model.load_state_dict(state_dict)

model = model.to(DEVICE)
# Switch to inference mode (disables dropout, uses BatchNorm running stats).
model.eval()

  model.load_state_dict(torch.load(model_path, map_location=torch.device(DEVICE)))


TinyVit(
  (patch_embed): PatchEmbed(
    (conv1): ConvNorm(
      (conv): Conv2d(3, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (act): GELU(approximate='none')
    (conv2): ConvNorm(
      (conv): Conv2d(48, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (stages): Sequential(
    (0): ConvLayer(
      (blocks): Sequential(
        (0): MBConv(
          (conv1): ConvNorm(
            (conv): Conv2d(96, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (act1): GELU(approximate='none')
          (conv2): ConvNorm(
            (conv): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=Fals

In [7]:
# Run the model over the whole loader and collect, in loader order:
#   all_labels - ground-truth class indices
#   all_preds  - arg-max predicted class indices
#   logits     - one 1-D tensor of raw class scores per sample (a later cell
#                stacks these and converts them to probabilities)
all_labels = []
all_preds = []
logits = []

with torch.no_grad():  # inference only: no autograd bookkeeping needed
    for inputs, labels in data_loader:
        outputs = model(inputs.to(DEVICE))

        # Iterating a 2-D tensor yields its rows, so this stores one entry
        # per sample.
        logits.extend(outputs)

        all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

all_labels = np.array(all_labels)
all_preds = np.array(all_preds)

# Order class names by their index so row i of the report is class index i.
# Passing `labels` explicitly keeps the report well-defined even when some
# class never appears in either the labels or the predictions.
class_names = sorted(data_dataset.class_to_idx, key=data_dataset.class_to_idx.get)
print(classification_report(
    all_labels,
    all_preds,
    labels=list(range(len(class_names))),
    target_names=class_names,
))

  x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)


              precision    recall  f1-score   support

         GNB       0.80      0.78      0.79        50
         GNC       0.95      0.38      0.54        50
         GPB       0.84      0.86      0.85        50
         GPC       0.61      0.98      0.75        50

    accuracy                           0.75       200
   macro avg       0.80      0.75      0.73       200
weighted avg       0.80      0.75      0.73       200



In [8]:
# Display the dataset's class-name -> class-index mapping, for reading the
# integer labels/predictions produced above.
data_dataset.class_to_idx

{'GNB': 0, 'GNC': 1, 'GPB': 2, 'GPC': 3}

In [None]:
# Convert the collected raw scores to per-class probabilities.
#
# After this cell `logits` holds a 2-D tensor of probabilities (one row per
# sample) instead of the list of per-sample score tensors built by the
# inference cell. The isinstance guard makes the cell safe to re-run: once
# the conversion has happened it is skipped, rather than stacking a tensor or
# applying softmax a second time.
if isinstance(logits, list):
    logits = torch.stack(logits)
    logits = nn.Softmax(dim=1)(logits)

tensor([[0.7473, 0.1190, 0.0941, 0.0396],
        [0.6064, 0.0090, 0.2838, 0.1009],
        [0.8222, 0.0335, 0.0611, 0.0832],
        [0.4364, 0.0147, 0.5044, 0.0445],
        [0.9148, 0.0197, 0.0202, 0.0453],
        [0.8382, 0.0509, 0.0950, 0.0159],
        [0.6633, 0.1509, 0.1034, 0.0824],
        [0.5562, 0.1232, 0.2418, 0.0788],
        [0.7127, 0.0135, 0.2426, 0.0311],
        [0.7020, 0.0209, 0.0753, 0.2018],
        [0.6452, 0.1812, 0.0690, 0.1046],
        [0.5278, 0.0101, 0.3119, 0.1503],
        [0.3892, 0.0059, 0.3126, 0.2923],
        [0.0448, 0.0135, 0.8953, 0.0465],
        [0.7107, 0.0275, 0.1721, 0.0897],
        [0.2538, 0.4197, 0.0655, 0.2610],
        [0.2458, 0.0291, 0.5494, 0.1757],
        [0.6288, 0.0088, 0.2062, 0.1562],
        [0.8478, 0.1021, 0.0139, 0.0362],
        [0.4775, 0.0220, 0.3278, 0.1727],
        [0.1349, 0.0109, 0.4119, 0.4423],
        [0.6149, 0.0106, 0.3388, 0.0357],
        [0.1228, 0.0595, 0.3562, 0.4614],
        [0.7719, 0.0497, 0.0758, 0

In [16]:
# Print a comma-separated table: one probability column per class, followed
# by the true label index and the predicted index for each sample.
# The header and the probability columns are derived from the dataset's class
# mapping (ordered by class index) rather than hard-coded to four names.
class_names = sorted(data_dataset.class_to_idx, key=data_dataset.class_to_idx.get)
print(", ".join(class_names + ["label", "pred"]))
for sample_probs, label, pred in zip(logits, all_labels, all_preds):
    # Formatting each 0-dim tensor with an f-string prints its Python scalar.
    prob_fields = ", ".join(f"{p}" for p in sample_probs)
    print(f"{prob_fields}, {label}, {pred}")

GNB, GNC, GPB, GPC, label, pred
0.7473011612892151, 0.11903907358646393, 0.09406524151563644, 0.03959449380636215, 0, 0
0.6064049005508423, 0.008950510062277317, 0.28375568985939026, 0.10088887065649033, 0, 0
0.8222436904907227, 0.03349340707063675, 0.061087556183338165, 0.08317539840936661, 0, 0
0.43639472126960754, 0.014715207740664482, 0.504367470741272, 0.04452258348464966, 0, 2
0.9148030877113342, 0.019677210599184036, 0.020247962325811386, 0.04527170956134796, 0, 0
0.8382396697998047, 0.05085292086005211, 0.09501572698354721, 0.015891732648015022, 0, 0
0.6632636785507202, 0.150929793715477, 0.10344664752483368, 0.0823599249124527, 0, 0
0.5562048554420471, 0.12318769842386246, 0.24181540310382843, 0.07879205793142319, 0, 0
0.7127335667610168, 0.013537299819290638, 0.24263426661491394, 0.031094852834939957, 0, 0
0.7020370364189148, 0.020871849730610847, 0.07534057646989822, 0.20175053179264069, 0, 0
0.6452324986457825, 0.18120619654655457, 0.06895191222429276, 0.10460937023162842, 

In [13]:
# Display the class-name -> class-index mapping again, as a key for the
# label/pred columns of the table printed by the previous cell.
data_dataset.class_to_idx

{'GNB': 0, 'GNC': 1, 'GPB': 2, 'GPC': 3}