In [21]:
import torch
from torch import nn
from torch.nn import functional as F
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
import time


In [10]:
class CarvanaDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.images[index].replace(".jpg", "_mask.gif"))
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

In [11]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)
    

class UNET(nn.Module):
    def __init__(
            self,
            in_channels=3, 
            out_channels=1,
            features=[64, 128, 256, 512],
    ):
        super().__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))

            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            if x.shape != skip_connection.shape:
                x = F.interpolate(x, size=skip_connection.shape[2:], mode="bilinear", align_corners=True)
            concat_skip = torch.cat((skip_connection, x), dim=1) # along channels
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)

In [12]:
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)

def load_checkpoint(checkpoint, model):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])

def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    train_ds = CarvanaDataset(
        image_dir=train_dir,
        mask_dir=train_maskdir,
        transform=train_transform,
    )

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )

    val_ds = CarvanaDataset(
        image_dir=val_dir,
        mask_dir=val_maskdir,
        transform=val_transform,
    )

    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )

    return train_loader, val_loader

def check_accuracy(loader, model, device="cuda"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )

    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    print(f"Dice score: {dice_score/len(loader)}")
    model.train()

def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="cuda"
):
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(
            preds, f"{folder}/pred_{idx}.png"
        )
        torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")

    model.train()

In [None]:
# Hyperparameters 
LEARNING_RATE = 1e-4
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    print("Using CUDA (GPU)")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    print("Using MPS (Apple Silicon GPU)")
else:
    DEVICE = torch.device("cpu")
    print("Using CPU") 
BATCH_SIZE = 2
NUM_EPOCHS = 2
NUM_WORKERS = 0
IMAGE_HEIGHT = 600  # 1280 originally
IMAGE_WIDTH = 900  # 1918 originally
PIN_MEMORY = True
LOAD_MODEL = False
TRAIN_IMG_DIR = "./train/"
TRAIN_MASK_DIR = "./train_masks/"
VAL_IMG_DIR = "./val/"
VAL_MASK_DIR = "./val_masks/"

def train_fn(loader, model, optimizer, loss_fn):
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward
        predictions = model(data)
        loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())


def main():
    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    if LOAD_MODEL:
        load_checkpoint(torch.load("my_checkpoint.pth.tar"), model)


    check_accuracy(val_loader, model, device=DEVICE)

    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn)

        # save model
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer":optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        # check accuracy
        check_accuracy(val_loader, model, device=DEVICE)

        # print some examples to a folder
        save_predictions_as_imgs(
            val_loader, model, folder="saved_images/", device=DEVICE
        )



main()

In [22]:
def calculate_dice_score(predicted_mask, ground_truth_mask):
    """Calculates the Dice score between two binary masks."""
    intersection = (predicted_mask * ground_truth_mask).sum()
    dice_score = (2 * intersection) / (
        torch.sum(predicted_mask) + torch.sum(ground_truth_mask) + 1e-8
    )
    return dice_score.item()

def preprocess_image(image_path, input_size, device):
    """Loads and preprocesses an image for model input."""
    image = Image.open(image_path).convert('RGB')
    
    transform = transforms.Compose([
        transforms.Resize(input_size), 
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]) 
    ])
    
    return transform(image).unsqueeze(0).to(device)

def load_original_ground_truth_mask(mask_path, device):
    """Loads the ground truth mask at its original size."""
    mask_image = Image.open(mask_path).convert('L')

    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    original_size = mask_image.size # (width, height)
    original_hw = (original_size[1], original_size[0]) 
    return transform(mask_image).unsqueeze(0).to(device), original_hw

