In [41]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import numpy as np
from PIL import Image
from tqdm import tqdm


In [42]:
from store_and_retrieve import load_images_from_mongodb

In [43]:
class MongoDBDataset(Dataset):
    """Dataset over the records returned by ``load_images_from_mongodb``.

    Each record is read as a mapping with an ``"image_array"`` entry (an array
    exposing ``astype``) and a ``"metadata"`` mapping. The sample label is
    derived from the truthiness of ``metadata["is_anatomy"]``.

    Args:
        query, modality, body_part: Forwarded unchanged to
            ``load_images_from_mongodb``; their meaning is defined by that
            helper.
        transform: Optional callable applied to the PIL image of every sample.
    """

    def __init__(self, query=None, modality=None, body_part=None, transform=None):
        super().__init__()
        self.transform = transform
        # All records are fetched once, up front, and kept on the instance.
        self.data = load_images_from_mongodb(query, modality, body_part)

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the record at position ``idx``.

        The stored array is cast to ``uint8`` and wrapped in a PIL image, then
        passed through ``self.transform`` when one is set. The label is ``1``
        when ``metadata["is_anatomy"]`` is truthy and ``0`` otherwise, which
        includes the case where the key is absent.
        """
        record = self.data[idx]
        pixels = record["image_array"].astype(np.uint8)
        image = Image.fromarray(pixels)
        if self.transform:
            image = self.transform(image)
        label = int(bool(record["metadata"].get("is_anatomy")))
        return image, label

    @property
    def classes(self):
        """Sorted distinct ``is_anatomy`` values, ignoring missing/``None`` ones."""
        flags = (record["metadata"].get("is_anatomy") for record in self.data)
        return sorted({flag for flag in flags if flag is not None})


In [44]:
# Use the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalise each of the three channels with the given mean/std.
# NOTE(review): these constants look like the usual ImageNet statistics;
# confirm they match what the checkpoint was trained with.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


In [45]:
# Load the evaluation records from MongoDB with the preprocessing pipeline
# attached. NOTE(review): 'xray_test' / 'chest_test' presumably select the
# held-out split; confirm against load_images_from_mongodb.
test_dataset = MongoDBDataset(
    query={},
    modality='xray_test',
    body_part='chest_test',
    transform=transform,
)
print("Test dataset length:", len(test_dataset))

# Fixed (unshuffled) order, 32 samples per batch.
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

Loading images: 100%|██████████| 624/624 [00:00<00:00, 741.34 images/s]

624 Image(s) loaded from MongoDB
Test dataset length: 624





In [46]:
# Checkpoint location. The ANATOMY_MODEL_PATH environment variable overrides
# the machine-specific default so the notebook can run on other machines
# without editing this cell.
model_path = os.environ.get(
    "ANATOMY_MODEL_PATH",
    r"D:\Personal Project\Federated-Anatomy-Prediction\src\model.pth",
)

# Build the architecture only (weights=None). The strict load_state_dict call
# below replaces every parameter and buffer, so fetching the pretrained
# ImageNet weights first would only add a network download.
model = models.resnet18(weights=None)
num_of_features = model.fc.in_features

# Replace the stock classifier with a three-hidden-layer MLP head ending in
# two output logits; this must match the layout stored in the checkpoint.
model.fc = nn.Sequential(
    nn.Linear(num_of_features, 1024),
    nn.BatchNorm1d(1024),
    nn.ReLU(),
    nn.Dropout(0.3),

    nn.Linear(1024, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(),
    nn.Dropout(0.3),

    nn.Linear(512, 256),
    nn.BatchNorm1d(256),
    nn.ReLU(),
    nn.Dropout(0.2),

    nn.Linear(256, 2)
)

# NOTE(review): torch.load unpickles the file unless weights_only=True is in
# effect, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)
# Inference mode: BatchNorm uses its running statistics and Dropout is disabled.
model.eval()


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [47]:
# Evaluate the loaded model on the test loader: accumulate the summed loss and
# the number of correct predictions, and keep every prediction/label for any
# later per-class analysis.
criterion = nn.CrossEntropyLoss()
test_loss = 0.0
correct = 0
total = 0
all_preds = []
all_labels = []

# Set inference mode here as well, so this cell does not depend on an earlier
# cell having left the model in eval mode (BatchNorm/Dropout behave
# differently while training).
model.eval()
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="evalutae"):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        # CrossEntropyLoss averages over the batch by default; scale by the
        # batch size so the final mean is weighted correctly when the last
        # batch is smaller.
        test_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Fail with a clear message instead of a bare ZeroDivisionError when the
# loader yielded nothing.
if total == 0:
    raise ValueError("test_loader produced no samples; cannot compute test metrics")

avg_loss = test_loss / total
accuracy = correct / total
print(f"Test Loss: {avg_loss:.4f}")
print(f"Test Accuracy: {accuracy:.4f}")


evalutae: 100%|██████████| 20/20 [00:01<00:00, 14.27it/s]

Test Loss: 0.1869
Test Accuracy: 0.9391



