# **SETUP**

In [1]:
import os
import sys

# Make the project's source directories importable (hard-coded absolute paths on the
# training server); model_utils / data_loader below are imported from them.
sys.path.append("/home/ma012/AlexServer/phase2/src")
sys.path.append("/home/ma012/AlexServer/phase2/src/utils")

import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import transforms, datasets
from torchvision.transforms.functional import InterpolationMode

from torch import nn

import timm
import torchvision.models as models
from torchinfo import summary


# NOTE(review): later cells also call `load_vit_model(...)` and `load_cnn_model(...)`,
# which are NOT imported here (only the *_config variants are). Confirm which loader
# names model_utils actually exports.
from model_utils import load_vit_model_config,load_cnn_model_config, set_seeds
from data_loader import create_dataloaders,create_dataloaders_2

from skimage.segmentation import mark_boundaries

import numpy as np

from PIL import Image

# Add the parent of the working directory so the `going_modular` package imported
# below can be resolved.
current_dir = os.getcwd()
sys.path.append(os.path.abspath(os.path.join(current_dir, "..")))

from going_modular.going_modular.engine import trainVal, test_step, load_model_checkpoint, ensemble_test_step,ensemble_inference_step,my_ensemble_test_step,max_each_ensemble,inference_step
from helper_functions import plot_loss_curves

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DEVICE

'cuda'

## General Configuration

In [3]:
SEED = 42
NUM_WORKERS = 0  # os.cpu_count()

# Batch sizes
BATCH_SIZE_64 = 64
BATCH_SIZE_32 = 32
BATCH_SIZE_16 = 16
BATCH_SIZE_8 = 8

# Epoch budgets
EPOCHS_25 = 25
EPOCHS_20 = 20
EPOCHS_15 = 15
EPOCHS_10 = 10

# Learning rates
LR_0_0001 = 0.0001
LR_0_1 = 0.1

LR_0_00005 = 0.00005
LR_0_05 = 0.05

LOG_CHECK = "/home/ma012/AlexServer/phase2/log/xai_check.txt"

DATASET_DIR = "/home/ma012/AlexServer/Dataset"
DATASET_KAGGLE_DIR = "/home/ma012/AlexServer/KaggleDataset/reorganized"
DATASET_BT_LARGE_4C = "/home/ma012/AlexServer/phase2/dataset/BT_Large_4c"

# CLASS_NAMES is the single source of truth. NUM_CLASSES is derived from it so the two
# cannot drift apart (the 3-class setup used later omits 'no_tumor').
CLASS_NAMES = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']
NUM_CLASSES = len(CLASS_NAMES)

TRAIN_RATIO = 0.7
VAL_RATIO = 0.1
TEST_RATIO = 0.2
# The split fractions must cover the whole dataset (small tolerance for float rounding).
assert abs(TRAIN_RATIO + VAL_RATIO + TEST_RATIO - 1.0) < 1e-9, "split ratios must sum to 1"

loss_fn = torch.nn.CrossEntropyLoss()

# **ViT training**

## ViT Config

In [None]:

# ---- Training-log destinations (LOG_32 is the one passed to trainVal below) ----
LOG_32 = "/home/ma012/AlexServer/log/batch32.txt"
LOG_1K = "/home/ma012/AlexServer/log/train_1k.txt"
LOG_21k = "/home/ma012/AlexServer/log/train_21k.txt"
LOG_21k_1k = "/home/ma012/AlexServer/log/train_21k_1k.txt"

# ---- Active experiment -------------------------------------------------------
# Backbone: timm ViT-L/16 with AugReg ImageNet-21k weights; the checkpoint goes
# into the Adadelta run directory (the optimizer chosen in the run cell).
MODEL_CFG = "timm.vit_large_patch16_224.augreg_in21k.PRETRAINED"
SAVE_PATH = "/home/ma012/AlexServer/log/ViT_L_16/Adadelta/21k_case33.pth"

