In [1]:
import torch
from torch import nn
from torchvision import transforms
from PIL import Image
import os

# Define your model architecture (ConvNet)
class ConvNet(nn.Module):
    def __init__(self, num_classes=4):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 12, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(12)
        self.relu1 = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(12, 20, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(20, 32, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(32)
        self.relu3 = nn.ReLU()
        self.fc = nn.Linear(75 * 75 * 32, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = x.view(-1, 32 * 75 * 75)
        x = self.fc(x)
        return x

# Now, load your model
MODEL_PATH = 'D:/Brain_Tumor_Detection/brain_tumor_model.pth'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConvNet(num_classes=4).to(device)

if os.path.exists(MODEL_PATH):
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()

# Test with an image
transform = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

image_path = 'image(23).jpg'  # Replace with your image path
image = Image.open(image_path)

# Process the image
img = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
    output = model(img)

# Get prediction
_, predicted = torch.max(output, 1)
classes = ['no_tumor', 'glioma_tumor', 'meningioma_tumor', 'pituitary_tumor']
predicted_class = classes[predicted.item()]
print(f"Predicted tumor type: {predicted_class}")


Predicted tumor type: pituitary_tumor


  model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
