# In [11]:
import torch
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

# In [12]:
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import torch
import numpy as np
import os

class CXRSegmentationDataset(Dataset):
    """Chest X-ray dataset yielding an image and a 3-class segmentation mask.

    A sample is identified by the stem of a ``.png`` file in ``img_dir``; its
    masks are read from ``<id>_lungs.png`` in ``lungs_dir`` and
    ``<id>_heart.png`` in ``heart_dir``.
    """

    # Rotation augmentation draws its angle uniformly from [-MAX, +MAX] degrees.
    MAX_ROTATION_DEG = 15.0

    def __init__(self, img_dir, lungs_dir, heart_dir,
                 size=(512, 512),
                 augment=False):
        """
        img_dir   : folder of grayscale PNGs of original X-rays
        lungs_dir : folder of lungs masks (0/255 PNGs)
        heart_dir : folder of heart masks (0/255 PNGs)
        size      : (height, width) to resize to, as expected by
                    ``transforms.Resize``
        augment   : whether to apply random flips/rotations
        """
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index refers to the same sample on every run and every machine.
        self.ids = sorted(
            fname[:-4] for fname in os.listdir(img_dir) if fname.endswith(".png")
        )
        self.img_dir   = img_dir
        self.lungs_dir = lungs_dir
        self.heart_dir = heart_dir

        # image transformations
        self.tf_img = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),                         # [1,H,W], floats in [0,1]
            transforms.Normalize(mean=[0.5], std=[0.5])    # adjust to your data
        ])
        # mask resizing (nearest to preserve labels)
        self.tf_mask = transforms.Resize(
            size, interpolation=transforms.InterpolationMode.NEAREST)

        # Names of the candidate augmentations, or None when disabled. One of
        # them is drawn per sample (see _sample_augmentation). Setting this
        # attribute to None/empty turns augmentation off.
        self.aug = ("hflip", "vflip", "rotate") if augment else None

    def __len__(self):
        """Return the number of samples found in ``img_dir``."""
        return len(self.ids)

    @staticmethod
    def _load_gray(path):
        """Open ``path`` and return it as a fully loaded 8-bit grayscale image.

        The file handle is closed before returning. Converting to mode "L"
        guarantees a single channel regardless of how the PNG was saved.
        """
        with Image.open(path) as im:
            return im.convert("L")

    def _sample_augmentation(self):
        """Draw one augmentation and return it as a function of a PIL image.

        The operation and its parameters are sampled exactly once, so calling
        the returned function on the image and on both masks applies the same
        geometric transform to all three. Returns None when the drawn flip is
        skipped (each flip is applied with probability 0.5).

        Sampling uses torch's global generator without reseeding it.
        """
        tf = transforms.functional
        name = self.aug[torch.randint(len(self.aug), (1,)).item()]

        if name == "rotate":
            limit = self.MAX_ROTATION_DEG
            angle = torch.empty(1).uniform_(-limit, limit).item()

            def rotate(im):
                # Nearest-neighbour keeps mask values in {0, 255}.
                return tf.rotate(
                    im, angle,
                    interpolation=transforms.InterpolationMode.NEAREST)

            return rotate

        if torch.rand(1).item() >= 0.5:
            return None
        return tf.hflip if name == "hflip" else tf.vflip

    @staticmethod
    def _merge_masks(mask_l, mask_h, threshold=127):
        """Combine binary lung/heart masks into one class-index mask.

        mask_l, mask_h : uint8 tensors of shape [1,H,W] (or [H,W]); pixels
                         strictly greater than ``threshold`` are foreground
        returns        : long tensor [H,W] with 0=BG, 1=Lung, 2=Heart

        Heart is written last, so it wins where the two masks overlap.
        """
        lung_fg  = mask_l.squeeze(0) > threshold
        heart_fg = mask_h.squeeze(0) > threshold
        mask = torch.zeros(lung_fg.shape, dtype=torch.long)
        mask[lung_fg]  = 1
        mask[heart_fg] = 2
        return mask

    def __getitem__(self, idx):
        """Return ``(img, mask)`` for sample ``idx``.

        img  : float tensor [1,H,W], normalised by ``tf_img``
        mask : long tensor [H,W] with values in {0,1,2}
        """
        _id = self.ids[idx]

        # Load image and masks
        img    = self._load_gray(os.path.join(self.img_dir,   f"{_id}.png"))
        mask_l = self._load_gray(os.path.join(self.lungs_dir, f"{_id}_lungs.png"))
        mask_h = self._load_gray(os.path.join(self.heart_dir, f"{_id}_heart.png"))

        # Apply the same random augmentation to all three
        if self.aug:
            op = self._sample_augmentation()
            if op is not None:
                img, mask_l, mask_h = op(img), op(mask_l), op(mask_h)

        # Resize
        img = self.tf_img(img)           # tensor [1,H,W]

        # Resize masks, then convert to tensors [1,H,W] uint8
        to_tensor = transforms.PILToTensor()
        mask_l = to_tensor(self.tf_mask(mask_l))
        mask_h = to_tensor(self.tf_mask(mask_h))

        # Build a single multi-class mask: 0=BG,1=Lung,2=Heart
        return img, self._merge_masks(mask_l, mask_h)