# Importing Libraries

In [26]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.nn import Linear, Sequential, Dropout, ReLU
from torch import load, inference_mode, round, sigmoid
from sklearn import metrics
from torchvision.models.regnet import regnet_y_3_2gf, RegNet_Y_3_2GF_Weights
from torchvision.models.swin_transformer import swin_v2_t, Swin_V2_T_Weights
from torchvision.models.efficientnet import efficientnet_v2_s, EfficientNet_V2_S_Weights

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [6]:
import sys

# Make the parent directory importable so the project-local ``utils`` module resolves.
# Guarded so that re-running this cell does not keep appending duplicate entries.
PROJECT_ROOT = "../"
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
from utils import ImagesOnlyDataset

In [7]:
from torchvision.transforms import v2, InterpolationMode
import torch

# Channel statistics used by every pipeline's Normalize step (the standard ImageNet values).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def make_eval_transform(resize_hw, crop_hw):
    """Build a deterministic evaluation pipeline.

    Steps: image -> float32 tensor scaled to [0, 1] -> bicubic, antialiased resize
    to ``resize_hw`` -> center crop to ``crop_hw`` -> ImageNet mean/std normalization.

    NOTE(review): these sizes must match whatever preprocessing was used while
    fine-tuning; they are not derived from the ``*_Weights`` enums imported earlier.
    """
    return v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Resize(resize_hw, interpolation=InterpolationMode.BICUBIC, antialias=True),
        v2.CenterCrop(crop_hw),
        v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


RegNet_transform = make_eval_transform((232, 232), (224, 224))
# Resize and crop sizes are equal here, so the center crop does not remove any pixels.
EfficientNet_transform = make_eval_transform((384, 384), (384, 384))
SwinV2_transform = make_eval_transform((260, 260), (256, 256))

# Preparing Test Data

In [5]:
# Targets and features live in separate CSVs; both are passed to the datasets together,
# so their rows are presumably aligned one-to-one (TODO confirm).
TEST_TARGETS_CSV = "../Data/Processed/test_targets.csv"
TEST_FEATURES_CSV = "../Data/Processed/test_features.csv"

test_targets = pd.read_csv(TEST_TARGETS_CSV)
test_features = pd.read_csv(TEST_FEATURES_CSV)

In [8]:
# One dataset per backbone: identical files and targets, differing only in preprocessing.
TEST_IMAGE_DIR = "../Data/images"

test_dataset_RegNet = ImagesOnlyDataset(
    test_features['filename'], test_targets, TEST_IMAGE_DIR, RegNet_transform
)
test_dataset_EfficientNet = ImagesOnlyDataset(
    test_features['filename'], test_targets, TEST_IMAGE_DIR, EfficientNet_transform
)
test_dataset_SwinV2 = ImagesOnlyDataset(
    test_features['filename'], test_targets, TEST_IMAGE_DIR, SwinV2_transform
)

In [9]:
# Evaluation loaders. Shuffling is disabled because the metrics below do not depend on
# sample order; a fixed order makes predictions reproducible run-to-run and directly
# comparable row-by-row across the three models.
TEST_BATCH_SIZE = 16

test_dataloader_regnet = DataLoader(test_dataset_RegNet, batch_size=TEST_BATCH_SIZE, shuffle=False)
test_dataloader_efficientnet = DataLoader(test_dataset_EfficientNet, batch_size=TEST_BATCH_SIZE, shuffle=False)
test_dataloader_swinv2 = DataLoader(test_dataset_SwinV2, batch_size=TEST_BATCH_SIZE, shuffle=False)

# Preparing Models

## RegNet

In [11]:
# Rebuild the RegNet-Y 3.2GF architecture without pretrained weights, replace the
# classification head with one logit per label (22), then load the fine-tuned weights.
regnet_model = regnet_y_3_2gf()
# Derive the head's input width from the stock layer rather than hard-coding it.
regnet_model.fc = Linear(regnet_model.fc.in_features, 22)
# map_location="cpu": the evaluation cells never move the model to a GPU, so this lets a
# checkpoint saved from GPU load on a CPU-only machine.
# weights_only=True: the file is only a state_dict, so refuse arbitrary pickled objects.
regnet_model.load_state_dict(
    load("../Models/FinetunedRegNet.pth", map_location="cpu", weights_only=True)
)

<All keys matched successfully>

In [14]:
# Report model size in millions of parameters (every tensor from .parameters()).
total_params = 0
for param in regnet_model.parameters():
    total_params += param.numel()
print(f"Number of Parameters: {total_params / 1e6:.3f}M")

Number of Parameters: 17.957M


## EfficientNet

