In [16]:
import torch
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

In [8]:
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import torch
import numpy as np
import os

class CXRSegmentationDataset(Dataset):
    """Chest X-ray images paired with lung and heart masks for 3-class segmentation.

    Each sample is identified by the stem of a ``.jpg`` file in ``img_dir``; the
    matching masks are expected at ``<lungs_dir>/<id>_lungs.png`` and
    ``<heart_dir>/<id>_heart.png``.

    ``__getitem__`` returns ``(img, mask)`` where ``img`` is a float tensor
    ``[1,H,W]`` normalized with mean 0.5 / std 0.5 and ``mask`` is a long tensor
    ``[H,W]`` with 0 = background, 1 = lung, 2 = heart (heart wins on overlap).
    """

    # Maximum absolute rotation (degrees) used by the random-rotation augmentation.
    _MAX_ROTATION_DEG = 15.0

    def __init__(self, img_dir, lungs_dir, heart_dir,
                 size=(512,512),
                 augment=False):
        """
        img_dir   : folder of resized grayscale JPEGs of original X-rays
        lungs_dir : folder of resized lungs masks (0/255 PNGs)
        heart_dir : folder of resized heart masks (0/255 PNGs)
        size      : target size passed to transforms.Resize, which interprets a
                    2-sequence as (height, width)
        augment   : whether to apply random flips/rotations
        """
        # collect IDs from the .jpg files in img_dir; sorted because os.listdir
        # returns entries in arbitrary order, which would make index -> sample
        # mapping differ between runs/machines
        self.ids = sorted(os.path.splitext(f)[0]
                          for f in os.listdir(img_dir)
                          if f.lower().endswith(".jpg"))
        print(f"Found {len(self.ids)} samples in {img_dir}")

        self.img_dir   = img_dir
        self.lungs_dir = lungs_dir
        self.heart_dir = heart_dir

        # image transformations
        self.tf_img = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),                         # [1,H,W], floats in [0,1]
            transforms.Normalize(mean=[0.5], std=[0.5])    # adjust to your data
        ])
        # mask resizing (nearest to preserve labels)
        self.tf_mask = transforms.Resize(size,
                                         interpolation=transforms.InterpolationMode.NEAREST)

        # Kept so existing code that inspects `.aug` keeps working; only its
        # truthiness is used below. It is NOT called on the samples because
        # RandomChoice selects its transform with Python's `random` module, so
        # calling it three times cannot be synchronised via torch.manual_seed.
        self.aug = transforms.RandomChoice([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(15),
        ]) if augment else None

    def __len__(self):
        """Number of samples (one per .jpg found in img_dir)."""
        return len(self.ids)

    def _augment_together(self, img, mask_l, mask_h):
        """Apply ONE randomly drawn geometric transform identically to all three PIL images.

        Mirrors the original policy: pick uniformly among horizontal flip
        (p=0.5), vertical flip (p=0.5) and a rotation by an angle drawn
        uniformly from [-15, 15] degrees. The random parameters are drawn once
        and then applied through the functional API, so image and masks always
        stay aligned. Randomness comes from the torch RNG without reseeding it.
        """
        tf = transforms.functional
        choice = int(torch.randint(0, 3, (1,)))

        if choice == 0:
            if torch.rand(1).item() < 0.5:
                return tf.hflip(img), tf.hflip(mask_l), tf.hflip(mask_h)
            return img, mask_l, mask_h

        if choice == 1:
            if torch.rand(1).item() < 0.5:
                return tf.vflip(img), tf.vflip(mask_l), tf.vflip(mask_h)
            return img, mask_l, mask_h

        angle = float(torch.empty(1).uniform_(-self._MAX_ROTATION_DEG,
                                              self._MAX_ROTATION_DEG))
        nearest  = transforms.InterpolationMode.NEAREST
        bilinear = transforms.InterpolationMode.BILINEAR
        # bilinear for the intensity image, nearest for masks so that no
        # intermediate (non-label) values are created
        img    = tf.rotate(img,    angle, interpolation=bilinear)
        mask_l = tf.rotate(mask_l, angle, interpolation=nearest)
        mask_h = tf.rotate(mask_h, angle, interpolation=nearest)
        return img, mask_l, mask_h

    @staticmethod
    def _combine_masks(mask_l, mask_h):
        """Merge two binary masks into one label map: 0=BG, 1=Lung, 2=Heart.

        mask_l, mask_h : uint8 tensors shaped [1,H,W] (or [H,W]); values above
                         127 count as foreground, so exact 255 is not required.
        Returns a long tensor [H,W]. Heart is written last and therefore
        overrides lung where the two overlap.
        """
        lung  = mask_l.squeeze(0) > 127
        heart = mask_h.squeeze(0) > 127
        mask = torch.zeros_like(lung, dtype=torch.long)
        mask[lung]  = 1
        mask[heart] = 2
        return mask

    def __getitem__(self, idx):
        """Load, (optionally) augment, resize and encode sample `idx`.

        Returns (img, mask): img float tensor [1,H,W], mask long tensor [H,W]
        with values in {0,1,2}.
        """
        _id = self.ids[idx]

        img_path   = os.path.join(self.img_dir,   f"{_id}.jpg")
        lung_path  = os.path.join(self.lungs_dir, f"{_id}_lungs.png")
        heart_path = os.path.join(self.heart_dir, f"{_id}_heart.png")

        # Image.open is lazy and keeps the file open; convert() decodes into a
        # new in-memory image so the handle can be closed right away. Masks are
        # forced to single-channel 8-bit so 1-bit / palette / RGB PNGs all end
        # up as one [1,H,W] uint8 plane.
        with Image.open(img_path) as fh:
            img = fh.convert("L")
        with Image.open(lung_path) as fh:
            mask_l = fh.convert("L")
        with Image.open(heart_path) as fh:
            mask_h = fh.convert("L")

        # Apply the same random augmentation to all three
        if self.aug:
            img, mask_l, mask_h = self._augment_together(img, mask_l, mask_h)

        # Resize
        img    = self.tf_img(img)         # tensor [1,H,W]
        mask_l = self.tf_mask(mask_l)     # PIL image resized
        mask_h = self.tf_mask(mask_h)

        # Convert masks to tensors [1,H,W] uint8
        to_tensor = transforms.PILToTensor()
        mask_l = to_tensor(mask_l)
        mask_h = to_tensor(mask_h)

        # Build a single multi-class mask: 0=BG,1=Lung,2=Heart
        mask = self._combine_masks(mask_l, mask_h)

        return img, mask  # img: [1,H,W], mask: [H,W] ints in {0,1,2}


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# — Attention Block —
class AttentionBlock(nn.Module):
    """Gate a skip-connection feature map with a single-channel attention map.

    F_g   : channels of the gating signal `g`
    F_l   : channels of the skip connection `x`
    F_int : channels of the intermediate space both inputs are projected into
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def pointwise(c_in, c_out):
            # 1x1 convolution (stride 1, no padding, with bias) + batch norm
            return [nn.Conv2d(c_in, c_out, kernel_size=1), nn.BatchNorm2d(c_out)]

        # projections of the gating signal and of the skip connection
        self.W_g = nn.Sequential(*pointwise(F_g, F_int))
        self.W_x = nn.Sequential(*pointwise(F_l, F_int))
        # collapse to one channel and squash into (0, 1)
        self.psi = nn.Sequential(*pointwise(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return `x` scaled element-wise by the attention map computed from `g` and `x`.

        `g` and `x` must have matching batch and spatial dimensions, since
        their projections are added together.
        """
        combined = self.relu(self.W_g(g) + self.W_x(x))
        attention = self.psi(combined)      # [N,1,H,W], broadcast over channels of x
        return x * attention

