# In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from PIL import Image

# Define the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Class labels, in the index order of the trained classifier head
# (NOTE(review): confirm this order matches the training script).
class_names = ['benign', 'malignant', 'normal']


def _normalize_head_key(key):
    """Return *key* with any classifier-head prefix rewritten to ``heads.``.

    Accepts checkpoints whose head was saved as ``head.*``, ``heads.*``
    (``heads`` replaced by a single Linear, as done below) or ``heads.head.*``
    (torchvision's default head container). A blanket
    ``key.replace('head', 'heads')`` would turn ``heads.weight`` into
    ``headss.weight``, which ``strict=False`` then silently ignores.
    """
    for prefix in ('heads.head.', 'heads.', 'head.'):
        if key.startswith(prefix):
            return 'heads.' + key[len(prefix):]
    return key


# Build the ViT-B/16 architecture. No weights are requested here, so every
# parameter value must come from the checkpoint loaded below.
pretrained_vit = models.vit_b_16()

# Freeze the base parameters. This happens before the head is replaced, so the
# new head stays trainable, as it was originally.
for parameter in pretrained_vit.parameters():
    parameter.requires_grad = False

# Replace the classifier head BEFORE loading the checkpoint. Replacing it
# after loading would discard the trained head weights and leave a randomly
# initialised layer making the predictions.
pretrained_vit.heads = nn.Linear(in_features=768, out_features=len(class_names))
pretrained_vit.to(device)

# Load the saved model state dict
path = r'model_path.pth'
pretrained_vit_state_dict = torch.load(path, map_location=device)

# Rename classifier-head keys so they match this model's `heads` Linear
new_pretrained_vit_state_dict = {
    _normalize_head_key(k): v for k, v in pretrained_vit_state_dict.items()
}

# strict=False tolerates extra checkpoint entries. A *missing* key would leave
# part of the model with random weights, though, so fail loudly instead.
incompatible = pretrained_vit.load_state_dict(new_pretrained_vit_state_dict, strict=False)
if incompatible.missing_keys:
    raise RuntimeError(
        f"Checkpoint {path!r} is missing weights for: {incompatible.missing_keys}"
    )
if incompatible.unexpected_keys:
    print("Ignoring unexpected checkpoint keys:", incompatible.unexpected_keys)

# Read the test image from disk and force three (RGB) channels. The context
# manager closes the file once the converted copy has been made.
image_path = r"DATASET\test\malignant\malignant347.png"
with Image.open(image_path) as source_image:
    image = source_image.convert('RGB')

# Preprocessing pipeline: resize to 224x224, convert to a float tensor, then
# normalise with the commonly used ImageNet per-channel mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Apply the pipeline, then prepend a batch dimension of size 1 and move the
# tensor to the same device as the model.
input_image = transform(image)
input_image = input_image.unsqueeze(0).to(device)

# Switch layers such as dropout to inference behaviour, and run the forward
# pass without building an autograd graph.
pretrained_vit.eval()
with torch.no_grad():
    outputs = pretrained_vit(input_image)

# Highest-scoring class index per batch row (first index wins on ties),
# mapped back to its label.
predicted = outputs.argmax(dim=1)
predicted_class = class_names[predicted.item()]
print(f"Predicted class: {predicted_class}")
