In [None]:
import torch
import torchvision
torchvision.disable_beta_transforms_warning()
from torchvision.transforms import v2
from torchvision.models import efficientnet_b0,EfficientNet_B0_Weights,densenet121,DenseNet121_Weights
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import cv2

import warnings

class KONet(torch.nn.Module):
    """Weighted-logit ensemble of an EfficientNet-B0 and a DenseNet-121 classifier.

    Both backbones start from their torchvision ``DEFAULT`` pretrained weights.
    Each classification head is replaced by dropout followed by a freshly
    initialised ``Linear`` layer with ``n_classes`` outputs. ``forward`` returns
    ``m1_ratio * efficientnet_logits + m2_ratio * densenet_logits``.

    Args:
        m1_ratio: Weight of the EfficientNet-B0 logits in the blend.
        m2_ratio: Weight of the DenseNet-121 logits. ``m1_ratio + m2_ratio``
            must equal 1 (up to floating-point rounding).
        m1_dropout: Dropout probability in front of the EfficientNet head.
        m2_dropout: Dropout probability in front of the DenseNet head.
        n_classes: Number of outputs of both heads.

    Raises:
        AssertionError: If the two ratios do not sum to 1.
    """

    def __init__(
            self,
            m1_ratio=0.6,
            m2_ratio=0.4,
            m1_dropout=0.1,
            m2_dropout=0.3,
            n_classes=2
    ):
        super().__init__()
        self._check_ratios(m1_ratio, m2_ratio)
        self.n_classes = n_classes
        self.m1_ratio = m1_ratio
        self.m2_ratio = m2_ratio
        self.m1_dropout = m1_dropout
        self.m2_dropout = m2_dropout

        self.efficient = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
        # Read the head's input width from the pretrained model instead of hard-coding it.
        efficient_width = self.efficient.classifier[-1].in_features
        self.efficient.classifier[0] = torch.nn.Dropout(p=self.m1_dropout, inplace=True)
        self.efficient.classifier[-1] = torch.nn.Linear(in_features=efficient_width, out_features=self.n_classes)

        self.dense = densenet121(weights=DenseNet121_Weights.DEFAULT)
        dense_width = self.dense.classifier.in_features
        # Sequential(Dropout, Linear) keeps the checkpoint key layout classifier.1.*.
        self.dense.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=self.m2_dropout, inplace=True),
            torch.nn.Linear(in_features=dense_width, out_features=n_classes),
        )

    @staticmethod
    def _check_ratios(m1_ratio, m2_ratio, tol=1e-9):
        """Assert that the two ensemble weights sum to 1 within ``tol``.

        An exact ``== 1`` test would reject ratios that miss 1 only by
        floating-point rounding, e.g. ``0.6000000000000001 + 0.4000000000000001``.
        """
        total = m1_ratio + m2_ratio
        assert abs(total - 1) <= tol, f"m1_ratio + m2_ratio must be 1, got {total}"

    def forward(self, x):
        """Return the blended logits ``m1_ratio * efficient(x) + m2_ratio * dense(x)``."""
        m1 = self.efficient(x)
        m2 = self.dense(x)
        return self.m1_ratio * m1 + self.m2_ratio * m2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): this silences every warning for the rest of the kernel session.
warnings.filterwarnings("ignore")
n_classes=2
image_shape=224
augmented_dataset_size=4000  # unused in the visible code

# Raw string: "\O" and "\d" are invalid escape sequences in a normal string
# literal (SyntaxWarning since Python 3.12). The resulting value is unchanged.
path=r"D:\Osteoporosis detection\datasets\Osteoporosis Knee X-ray Dataset"

# NOTE(review): v2.ToDtype is not expected to rescale pixels to [0, 1] (verify
# for the installed torchvision). If so, Normalize(0.5, 0.5) maps pixels to
# roughly [-1, 509] rather than [-1, 1]. Left unchanged because the saved
# checkpoints were presumably trained with exactly this preprocessing.
non_augment_transform=v2.Compose([v2.ToImageTensor(),
                       v2.ToDtype(torch.float32),
                       v2.Resize((image_shape,image_shape),antialias=True),
                       v2.Normalize(mean=[0.5], std=[0.5]),
                       ])