# — Double Convolution block —
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv, padding 1 -> batch norm -> ReLU) stages.

    The first stage maps `in_ch` to `out_ch` channels, the second keeps
    `out_ch`. Padding of 1 with a 3x3 kernel leaves the spatial size unchanged.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run `x` through both conv stages."""
        return self.double_conv(x)

# — Attention U‑Net —
class AttUNet(nn.Module):
    """U-Net whose skip connections are gated by AttentionBlock before concatenation.

    n_channels : channels of the input image
    n_classes  : channels of the output logits
    features   : list of encoder widths, shallowest first; the bottleneck uses
                 twice the last width and the decoder mirrors the encoder
    """

    def __init__(self, n_channels=1, n_classes=3, features=[64,128,256,512]):
        super().__init__()

        # Encoder: one DoubleConv per width, each followed (in forward) by 2x2 max-pooling
        self.downs = nn.ModuleList(
            DoubleConv(c_in, c_out)
            for c_in, c_out in zip([n_channels] + features, features)
        )
        self.pool = nn.MaxPool2d(2)

        # Bottleneck doubles the deepest encoder width
        deepest = features[-1]
        self.bottleneck = DoubleConv(deepest, deepest * 2)

        # Decoder: self.ups alternates [upsample, DoubleConv, upsample, DoubleConv, ...];
        # self.attentions holds one gate per decoder level
        self.ups = nn.ModuleList()
        self.attentions = nn.ModuleList()
        for width in reversed(features):
            self.ups.append(nn.ConvTranspose2d(width * 2, width, kernel_size=2, stride=2))
            self.attentions.append(AttentionBlock(F_g=width, F_l=width, F_int=width // 2))
            self.ups.append(DoubleConv(width * 2, width))

        # 1x1 conv producing per-class logits
        self.final_conv = nn.Conv2d(features[0], n_classes, kernel_size=1)

    def forward(self, x):
        """Return logits of shape [N, n_classes, H, W] for input `x` of shape [N, n_channels, H, W]."""
        encoder_outputs = []
        for stage in self.downs:
            x = stage(x)
            encoder_outputs.append(x)   # saved before pooling for the skip path
            x = self.pool(x)

        x = self.bottleneck(x)

        # walk the decoder from the deepest skip to the shallowest
        for level, skip in enumerate(reversed(encoder_outputs)):
            upsample = self.ups[2 * level]
            fuse     = self.ups[2 * level + 1]
            gate     = self.attentions[level]

            x = upsample(x)
            # resample (not crop) when pooling rounded an odd size down
            if x.shape != skip.shape:
                x = F.interpolate(x, size=skip.shape[2:])

            gated_skip = gate(g=x, x=skip)
            x = fuse(torch.cat([gated_skip, x], dim=1))

        return self.final_conv(x)


In [10]:
from torch.utils.data import DataLoader
# Requires CXRSegmentationDataset from the cell above.
# Training set: X-ray JPEGs with their lung / heart masks, augmentation enabled.
train_ds = CXRSegmentationDataset(
    "processed_images",
    "processed_masks/lungs",
    "processed_masks/heart",
    size=(512,512),
    augment=True,
)
# Shuffled mini-batches of 8, loaded by 4 worker processes.
train_loader = DataLoader(
    train_ds,
    batch_size=8,
    shuffle=True,
    num_workers=4,
)


Found 253 samples in processed_images


In [11]:
import os
import wandb

In [19]:
# Rebuild the loader so batches are loaded in the main process
# (num_workers=0 disables worker subprocesses); this replaces the
# loader created earlier.
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=0)

In [20]:
# 0) authenticate with W&B
# Never hard-code the API key in the notebook: it ends up in version control
# and in every shared copy. Export WANDB_API_KEY in the environment before
# starting the kernel (or run `wandb login` once); without either, wandb.login()
# falls back to an interactive prompt. Any key that was ever committed in a
# notebook should be revoked and replaced.
if not os.environ.get("WANDB_API_KEY"):
    print("WANDB_API_KEY is not set; wandb will use stored credentials or prompt for a key.")
wandb.login()

# 1) init W&B
# The config values are read back by the training cell through wandb.config.
wandb.init(
    project="cxr-segmentation",
    name="attunet-batch-run",
    config={
        "epochs": 15,
        "batch_size": 8,
        "learning_rate": 1e-4,
        "architecture": "Attention U-Net",
        "input_size": [512,512],
        "n_classes": 3
    }
)



In [None]:
# 2) pick the GPU when one is available, otherwise fall back to CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# 3) model, optimizer and loss; hyper-parameters come from the W&B run config
model = AttUNet(n_channels=1, n_classes=wandb.config.n_classes)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=wandb.config.learning_rate)
criterion = CrossEntropyLoss()

# 4) watch gradients & parameters
wandb.watch(model, log="all", log_freq=50)

# 5) training loop
for epoch in range(wandb.config.epochs):
    model.train()
    running_loss = 0.0

    for imgs, masks in train_loader:
        imgs = imgs.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight by batch size so the epoch
        # figure is a per-sample average even when the last batch is smaller
        batch_samples = imgs.size(0)
        running_loss += loss.item() * batch_samples

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{wandb.config.epochs}, Loss: {epoch_loss:.4f}")

    # log to W&B
    wandb.log({"epoch": epoch+1, "train_loss": epoch_loss})

# 6) save model artifact
torch.save(model.state_dict(), "attunet_final.pt")
wandb.save("attunet_final.pt")

# 7) finish
wandb.finish()