In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import segmentation
from torchvision import datasets, transforms

import numpy as np
import os
import timm
from PIL import Image
import matplotlib.pyplot as plt

In [2]:
# Pick the compute device once: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Stage 1: Classification between normal and abnormal
# ResNet-101 whose final `fc` is replaced by a 3-layer MLP head emitting 2 logits.
# All weights come from the fine-tuned checkpoint below, so no pretrained weights are requested.
stage1_model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet101', pretrained=False)
stage1_model.fc = nn.Sequential(
    nn.Linear(stage1_model.fc.in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(256, 2)  # 2 classes: normal or abnormal
)
stage1_model.to(device)
stage1_model.eval()  # inference mode: disables the Dropout layers above

# Load pre-trained weights for stage 1.
# map_location lets a checkpoint saved on a GPU be deserialized on a CPU-only runtime.
stage1_model.load_state_dict(
    torch.load('/kaggle/input/final_model_v2/pytorch/v1/1/stage 1.pth', map_location=device)
)

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip


<All keys matched successfully>

In [4]:
# Stage 2: Segmentation
# DeepLabV3-ResNet50 with its last classifier layer swapped for a 1-channel output.
# pretrained_backbone=False: every backbone weight is overwritten by the checkpoint below,
# so downloading the ImageNet backbone weights is wasted work.
stage2_model = segmentation.deeplabv3_resnet50(pretrained=False, pretrained_backbone=False)
stage2_model.classifier[-1] = nn.Conv2d(256, 1, kernel_size=1)
stage2_model.to(device)
stage2_model.eval()

# map_location lets a GPU-saved checkpoint be deserialized on a CPU-only runtime.
checkpoint = torch.load('/kaggle/input/final_model_v2/pytorch/v1/1/stage 2.pth', map_location=device)

# Keep only the parameters this architecture defines, but report what gets discarded
# so a mismatched checkpoint is visible instead of silently ignored.
expected_keys = stage2_model.state_dict().keys()
state_dict = {k: v for k, v in checkpoint.items() if k in expected_keys}
dropped_keys = sorted(set(checkpoint) - set(state_dict))
if dropped_keys:
    print(f"Ignoring {len(dropped_keys)} checkpoint keys not used by the model, e.g. {dropped_keys[:5]}")

load_result = stage2_model.load_state_dict(state_dict, strict=False)
# strict=False tolerates the filtering above; missing keys would mean randomly initialised layers.
assert not load_result.missing_keys, f"Checkpoint lacks {len(load_result.missing_keys)} model keys: {load_result.missing_keys[:5]}"
load_result

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 152MB/s] 


<All keys matched successfully>

In [5]:
# Stage 3: Classification between benign and malignant
# timm Inception-v4 whose `last_linear` is replaced by a 3-layer MLP head emitting 2 logits.
# All weights come from the fine-tuned checkpoint below, so no pretrained weights are requested.
stage3_model = timm.create_model('inception_v4', pretrained=False)
stage3_model.last_linear = nn.Sequential(
    nn.Linear(stage3_model.last_linear.in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(256, 2)  # 2 classes: benign or malignant
)
stage3_model.to(device)
stage3_model.eval()  # inference mode: disables the Dropout layers above

# Load pre-trained weights for stage 3.
# map_location lets a checkpoint saved on a GPU be deserialized on a CPU-only runtime.
stage3_model.load_state_dict(
    torch.load('/kaggle/input/final_model_v2/pytorch/v3/1/Stage 3.pth', map_location=device)
)

<All keys matched successfully>

In [6]:
dataset_path = '/kaggle/input/evalset/Testing'

# NOTE(review): no transforms.Normalize() here — this must match the preprocessing used
# when the three models were trained; confirm.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Remove a stray macOS .DS_Store from the dataset root. os.path.join is required:
# plain concatenation would produce '.../Testing.DS_Store', which never exists.
ds_store_path = os.path.join(dataset_path, '.DS_Store')
if os.path.exists(ds_store_path):
    try:
        os.remove(ds_store_path)
    except OSError as err:
        # Kaggle input directories are typically read-only; report instead of crashing.
        print(f"Could not remove {ds_store_path}: {err}")

dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
# The evaluation loop hard-codes label indices 0/1/2 as benign/malignant/normal.
assert len(dataset.classes) == 3, f"Expected 3 class folders, found {dataset.classes}"
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

In [9]:
# Cascade evaluation over the test set (data_loader uses batch_size=1):
#   stage 1 -> predicted normal? stop.
#   stage 2 -> empty segmentation mask? treat as normal and stop.
#   stage 3 -> classify the masked image as benign (0) or malignant (1).
# Dataset labels are: [benign malignant normal] -> indices 0, 1, 2.
NORMAL_LABEL = 2
STAGE1_NORMAL_PRED = 1  # stage-1 output index interpreted as "normal"
# NOTE(review): compared against raw stage-2 outputs (no sigmoid applied) — confirm this
# matches how the segmentation head was trained.
MASK_THRESHOLD = 0.8

total = 0
correct = 0

# Inference only: skip building the autograd graph (saves memory and time).
with torch.no_grad():
    for images, labels in data_loader:
        images = images.to(device)
        label = labels.item()  # valid because batch_size=1
        # Every image ends in exactly one prediction branch below, so count it once here;
        # this guarantees no image is double-counted when it falls through a stage.
        total += 1

        # Stage 1
        pred_stage1 = stage1_model(images).argmax(dim=1).item()
        if pred_stage1 == STAGE1_NORMAL_PRED:
            print("Image is predicted as normal - stage 1")
            correct += int(label == NORMAL_LABEL)
            continue

        # Stage 2: the model returns a dict whose 'out' entry holds the mask map.
        mask = (stage2_model(images)['out'] > MASK_THRESHOLD).float()
        if mask.sum().item() == 0:
            print("Image is predicted as normal - black mask")
            correct += int(label == NORMAL_LABEL)
            continue  # always stop here, whatever the true label is

        # Stage 3: classify only the segmented region.
        pred_stage3 = stage3_model(images * mask).argmax(dim=1).item()
        if pred_stage3 == 0:
            print("Image is predicted as benign")
        else:
            print("Image is predicted as malignant")
        correct += int(label == pred_stage3)

print(f'Total images: {total}, Correct predictions: {correct}')
if total:
    print(f'ACC: {(correct / total) * 100}%')
else:
    print('ACC: n/a (no images evaluated)')

Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as malignant
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is predicted as benign
Image is pr