# ---- Alternatives from earlier runs (swap in to reproduce) -------------------
# Checkpoint destinations:
#   /home/ma012/AlexServer/log/ViT_B_16/RMSprop/case3.pth
#   /home/ma012/AlexServer/log/ViT_L_16/RMSprop/new_21k_case3.pth
#   /home/ma012/AlexServer/log/ViT_B_32/adam/21k_case33.pth
#   /home/ma012/AlexServer/log/ViT_L_32/adam/21k_case1.pth
#   /home/ma012/AlexServer/log/ViT_L_16/adam/new_21k_case2.pth
#   /home/ma012/AlexServer/log/ViT_L_16/Adadelta/new_21k_case3.pth
#   /home/ma012/AlexServer/log/case384/checkpoint/checkpoint_case2_augred_in1k.pth
#   /home/ma012/AlexServer/log/case384/checkpoint/checkpoint_case4_orig_in21k_in1k.pth
#
# torchvision backbones (ImageNet-1k):
#   torchvision.ViT_B_16_Weights.IMAGENET1K_V1
#   torchvision.ViT_B_32_Weights.IMAGENET1K_V1
#   torchvision.ViT_L_16_Weights.IMAGENET1K_V1
#   torchvision.ViT_L_32_Weights.IMAGENET1K_V1
#
# timm, AugReg ImageNet-21k:
#   timm.vit_base_patch16_224.augreg_in21k.PRETRAINED
#   timm.vit_base_patch32_224.augreg_in21k.PRETRAINED
#   timm.vit_large_patch32_224.orig_in21k.PRETRAINED
#
# timm, ImageNet-21k fine-tuned on ImageNet-1k:
#   timm.vit_base_patch16_224.augreg2_in21k_ft_in1k.PRETRAINED
#   timm.vit_base_patch32_224.augreg_in21k_ft_in1k.PRETRAINED
#   timm.vit_large_patch16_224.augreg_in21k_ft_in1k.PRETRAINED
#   timm.vit_large_patch32_224.augreg_in21k_ft_in1k.PRETRAINED   (marked "error")
#
# timm, 384x384 input:
#   timm.vit_base_patch16_384.augreg_in1k.PRETRAINED
#   timm.vit_base_patch32_384.augreg_in1k.PRETRAINED
#   timm.vit_base_patch16_384.orig_in21k_ft_in1k
#   timm.vit_base_patch32_384.augreg_in21k_ft_in1k
#   timm.vit_large_patch16_384.augreg_in21k_ft_in1k
#   timm.vit_large_patch32_384.orig_in21k_ft_in1k
#   NOTE(review): the last four lack the ".PRETRAINED" suffix used by the other
#   timm entries; confirm the loader accepts them as written.

## ViT Training Run

In [None]:
set_seeds(SEED)
model, model_transforms = load_vit_model_config(MODEL_CFG, NUM_CLASSES)

# 224x224 backbones: convert to grayscale with 3 output channels, then apply the
# backbone's own preprocessing.
custom_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    model_transforms,
])

# Optimizers used in earlier runs:
#   torch.optim.RMSprop(model.parameters(), lr=LR_0_00005)
#   torch.optim.Adam(model.parameters(), lr=LR_0_00005)
optimizer = torch.optim.Adadelta(model.parameters(), lr=LR_0_05)

loss_fn = torch.nn.CrossEntropyLoss()

# Reuse the configured dataset root rather than repeating the literal path.
dataset_dir = DATASET_DIR

train_loader, val_loader, test_loader = create_dataloaders(
    dataset_dir=dataset_dir,
    transform=custom_transforms,
    batch_size=BATCH_SIZE_16,
    num_workers=NUM_WORKERS,
    train_ratio=TRAIN_RATIO,
    val_ratio=VAL_RATIO,
    test_ratio=TEST_RATIO,
    seed=SEED,
)

In [None]:
# Number of batches in each split: train, val, test.
for loader in (train_loader, val_loader, test_loader):
    print(len(loader))

In [None]:
# Train/validate for EPOCHS_15 epochs. SAVE_PATH and LOG_32 are the checkpoint and
# log destinations handed to trainVal.
results = trainVal(
    model,
    train_loader,
    val_loader,
    optimizer,
    loss_fn,
    EPOCHS_15,
    DEVICE,
    SAVE_PATH,
    LOG_32,
)

In [None]:
# Plot the curves from the history returned by trainVal; type="val" is passed through to plot_loss_curves.
plot_loss_curves(results, type="val")

# **ViT Testing**