In [16]:
# Rebuild EfficientNetV2-S without pretrained weights, swap in the fine-tuning head
# (one logit per label, 22), then load the fine-tuned weights.
efficientnet_model = efficientnet_v2_s()
# The stock head is Sequential(Dropout, Linear); read the Linear's input width before
# replacing it instead of hard-coding the value.
efficientnet_head_in_features = efficientnet_model.classifier[1].in_features
efficientnet_model.classifier = Sequential(
    Dropout(p=0.2),
    # NOTE(review): a ReLU after dropout on pooled features is unusual. It is kept because it
    # presumably matches the head used during fine-tuning, and removing it would change outputs.
    ReLU(),
    Linear(in_features=efficientnet_head_in_features, out_features=22)
)
# map_location="cpu": evaluation here never moves the model to a GPU.
# weights_only=True: the file is only a state_dict, so refuse arbitrary pickled objects.
efficientnet_model.load_state_dict(
    load("../Models/FinetunedEfficientNet.pth", map_location="cpu", weights_only=True)
)

<All keys matched successfully>

In [17]:
# Report model size in millions of parameters (every tensor from .parameters()).
total_params = 0
for param in efficientnet_model.parameters():
    total_params += param.numel()
print(f"Number of Parameters: {total_params / 1e6:.3f}M")

Number of Parameters: 20.206M


## SwinV2

In [18]:
# Rebuild Swin-V2-T without pretrained weights, replace the head with one logit per
# label (22), then load the fine-tuned weights.
swinv2_model = swin_v2_t()
# Derive the head's input width from the stock layer rather than hard-coding it.
swinv2_model.head = Linear(swinv2_model.head.in_features, 22)
# map_location="cpu": evaluation here never moves the model to a GPU.
# weights_only=True: the file is only a state_dict, so refuse arbitrary pickled objects.
swinv2_model.load_state_dict(
    load("../Models/FinetunedSwinV2.pth", map_location="cpu", weights_only=True)
)

<All keys matched successfully>

In [19]:
# Report model size in millions of parameters (every tensor from .parameters()).
total_params = 0
for param in swinv2_model.parameters():
    total_params += param.numel()
print(f"Number of Parameters: {total_params / 1e6:.3f}M")

Number of Parameters: 27.599M


# Testing Performance

## Accuracy (exact-match / subset accuracy across all labels)

### RegNet

In [20]:
# Run RegNet over the test set and gather ground-truth and predicted label vectors.
# Each logit is an independent yes/no decision: sigmoid, then round (threshold at 0.5).
regnet_model.eval()
regnet_true_rows = []
regnet_pred_rows = []
with inference_mode():
    for image_batch, label_batch in test_dataloader_regnet:
        prob_batch = sigmoid(regnet_model(image_batch))
        regnet_true_rows.extend(label_batch.cpu().numpy())
        regnet_pred_rows.extend(prob_batch.round().cpu().numpy())

regnet_true_labels, regnet_pred_labels = np.array(regnet_true_rows), np.array(regnet_pred_rows)

# Targets are presumably multi-hot label vectors. In that case accuracy_score is subset
# (exact-match) accuracy: a sample counts only if every one of its labels is correct.
regnet_accuracy = metrics.accuracy_score(regnet_true_labels, regnet_pred_labels)
print(f'Accuracy: {regnet_accuracy * 100:.2f}%')

Accuracy: 70.55%


### EfficientNet

In [21]:
# Run EfficientNet over the test set and gather ground-truth and predicted label vectors.
# Each logit is an independent yes/no decision: sigmoid, then round (threshold at 0.5).
efficientnet_model.eval()
efficientnet_true_rows = []
efficientnet_pred_rows = []
with inference_mode():
    for image_batch, label_batch in test_dataloader_efficientnet:
        prob_batch = sigmoid(efficientnet_model(image_batch))
        efficientnet_true_rows.extend(label_batch.cpu().numpy())
        efficientnet_pred_rows.extend(prob_batch.round().cpu().numpy())

efficientnet_true_labels = np.array(efficientnet_true_rows)
efficientnet_pred_labels = np.array(efficientnet_pred_rows)

# Targets are presumably multi-hot label vectors. In that case accuracy_score is subset
# (exact-match) accuracy: a sample counts only if every one of its labels is correct.
efficientnet_accuracy = metrics.accuracy_score(efficientnet_true_labels, efficientnet_pred_labels)
print(f'Accuracy: {efficientnet_accuracy * 100:.2f}%')

Accuracy: 65.03%


### SwinV2

In [22]:
# Run SwinV2 over the test set and gather ground-truth and predicted label vectors.
# Each logit is an independent yes/no decision: sigmoid, then round (threshold at 0.5).
swinv2_model.eval()
swinv2_true_rows = []
swinv2_pred_rows = []
with inference_mode():
    for image_batch, label_batch in test_dataloader_swinv2:
        prob_batch = sigmoid(swinv2_model(image_batch))
        swinv2_true_rows.extend(label_batch.cpu().numpy())
        swinv2_pred_rows.extend(prob_batch.round().cpu().numpy())

