In [None]:
import os
import datetime
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
import seaborn as sn
from torchcam.methods import CAM, GradCAM
from torchcam.utils import overlay_mask
from torchvision.transforms.functional import to_pil_image
from LandUseDataset import LandUseDataset, Mode
from Models import LandUseModelResnet50
from Models import LandUseModelResnet50NoFeatures, LandUseModelResnet152NoFeatures, LandUseModelVisionTransformerB16NoFeatures, LandUseModelDensenet161NoFeatures


In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): IMAGE_SIZE was passed to LandUseDataset below but never defined anywhere
# in this notebook, so a fresh kernel raised NameError. 224 matches the Resize in the
# "validation" pipeline; confirm it is the value the checkpoint was trained with.
IMAGE_SIZE = 224
NUM_WORKERS = os.cpu_count()

# Evaluation-time preprocessing; "reverse" converts a tensor back to a PIL image.
transform = {
    "validation": transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        # Normalize is disabled; presumably the model was trained on un-normalised inputs — confirm.
        #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
    "reverse": transforms.ToPILImage()
}

# batch_size must stay 1: later cells read single predictions/targets with `.item()`
# and index `X[0]` when plotting.
dataset = LandUseDataset(Mode.EXTERNAL, transform=transform["validation"], image_size=IMAGE_SIZE)
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS)

test_dataset = LandUseDataset(Mode.TEST, transform=transform["validation"], image_size=IMAGE_SIZE)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
# Trained weights to evaluate. Set the LANDUSE_CHECKPOINT environment variable to point
# elsewhere instead of editing this machine-specific default path.
CHECKPOINT_PATH = os.environ.get(
    "LANDUSE_CHECKPOINT",
    r"C:\Users\aakas\Documents\MLDL Project\Project\Checkpoints\visiontransformerb16_nofeatures\2024-02-11 20_30_58\model_visiontransformerb16_nofeatures_last_48.pt",
)

model = LandUseModelVisionTransformerB16NoFeatures(len(dataset.classes), DEVICE).to(DEVICE)

# map_location remaps tensors saved on another device (e.g. a GPU) onto DEVICE, so the
# checkpoint also loads on a CPU-only machine.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print(model)

Accuracy on the test and external datasets

In [None]:
def evaluate_accuracy(model, data_loader, device=DEVICE):
    """Return the top-1 accuracy of `model` on `data_loader`, in percent.

    Args:
        model: classifier whose output's argmax over dim 1 is the predicted class index.
        data_loader: yields (inputs, targets) batches; any batch size is accepted.
        device: device the batches are moved to before the forward pass.

    Raises:
        ValueError: if the loader's dataset is empty.
    """
    n_samples = len(data_loader.dataset)
    if n_samples == 0:
        raise ValueError("cannot compute accuracy of an empty dataset")

    n_correct = 0
    model.eval()
    with torch.no_grad():
        for X, y in tqdm(data_loader):
            X, y = X.to(device), y.to(device)
            preds = model(X).argmax(dim=1)
            # reshape guards against targets shaped (B, 1), which would otherwise broadcast.
            n_correct += (preds == y.reshape(-1)).sum().item()

    return 100 * n_correct / n_samples


print("External")
print(evaluate_accuracy(model, loader))

print("test")
print(evaluate_accuracy(model, test_loader))

Plot and save the misclassified images from the test dataset

In [None]:
# Class index -> class name, built once instead of scanning the dict for every image.
idx_to_class = {idx: name for name, idx in test_dataset.classes.items()}

# savefig writes into this directory; create it so a fresh checkout does not fail.
os.makedirs("Misclassified", exist_ok=True)

model.eval()
with torch.no_grad():
    for i, (X, y) in enumerate(tqdm(test_loader)):
        X, y = X.to(DEVICE), y.to(DEVICE)
        pred = model(X)
        target_idx = y.item()
        pred_idx = pred.argmax(dim=1).item()

        if pred_idx == target_idx:
            continue

        target_class = idx_to_class.get(target_idx, "")
        pred_class = idx_to_class.get(pred_idx, "")

        fig, ax = plt.subplots(1, 1, figsize=(8, 4))
        ax.imshow(to_pil_image(X[0]))
        ax.axis("off")
        # NOTE(review): scores are scaled by 100 and shown as "%", which assumes the model
        # outputs probabilities (e.g. a softmax head) rather than raw logits — confirm.
        fig.suptitle(f'Target: {target_class} ({100 * pred[0][target_idx].item():.3f}%)\nPred: {pred_class} ({100 * pred[0][pred_idx].item():.3f}%)', fontsize=12)
        # Save before showing: with non-inline backends show() can leave the figure blank.
        fig.savefig(f"Misclassified/{model.name}_{i}.svg")
        plt.show()
        # Release the figure; otherwise every misclassified image keeps one open.
        plt.close(fig)

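Re-score the test dataset with the 2% range method: a misclassified image is counted as correct when the scores of the predicted and target classes differ by less than 2 percentage points. (Activation maps for the misclassified images are not currently generated.)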

In [None]:
load = test_loader
ds = test_dataset  # read by the Results cell

# Tolerance (in percentage points) for the "2% range method": a misclassified image is
# re-scored as correct when the predicted and target class scores differ by less than this.
# NOTE(review): scores are scaled by 100, which assumes the model outputs probabilities — confirm.
RANGE_TOLERANCE = 2

correct = 0          # argmax matches the target
incorrect = 0        # argmax differs from the target
incorrect_same = 0   # subset of `incorrect` inside the tolerance (counted as hits in y_pred)
y_target = []        # true class index per image, in loader order
y_pred = []          # class index after applying the tolerance rule, in loader order

model.eval()
with torch.no_grad():
    for X, y in tqdm(load):
        X, y = X.to(DEVICE), y.to(DEVICE)
        pred = model(X)
        target_idx = y.item()
        pred_idx = pred.argmax(dim=1).item()
        y_target.append(target_idx)

        if pred_idx == target_idx:
            correct += 1
            y_pred.append(pred_idx)
            continue

        incorrect += 1
        target_prob = 100 * pred[0][target_idx].item()
        pred_prob = 100 * pred[0][pred_idx].item()

        if target_prob - RANGE_TOLERANCE < pred_prob < target_prob + RANGE_TOLERANCE:
            # Near-tie between predicted and true class: count it as correct.
            incorrect_same += 1
            y_pred.append(target_idx)
        else:
            y_pred.append(pred_idx)

# NOTE(review): no class-activation maps are produced here. torchcam's CAM/GradCAM (imported
# at the top) need a target layer name that exists on the loaded model; add that here if the
# visualisation is wanted.

Results

In [None]:
def _percent(part, whole):
    """Return part/whole as a percentage, or nan when `whole` is 0 (avoids ZeroDivisionError)."""
    return 100 * part / whole if whole else float("nan")


print(f"Accuracy: {_percent(correct, len(ds))}")
# Accuracy when near-ties inside the tolerance are also counted as correct.
print(f"AccuracyNew: {_percent(correct + incorrect_same, len(ds))}")
print(f"Incorrect: {_percent(incorrect, len(ds))}")
# Share of the misclassified images that fell inside the tolerance; nan if none were wrong.
print(f"Incorrect with same prob: {_percent(incorrect_same, incorrect)}")
print(incorrect, incorrect_same)

In [None]:
# Class index -> class name. Built from the dataset's name -> index dict so that names
# line up with the integer targets even if the dict is not stored in index order.
idx_to_class = {idx: name for name, idx in test_dataset.classes.items()}
# Class names ordered by class index (labels[k] names class k when indices are 0..n-1).
labels = [idx_to_class[idx] for idx in sorted(idx_to_class)]

y_pred_labels = [idx_to_class[i] for i in y_pred]
y_target_labels = [idx_to_class[i] for i in y_target]

# With average=None this returns (precision, recall, f-score, support), one entry per label.
data = precision_recall_fscore_support(y_target_labels, y_pred_labels, average=None, labels=labels)

In [None]:
# `data` is (precision, recall, f-score, support), each holding one value per class label.
# Averaging over the actual number of labels (instead of a hard-coded 33) gives the
# unweighted per-class (macro) mean for any number of classes.
for metric_name, per_class in zip(("precision", "recall", "f1", "support"), data):
    print(f"{metric_name}: {sum(per_class) / len(per_class)}")

In [None]:
# Pass every class index explicitly: without `labels=`, classes that never occur are
# dropped, and the matrix no longer matches the `labels` used for rows/columns.
cf = confusion_matrix(y_true=y_target, y_pred=y_pred, labels=sorted(test_dataset.classes.values()))
df_cm = pd.DataFrame(cf, index=labels, columns=labels)

fig, ax = plt.subplots(figsize=(15, 15))
# vmax caps the colour scale so the large diagonal counts don't wash out off-diagonal errors
# (presumably hand-tuned for this test set).
sn.heatmap(df_cm, annot=True, vmax=30, ax=ax)
ax.set(xlabel="Pred", ylabel="Target")
# NOTE(review): the output filename does not mention the ViT model being evaluated — confirm.
fig.savefig("fc_nofeatures_acc2.svg")
plt.show()