In [1]:
# Cell 1: Imports and Setup
import torch
from torchvision import transforms
from PIL import Image
import random
import os
import sys

# Make the repository root (the parent of the current working directory)
# importable so that project packages can be imported from this notebook.
ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), ".."))
# Guard the insert so re-running this cell does not pile up duplicate
# entries in sys.path.
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)

from classification.classification_model import get_resnet18_model

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


In [2]:
# Cell 2: Load Pretrained Model
# Build the network with a 2-class head, then restore the saved weights.
model = get_resnet18_model(num_classes=2)
# weights_only=True restricts unpickling to tensors and plain containers, so
# a tampered checkpoint file cannot execute arbitrary code while loading.
# (Requires torch >= 1.13.) map_location places the tensors on `device`
# regardless of the device they were saved from.
state_dict = torch.load(
    '../models/resnet18_brain_tumour.pth',
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)
model.to(device)
# Switch to evaluation mode for inference; as the last expression of the cell
# this also displays the module structure.
model.eval()




ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [37]:
# Cell 3: Define a transform and select random image
base_dir = '../data/classification'
class_dirs = ['yes', 'no']
# Only entries with one of these extensions are treated as images, so stray
# files such as .DS_Store / Thumbs.db or sub-folders are never selected.
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')

chosen_class = random.choice(class_dirs)
class_dir = os.path.join(base_dir, chosen_class)
# os.listdir returns entries in arbitrary order; sorting makes the pick
# reproducible when `random` is seeded.
candidate_files = sorted(
    name for name in os.listdir(class_dir)
    if name.lower().endswith(image_extensions)
)
if not candidate_files:
    raise FileNotFoundError(f"No image files found in {class_dir!r}")
image_file = random.choice(candidate_files)
image_path = os.path.join(class_dir, image_file)

# Preprocessing pipeline applied to every image before it reaches the model.
channel_mean = [0.5, 0.5, 0.5]
channel_std = [0.5, 0.5, 0.5]
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),  # Ensures 3 channels
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
transform = transforms.Compose(preprocessing_steps)

In [38]:
# Cell 4: Load and preprocess image
# Open the selected file in single-channel ("L") mode and run it through the
# preprocessing pipeline.
grayscale_image = Image.open(image_path).convert("L")
image_tensor = transform(grayscale_image)
# unsqueeze(0) adds a leading dimension of size 1 before moving to `device`.
image = image_tensor.unsqueeze(0).to(device)

In [39]:
# Cell 5: Perform inference
# Gradients are not needed for prediction; disabling autograd avoids building
# a computation graph and saves memory and time.
with torch.no_grad():
    output = model(image)
# Index of the highest score along dim 1, converted to a plain Python int.
pred = torch.argmax(output, dim=1).item()

In [40]:

# Cell 6: Display result
# Report the chosen file, the label implied by its folder, and the prediction.
ground_truth_label = chosen_class.upper()
if pred == 1:
    predicted_label = 'TUMOUR'
else:
    predicted_label = 'NO TUMOUR'

print(f"✅ Image: {image_file}")
print(f"Ground Truth: {ground_truth_label}")
print(f"Predicted: {predicted_label}")

✅ Image: no 90.jpg
Ground Truth: NO
Predicted: NO TUMOUR