swinv2_true_labels, swinv2_pred_labels = np.array(swinv2_true_rows), np.array(swinv2_pred_rows)

# Targets are presumably multi-hot label vectors. In that case accuracy_score is subset
# (exact-match) accuracy: a sample counts only if every one of its labels is correct.
swinv2_accuracy = metrics.accuracy_score(swinv2_true_labels, swinv2_pred_labels)
print(f'Accuracy: {swinv2_accuracy * 100:.2f}%')

Accuracy: 61.96%


## Confusion Matrices and Classification Reports

### RegNet

In [23]:
# Per-label 2x2 confusion matrices, plus a per-label precision/recall/F1 report.
# zero_division=0 states explicitly what the report already prints: labels with no
# predicted (or no true) samples score 0.00. Without it, sklearn raises an
# UndefinedMetricWarning that the notebook's global warnings filter silently hides.
regnet_cf = metrics.multilabel_confusion_matrix(regnet_true_labels, regnet_pred_labels)
print(metrics.classification_report(regnet_true_labels, regnet_pred_labels, zero_division=0))

              precision    recall  f1-score   support

           0       0.83      0.38      0.53        13
           1       0.89      0.80      0.84        99
           2       0.00      0.00      0.00         1
           3       1.00      1.00      1.00         1
           4       1.00      0.67      0.80         6
           5       1.00      1.00      1.00         1
           6       0.00      0.00      0.00         1
           7       0.00      0.00      0.00         2
           8       0.00      0.00      0.00         2
           9       1.00      0.50      0.67         2
          10       0.00      0.00      0.00         2
          11       1.00      0.50      0.67         2
          12       0.25      0.25      0.25         4
          13       1.00      1.00      1.00         1
          14       1.00      0.50      0.67         6
          15       0.95      0.96      0.96       141
          16       0.75      1.00      0.86         3
          17       0.50    

In [24]:
# Per-label 2x2 confusion matrices, plus a per-label precision/recall/F1 report.
# zero_division=0 states explicitly what the report already prints: labels with no
# predicted (or no true) samples score 0.00. Without it, sklearn raises an
# UndefinedMetricWarning that the notebook's global warnings filter silently hides.
efficientnet_cf = metrics.multilabel_confusion_matrix(efficientnet_true_labels, efficientnet_pred_labels)
print(metrics.classification_report(efficientnet_true_labels, efficientnet_pred_labels, zero_division=0))

              precision    recall  f1-score   support

           0       0.00      0.00      0.00        13
           1       0.88      0.83      0.85        99
           2       0.00      0.00      0.00         1
           3       0.00      0.00      0.00         1
           4       0.33      0.17      0.22         6
           5       0.00      0.00      0.00         1
           6       0.00      0.00      0.00         1
           7       0.00      0.00      0.00         2
           8       0.00      0.00      0.00         2
           9       0.00      0.00      0.00         2
          10       0.00      0.00      0.00         2
          11       0.00      0.00      0.00         2
          12       0.00      0.00      0.00         4
          13       0.00      0.00      0.00         1
          14       0.33      0.17      0.22         6
          15       0.96      0.98      0.97       141
          16       0.00      0.00      0.00         3
          17       0.00    

In [25]:
# Per-label 2x2 confusion matrices, plus a per-label precision/recall/F1 report.
# zero_division=0 states explicitly what the report already prints: labels with no
# predicted (or no true) samples score 0.00. Without it, sklearn raises an
# UndefinedMetricWarning that the notebook's global warnings filter silently hides.
swinv2_cf = metrics.multilabel_confusion_matrix(swinv2_true_labels, swinv2_pred_labels)
print(metrics.classification_report(swinv2_true_labels, swinv2_pred_labels, zero_division=0))

              precision    recall  f1-score   support

           0       0.00      0.00      0.00        13
           1       0.64      0.88      0.74        99
           2       0.00      0.00      0.00         1
           3       0.00      0.00      0.00         1
           4       0.00      0.00      0.00         6
           5       0.00      0.00      0.00         1
           6       0.00      0.00      0.00         1
           7       0.00      0.00      0.00         2
           8       0.00      0.00      0.00         2
           9       0.00      0.00      0.00         2
          10       0.00      0.00      0.00         2
          11       0.00      0.00      0.00         2
          12       0.00      0.00      0.00         4
          13       0.00      0.00      0.00         1
          14       0.00      0.00      0.00         6
          15       0.95      0.98      0.96       141
          16       0.00      0.00      0.00         3
          17       0.00    