def evaluate_model(model, image_dir, mask_dir, image_filename, mask_filename, num_classes, device, scaling_factors):
    """
    Evaluates the model at different input scales.
    Inference time is measured on the scaled input.
    Dice score is calculated by resizing the predicted mask back to the original GT size.
    """
    original_image_path = os.path.join(image_dir, image_filename)
    original_mask_path = os.path.join(mask_dir, mask_filename)

    # --- Load Original Ground Truth Mask (ONCE) ---
    print("Loading original ground truth mask...")
    original_gt_mask_tensor, original_mask_hw = load_original_ground_truth_mask(original_mask_path, device)
    print(f"Original mask size (H, W): {original_mask_hw}")

    # Get original image size for scaling calculations
    original_image_pil = Image.open(original_image_path)
    original_width, original_height = original_image_pil.size


    # --- Dummy Inference Run (Warm-up) ---
    print("Performing dummy inference run to warm-up...")
    # Use a representative size, e.g., the largest scale factor or original size
    # Here using original size, but could use max scaled size too.
    dummy_input_size = (original_height, original_width)
    dummy_image_tensor = preprocess_image(original_image_path, dummy_input_size, device)
    with torch.no_grad():
        _ = model(dummy_image_tensor) # Just run forward pass, no timing or storage
    torch.cuda.synchronize() if device.type == 'cuda' else None # Ensure completion if on CUDA
    print("Warm-up complete.")
    # --- End Dummy Inference Run ---

    results = {}

    for scale_factor in scaling_factors:
        # Calculate scaled input size based on original image dimensions
        scaled_height = int(original_height * scale_factor)
        scaled_width = int(original_width * scale_factor)
        input_size = (scaled_height, scaled_width) # H, W for transforms.Resize

        print(f"\n--- Scaling Factor: {scale_factor:.2f} (Input Size: {input_size}) ---")

        # Preprocess the image to the current scaled size
        image_tensor = preprocess_image(original_image_path, input_size, device)

        # --- Inference ---
        start_time = time.time()
        with torch.no_grad():
            # Model outputs prediction at the scaled size
            predicted_mask_scaled = torch.sigmoid(model(image_tensor))
        torch.cuda.synchronize() if device.type == 'cuda' else None # Ensure completion for accurate timing
        inference_time = time.time() - start_time

        # --- Resize Predicted Mask to Original GT Size ---
        # predicted_mask_scaled shape: [1, C, H_scaled, W_scaled]
        # original_mask_hw: (H_original, W_original)
        print(f"Using technique bicubic")
        predicted_mask_original_size = F.interpolate(
            predicted_mask_scaled,
            size=original_mask_hw, # Target size (H, W)
            mode='bicubic',       
            # align_corners=False    # Recommended setting
        )

        # --- Threshold and Calculate Dice Score ---
        # Threshold the *resized* prediction
        predicted_mask_binary = (predicted_mask_original_size > 0.5).float()

        # Calculate Dice against the *original sized* ground truth mask
        dice_score = calculate_dice_score(predicted_mask_binary, original_gt_mask_tensor)

        # Store results
        results[scale_factor] = {
            "inference_time": inference_time,
            "dice_score": dice_score,
            "input_size": input_size
        }

        print(f"Input Size (H, W): {input_size}")
        print(f"Inference Time: {inference_time:.4f} seconds")
        print(f"Dice score (vs Original GT): {dice_score:.4f}")


    print("\n--- Summary Results ---")
    for scale_factor, metrics in results.items():
        print(f"Scaling Factor: {scale_factor:.2f}, Input Size: {metrics['input_size']}, Inference Time: {metrics['inference_time']:.4f}s, Dice Score: {metrics['dice_score']:.4f}")


if __name__ == '__main__':
    image_dir = './train/'
    mask_dir = './train_masks/'
    image_filename = '0cdf5b5d0ce1_01.jpg'
    mask_filename = '0cdf5b5d0ce1_01_mask.gif'
    num_classes = 1

    # --- Device Setup ---
    if torch.cuda.is_available():
        DEVICE = torch.device("cuda")
        print("Using CUDA (GPU)")
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        DEVICE = torch.device("mps")
        print("Using MPS (Apple Silicon GPU)")
    else:
        DEVICE = torch.device("cpu")
        print("Using CPU")

    # --- Model Loading ---
    print("Loading model...")
    model = UNET(in_channels=3, out_channels=num_classes)
    try:
        checkpoint_path = "my_checkpoint.pth.tar"
        if not os.path.exists(checkpoint_path):
             raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
        load_checkpoint(torch.load(checkpoint_path, map_location=DEVICE), model) 
        print(f"Checkpoint '{checkpoint_path}' loaded successfully.")
    except FileNotFoundError as e:
        print(f"Error: {e}")
        print("Please ensure the checkpoint file exists and the path is correct.")
        exit() # Exit if checkpoint is missing
    except Exception as e:
        print(f"An error occurred during model loading: {e}")
        exit()


    model.to(DEVICE)
    model.eval() 
    scaling_factors = [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5] 

    evaluate_model(model, image_dir, mask_dir, image_filename, mask_filename, num_classes, DEVICE, scaling_factors)

Using MPS (Apple Silicon GPU)
Loading model...
=> Loading checkpoint
Checkpoint 'my_checkpoint.pth.tar' loaded successfully.
Loading original ground truth mask...
Original mask size (H, W): (1280, 1918)
Performing dummy inference run to warm-up...
Warm-up complete.

--- Scaling Factor: 0.10 (Input Size: (128, 191)) ---
Using technique bicubic
Input Size (H, W): (128, 191)
Inference Time: 0.4250 seconds
Dice score (vs Original GT): 0.8426

--- Scaling Factor: 0.15 (Input Size: (192, 287)) ---
Using technique bicubic
Input Size (H, W): (192, 287)
Inference Time: 0.1625 seconds
Dice score (vs Original GT): 0.9190

--- Scaling Factor: 0.20 (Input Size: (256, 383)) ---
Using technique bicubic
Input Size (H, W): (256, 383)
Inference Time: 0.1653 seconds
Dice score (vs Original GT): 0.9445

--- Scaling Factor: 0.25 (Input Size: (320, 479)) ---
Using technique bicubic
Input Size (H, W): (320, 479)
Inference Time: 0.1596 seconds
Dice score (vs Original GT): 0.9866

--- Scaling Factor: 0.30 (Inp