In [17]:
import torch
import torchvision
torchvision.disable_beta_transforms_warning()
from torchvision.transforms import v2
from torchvision.models import efficientnet_b0,EfficientNet_B0_Weights,densenet121,DenseNet121_Weights
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os
import random
import warnings
from typing import Tuple, List, Dict

def set_random_seed(seed: int = 2222, deterministic: bool = False):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: value fed to every random number generator.
        deterministic: when True, request deterministic cuDNN kernels and turn
            the cuDNN autotuner off; when False, enable the autotuner (faster,
            but free to pick non-deterministic algorithms).

    Note: setting PYTHONHASHSEED here does not change hash randomization of the
    already-running interpreter; it only affects Python processes started later.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # type: ignore
    # The autotuner benchmarks candidate kernels and may choose non-deterministic
    # ones, so it has to be disabled whenever determinism is requested.
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.deterministic = deterministic  # type: ignore

class CustomImageFolder(torchvision.datasets.ImageFolder):
    """ImageFolder whose class -> index mapping can be supplied by the caller.

    The mapping is taken from the ``class_map`` keyword argument. When that is
    omitted, it falls back to a module-level global named ``map`` (the channel
    ``prep_dataset`` uses). If neither yields a dict, the default ImageFolder
    behaviour (scanning ``root`` for class sub-directories) is used.
    """

    def __init__(self, root, *args, class_map=None, **kwargs):
        # find_classes() is invoked from the parent constructor, so the mapping
        # must be stored before delegating to it.
        self._class_map = class_map
        super().__init__(root, *args, **kwargs)

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """
        Return (classes, class_to_idx), preferring the configured mapping over a directory scan.
        """
        class_map = getattr(self, "_class_map", None)
        if class_map is None:
            # Look the global up explicitly: a bare `map` would resolve to the
            # builtin function when no global is set, and then crash on .keys().
            candidate = globals().get("map")
            class_map = candidate if isinstance(candidate, dict) else None
        self.map = class_map
        if class_map is None:
            return super().find_classes(directory)
        # Copy so the dataset does not alias (and cannot mutate) the caller's dict.
        return list(class_map.keys()), dict(class_map)
    