In [4]:
ViT_CHECKPOINT = [
    # Each entry: [checkpoint path, config string passed to load_vit_model_config].
    # Figshare Augmentation
    ["/home/ma012/AlexServer/phase2/log/drive/Figshare Augmentation/L16_21K_case3.pth", "timm.vit_large_patch16_224.augreg_in21k.PRETRAINED"],

    # BT-4C
    # NOTE(review): this path contains "drive/ BT-4C" (a space before BT-4C), unlike the
    # sibling entries. Confirm the directory really is named " BT-4C".
    ["/home/ma012/AlexServer/phase2/log/drive/ BT-4C/L16_1K_case3.pth", "torchvision.ViT_L_16_Weights.IMAGENET1K_V1"], 

    # Hugging Face
    ["/home/ma012/AlexServer/phase2/log/drive/Hugging Face/L16_1K_case3.pth", "torchvision.ViT_L_16_Weights.IMAGENET1K_V1"], 
]

In [5]:
# Build the BT-4C ViT (entry 1 of ViT_CHECKPOINT) and loaders for the BT_Large_4c dataset.
set_seeds(SEED)
model, model_transforms = load_vit_model_config(ViT_CHECKPOINT[1][1], NUM_CLASSES)

# Unlike the training cell, no Grayscale step is applied here; only the backbone's
# own preprocessing is used.
CUSTOM_TRANSFORMS = transforms.Compose([model_transforms])

loss_fn = torch.nn.CrossEntropyLoss()

# create_dataloaders_2 is given only val_ratio (no train/test ratios).
train_loader, val_loader, test_loader = create_dataloaders_2(
    dataset_dir=DATASET_BT_LARGE_4C,
    transform=CUSTOM_TRANSFORMS,
    batch_size=BATCH_SIZE_32,
    num_workers=NUM_WORKERS,
    val_ratio=VAL_RATIO,
    seed=SEED,
)

In [None]:
# Evaluate each of the four best ViT checkpoints individually.
# NOTE: BEST_CHECKPOINT is defined in the "ENSEMBLE-TEST / Set up" section further down;
# on a fresh kernel run that cell first, otherwise this raises NameError.
# NOTE(review): test_loader is whatever the most recent dataloader cell built (here the
# BT_Large_4c loader with the torchvision ViT-L/16 transforms). That may not be the
# preprocessing these timm checkpoints expect; confirm before trusting the numbers.
model_name = ['ViT B-16', 'ViT B-32', 'ViT L-16', 'ViT L-32']
assert len(model_name) == len(BEST_CHECKPOINT), "model_name and BEST_CHECKPOINT must align"

vit_test_results = {}
for name, (cp_path, model_cfg) in zip(model_name, BEST_CHECKPOINT):
    # Write the section header to the same log file passed to test_step, so each model's
    # metrics follow its name (same layout as the CNN loop).
    with open(LOG_CHECK, "a", encoding="utf-8") as file:
        file.write(f"\n ======================   {name}  ==================================\n")

    model, model_transforms = load_vit_model_config(model_cfg, NUM_CLASSES)
    model = load_model_checkpoint(model, cp_path, DEVICE)

    test_results = test_step(model, test_loader, loss_fn, DEVICE, LOG_CHECK)
    vit_test_results[name] = test_results

ViT original head


## Evaluate a Single ViT Checkpoint

In [6]:
# Load the BT-4C checkpoint (entry 1) and evaluate it on the current test_loader.
checkpoint_path, checkpoint_cfg = ViT_CHECKPOINT[1]

model, model_transforms = load_vit_model_config(checkpoint_cfg, NUM_CLASSES)
model = load_model_checkpoint(model, checkpoint_path, DEVICE)

test_results = test_step(model, test_loader, loss_fn, DEVICE, LOG_CHECK)

# **ENSEMBLE-TEST**

## ViT Ensemble Model

### Set up

In [12]:
# Best ViT checkpoint per backbone: [checkpoint path, model config string].
BEST_CHECKPOINT = [
    [path, cfg]
    for path, cfg in (
        ("/home/ma012/AlexServer/log/vit_best/B16_21K_case3.pth", "timm.vit_base_patch16_224.augreg_in21k.PRETRAINED"),
        ("/home/ma012/AlexServer/log/vit_best/B32_21K_case2.pth", "timm.vit_base_patch32_224.augreg_in21k.PRETRAINED"),
        ("/home/ma012/AlexServer/log/vit_best/L16_21K_case3.pth", "timm.vit_large_patch16_224.augreg_in21k.PRETRAINED"),
        ("/home/ma012/AlexServer/log/vit_best/L32_21K_case1.pth", "timm.vit_large_patch32_224.orig_in21k.PRETRAINED"),
    )
]

