In [8]:
import os
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from Unet import get_model  

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class CustomDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)
        self.valid_images = []  

    def __len__(self):
        return len(self.valid_images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.valid_images[idx])
        image_name = os.path.splitext(self.valid_images[idx])[0]  

        mask_file_name = image_name.replace('cxrimage_', 'cxrmask_') + '.jpeg'
        mask_path = os.path.join(self.mask_dir, mask_file_name)

        if not os.path.exists(mask_path):
            return None 

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        if self.transform:
            image = self.transform(image)
            mask = np.array(mask)
            mask = (mask > 0).astype(np.uint8)
            mask = torch.from_numpy(mask).unsqueeze(0)
        return image, mask.float()

    def filter_valid_images(self):
        valid_images = []
        for image in self.images:
            img_path = os.path.join(self.image_dir, image)
            image_name = os.path.splitext(image)[0]  

            mask_file_name = image_name.replace('cxrimage_', 'cxrmask_') + '.jpeg'
            mask_path = os.path.join(self.mask_dir, mask_file_name)

            if os.path.exists(mask_path):
                valid_images.append(image)
        self.valid_images = valid_images

num_classes = 1
model = get_model(num_classes)
model.load_state_dict(torch.load('Unet_epoch_10.pth', map_location=device, weights_only=True))
model.to(device)
model.eval()

test_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

test_image_dir = 'ChestXray/test/image'
test_mask_dir = 'ChestXray/test/mask'
test_dataset = CustomDataset(test_image_dir, test_mask_dir, transform=test_transform)
test_dataset.filter_valid_images()  
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

result_dir = 'ChestXray/result/Unet10'
os.makedirs(result_dir, exist_ok=True)

with torch.no_grad():
    for i, (images, _) in enumerate(test_loader):
        images = images.to(device)
        outputs = model(images)
        outputs = torch.sigmoid(outputs)  
        predicted_masks = (outputs > 0.5).float()  

        for j in range(predicted_masks.shape[0]):
            predicted_mask = predicted_masks[j, 0].cpu().numpy() * 255 
            image_name = os.path.splitext(test_dataset.valid_images[i * predicted_masks.shape[0] + j])[0] + '_pred_mask.jpeg'
            result_path = os.path.join(result_dir, image_name)
            Image.fromarray(predicted_mask.astype(np.uint8)).save(result_path)

print(f'Prediction results have been saved to {result_dir}')

Prediction results have been saved to ChestXray/result/Unet10


In [12]:
def iou_score(pred_mask, true_mask):
    smooth = 1e-6
    pred_mask = pred_mask.astype(bool)
    true_mask = true_mask.astype(bool)
    intersection = np.logical_and(pred_mask, true_mask).sum()
    union = np.logical_or(pred_mask, true_mask).sum()
    iou = (intersection + smooth) / (union + smooth)
    return iou

pred_dir = 'ChestXray/result/Unet10'  
true_dir = 'ChestXray/test/mask'  

pred_files = os.listdir(pred_dir)
true_files = os.listdir(true_dir)

pred_names = {os.path.splitext(f)[0].replace('cxrimage_', '').replace('_pred_mask', '') for f in pred_files}
true_names = {os.path.splitext(f)[0].replace('cxrmask_', '') for f in true_files}

if pred_names != true_names:
    raise ValueError("The filenames of the predicted result images and the ground truth label images do not match.")

ious = []
for pred_file in pred_files:
    pred_name = os.path.splitext(pred_file)[0].replace('cxrimage_', '').replace('_pred_mask', '')
    true_file = f'cxrmask_{pred_name}.jpeg'
    pred_path = os.path.join(pred_dir, pred_file)
    true_path = os.path.join(true_dir, true_file)
    pred_mask = np.array(Image.open(pred_path).convert("L")) > 0  
    true_mask = np.array(Image.open(true_path).convert("L")) > 0  
    iou = iou_score(pred_mask, true_mask)
    ious.append(iou)

mean_iou = np.mean(ious)
print(f'Mean IoU: {mean_iou:.4f}')

Mean IoU: 0.9109