class GradCAM(torch.nn.Module):
    """Wrap a model so that a forward/backward pass records Grad-CAM inputs.

    Forward hooks on ``target_layer`` store that layer's output (the activation)
    and the gradient flowing back into it, both detached and moved to CPU.
    Typical use: call the wrapper like the model, back-propagate a class score,
    then read the captured tensors with :meth:`get_activation_gradient`.

    Args:
        model: network to explain; it is switched to eval mode.
        target_layer: submodule of ``model`` whose output is explained.
    """

    def __init__(self, model, target_layer):
        super().__init__()

        self.model = model.eval()
        self.target_layer = target_layer

        self.activation = None
        self.gradient = None

        # Keep the handles so the hooks can be detached later; otherwise every
        # new GradCAM built on the same layer stacks another pair of hooks.
        self._hook_handles = [
            target_layer.register_forward_hook(self.hook_activation),
            target_layer.register_forward_hook(self.hook_gradient),
        ]

    def hook_activation(self, module, input, output):
        """Forward hook: remember the target layer's output."""
        self.activation = output.cpu().detach()

    def hook_gradient(self, module, input, output):
        """Forward hook: attach a tensor hook capturing the gradient w.r.t. the output."""
        # Under torch.no_grad() the output does not require grad and
        # register_hook would raise; no backward pass can happen then anyway.
        if not output.requires_grad:
            return

        def save_grad(grad):
            self.gradient = grad.cpu().detach()

        output.register_hook(save_grad)

    def forward(self, x):
        """Run the wrapped model after clearing the previous pass's captures."""
        # Reset both captures so a stale gradient from an earlier backward pass
        # is never returned for a new input.
        self.activation = None
        self.gradient = None
        return self.model(x)

    def get_activation_gradient(self):
        """Return ``(activation, gradient)`` from the most recent pass (either may be None)."""
        return self.activation, self.gradient

    def remove_hooks(self):
        """Detach the forward hooks from ``target_layer``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []
    
def prep_dataset(path, image_shape=224, augmented_dataset_size=4000, new_map=None,
                 train_split=0.8, valid_split=0.1, test_split=0.1, split_seed=42):
    """Load an image folder and split it reproducibly into (train+valid, test).

    Args:
        path: root directory containing one sub-directory per class.
        image_shape: every image is resized to ``(image_shape, image_shape)``.
        augmented_dataset_size: currently unused (augmentation is disabled);
            kept so existing call sites keep working.
        new_map: optional ``{class_name: index}`` mapping for
            ``CustomImageFolder``; ``None`` means "scan ``path`` for classes".
        train_split, valid_split, test_split: split fractions. Train and
            validation are returned merged as a single subset.
        split_seed: seed of the generator driving ``random_split``.

    Returns:
        ``[train_and_valid_subset, test_subset]`` from
        ``torch.utils.data.random_split``.
    """
    # No normalization here; it is left to the caller.
    eval_transform = v2.Compose([
        v2.ToImageTensor(),
        v2.ToDtype(torch.float32),
        v2.Resize((image_shape, image_shape), antialias=True),
    ])

    # CustomImageFolder reads its class mapping from the module-level global
    # `map` while it is being constructed. Publish new_map there only for that
    # duration and restore the previous state afterwards, so the builtin `map`
    # is not left shadowed for the rest of the session.
    namespace = globals()
    missing = object()
    previous = namespace.get("map", missing)
    namespace["map"] = new_map
    try:
        dataset = CustomImageFolder(path, transform=eval_transform)
    finally:
        if previous is missing:
            namespace.pop("map", None)
        else:
            namespace["map"] = previous

    generator = torch.Generator().manual_seed(split_seed)
    return torch.utils.data.random_split(dataset, [train_split + valid_split, test_split],
                                         generator=generator)

def save_gradcam(model, sample, image, class_int, img_name, image_shape):
    """Compute a Grad-CAM heatmap for one class and save it blended over ``image``.

    Args:
        model: a GradCAM-style wrapper. Its forward pass must record the target
            layer's activation and gradient, exposed via ``get_activation_gradient()``.
        sample: normalized model input; the ``squeeze(0)`` calls below assume
            a batch of one.
        image: channels-last image array the heatmap is added to (assumed
            ``(image_shape, image_shape, 3)`` - TODO confirm for other inputs).
        class_int: index of the class score to explain.
        img_name: output path passed to ``cv2.imwrite``.
        image_shape: side length the heatmap is resized to.

    Raises:
        RuntimeError: if the wrapper did not record both an activation and a gradient.
    """
    output = model(sample)
    # Back-propagate only the chosen class score. backward() needs a
    # single-element tensor, which holds for a batch of one.
    output[:, class_int].backward()

    activation, gradient = model.get_activation_gradient()
    if activation is None or gradient is None:
        raise RuntimeError(
            "Grad-CAM hooks captured no activation/gradient; make sure the hooks are "
            "attached and the forward pass runs with gradients enabled."
        )
    activation, gradient = activation.squeeze(0), gradient.squeeze(0)

    # One weight per channel: the spatially averaged gradient.
    channel_weights = gradient.mean(dim=(1, 2))
    weighted = activation * channel_weights.reshape(-1, 1, 1)

    # Keep only positive evidence, then scale to [0, 1]. A map with no positive
    # values stays all-zero instead of turning into NaN from a division by zero.
    heatmap = weighted.mean(dim=0).clamp(min=0)
    peak = heatmap.max()
    if peak > 0:
        heatmap = heatmap / peak
    heatmap = heatmap.detach().cpu().numpy()

    heatmap = cv2.resize(heatmap, (image_shape, image_shape))
    heatmap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)

    cv2.imwrite(img_name, image + 0.5 * heatmap)

def create_model(model_name, n_classes, checkpoint_dir="model"):
    """Build a backbone, load its fine-tuned weights and choose a Grad-CAM layer.

    The architecture is selected by substring of ``model_name`` ('efficient',
    'dense', 'mobilenet', 'conv_next', checked in that order). The classifier
    head is replaced so it outputs ``n_classes`` scores and matches the
    checkpoint's layout.

    Args:
        model_name: architecture tag; also the checkpoint file-name prefix.
        n_classes: number of outputs of the new classifier head.
        checkpoint_dir: directory containing ``{model_name}best_param.pkl``.

    Returns:
        ``(model, target_layer)``, where ``target_layer`` is the module whose
        activations Grad-CAM should hook.

    Raises:
        ValueError: if ``model_name`` matches none of the supported architectures.
    """
    if 'efficient' in model_name:
        model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
        p = 0.1
        model.classifier[0] = torch.nn.Dropout(p=p, inplace=True)
        model.classifier[-1] = torch.nn.Linear(in_features=1280, out_features=n_classes)
        target_layer = model.features[8][0]

    elif 'dense' in model_name:
        model = densenet121(weights=DenseNet121_Weights.DEFAULT)
        p = 0.3
        model.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=p, inplace=True),
            torch.nn.Linear(in_features=1024, out_features=n_classes),
        )
        target_layer = model.features[-2].denselayer16.conv2

    elif 'mobilenet' in model_name:
        model = torchvision.models.mobilenet_v3_small(weights='DEFAULT')
        model.classifier[3] = torch.nn.Linear(in_features=1024, out_features=n_classes)
        target_layer = model.features[-1][0]

    elif 'conv_next' in model_name:
        p = 0.3
        model = torchvision.models.convnext_tiny(weights='DEFAULT')
        model.classifier[2] = torch.nn.Sequential(
            torch.nn.Dropout(p=p, inplace=True),
            torch.nn.Linear(in_features=768, out_features=n_classes),
        )
        target_layer = model.features[-1][-1].block[0]

    else:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected it to contain "
            "'efficient', 'dense', 'mobilenet' or 'conv_next'"
        )

    checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}best_param.pkl")
    # map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
    # machine; callers are expected to move the model to their device afterwards.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)

    return model, target_layer

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): this silences *every* warning for the rest of the session,
# including ones that flag real problems; consider narrowing the filter.
warnings.filterwarnings("ignore")
n_classes=3
image_shape=224
augmented_dataset_size=4000
# Seeds the global RNGs only; prep_dataset uses its own fixed split seed.
seed=42
# Raw strings so the backslashes in these Windows paths are never read as
# escape sequences (plain strings here raise invalid-escape warnings).
# NOTE(review): absolute local paths; consider a configurable DATA_DIR.
path1=r'D:\Osteoporosis detection\datasets\Osteoporosis Knee X-ray modified\Osteoporosis Knee X-ray Preprocessed'
path2=r"D:\Osteoporosis detection\datasets\Osteoporosis Knee X-ray only osteopenia Preprocessed"

new_n_classes=3
# map1 names the class folders used from path1 and map2 those from path2.
# Merged, they form the label space normal=0, osteopenia=1, osteoporosis=2.
map1={'normal':0,'osteoporosis':2}
map2={'osteopenia':1}
new_map={**map1, **map2}
idx_to_class = {v: k for k, v in new_map.items()}
# Every index 0..n_classes-1 must map to exactly one class name, otherwise
# predictions would be decoded with the wrong labels.
assert sorted(new_map.values()) == list(range(n_classes)), new_map

set_random_seed(seed)

dataset1,test_set1=prep_dataset(path1,image_shape,augmented_dataset_size,map1)
dataset2,test_set2=prep_dataset(path2,image_shape,augmented_dataset_size,map2)

test_set = torch.utils.data.ConcatDataset([test_set1, test_set2])

In [19]:
model_name = 'mobilenet_incremental_3_class'
print('Model: ', model_name)

# Rebuild the architecture, load its fine-tuned checkpoint and get the layer
# Grad-CAM will hook. Nothing is frozen here: the model is only used for inference.
model, target_layer = create_model(model_name, n_classes)
# eval() and to() both return the module itself, so this rebinds the same object.
model = model.eval().to(device)

# mean 0 / std 1 is an identity transform: it yields an untouched copy of the
# pixels, used only for images written to disk.
image_normalizer = v2.Normalize(mean=[0], std=[1])
# Normalization applied to network inputs; presumably matches training - TODO confirm.
sample_normalizer = v2.Normalize(mean=[0.5], std=[0.5])

Model:  mobilenet_incremental_3_class


In [23]:
wrong_indices = []
model.eval()
os.makedirs('misclassify', exist_ok=True)
with torch.no_grad():
    for idx in range(len(test_set)):
        sample, label = test_set[idx]
        # Channels-last copy of the raw pixels; only used for the saved image.
        image = image_normalizer(sample).cpu().numpy().transpose(1, 2, 0)
        # Network input: normalized, batched (size 1) and on the model's device.
        sample = sample_normalizer(sample).unsqueeze(0).to(device)
        output = model(sample)
        pred = output.argmax(dim=1).item()
        if pred == label:
            continue
        wrong_indices.append((idx, pred, label))
        # File name encodes "<test index>,<predicted>,<true>".
        cv2.imwrite(f'misclassify/{idx},{pred},{label}.jpg', image)
print("Misclassified indices:", wrong_indices)

Misclassified indices: [(4, 1, 0), (7, 1, 0), (18, 2, 1), (20, 0, 1), (27, 0, 1), (28, 2, 1), (29, 2, 1), (31, 2, 1)]


In [25]:
os.makedirs('misclassify heatmaps', exist_ok=True)

# NOTE(review): re-running this cell attaches another pair of hooks to
# target_layer, because the previous wrapper's hooks are never removed.
gradcam_model=GradCAM(model,target_layer)
# Hand-picked subset of the (idx, predicted, true) tuples found in the
# misclassification scan.
selected_samples = [(7, 1, 0), (18, 2, 1), (27, 0, 1)]
for idx, pred_class, true_class in selected_samples:
    raw_sample = test_set[idx][0].unsqueeze(0)

    # Channels-last copy of the raw pixels for writing to disk / overlaying.
    image = image_normalizer(raw_sample).cpu().numpy()[0].transpose(1, 2, 0)

    prefix = (f'misclassify heatmaps/Idx {idx} Predicted {idx_to_class[pred_class]} '
              f'Actual {idx_to_class[true_class]}')
    cv2.imwrite(f'{prefix}.jpg', image)

    # Normalize exactly once and reuse the same input for both heatmaps.
    # Normalizing the already-normalized tensor again would explain a different
    # input than the one the model was evaluated on.
    model_input = sample_normalizer(raw_sample.to(device))

    save_gradcam(gradcam_model, model_input, image, pred_class,
                 f'{prefix} pred heatmap.jpg', image_shape)
    save_gradcam(gradcam_model, model_input, image, true_class,
                 f'{prefix} actual heatmap.jpg', image_shape)