In [13]:
# Print each ViT backbone's preprocessing pipeline (cp_path is not needed here).
# `load_vit_model` is not imported in the setup cell; use the imported loader, which
# takes the same (config, num_classes) arguments elsewhere in this notebook.
for cp_path, model_cfg in BEST_CHECKPOINT:
    model, model_transforms = load_vit_model_config(model_cfg, NUM_CLASSES)
    print(f"{model_transforms} \n")

Compose(
    Resize(size=248, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    MaybeToTensor()
    Normalize(mean=tensor([0.5000, 0.5000, 0.5000]), std=tensor([0.5000, 0.5000, 0.5000]))
) 

Compose(
    Resize(size=248, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    MaybeToTensor()
    Normalize(mean=tensor([0.5000, 0.5000, 0.5000]), std=tensor([0.5000, 0.5000, 0.5000]))
) 

Compose(
    Resize(size=248, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    MaybeToTensor()
    Normalize(mean=tensor([0.5000, 0.5000, 0.5000]), std=tensor([0.5000, 0.5000, 0.5000]))
) 

ViT original head
Compose(
    Resize(size=248, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    MaybeToTensor()
    Normalize(mean=tensor([0.5000, 0.5000, 0.5000]), std=tensor([0.5000, 0.5000, 0.5000]))
) 



In [8]:
# Reference ViT-B/16 backbone: only its preprocessing is needed here. All four ViT
# backbones print identical transforms in the cell above.
# (`load_vit_model` is not imported; use the imported load_vit_model_config.)
model_test, model_transforms = load_vit_model_config("timm.vit_base_patch16_224.augreg_in21k.PRETRAINED", NUM_CLASSES)

# Grayscale with 3 output channels, then the timm preprocessing; used for the ensemble test split.
CUSTOM_TRANSFORMS = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    model_transforms,
])
model_transforms

Compose(
    Resize(size=248, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    MaybeToTensor()
    Normalize(mean=tensor([0.5000, 0.5000, 0.5000]), std=tensor([0.5000, 0.5000, 0.5000]))
)

In [6]:
loss_fn = torch.nn.CrossEntropyLoss()

# Only the test split is needed for the ensemble evaluation. Dataset root, ratios and
# seed match the training cell.
split_settings = dict(
    batch_size=BATCH_SIZE_16,
    num_workers=NUM_WORKERS,
    train_ratio=TRAIN_RATIO,
    val_ratio=VAL_RATIO,
    test_ratio=TEST_RATIO,
    seed=SEED,
)
_, _, test_loader = create_dataloaders(
    dataset_dir=DATASET_DIR,
    transform=CUSTOM_TRANSFORMS,
    **split_settings,
)

### Load models

In [15]:
# Load the four fine-tuned ViT checkpoints that make up the ensemble.
models_list = []
for cp_path, model_cfg in BEST_CHECKPOINT:
    # `load_vit_model` is not imported in the setup cell; use the imported loader, which
    # takes the same (config, num_classes) arguments elsewhere in this notebook.
    model, model_transforms = load_vit_model_config(model_cfg, NUM_CLASSES)
    model = load_model_checkpoint(model, cp_path, DEVICE)

    models_list.append(model)

ViT original head


## ViT Ensemble - Testing Results

In [16]:
# Evaluate the ViT ensemble on the test split and report its four metrics.
ensemble_metrics = ensemble_test_step(models_list, test_loader, loss_fn, DEVICE, LOG_CHECK)
(ensemble_loss, ensemble_acc,
 ensemble_sensitivity, ensemble_specificity) = ensemble_metrics

print(f"Ensemble Test Loss: {ensemble_loss:.4f} | Ensemble Test Accuracy: {ensemble_acc:.4f}")
print(f"Ensemble sensitivity: {ensemble_sensitivity:.4f} | Ensemble specificity: {ensemble_specificity:.4f}")

Ensemble Test Loss: 0.0187 | Ensemble Test Accuracy: 0.9739
Ensemble sensitivity: 0.9706 | Ensemble specificity: 0.9860


# **CNN**

In [12]:
# Fine-tuned CNN checkpoints: [checkpoint path, torchvision weights config].
# NOTE(review): the MobileNet_V2 and ResNet50 file names say IMAGENET1K_V1 while their
# configs request IMAGENET1K_V2. Confirm which weights these checkpoints started from.
CHECKPOINT_CNN = [
    [path, weights]
    for path, weights in (
        ("/home/ma012/AlexServer/log/cnn_best/torchvision_VGG16_Weights_IMAGENET1K_V1.pth", "torchvision.VGG16_Weights.IMAGENET1K_V1"),
        ("/home/ma012/AlexServer/log/cnn_best/torchvision_MobileNet_V2_Weights_IMAGENET1K_V1.pth", "torchvision.MobileNet_V2_Weights.IMAGENET1K_V2"),
        ("/home/ma012/AlexServer/log/cnn_best/torchvision_GoogLeNet_Weights_IMAGENET1K_V1.pth", "torchvision.GoogLeNet_Weights.IMAGENET1K_V1"),
        ("/home/ma012/AlexServer/log/cnn_best/torchvision_ResNet50_Weights_IMAGENET1K_V1.pth", "torchvision.ResNet50_Weights.IMAGENET1K_V2"),
    )
]

In [13]:
cnn_models_list = []

# Print the preprocessing transform that comes with each CNN weights config.
# NOTE(review): three-argument form load_cnn_model(DEVICE, cfg, n); the inference
# cells call load_cnn_model(cfg, n). Only one signature can be current.
for cp_path, model_cfg in CHECKPOINT_CNN:
    model_cnn, model_transforms_cnn = load_cnn_model(DEVICE, model_cfg, NUM_CLASSES)
    print(model_transforms_cnn, end="\n\n")

ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)

ImageClassification(
    crop_size=[224]
    resize_size=[232]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)

ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)

ImageClassification(
    crop_size=[224]
    resize_size=[232]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)



In [4]:
# ImageNet normalisation statistics (the same values the torchvision weight transforms print above).
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

DEFAULT_SIZE = (256, 256)
CROP_SIZE = 224

DEFAULT_NORMALIZE = transforms.Normalize(mean=MEAN, std=STD)

# One preprocessing pipeline shared by every CNN:
# grayscale (3 channels) -> resize -> center crop -> tensor -> normalise.
# NOTE(review): Resize((256, 256)) forces a square resize. The torchvision weight
# transforms printed above use a single resize size (256 or 232), which scales the
# shorter edge instead. Confirm this matches how the checkpoints were trained.
_default_steps = [
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize(DEFAULT_SIZE, interpolation=InterpolationMode.BILINEAR),
    transforms.CenterCrop(CROP_SIZE),
    transforms.ToTensor(),
    DEFAULT_NORMALIZE,
]
DEFAULT_TRANSFORMS = transforms.Compose(_default_steps)

_, _, test_loader = create_dataloaders(
    dataset_dir=DATASET_DIR,
    transform=DEFAULT_TRANSFORMS,
    batch_size=BATCH_SIZE_16,
    num_workers=NUM_WORKERS,
    train_ratio=TRAIN_RATIO,
    val_ratio=VAL_RATIO,
    test_ratio=TEST_RATIO,
    seed=SEED,
)

## **Individual Testing**

In [11]:
# Names label the log sections and must line up one-to-one with CHECKPOINT_CNN.
model_name = ['VGG16', 'MobileNet_V2', 'GoogLeNet', 'ResNet50']
assert len(model_name) == len(CHECKPOINT_CNN), "model_name and CHECKPOINT_CNN must align"

cnn_test_results = {}
for name, (cp_path, model_cfg) in zip(model_name, CHECKPOINT_CNN):
    with open(LOG_CHECK, "a", encoding="utf-8") as file:
        file.write(f"\n======================   {name}  ==============================\n")

    # NOTE(review): three-argument form load_cnn_model(DEVICE, cfg, n); the inference
    # cells call load_cnn_model(cfg, n). Only one signature can be current.
    model_cnn, model_transforms_cnn = load_cnn_model(DEVICE, model_cfg, NUM_CLASSES)
    model_cnn = load_model_checkpoint(model_cnn, cp_path, DEVICE)

    test_cnn_results = test_step(model_cnn, test_loader, loss_fn, DEVICE, LOG_CHECK)
    cnn_test_results[name] = test_cnn_results

In [None]:
# Reference: the [checkpoint path, weights config] pairs evaluated in this section.
# (Displays CHECKPOINT_CNN instead of re-typing the pairs as no-op tuple expressions.)
CHECKPOINT_CNN


In [6]:
# Single-model check: the MobileNet_V2 checkpoint on the test_loader built with DEFAULT_TRANSFORMS.
mobilenet_cfg = "torchvision.MobileNet_V2_Weights.IMAGENET1K_V2"
mobilenet_ckpt = "/home/ma012/AlexServer/log/cnn_best/torchvision_MobileNet_V2_Weights_IMAGENET1K_V1.pth"

# NOTE(review): two-argument form load_cnn_model(cfg, n); other cells call
# load_cnn_model(DEVICE, cfg, n). Only one signature can be current.
model, model_transforms = load_cnn_model(mobilenet_cfg, NUM_CLASSES)
model = load_model_checkpoint(model, mobilenet_ckpt, DEVICE)

test_results = test_step(model, test_loader, loss_fn, DEVICE, LOG_CHECK)

## **Ensemble Test**

In [12]:
# Load every CNN checkpoint that makes up the ensemble.
cnn_models_list = []
for cp_path, model_cfg in CHECKPOINT_CNN:
    model_cnn, model_transforms_cnn = load_cnn_model(DEVICE, model_cfg, NUM_CLASSES)
    model_cnn = load_model_checkpoint(model_cnn, cp_path, DEVICE)
    cnn_models_list += [model_cnn]

In [13]:
# Evaluate the CNN ensemble on the test split and report its four metrics.
cnn_ensemble_metrics = ensemble_test_step(cnn_models_list, test_loader, loss_fn, DEVICE, LOG_CHECK)
(ensemble_cnn_loss, ensemble_cnn_acc,
 ensemble_cnn_sensitivity, ensemble_cnn_specificity) = cnn_ensemble_metrics

print(f"Ensemble_cnn Test Loss: {ensemble_cnn_loss:.4f} | Ensemble_cnn Test Accuracy: {ensemble_cnn_acc:.4f}")
print(f"Ensemble_cnn sensitivity: {ensemble_cnn_sensitivity:.4f} | Ensemble_cnn specificity: {ensemble_cnn_specificity:.4f}")

Ensemble_cnn Test Loss: 0.0368 | Ensemble_cnn Test Accuracy: 0.9772
Ensemble_cnn sensitivity: 1.0000 | Ensemble_cnn specificity: 0.9930


# **Inference Stage**

In [None]:
# The checkpoints used for inference below produce 3-way outputs (see the softmax
# vectors printed further down), so keep NUM_CLASSES in step with this class list.
CLASS_NAMES = ['glioma_tumor', 'meningioma_tumor', 'pituitary_tumor']
NUM_CLASSES = len(CLASS_NAMES)

In [None]:
# Reference: the ViT [checkpoint path, model config] pairs (defined as BEST_CHECKPOINT
# in the ensemble set-up). Displayed instead of re-typing them as no-op tuple expressions.
BEST_CHECKPOINT

In [7]:
vit_model, vit_model_transforms = load_vit_model("timm.vit_base_patch16_224.augreg_in21k.PRETRAINED", NUM_CLASSES)
load_vit_model = load_model_checkpoint(vit_model, "/home/ma012/AlexServer/log/vit_best/B16_21K_case3.pth", DEVICE)

VIT_CUSTOM_TRANSFORMS = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    vit_model_transforms
])

