In [None]:
%matplotlib notebook
import torch
import torchvision
from torchvision import transforms
import pytorch_lightning as pl
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def load_file(path):
    """Read a NumPy ``.npy`` file and return its contents cast to float32."""
    raw = np.load(path)
    as_float = raw.astype(np.float32)
    return as_float


In [None]:
# Validation preprocessing: ndarray -> tensor, then standardise with a fixed
# mean/std (presumably statistics computed on the training data — TODO confirm).
val_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=0.49, std=0.248),
    ]
)

# DatasetFolder treats each sub-folder of the root as one class and loads
# every file with a matching extension through `load_file`.
val_dataset = torchvision.datasets.DatasetFolder(
    root="Processed/val/",
    loader=load_file,
    extensions="npy",
    transform=val_transforms,
)

In [None]:
# Build a stock (untrained) resnet18 only to inspect its layer structure below.
temp_model = torchvision.models.resnet18()
temp_model

In [None]:
# Every top-level child except the final two (the average pool and the fc head).
[*temp_model.children()][:-2]

In [None]:
# Chain those layers into one module whose output is the last convolutional feature map.
torch.nn.Sequential(*[*temp_model.children()][:-2])

In [None]:
class PneumoniaModel(pl.LightningModule):
    """ResNet-18 binary pneumonia classifier that also exposes its last conv feature map.

    The stem is replaced by a 1-channel convolution and the classifier head by a
    single-logit linear layer. ``forward`` returns both the logit and the feature
    map so a class activation map (CAM) can be built from the ``fc`` weights.
    """

    def __init__(self):
        super().__init__()

        self.model = torchvision.models.resnet18()
        # Single-channel input instead of RGB; kernel/stride/padding match the stock stem.
        self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # One output logit for the binary decision.
        self.model.fc = torch.nn.Linear(in_features=512, out_features=1)

        # All children except the trailing avgpool and fc. These submodules are the
        # same objects as in self.model, so weights loaded into ``model.*`` are used here.
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])

    def forward(self, data):
        """Return ``(logit, feature_map)`` for a batch of images.

        Args:
            data: image batch with one input channel, i.e. (B, 1, H, W).

        Returns:
            pred: raw logits of shape (B, 1); apply a sigmoid for a probability.
            feature_map: backbone output of shape (B, 512, h, w).
        """
        feature_map = self.feature_map(data)

        avg_pool_output = torch.nn.functional.adaptive_avg_pool2d(input=feature_map, output_size=(1, 1))

        # Keep the batch axis: flattening from dim 0 would merge every sample into a
        # single vector, breaking the fc layer for batch sizes > 1.
        avg_pool_output_flattened = torch.flatten(avg_pool_output, start_dim=1)

        pred = self.model.fc(avg_pool_output_flattened)
        return pred, feature_map



In [None]:
def cam(model, img):
    """Compute a raw (unnormalised) class activation map for a single image.

    Note: a later cell redefines ``cam`` with min-max scaling, which shadows
    this definition once that cell has run.

    Args:
        model: module returning ``(logit, feature_map)`` whose classifier is
            ``model.model.fc`` (as ``PneumoniaModel`` provides).
        img: one image without a batch axis; it is batched via ``unsqueeze(0)``.

    Returns:
        (cam_img, prob): the CAM as an (h, w) CPU tensor matching the feature
        map's spatial size, and the sigmoid of the model's logit.
    """
    with torch.no_grad():
        pred, features = model(img.unsqueeze(0))
    # Take channel count and spatial size from the feature map instead of
    # hard-coding 512 x 7 x 7, which only matches a single input resolution.
    _, channels, height, width = features.shape
    features = features.reshape((channels, height * width))
    # fc.weight is (1, channels); row 0 holds the weights of the single logit.
    weight = model.model.fc.weight[0].detach()

    # Weighted sum of the feature channels at each spatial position.
    activation = torch.matmul(weight, features)
    cam_img = activation.reshape(height, width).cpu()
    return cam_img, torch.sigmoid(pred)

