# Import Libraries
This section imports the necessary libraries for data handling, model training, and visualization.

In [6]:
import os
import glob
import cv2
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split

from unet_model import UNet

# Define Dataset Class
This section defines the `MammoDataset` class, which is responsible for loading and preprocessing mammogram images and their corresponding masks.

In [7]:
# -------------------- Dataset --------------------
class MammoDataset(Dataset):
    def __init__(self, image_paths, mask_paths):
        assert len(image_paths) == len(mask_paths), "Image and mask count mismatch"
        self.image_paths = image_paths
        self.mask_paths = mask_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = cv2.imread(self.image_paths[idx], cv2.IMREAD_GRAYSCALE)
        mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)

        image = image.astype(np.float32) / 255.0
        mask = mask.astype(np.float32) / 255.0

        image = np.expand_dims(image, axis=0)  # Add channel dimension
        mask = np.expand_dims(mask, axis=0)

        return torch.tensor(image), torch.tensor(mask)

# Load and Prepare Data
This section loads the mammogram images and masks, applies transformations, and splits the dataset into training and validation sets.

In [8]:
# -------------------- Load Data --------------------
image_dir = "../../data/train/inputs"
mask_dir = "../../data/train/masks"
image_paths = sorted(glob.glob(os.path.join(image_dir, "*.png")))
mask_paths = sorted(glob.glob(os.path.join(mask_dir, "*.png")))
print(len(image_paths), len(mask_paths))
dataset = MammoDataset(image_paths, mask_paths)

print(len(dataset))
print(dataset[0])

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4)


28 28
28
(tensor([[[0.9020, 0.9137, 0.9373,  ..., 0.0000, 0.0000, 0.0000],
         [0.9059, 0.9176, 0.9373,  ..., 0.0000, 0.0000, 0.0000],
         [0.9098, 0.9176, 0.9373,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.9765, 1.0000, 1.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.9765, 1.0000, 1.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.9765, 1.0000, 1.0000,  ..., 0.0000, 0.0000, 0.0000]]]), tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]]))


# Initialize Model
This section initializes the U-Net model, sets the device (CPU or GPU), and defines the loss function and optimizer.

In [9]:
# -------------------- Model Setup --------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
model = UNet().to(device)

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

Using device: cuda


# Training Loop
This section contains the training loop, which iterates over the dataset for a specified number of epochs, calculates losses, and saves the best model.

In [10]:
# -------------------- Training --------------------
epochs = 50
best_loss = float("inf")

for epoch in range(epochs):
    model.train()
    train_loss = 0
    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        preds = model(images)
        loss = criterion(preds, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)
            preds = model(images)
            loss = criterion(preds, masks)
            val_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Save best model
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), "best_unet.pth")


Epoch 1/50, Train Loss: 3.0505, Val Loss: 1.3913
Epoch 2/50, Train Loss: 2.0052, Val Loss: 1.3833
Epoch 3/50, Train Loss: 1.7765, Val Loss: 1.3492
Epoch 4/50, Train Loss: 1.5662, Val Loss: 1.3068
Epoch 5/50, Train Loss: 1.4658, Val Loss: 1.2249
Epoch 6/50, Train Loss: 1.3975, Val Loss: 1.0852
Epoch 7/50, Train Loss: 1.3041, Val Loss: 0.8910
Epoch 8/50, Train Loss: 1.2973, Val Loss: 0.6924
Epoch 9/50, Train Loss: 1.2057, Val Loss: 0.5830
Epoch 10/50, Train Loss: 1.2011, Val Loss: 0.6400
Epoch 11/50, Train Loss: 1.1295, Val Loss: 0.4631
Epoch 12/50, Train Loss: 1.1892, Val Loss: 0.5933
Epoch 13/50, Train Loss: 1.1323, Val Loss: 0.3884
Epoch 14/50, Train Loss: 1.1277, Val Loss: 0.4479
Epoch 15/50, Train Loss: 1.1127, Val Loss: 0.3592
Epoch 16/50, Train Loss: 1.0230, Val Loss: 0.3620
Epoch 17/50, Train Loss: 1.0141, Val Loss: 0.4499
Epoch 18/50, Train Loss: 0.9839, Val Loss: 0.3528
Epoch 19/50, Train Loss: 0.9731, Val Loss: 0.3523
Epoch 20/50, Train Loss: 0.9121, Val Loss: 0.5272
Epoch 21/