cnn_model, cnn_model_transforms = load_cnn_model("torchvision.MobileNet_V2_Weights.IMAGENET1K_V2", NUM_CLASSES)
load_cnn_model = load_model_checkpoint(cnn_model, "/home/ma012/AlexServer/log/cnn_best/torchvision_MobileNet_V2_Weights_IMAGENET1K_V1.pth", DEVICE)

CNN_CUSTOM_TRANSFORMS = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    cnn_model_transforms
])

In [8]:
# Load and preprocess the image
img_path = "/home/ma012/AlexServer/Dataset/pituitary_tumor/994.jpg"
# image = Image.open(img_path).convert("L")  # grayscale image assumed
image = Image.open(img_path)

# Apply the same transforms as trainingcnn_image_tensor
vit_image_tensor = VIT_CUSTOM_TRANSFORMS(image).unsqueeze(0)  # shape: [1, C, H, W]
cnn_image_tensor = CNN_CUSTOM_TRANSFORMS(image).unsqueeze(0)  # shape: [1, C, H, W]


### individual ViT and CNN 

In [9]:
vit_predicted_class, vit_softmax = inference_step(load_vit_model,vit_image_tensor,DEVICE)

print(f"predicted_class: {vit_predicted_class} | softmax: {vit_softmax}")

