In [None]:
import torch
import torch.nn as nn
from torchvision import models


In [None]:
class HierarchicalResNet(nn.Module):
    """ResNet-18 backbone shared by three independent linear classification heads.

    The backbone's final ``fc`` layer is replaced with ``nn.Identity`` so the
    backbone emits ``in_feats``-dimensional features (``in_feats`` being the
    original ``fc.in_features``). Those features are fed to three separate
    ``nn.Linear`` heads: anatomy, category and disease.

    Args:
        n_anatomy: number of output logits of the anatomy head.
        n_category: number of output logits of the category head.
        n_disease: number of output logits of the disease head.
        backbone_weights: weights spec forwarded to ``models.resnet18``.
            Defaults to ``'IMAGENET1K_V1'`` (the previous, hard-coded behaviour).
            Pass ``None`` when a full ``state_dict`` is loaded right after
            construction, to skip downloading ImageNet weights that would
            immediately be overwritten.
    """

    def __init__(self, n_anatomy, n_category, n_disease,
                 backbone_weights='IMAGENET1K_V1'):
        super().__init__()
        self.backbone = models.resnet18(weights=backbone_weights)
        in_feats = self.backbone.fc.in_features
        # Drop the ImageNet classifier so the backbone returns raw features.
        self.backbone.fc = nn.Identity()

        # One linear head per level of the label hierarchy, all sharing the backbone.
        self.anatomy_head = nn.Linear(in_feats, n_anatomy)
        self.category_head = nn.Linear(in_feats, n_category)
        self.disease_head = nn.Linear(in_feats, n_disease)

    def forward(self, x):
        """Return ``(anatomy_logits, category_logits, disease_logits)``.

        ``x`` is passed straight to the ResNet backbone; presumably a batch of
        3-channel images shaped ``(N, 3, H, W)`` — TODO confirm against callers.
        The backbone features are computed once and reused by all three heads.
        """
        features = self.backbone(x)
        return (
            self.anatomy_head(features),
            self.category_head(features),
            self.disease_head(features)
        )


In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Location of the trained checkpoint; edit for your environment.
CHECKPOINT_PATH = "/content/hierarchical_mri_model_15epoch.pth"

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source. If this checkpoint holds only
# tensors and plain Python values, consider passing weights_only=True.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

# Head sizes are stored in the checkpoint so the architecture matches the weights.
# NOTE(review): constructing the model also fetches ImageNet backbone weights
# (see the download log), which load_state_dict below then overwrites.
model = HierarchicalResNet(
    checkpoint['num_anatomy'],
    checkpoint['num_category'],
    checkpoint['num_disease']
)

model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval();  # trailing ';' suppresses the very long module repr in the cell output


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 220MB/s]


HierarchicalResNet(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tru

In [None]:
from torchvision import transforms
from PIL import Image

# Inference preprocessing applied to a PIL image:
#   1. resize to exactly 224x224 (aspect ratio is not preserved),
#   2. convert to a float tensor,
#   3. normalise each channel with the mean/std below (the standard ImageNet
#      statistics, presumably matching the training pipeline — confirm).
infer_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


In [None]:
img_path = "/content/Bilateral_Ulnar_Impaction_Syndrome002.jpg"  # change this

# Open the file inside a context manager so its handle is released promptly.
# convert() returns a new image, which stays valid after the source is closed.
with Image.open(img_path) as raw_img:
    img = raw_img.convert("RGB")

img_tensor = infer_transform(img)


In [None]:
# Fail loudly if preprocessing did not yield a single 3x224x224 image tensor.
assert img_tensor.shape == (3, 224, 224), f"unexpected input shape: {tuple(img_tensor.shape)}"
print(img_tensor.shape)


torch.Size([3, 224, 224])


In [None]:
# Forward one image (a batch dimension is prepended) without tracking gradients.
with torch.no_grad():
    anat_logits, cat_logits, dis_logits = model(img_tensor.unsqueeze(0).to(device))

# Predicted class index per head: argmax over dim 1. `.item()` works here
# because the batch holds exactly one image.
pred_anatomy, pred_category, pred_disease = (
    head_logits.argmax(dim=1).item()
    for head_logits in (anat_logits, cat_logits, dis_logits)
)


In [None]:
import torch.nn.functional as F

# Class probabilities for the disease head only; `confidence` is the largest one.
probs = dis_logits.softmax(dim=1)
confidence = float(probs.max())


In [None]:
# Report the predicted index for each head; the confidence refers to the disease head.
print(
    f"Anatomy: {pred_anatomy}",
    f"Category: {pred_category}",
    f"Disease: {pred_disease}",
    f"Confidence: {confidence}",
    sep="\n",
)


Anatomy: 2
Category: 3
Disease: 5
Confidence: 0.9999867677688599