In [None]:

# Restore trained weights and switch to inference mode (eval() also returns the
# module, so the cell displays it). strict=False tolerates state-dict key
# mismatches — presumably the checkpoint lacks the `feature_map.*` aliases of
# `model.*`; TODO confirm no other keys are silently skipped.
model = PneumoniaModel.load_from_checkpoint(checkpoint_path="weights/weights_3.ckpt", strict=False)
model.eval()

In [None]:
def cam(model, img):
    """Compute a class activation map (CAM) for one image, min-max scaled to [0, 1].

    Args:
        model: module returning ``(logit, feature_map)`` whose classifier is
            ``model.model.fc`` (as ``PneumoniaModel`` provides).
        img: one image without a batch axis; it is batched via ``unsqueeze(0)``.

    Returns:
        (cam_img, prob): the CAM as an (h, w) CPU tensor in [0, 1] (all zeros if
        the raw map is constant), and the sigmoid of the model's logit.
    """
    with torch.no_grad():
        pred, features = model(img.unsqueeze(0))
    _, c, h, w = features.shape

    # Drop the (size-1) batch axis and flatten space: (c, h*w).
    features = features.reshape((c, h * w))

    # fc.weight is (1, c); row 0 holds one weight per feature channel.
    weight = model.model.fc.weight[0].detach()

    # Weighted sum of the feature channels at each spatial position.
    activation = torch.matmul(weight, features)

    # Min-max scale. A constant map would give 0 / 0 = NaN, so it stays all zeros.
    activation = activation - torch.min(activation)
    peak = torch.max(activation)
    cam_img = activation / peak if peak > 0 else activation

    cam_img = cam_img.reshape(h, w).cpu()

    return cam_img, torch.sigmoid(pred)

def visualize(img, heatmap, pred):
    """Draw the image alone and with ``heatmap`` overlaid, titled with the thresholded prediction.

    Only channel 0 of ``img`` is drawn. ``heatmap`` is resized to that channel's
    height and width so the overlay lines up with the image.
    """
    image_2d = img[0]
    target_size = (image_2d.shape[0], image_2d.shape[1])
    overlay = transforms.functional.resize(heatmap.unsqueeze(0), target_size)[0]

    fig, (left_ax, right_ax) = plt.subplots(1, 2)

    left_ax.imshow(image_2d, cmap="bone")

    right_ax.imshow(image_2d, cmap="bone")
    right_ax.imshow(overlay, alpha=0.5, cmap="jet")
    plt.title(f"Pneumonia: {(pred > 0.5).item()}")

In [None]:
def visualize(img, cam, pred):
    """Plot an image beside the same image overlaid with its class activation map.

    Args:
        img: image tensor with a leading channel axis; only channel 0 is drawn.
        cam: 2-D activation map; it is resized to the image's height and width.
        pred: predicted probability (single-element tensor or float), shown in
            the title.
    """
    img = img[0]
    # Resize to the image's own size rather than a hard-coded 224 x 224, so the
    # overlay stays aligned for inputs of any resolution.
    cam = transforms.functional.resize(cam.unsqueeze(0), (img.shape[0], img.shape[1]))[0]

    fig, axis = plt.subplots(1, 2)
    axis[0].imshow(img, cmap="bone")
    axis[1].imshow(img, cmap="bone")
    axis[1].imshow(cam, alpha=0.5, cmap="jet")
    # Show a readable probability instead of the raw tensor repr.
    plt.title(f"Pneumonia probability: {float(pred):.3f}")

In [None]:
# Position in val_dataset of the example to explain (negative counts from the end).
SAMPLE_INDEX = -6
# DatasetFolder items are (sample, class_index); only the sample is needed here.
img = val_dataset[SAMPLE_INDEX][0]
activation_map, pred = cam(model, img)

In [None]:
# Draw the chosen sample next to its CAM overlay; `;` keeps the cell from echoing a return value.
visualize(img, activation_map, pred);