In [None]:
import torch
import torchvision
torchvision.disable_beta_transforms_warning()
from torchvision.transforms import v2
from torchvision.models import efficientnet_b0,EfficientNet_B0_Weights,densenet121,DenseNet121_Weights
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os
import random
import warnings

def set_random_seed(seed: int = 2222, deterministic: bool = False):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: Value used to seed every RNG.
        deterministic: If True, request deterministic cuDNN kernels and disable the
            cuDNN autotuner (``benchmark``), which can otherwise select different
            algorithms between runs. If False, the autotuner is enabled for speed.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Only affects subprocesses started after this point; the running interpreter's
    # hash seed is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every visible GPU, not only the current one
    # The autotuner must be off for deterministic runs; it used to be forced on.
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.deterministic = deterministic  # type: ignore

class GradCAM(torch.nn.Module):
    """Wrap a model and capture one layer's activation and gradient for Grad-CAM.

    Two forward hooks are placed on ``target_layer``:
    ``hook_activation`` stores a detached CPU copy of the layer output in
    ``self.activation``; ``hook_gradient`` attaches a tensor hook that stores the
    gradient flowing back into that output in ``self.gradient``.
    """

    def __init__(self, model, target_layer):
        super(GradCAM, self).__init__()

        # Put the wrapped model in eval mode so dropout / batch-norm behave as at inference.
        self.model = model.eval()
        self.target_layer = target_layer

        self.activation = None
        self.gradient = None

        # Keep the handles so the hooks can be detached later via remove_hooks().
        self._hook_handles = [
            target_layer.register_forward_hook(self.hook_activation),
            target_layer.register_forward_hook(self.hook_gradient),
        ]

    def hook_activation(self, module, input, output):
        """Forward hook: store a detached CPU copy of the layer output."""
        self.activation = output.cpu().detach()

    def hook_gradient(self, module, input, output):
        """Forward hook: arrange for the gradient w.r.t. the layer output to be saved on backward."""
        # register_hook raises on tensors that do not require grad (e.g. under
        # torch.no_grad()); in that case there is no gradient to capture anyway.
        if not output.requires_grad:
            return

        def save_grad(grad):
            self.gradient = grad.cpu().detach()

        output.register_hook(save_grad)

    def __call__(self, x):
        """Run the wrapped model on x after clearing the previously captured tensors."""
        self.activation = None
        # Must reset the same attribute get_activation_gradient() returns, otherwise a
        # gradient from an earlier backward pass would be reported for this pass.
        self.gradient = None
        return self.model(x)

    # method for the gradient extraction
    def get_activation_gradient(self):
        """Return (activation, gradient) from the most recent pass; either may be None."""
        return self.activation, self.gradient

    def remove_hooks(self):
        """Detach both forward hooks from target_layer."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()
    
def _gradcam_heatmap(activation, gradient):
    """Combine one sample's activation and gradient into a Grad-CAM map scaled to [0, 1].

    Args:
        activation: layer output for a single sample, shaped (channels, H, W).
        gradient: gradient of the target score w.r.t. that output, same shape.

    Returns:
        float32 numpy array of shape (H, W). All zeros when no location has a
        positive contribution, instead of NaNs from dividing by a zero maximum.
    """
    # Per-channel importance = spatially averaged gradient.
    channel_weights = gradient.mean(dim=(1, 2))
    heatmap = (activation * channel_weights.reshape(-1, 1, 1)).mean(dim=0).clamp(min=0)
    peak = heatmap.max()
    if peak > 0:
        heatmap = heatmap / peak
    return heatmap.detach().cpu().numpy().astype(np.float32)


def save_gradcam(model, sample, class_int, img_name, image_shape):
    """Write a Grad-CAM overlay of ``sample`` for output class ``class_int`` to ``img_name``.

    Args:
        model: GradCAM wrapper whose hooks capture the target layer's activation and gradient.
        sample: batch holding exactly one image, shaped (1, C, H, W), on the model's device.
        class_int: index of the output logit to explain.
        img_name: output path passed to cv2.imwrite.
        image_shape: side length the (square) heatmap is resized to; must match the
            sample's H and W so the heatmap can be added to the image.

    Raises:
        RuntimeError: if the hooks captured no activation or gradient.
        OSError: if cv2.imwrite reports failure.
    """
    output = model(sample)
    # Index the single sample explicitly so backward() receives a scalar.
    output[0, class_int].backward()

    activation, gradient = model.get_activation_gradient()
    if activation is None or gradient is None:
        raise RuntimeError(
            "Grad-CAM hooks captured no activation/gradient; "
            "is target_layer part of the forward pass?"
        )

    heatmap = _gradcam_heatmap(activation.squeeze(0), gradient.squeeze(0))
    heatmap = cv2.resize(heatmap, (image_shape, image_shape))
    heatmap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)

    # (1, C, H, W) -> (H, W, C); cv2 expects BGR, so reverse the channel axis
    # (assumes the tensor is RGB, as produced by ImageFolder's default loader — confirm).
    image = sample.detach().cpu().numpy()[0].transpose(1, 2, 0)[..., ::-1]
    # Blend, then round and clamp explicitly to the 8-bit range JPEG stores.
    overlay = np.clip(np.rint(image + 0.5 * heatmap), 0, 255).astype(np.uint8)
    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(img_name, overlay):
        raise OSError(f"cv2.imwrite failed for {img_name}")

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): this silences *every* warning for the rest of the session, which can
# hide real problems (deprecations, dtype issues); consider filtering specific categories.
warnings.filterwarnings("ignore")
n_classes = 2
image_shape = 224
augmented_dataset_size = 4000
seed = 42
set_random_seed(seed)
# Raw strings keep the Windows backslashes literal instead of parsing "\O" / "\d"
# as (invalid) escape sequences.
path1 = r"D:\Osteoporosis detection\datasets\Osteoporosis Knee X-ray Dataset Preprocessed"
path2 = r"D:\Osteoporosis detection\datasets\Osteoporosis Knee X-ray modified\Osteoporosis Knee X-ray Preprocessed"
# Normalize with mean 0 / std 1 is an identity; kept as in the training pipeline.
non_augment_transform = v2.Compose([
    v2.ToImageTensor(),
    v2.ToDtype(torch.float32),
    v2.Resize((image_shape, image_shape), antialias=True),
    v2.Normalize(mean=[0], std=[1]),
])
train_split = 0.8
valid_split = 0.1
test_split = 0.1


def seeded_test_split(dataset, split_seed=seed):
    """Return the test subset of a train/valid/test random_split driven by ``split_seed``.

    A fresh generator is created per call, so the split of one dataset does not
    depend on how many splits were drawn before it.
    """
    generator = torch.Generator().manual_seed(split_seed)
    _, _, test_subset = torch.utils.data.random_split(
        dataset, [train_split, valid_split, test_split], generator=generator
    )
    return test_subset


non_augmented_dataset1 = torchvision.datasets.ImageFolder(path1, transform=non_augment_transform)
test_set1 = seeded_test_split(non_augmented_dataset1)

non_augmented_dataset2 = torchvision.datasets.ImageFolder(path2, transform=non_augment_transform)
test_set2 = seeded_test_split(non_augmented_dataset2)

In [18]:
model_name = 'efficient'
print('Model: ', model_name)
# EfficientNetB0 has 16 MBConv layers: freeze up to the 8th MBConv layer, i.e. everything before the 5th Sequential stage.
# DenseNet121 has 58 dense layers: freeze up to the 29th dense layer, i.e. everything before dense block 3.

# weights=None builds the bare architecture: the strict load_state_dict below overwrites
# every parameter and buffer, so fetching ImageNet weights first would only add a download.
if model_name == 'efficient':
    model = efficientnet_b0(weights=None)
    p = 0.1
    model.classifier[0] = torch.nn.Dropout(p=p, inplace=True)
    model.classifier[-1] = torch.nn.Linear(in_features=1280, out_features=n_classes)
    target_layer = model.features[8][0]

elif model_name == 'dense':
    model = densenet121(weights=None)
    p = 0.3
    model.classifier = torch.nn.Sequential(
        torch.nn.Dropout(p=p, inplace=True),
        torch.nn.Linear(in_features=1024, out_features=n_classes),
    )
    target_layer = model.features[-2].denselayer16.conv2

elif 'mobilenet' in model_name:
    model = torchvision.models.mobilenet_v3_small(weights=None)
    model.classifier[3] = torch.nn.Linear(in_features=1024, out_features=n_classes)
    target_layer = model.features[-1][0]

elif 'conv_next' in model_name:
    p = 0.3
    model = torchvision.models.convnext_tiny(weights=None)
    model.classifier[2] = torch.nn.Sequential(
        torch.nn.Dropout(p=p, inplace=True),
        torch.nn.Linear(in_features=768, out_features=n_classes),
    )
    target_layer = model.features[-1][-1].block[0]

else:
    # Fail here rather than with a confusing NameError on `model` below.
    raise ValueError(f"Unsupported model_name: {model_name!r}")

# map_location lets a checkpoint saved from a GPU session load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load(f"model/{model_name}best_param.pkl", map_location=device)
model.load_state_dict(state_dict)

model.eval()
model.to(device)

gradcam_model = GradCAM(model, target_layer)

Model:  efficient


In [19]:
# Class ids as used by the Grad-CAM cell (class_int 0 -> normal, 1 -> osteoporosis).
NORMAL_CLASS = 0
OSTEOPOROSIS_CLASS = 1

# `test_set` was never assigned in this notebook (only test_set1 / test_set2 exist),
# so this cell raised NameError on a fresh kernel.
# NOTE(review): test_set1 is assumed here — confirm which dataset should be visualised.
test_set = test_set1


def first_sample_of_class(dataset, label):
    """Return the first image in ``dataset`` whose label is ``label``, with a batch dim added.

    Selecting by label (instead of fixed indices) ensures each sample actually
    carries the class it is named after.
    """
    for idx in range(len(dataset)):
        image, target = dataset[idx]
        if target == label:
            return image.unsqueeze(0)
    raise ValueError(f"No sample with label {label} in dataset")


normal_sample = first_sample_of_class(test_set, NORMAL_CLASS).to(device)
osteoporosis_sample = first_sample_of_class(test_set, OSTEOPOROSIS_CLASS).to(device)

In [20]:
# Save the unmodified input images for comparison with the heatmaps.
for file_name, sample in (("normal_image.jpg", normal_sample),
                          ("osteoporosis_image.jpg", osteoporosis_sample)):
    # (1, C, H, W) tensor -> (H, W, C) array. cv2 writes BGR, so reverse the channel
    # axis (assumes RGB input, as ImageFolder's default loader yields — confirm).
    image = sample.cpu().numpy()[0].transpose(1, 2, 0)[..., ::-1]
    # Round and clamp explicitly to the 8-bit range JPEG stores.
    image_u8 = np.clip(np.rint(image), 0, 255).astype(np.uint8)
    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(file_name, image_u8):
        raise OSError(f"cv2.imwrite failed for {file_name}")

True

In [21]:
# One Grad-CAM overlay per sample, each explained against its own class id.
heatmap_jobs = (
    (0, f"normal_heatmap_{model_name}.jpg", normal_sample),
    (1, f"osteoporosis_heatmap_{model_name}.jpg", osteoporosis_sample),
)
for class_int, img_name, sample in heatmap_jobs:
    save_gradcam(gradcam_model, sample, class_int, img_name, image_shape)