In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from PIL import Image

In [2]:
class Alzheimer(nn.Module):
    def __init__(self):
        super().__init__()
        self.base = models.resnet18(pretrained=False)  # pretrained=False since we’re loading weights
        self.base.fc = nn.Linear(self.base.fc.in_features, 4)  # 4 classes

    def forward(self, x):
        return self.base(x)

# Load model
model = Alzheimer()
model.load_state_dict(torch.load("Alzehimer.pth", map_location="cpu"))
model.eval()



Alzheimer(
  (base): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runn

In [None]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load and preprocess image
img = Image.open("images.jpg").convert("RGB")
img = transform(img).unsqueeze(0)  

# Predict
with torch.no_grad():
    outputs = model(img)
    _, predicted = torch.max(outputs, 1)

classes = ["Mild Impairment", "Moderate Impairment", "No Impairment", "Very Mild Impairment"]
print("Prediction:", classes[predicted.item()])


Prediction: Mild Impairment