# Augmentation pipeline (presumably for training; unused in the visible code).
transforms=v2.Compose([v2.ToImageTensor(),
                        v2.ToDtype(torch.float32),
                        v2.RandomAffine(degrees=30,shear=30),
                        v2.RandomZoomOut(side_range=(1,1.5)),
                        v2.Resize((image_shape,image_shape),antialias=True),
                        v2.Normalize(mean=[0.5], std=[0.5]),
                        ])
non_augmented_dataset=torchvision.datasets.ImageFolder(path,transform=non_augment_transform)

# Seeded 80/10/10 split so the same test subset is drawn on every run.
generator1 = torch.Generator().manual_seed(42)
train_split=0.8
valid_split=0.1
test_split=0.1
train_set,valid_set,test_set=torch.utils.data.random_split(non_augmented_dataset, [train_split,valid_split,test_split],
                                                                generator=generator1)



In [10]:
model_name='efficient'
print('Model: ',model_name)
# Layer-freezing notes (nothing in this cell freezes layers):
# EfficientNetB0 has 16 MBConv layers, freeze till 8th MBConv layer then. Freeze all till before 5th sequential
# DenseNet121 has 58 dense layers, freeze till 29th dense layer then. #Till before dense block 3
#
# Each branch rebuilds the architecture that the checkpoint below was saved
# from. Head input widths are read from the pretrained heads instead of
# being hard-coded.
if model_name == 'efficient':
    model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
    p = 0.1
    in_features = model.classifier[-1].in_features
    model.classifier[0] = torch.nn.Dropout(p=p, inplace=True)
    model.classifier[-1] = torch.nn.Linear(in_features=in_features, out_features=n_classes)
    frozen_layers = 4  # unused in the visible code

elif model_name == 'dense':
    model = densenet121(weights=DenseNet121_Weights.DEFAULT)
    p = 0.3
    in_features = model.classifier.in_features
    model.classifier = torch.nn.Sequential(torch.nn.Dropout(p=p, inplace=True),
                                           torch.nn.Linear(in_features=in_features, out_features=n_classes),
                                           )

elif 'mobilenet' in model_name:
    model = torchvision.models.mobilenet_v3_small(weights='DEFAULT')
    in_features = model.classifier[3].in_features
    model.classifier[3] = torch.nn.Linear(in_features=in_features, out_features=n_classes)

elif 'conv_next' in model_name:
    p = 0.3
    model = torchvision.models.convnext_tiny(weights='DEFAULT')
    in_features = model.classifier[2].in_features
    model.classifier[2] = torch.nn.Sequential(torch.nn.Dropout(p=p, inplace=True),
                                              torch.nn.Linear(in_features=in_features, out_features=n_classes),
                                              )

elif model_name == 'KONet':
    m1_ratio = 0.6
    m2_ratio = 0.4
    m1_dropout = 0.1
    m2_dropout = 0.3
    model = KONet(m1_ratio=m1_ratio, m2_ratio=m2_ratio, m1_dropout=m1_dropout, m2_dropout=m2_dropout, n_classes=n_classes)

else:
    # Without this, an unknown name would silently reuse a `model` left over
    # from an earlier run of this cell.
    raise ValueError(f"Unknown model_name: {model_name!r}")

# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load(f"model/{model_name}best_param.pkl", map_location=device))

if not hasattr(model, 'features'):
    # KONet wraps two backbones (efficient, dense) and has no single `features` module.
    raise AttributeError(f"{model_name} has no `features` module to build heatmaps from")

# Keep only the convolutional feature extractor so forward() returns feature maps.
model = model.features
model.eval()
model.to(device)

# Per-channel weights used to combine the final feature maps into a heatmap.
# NOTE(review): these are the second-to-last parameters of `features`,
# presumably the scale of its final normalisation layer. They are not the
# classifier weights used by standard CAM — confirm this is intended.
weights = list(model.parameters())[-2]

Model:  efficient


In [11]:
# Two fixed test items. Each item is presumably (image tensor, class index).
# NOTE(review): indices 1 and 2 are assumed to be a normal and an osteoporosis
# case. Check normal_label / osteoporosis_label against
# non_augmented_dataset.class_to_idx.
normal_sample, normal_label = test_set[1]
osteoporosis_sample, osteoporosis_label = test_set[2]

