In [2]:
import pandas as pd
from PIL import Image
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms, models

In [3]:
# Custom ResNet Model
class CustomResNet(nn.Module):
    def __init__(self, num_classes):
        super(CustomResNet, self).__init__()
        self.model = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)  # Use ImageNet pretrained weights
        self.model.fc = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.model.fc.in_features, num_classes),
        )

    def forward(self, x):  # x is input data
        return self.model(x)


In [5]:
# Function to classify unseen images
def classify_image(image_path, csv_path, model_path):
    # Load the CSV to extract labels
    try:
        data = pd.read_csv(csv_path)
        unique_labels = sorted(data['Label'].unique())  # Ensure labels are sorted for consistency
        num_classes = len(unique_labels)
        print(f"Number of classes: {num_classes}")
    except Exception as e:
        print(f"Error loading CSV file: {e}")
        
    # Load the image
    try:
        img = Image.open(image_path).convert("RGB")  # Ensure the image is in RGB format
    except Exception as e:
        print(f"Error loading image: {e}")
        return

    # Apply the same transformations used during training
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize to match model input
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    img = transform(img).unsqueeze(0)  # Add batch dimension (1, C, H, W)

    # Load the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CustomResNet(num_classes=num_classes).to(device)

    try:
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    # Make prediction
    with torch.no_grad():
        img = img.to(device)
        output = model(img)  # Get model output
        pred = torch.argmax(output, dim=1).item()  # Get the predicted class

    print(f"Predicted class for the image: {unique_labels[pred]}")

# Usage
csv_path = "C:\\Users\\Acer\\Desktop\\Model\\Labels.csv"  # Path to your CSV file
image_path = "cariomegaly.png"  # Path to the image you want to classify
model_path = "best_model60000.pth"  # Path to your saved model

classify_image(image_path, csv_path, model_path)

Number of classes: 590


  model.load_state_dict(torch.load(model_path, map_location=device))


Predicted class for the image: Atelectasis|Cardiomegaly|Consolidation|Effusion|Infiltration|Pneumonia