predicted_class: 2 | softmax: tensor([0.0642, 0.0174, 0.9184], device='cuda:0')


In [10]:
cnn_predicted_class, cnn_softmax = inference_step(load_cnn_model,cnn_image_tensor,DEVICE)

print(f"predicted_class: {cnn_predicted_class} | softmax: {cnn_softmax}")

predicted_class: 0 | softmax: tensor([0.5875, 0.4076, 0.0049], device='cuda:0')


In [None]:
# Same four CNN checkpoints as CHECKPOINT_CNN: [checkpoint path, torchvision weights config].
BEST_CHECKPOINT_CNN = [
    ["/home/ma012/AlexServer/log/cnn_best/torchvision_VGG16_Weights_IMAGENET1K_V1.pth", "torchvision.VGG16_Weights.IMAGENET1K_V1"],

    ["/home/ma012/AlexServer/log/cnn_best/torchvision_MobileNet_V2_Weights_IMAGENET1K_V1.pth", "torchvision.MobileNet_V2_Weights.IMAGENET1K_V2"],

    ["/home/ma012/AlexServer/log/cnn_best/torchvision_GoogLeNet_Weights_IMAGENET1K_V1.pth", "torchvision.GoogLeNet_Weights.IMAGENET1K_V1"],

    ["/home/ma012/AlexServer/log/cnn_best/torchvision_ResNet50_Weights_IMAGENET1K_V1.pth", "torchvision.ResNet50_Weights.IMAGENET1K_V2"],
]