# Add a leading batch axis and move to the model's device.
normal_sample = normal_sample.unsqueeze(0).to(device)
osteoporosis_sample = osteoporosis_sample.unsqueeze(0).to(device)

In [12]:
# Inference only: no_grad avoids building and retaining an autograd graph.
# (Plain no_grad rather than inference_mode, because the outputs are later
# multiplied with the trainable `weights` tensor outside this block.)
with torch.no_grad():
    normal_output = model(normal_sample)
    osteoporosis_output = model(osteoporosis_sample)

In [13]:
def class_activation_heatmap(feature_maps, channel_weights, size):
    """Collapse batch item 0 of ``feature_maps`` into a JET-coloured heatmap.

    Computes ``sum_c channel_weights[c] * feature_maps[0, c]``, min-max
    normalises it to [0, 1], resizes it to ``size`` x ``size`` and colour-maps it.

    Args:
        feature_maps: Tensor indexed as [batch, channel, h, w]. Only batch item
            0 and its first ``len(channel_weights)`` channels are used.
        channel_weights: 1-D tensor with one weight per channel.
        size: Side length in pixels of the returned heatmap.

    Returns:
        uint8 array of shape (size, size, 3) from cv2.applyColorMap (BGR order).
    """
    n_channels = len(channel_weights)
    with torch.no_grad():
        cam = torch.einsum("c,chw->hw", channel_weights, feature_maps[0, :n_channels])
    cam = cam.cpu().numpy().astype(np.float32)

    # Min-max normalisation. Dividing by the max alone leaves negative values,
    # which wrap around in the uint8 cast, and divides by zero for a flat map.
    cam -= cam.min()
    peak = cam.max()
    if peak > 0:
        cam /= peak

    cam = cv2.resize(cam, (size, size))
    # Guard against rounding just outside [0, 1] before the uint8 cast.
    cam = np.clip(cam, 0.0, 1.0)
    return cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)


normal_heatmap = class_activation_heatmap(normal_output, weights, image_shape)
osteoporosis_heatmap = class_activation_heatmap(osteoporosis_output, weights, image_shape)

In [14]:
def tensor_to_bgr_image(sample, mean=0.5, std=0.5):
    """Undo ``Normalize(mean, std)`` on batch item 0 and return a uint8 BGR image.

    The original code wrote the *normalised* tensor, which OpenCV saturates
    into a washed-out image.

    Args:
        sample: Tensor indexed as [batch, channel, h, w]. Assumes 3 channels in
            RGB order, normalised with the same mean/std as the dataset
            transform (0.5 / 0.5). Assumes the dataset transform's ToDtype left
            pixels in 0-255 — confirm.
        mean: Mean used by ``v2.Normalize``.
        std: Standard deviation used by ``v2.Normalize``.

    Returns:
        uint8 array of shape (h, w, 3) in BGR channel order, as cv2 expects.
    """
    image = sample[0].detach().cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC
    image = image * std + mean  # invert (x - mean) / std
    image = np.clip(np.rint(image), 0, 255).astype(np.uint8)
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)


normal_image = tensor_to_bgr_image(normal_sample)
osteoporosis_image = tensor_to_bgr_image(osteoporosis_sample)

In [15]:
# cv2.imwrite reports failure (bad path, unsupported extension, ...) by
# returning False rather than raising, so check the result explicitly.
for file_name, image in (('normal_image.jpg', normal_image),
                         ('osteoporosis_image.jpg', osteoporosis_image)):
    if not cv2.imwrite(file_name, image):
        raise OSError(f"cv2.imwrite failed to write {file_name}")

True

In [16]:
def overlay_heatmap(image, heatmap, alpha=0.5):
    """Return ``image + alpha * heatmap`` rounded and saturated to uint8.

    The explicit round/clip avoids handing an out-of-range float array to
    cv2.imwrite.
    """
    blended = image.astype(np.float32) + alpha * heatmap.astype(np.float32)
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)


# NOTE(review): the file names do not include model_name, so runs for
# different models overwrite each other's overlays.
for file_name, image, heatmap in (
        ('normal_heatmap.jpg', normal_image, normal_heatmap),
        ('osteoporosis_heatmap.jpg', osteoporosis_image, osteoporosis_heatmap),
):
    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(file_name, overlay_heatmap(image, heatmap)):
        raise OSError(f"cv2.imwrite failed to write {file_name}")

True