In [None]:
import sys
sys.path.append('/tf/data')

import torch
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from pytorch_grad_cam import ScoreCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

from Classification.conv_net_model import convnext_large
from Classification.class_functions import split_ds, concat_data
from general_func import load_dataset

In [None]:
# Load the final ConvNeXt classifier checkpoint and switch it to inference mode.
# `params` is also read by later cells (batch size / loader workers).
params = dict(
    best_model_path='/tf/data/Classification/ConvNeXt/Grid_Search_synth_1/Search_2/62/Epoch_018.zip',
    batch_size=16,
    loader_workers=2,
)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda")

# Empty transform pipelines: the model constructor requires them, but no
# augmentation or preprocessing is configured in this notebook.
train_transform, val_transform = transforms.Compose([]), transforms.Compose([])

convnext_net = convnext_large(
    pretrained=False,
    in_22k=False,
    transform_train=train_transform,
    transform_val=val_transform,
    num_classes=2,
)

# NOTE(review): torch.load may unpickle arbitrary objects (depending on the
# PyTorch version's `weights_only` default) - only load trusted checkpoints.
model_weights = torch.load(params['best_model_path'], map_location=device)
convnext_net.load_state_dict(model_weights)

convnext_net.to(device)
convnext_net.eval()
print(device)

In [None]:
# Build the test loader from the positive ("1") and negative ("0") folders.
ds_pos = load_dataset(
    custom_path='/tf/data/cropped/test data/1',
)
ds_neg = load_dataset(
    custom_path='/tf/data/cropped/test data/0',
)

# Only the first element returned by split_ds is kept.
# NOTE(review): train_split=1 / val_split=0 presumably places every sample in
# that first split - confirm against split_ds.
train_scans_pos, _, _ = split_ds(
    ds_pos,
    train_split=1,
    val_split=0,
    seed=None,
)
train_scans_neg, _, _ = split_ds(
    ds_neg,
    train_split=1,
    val_split=0,
    seed=None,
)

# Merge both classes into a single loader used by the search cells below.
test_loader = concat_data(
    train_scans_pos,
    train_scans_neg,
    batch_size=params['batch_size'],
    workers=params['loader_workers'],
)