# Each CNN must see the image preprocessed with ITS OWN weight transform (the resize
# size differs between backbones: 256 vs 232), so keep one input tensor per model
# instead of overwriting a single tensor on every iteration.
cnn_models_list = []
cnn_image_tensors = []

for cp_path, model_cfg in BEST_CHECKPOINT_CNN:
    # NOTE(review): requires `load_cnn_model` to still be the loader function, i.e. not
    # rebound to a model object by an earlier cell.
    model_cnn, model_transforms_cnn = load_cnn_model(model_cfg, NUM_CLASSES)
    model_cnn = load_model_checkpoint(model_cnn, cp_path, DEVICE)
    cnn_models_list.append(model_cnn)

    CNN_CUSTOM_TRANSFORMS = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        model_transforms_cnn,
    ])
    cnn_image_tensor = CNN_CUSTOM_TRANSFORMS(image).unsqueeze(0)  # shape: [1, C, H, W]
    cnn_image_tensors.append(cnn_image_tensor)

for model, model_input in zip(cnn_models_list, cnn_image_tensors):
    cnn_predicted_class, cnn_softmax = inference_step(model, model_input, DEVICE)
    print(f"predicted_class: {cnn_predicted_class} | softmax: {cnn_softmax}")


## ViT ensemble

In [6]:
# Best ViT checkpoint per backbone (redefined so this inference section runs on its own).
_vit_best_pairs = (
    ("/home/ma012/AlexServer/log/vit_best/B16_21K_case3.pth", "timm.vit_base_patch16_224.augreg_in21k.PRETRAINED"),
    ("/home/ma012/AlexServer/log/vit_best/B32_21K_case2.pth", "timm.vit_base_patch32_224.augreg_in21k.PRETRAINED"),
    ("/home/ma012/AlexServer/log/vit_best/L16_21K_case3.pth", "timm.vit_large_patch16_224.augreg_in21k.PRETRAINED"),
    ("/home/ma012/AlexServer/log/vit_best/L32_21K_case1.pth", "timm.vit_large_patch32_224.orig_in21k.PRETRAINED"),
)
BEST_CHECKPOINT = list(map(list, _vit_best_pairs))