In [None]:
# Function to visualize CAM
def show_CAM(model, input, cam_function, cam_class, alpha=0.7, device = device, stage_depths=(3, 3, 27, 3)):
    """Plot one class-activation-map overlay per block of the model.

    For every block ``model.stages[s][b]`` the ``dwconv`` sub-module is used as
    the CAM target layer; the resulting heat map is blended over the input
    image and drawn in its own subplot row.

    Args:
        model: Network exposing ``stages[s][b].dwconv``; it is put in eval mode.
        input: ``(image, label)`` pair. ``image`` is fed to the model as-is, so
            it must already be batched and on the model's device
            (``int()`` on the prediction requires a batch of exactly one).
        cam_function: CAM class from ``pytorch_grad_cam`` (e.g. ``ScoreCAM``),
            called as ``cam_function(model=..., target_layers=[...])`` and used
            as a context manager so its hooks are released after each layer.
        cam_class: Index of the output class the CAM is computed for.
        alpha: Weight of the original image in the overlay (``image_weight``).
        device: Unused; kept only for backward compatibility.
        stage_depths: Number of blocks to visualise in each stage, in stage
            order. The default matches the previously hard-coded layout.

    Raises:
        ValueError: If ``stage_depths`` selects no layer at all.
    """
    model.eval()
    image, label = input

    print(image.shape)
    layers = [
        model.stages[stage][block].dwconv
        for stage, depth in enumerate(stage_depths)
        for block in range(depth)
    ]

    num_layers = len(layers)
    if num_layers == 0:
        raise ValueError('stage_depths must select at least one layer')

    # squeeze=False always yields a 2-D axes array, so a single layer needs no
    # special casing.
    fig, axes = plt.subplots(num_layers, 1, figsize=(26, num_layers), squeeze=False)

    # The prediction is only reported in the title; no gradients are needed.
    with torch.no_grad():
        out = model(image)
    _, predicted_class = torch.max(out, 1)
    pred0 = round(float(out[0][0]),3)
    pred1 = round(float(out[0][1]),3)

    # Loop-invariant inputs: the background image and the CAM target are the
    # same for every layer, so build them once.
    grid_image = vutils.make_grid(image.cpu(), nrow=1, padding=2, normalize=True).permute(1, 2, 0).numpy()
    targets = [ClassifierOutputTarget(cam_class)]

    for idx, layer in enumerate(layers):
        # The context manager releases the hooks registered on `layer`.
        with cam_function(model=model, target_layers=[layer]) as cam:
            grayscale_cam = cam(input_tensor=image, targets=targets)[0, :]
        result_original = show_cam_on_image(grid_image, grayscale_cam, use_rgb=True, image_weight=alpha)

        ax = axes[idx, 0]
        ax.imshow(result_original)
        ax.axis('off')
        ax.set_title(f'Layer: {idx}, CAM Class: {cam_class}', fontsize=12)

    # Show a plain integer instead of "tensor(1)" when the label is a scalar tensor.
    label_text = int(label) if torch.is_tensor(label) and label.numel() == 1 else label
    plt.suptitle(f'True label: {label_text}, Predicted: {int(predicted_class)}   -   Predicted logits [0,1]: [{pred0}, {pred1}]', fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()
    # Free the figure; this function is called repeatedly with large figures.
    plt.close(fig)

In [None]:
#Find a true positive
# Scan every sample of every batch (not only the first one) until a sample
# with label 1 that the model also predicts as 1 is found.
# `true_pos` ends up as [image (batched, on `device`), label].
true_pos = []
convnext_net.eval()
with torch.no_grad():  # inference only - no autograd graph needed
    for images, labels in test_loader:
        for image, label in zip(images, labels):
            image = image.unsqueeze(0).to(device)
            out = convnext_net(image)
            _, predicted_class = torch.max(out, 1)
            if label == 1 and int(predicted_class) == 1:
                true_pos = [image, label]
                break
        if true_pos:
            break

# Fail with a clear message instead of an IndexError further down.
if not true_pos:
    raise RuntimeError('No true positive found in test_loader')

with torch.no_grad():
    print(convnext_net(true_pos[0]), true_pos[1])
print(len(true_pos))

In [None]:
# Render ScoreCAM overlays for the true-positive sample, once per target class.
for class_ in range(2):
    show_CAM(convnext_net, true_pos, cam_function=ScoreCAM, cam_class=class_)

In [None]:
#Find a true negative
# Scan every sample of every batch (not only the first one) until a sample
# with label 0 that the model also predicts as 0 is found.
# `true_neg` ends up as [image (batched, on `device`), label].
true_neg = []
convnext_net.eval()
with torch.no_grad():  # inference only - no autograd graph needed
    for images, labels in test_loader:
        for image, label in zip(images, labels):
            image = image.unsqueeze(0).to(device)
            out = convnext_net(image)
            _, predicted_class = torch.max(out, 1)
            if label == 0 and int(predicted_class) == 0:
                true_neg = [image, label]
                break
        if true_neg:
            break

# Fail with a clear message instead of an IndexError further down.
if not true_neg:
    raise RuntimeError('No true negative found in test_loader')

with torch.no_grad():
    print(convnext_net(true_neg[0]), true_neg[1])
print(len(true_neg))

In [None]:
# Render ScoreCAM overlays for the true-negative sample, once per target class.
for class_ in range(2):
    show_CAM(convnext_net, true_neg, cam_function=ScoreCAM, cam_class=class_)

In [None]:
#Find a false positive
# Scan every sample of every batch (not only the first one) until a sample
# with label 0 that the model predicts as a different class is found.
# `false_pos` ends up as [image (batched, on `device`), label].
false_pos = []
convnext_net.eval()
with torch.no_grad():  # inference only - no autograd graph needed
    for images, labels in test_loader:
        for image, label in zip(images, labels):
            image = image.unsqueeze(0).to(device)
            out = convnext_net(image)
            _, predicted_class = torch.max(out, 1)
            if label == 0 and int(predicted_class) != 0:
                false_pos = [image, label]
                break
        if false_pos:
            break

# Fail with a clear message instead of an IndexError further down.
if not false_pos:
    raise RuntimeError('No false positive found in test_loader')

with torch.no_grad():
    print(convnext_net(false_pos[0]), false_pos[1])
print(len(false_pos))

In [None]:
# Render ScoreCAM overlays for the false-positive sample, once per target class.
for class_ in range(2):
    show_CAM(convnext_net, false_pos, cam_function=ScoreCAM, cam_class=class_)

In [None]:
#Find a false negative
# Scan every sample of every batch (not only the first one) until a sample
# with label 1 that the model predicts as a different class is found.
# `false_neg` ends up as [image (batched, on `device`), label].
false_neg = []
convnext_net.eval()
with torch.no_grad():  # inference only - no autograd graph needed
    for images, labels in test_loader:
        for image, label in zip(images, labels):
            image = image.unsqueeze(0).to(device)
            out = convnext_net(image)
            _, predicted_class = torch.max(out, 1)
            if label == 1 and int(predicted_class) != 1:
                false_neg = [image, label]
                break
        if false_neg:
            break

# Fail with a clear message instead of an IndexError further down.
if not false_neg:
    raise RuntimeError('No false negative found in test_loader')

with torch.no_grad():
    print(convnext_net(false_neg[0]), false_neg[1])
print(len(false_neg))

In [None]:
# Render ScoreCAM overlays for the false-negative sample, once per target class.
for class_ in range(2):
    show_CAM(convnext_net, false_neg, cam_function=ScoreCAM, cam_class=class_)