In [7]:
# Load the four fine-tuned ViT checkpoints for ensemble inference.
# `load_vit_model` is not imported in the setup cell; use the imported loader, which
# takes the same (config, num_classes) arguments elsewhere in this notebook.
models_list = []
for cp_path, model_cfg in BEST_CHECKPOINT:
    model, model_transforms = load_vit_model_config(model_cfg, NUM_CLASSES)
    model = load_model_checkpoint(model, cp_path, DEVICE)

    models_list.append(model)

ViT original head


In [None]:
# `image_tensor` was never defined. All four ViT backbones share the same
# preprocessing (see the transforms printed in the ensemble set-up), so the ViT input
# tensor built in the single-model inference cell is valid for every member.
predicted_class, ensemble_softmax = ensemble_inference_step(models_list, vit_image_tensor, DEVICE)

In [17]:
# Report the ViT ensemble result: class index and the scores returned by ensemble_inference_step.
print(f"predicted_class: {predicted_class} | softmax: {ensemble_softmax}")

predicted_class: 0 | softmax: tensor([9.9996e-01, 2.5323e-05, 1.8574e-05], device='cuda:0')


## CNN ensemble

In [4]:
# CNN checkpoints for ensemble inference: [checkpoint path, torchvision weights config].
_cnn_best_pairs = (
    ("/home/ma012/AlexServer/log/cnn_best/torchvision_VGG16_Weights_IMAGENET1K_V1.pth", "torchvision.VGG16_Weights.IMAGENET1K_V1"),
    ("/home/ma012/AlexServer/log/cnn_best/torchvision_MobileNet_V2_Weights_IMAGENET1K_V1.pth", "torchvision.MobileNet_V2_Weights.IMAGENET1K_V2"),
    ("/home/ma012/AlexServer/log/cnn_best/torchvision_GoogLeNet_Weights_IMAGENET1K_V1.pth", "torchvision.GoogLeNet_Weights.IMAGENET1K_V1"),
    ("/home/ma012/AlexServer/log/cnn_best/torchvision_ResNet50_Weights_IMAGENET1K_V1.pth", "torchvision.ResNet50_Weights.IMAGENET1K_V2"),
)
BEST_CHECKPOINT_CNN = [list(pair) for pair in _cnn_best_pairs]

In [5]:
# Load every CNN checkpoint for ensemble inference.
# NOTE(review): two-argument form load_cnn_model(cfg, n); the test cells call
# load_cnn_model(DEVICE, cfg, n). Only one signature can be current.
cnn_models_list = []
for cp_path, model_cfg in BEST_CHECKPOINT_CNN:
    model_cnn, model_transforms_cnn = load_cnn_model(model_cfg, NUM_CLASSES)
    model_cnn = load_model_checkpoint(model_cnn, cp_path, DEVICE)
    cnn_models_list += [model_cnn]

In [8]:
# `image_tensor` was never defined. The CNN ensemble was evaluated on a test_loader
# built with DEFAULT_TRANSFORMS (one shared pipeline for all members), so build the
# ensemble input the same way rather than using a single member's own transform.
cnn_ensemble_image_tensor = DEFAULT_TRANSFORMS(image).unsqueeze(0)  # shape: [1, C, H, W]
cnn_predicted_class, cnn_ensemble_softmax = ensemble_inference_step(cnn_models_list, cnn_ensemble_image_tensor, DEVICE)

In [9]:
# Report the CNN ensemble result: class index and the scores returned by ensemble_inference_step.
print(f"cnn predicted_class: {cnn_predicted_class} | softmax: {cnn_ensemble_softmax}")

cnn predicted_class: 0 | softmax: tensor([0.7466, 0.0103, 0.2431], device='cuda:0')


In [None]:
# Scratch check: find which predictions disagree with the labels and map them back
# to their source files.
test_pred_labels = torch.tensor([1, 0, 2, 3])
y = torch.tensor([1, 1, 2, 0])
# Example file paths aligned index-for-index with `y` (stand-ins for the dataset's sample paths).
paths = ["img_0.jpg", "img_1.jpg", "img_2.jpg", "img_3.jpg"]

mismatched = test_pred_labels != y
# tensor([False, True, False, True])

# flatten() (not squeeze()) keeps a 1-D result even when exactly one sample is wrong;
# squeeze() would give a 0-d tensor that cannot be iterated.
wrong_indices = torch.nonzero(mismatched).flatten()
wrong_paths = [paths[i] for i in wrong_indices.tolist()]
wrong_indices

tensor([1, 3])