In [20]:
#SETUP

import random, time
from pathlib import Path
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from torchvision.transforms import functional as TF
from torchvision import models
from torchvision.transforms import InterpolationMode as IM
import matplotlib.pyplot as plt
from datasets import load_from_disk


from datasets import load_from_disk

dataset = load_from_disk(r"C:\Users\arnav_vckkum5\OneDrive\coral-data\coralscapesdata")


In [21]:

# converts any input image into an RGB PIL image
def to_pil(img):
    # if its already PIL, keep it
    if isinstance(img, Image.Image): 
        return img
    # otherwise, make it an array
    arr = np.asarray(img)
    # if its a 2d gray scale image without color, make it grayscale PIL
    if arr.ndim == 2: 
        return Image.fromarray(arr.astype(np.uint8))
    # convert to PIL data type
    if arr.dtype != np.uint8:
        arr = np.clip(arr, 0, 255).astype(np.uint8)
    # convert values in array to valid rgb range and the convert to PIL-compatible type
    # If image array is not in the 8 bit range, clip all values to the valid 8-bit range for PIL handling
    if arr.ndim == 3 and arr.shape[2] == 4:  # RGBA -> RGB
        arr = arr[:, :, :3]
    # if images are RGBA, with an alpha transparency channel, drop that and make it sole RGB
    return Image.fromarray(arr)
    # turn array into PIL object

# convert segmentation masks into grayscale PIL images 
def mask_to_pil(mask): 
    if isinstance(mask, Image.Image):
        return mask.convert("L")
    # if mask is already PIL, make sure it's in grayscale mode 
    arr = np.asarray(mask)
    if arr.dtype != np.uint8:
        arr = arr.astype(np.uint8)
    # array should be numpy and 8-bit        
    return Image.fromarray(arr, mode="L") # return PIL grayscale image

# function to retrieve and confirm the number of seg. classes in dataset 
# sample limit: only scans first N integers for time saving
# hf train split selects the example: ex. dataset['train']
def get_num_classes(hf_train_split, sample_limit=None):
    """Scan masks to find max label and infer num_classes."""
    max_label = 0 # placeholder to keep track of max class 
    # How many samples to go over, all if sample limit is none
    N = len(hf_train_split) if sample_limit is None else min(sample_limit, len(hf_train_split))
    for i in range(N): 
        m = hf_train_split[i]["label"]
        m = np.asarray(m)
        # max label gets the largest pixel segmentation value found for the label it is cycling through
        max_label = max(max_label, int(m.max()))
    return max_label + 1


In [22]:
# SEGFORMER GOALS:
# 1. Every image and mask are the same size
# 2. Training data gets random augmentations (to improve generalization)
# 3. Validation/testing data is preprocessed deterministically (no randomness)
# 4. Both image and mask transformations stay perfectly aligned pixel-for-pixel
from torchvision import transforms as T
from torchvision.transforms import functional as TF, InterpolationMode as IM

class SegTransform:
    def __init__(self, size=512, crop_size=512, is_train=True):
        self.size = size
        self.crop_size = crop_size
        self.is_train = is_train
        self.color_jitter = T.ColorJitter(
            brightness=0.2,  # 20% brighter/darker
            contrast=0.2,    # 20% contrast variation
            saturation=0.2,  # 20% saturation shift
            hue=0.05         # small hue jitter
        )

    def __call__(self, img, mask):
        img = to_pil(img).convert("RGB")
        mask = mask_to_pil(mask)

        # Resize
        img  = TF.resize(img,  self.size, interpolation=IM.BILINEAR)
        mask = TF.resize(mask, self.size, interpolation=IM.NEAREST)

        if self.is_train and self.crop_size is not None:
            # >>> this is the correct way <<<
            i, j, h, w = T.RandomCrop.get_params(img, output_size=(self.crop_size, self.crop_size))
            img  = TF.crop(img,  i, j, h, w)
            mask = TF.crop(mask, i, j, h, w)

            if random.random() < 0.5:
                img  = TF.hflip(img)
                mask = TF.hflip(mask)
            
            img = self.color_jitter(img)


        else:
            if self.crop_size is not None and self.crop_size < self.size:
                img  = TF.center_crop(img,  self.crop_size)
                mask = TF.center_crop(mask, self.crop_size)

        img = TF.to_tensor(img)
        img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        mask = torch.from_numpy(np.array(mask, dtype=np.int64))
        return img, mask


In [23]:
class HFDatasetWrapper(Dataset):
    def __init__(self, hf_split, transform: SegTransform):
        self.ds = hf_split
        self.t = transform

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        row = self.ds[idx]
        img, mask = row["image"], row["label"]
        img, mask = self.t(img, mask)
        return img, mask

def collate_fn(batch):
    imgs, masks = zip(*batch)
    imgs = torch.stack(imgs, dim=0)
    masks = torch.stack(masks, dim=0)
    return imgs, masks


In [24]:
def segmentation_metrics(logits, targets, num_classes):
    """
    Compute pixel accuracy and mIoU for a batch.
    logits: (B, C, H, W) raw
    targets: (B, H, W) long
    """
    preds = logits.argmax(1)  # (B,H,W)
    valid = (targets >= 0)  # assume all valid; adjust if you use ignore_index
    correct = (preds[valid] == targets[valid]).sum().item()
    total = valid.sum().item()
    pix_acc = correct / max(1, total)

    # IoU per class
    ious = []
    for c in range(num_classes):
        pred_c = (preds == c)
        targ_c = (targets == c)
        inter = (pred_c & targ_c).sum().item()
        union = (pred_c | targ_c).sum().item()
        if union == 0:
            continue
        ious.append(inter / union)
    miou = float(np.mean(ious)) if ious else 0.0
    return pix_acc, miou

In [27]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Infer class count (full scan, set sample_limit to speed up if huge)
num_classes = get_num_classes(dataset["train"], sample_limit=None)
print("Inferred num_classes:", num_classes)

# Transforms (tweak sizes to fit GPU)
TRAIN_SIZE = 320  #544
CROP = 320   #512f
train_tf = SegTransform(size=TRAIN_SIZE, crop_size=CROP, is_train=True)
val_tf   = SegTransform(size=TRAIN_SIZE, crop_size=CROP, is_train=False)
test_tf  = SegTransform(size=TRAIN_SIZE, crop_size=CROP, is_train=False)

train_set = HFDatasetWrapper(dataset["train"], train_tf)
val_set   = HFDatasetWrapper(dataset["validation"], val_tf)
test_set  = HFDatasetWrapper(dataset["test"], test_tf)

train_loader = DataLoader(train_set, batch_size=3, shuffle=True, num_workers=0, pin_memory=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_set,   batch_size=4, shuffle=False, num_workers=0, pin_memory=True, collate_fn=collate_fn)
test_loader  = DataLoader(test_set,  batch_size=4, shuffle=False, num_workers=0, pin_memory=True, collate_fn=collate_fn)

Device: cpu
Inferred num_classes: 40


In [28]:
#choose backbone
model = models.segmentation.deeplabv3_resnet50(num_classes=num_classes)
model.to(device)

# loss
criterion = nn.CrossEntropyLoss(ignore_index=-1)

# Optimizer & scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=3e-4, steps_per_epoch=len(train_loader), epochs= 30
)

scaler = torch.amp.GradScaler(enabled=(device.type == "cuda"))


In [29]:
def run_epoch(model, loader, train=True):
    if train:
        model.train()
    else:
        model.eval()

    total_loss, total_pix, total_iou, total_batches = 0.0, 0.0, 0.0, 0
    for imgs, masks in loader:
        imgs, masks = imgs.to(device), masks.to(device)

        with torch.set_grad_enabled(train):
            out = model(imgs)["out"]
            loss = criterion(out, masks)

            if train:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()
                scheduler.step()

        pix_acc, miou = segmentation_metrics(out.detach().cpu(), masks.cpu(), num_classes)
        total_loss += loss.item()
        total_pix += pix_acc
        total_iou += miou
        total_batches += 1

    return {
        "loss": total_loss / max(1, total_batches),
        "pix_acc": total_pix / max(1, total_batches),
        "miou": total_iou / max(1, total_batches),    
    }


In [30]:
from torchvision import transforms as T
from torchvision.transforms import functional as TF, InterpolationMode as IM

class SegTransform:
    def __init__(self, size=512, crop_size=512, is_train=True):
        self.size = size
        self.crop_size = crop_size
        self.is_train = is_train

    def __call__(self, img, mask):
        img  = to_pil(img).convert("RGB")
        mask = mask_to_pil(mask)

        img  = TF.resize(img,  self.size, interpolation=IM.BILINEAR)
        mask = TF.resize(mask, self.size, interpolation=IM.NEAREST)

        if self.is_train and self.crop_size is not None:
            i, j, h, w = T.RandomCrop.get_params(img, (self.crop_size, self.crop_size))
            img  = TF.crop(img,  i, j, h, w)
            mask = TF.crop(mask, i, j, h, w)
            if random.random() < 0.5:
                img  = TF.hflip(img)
                mask = TF.hflip(mask)
        elif self.crop_size and self.crop_size < self.size:
            img  = TF.center_crop(img,  self.crop_size)
            mask = TF.center_crop(mask, self.crop_size)

        img  = TF.to_tensor(img)
        img  = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        mask = torch.from_numpy(np.array(mask, dtype=np.int64))
        return img, mask


In [32]:
EPOCHS = 30 # 25
best_miou = -1
save_path = r"c:\Users\arnav_vckkum5\OneDrive\coral-data\modelversions\deeplabv3_resnet50_coral.pth"

for epoch in range(1, EPOCHS + 1):
    t0 = time.time()
    tr = run_epoch(model, train_loader, train=True)
    va = run_epoch(model, val_loader, train=False)
    dt = time.time() - t0
    print(f"[{epoch:02d}/{EPOCHS}] "
          f"train: loss={tr['loss']:.4f} acc={tr['pix_acc']:.3f} mIoU={tr['miou']:.3f} | "
          f"val: loss={va['loss']:.4f} acc={va['pix_acc']:.3f} mIoU={va['miou']:.3f} | "
          f"{dt:.1f}s")

    # save best by val mIoU
    if va["miou"] > best_miou:
        best_miou = va["miou"]
        torch.save({"model": model.state_dict(),
                    "num_classes": num_classes}, save_path)
        print(f"  ↳ Saved checkpoint: {save_path} (best val mIoU {best_miou:.3f})")



[01/30] train: loss=2.8194 acc=0.346 mIoU=0.063 | val: loss=2.1522 acc=0.485 mIoU=0.096 | 5079.9s
  ↳ Saved checkpoint: c:\Users\arnav_vckkum5\OneDrive\coral-data\modelversions\deeplabv3_resnet50_coral.pth (best val mIoU 0.096)
[02/30] train: loss=2.0596 acc=0.458 mIoU=0.130 | val: loss=1.9461 acc=0.462 mIoU=0.117 | 5116.5s
  ↳ Saved checkpoint: c:\Users\arnav_vckkum5\OneDrive\coral-data\modelversions\deeplabv3_resnet50_coral.pth (best val mIoU 0.117)
[03/30] train: loss=1.7111 acc=0.502 mIoU=0.156 | val: loss=1.5180 acc=0.551 mIoU=0.143 | 5129.4s
  ↳ Saved checkpoint: c:\Users\arnav_vckkum5\OneDrive\coral-data\modelversions\deeplabv3_resnet50_coral.pth (best val mIoU 0.143)
[04/30] train: loss=1.5937 acc=0.516 mIoU=0.167 | val: loss=1.7648 acc=0.477 mIoU=0.128 | 5166.0s
[05/30] train: loss=1.5167 acc=0.529 mIoU=0.177 | val: loss=1.5519 acc=0.539 mIoU=0.144 | 5161.2s
  ↳ Saved checkpoint: c:\Users\arnav_vckkum5\OneDrive\coral-data\modelversions\deeplabv3_resnet50_coral.pth (best val mI

In [None]:
ckpt = torch.load(save_path, map_location="cpu")
model.load_state_dict(ckpt["model"])
model.eval()

test_metrics = run_epoch(model, test_loader, train=False)
print(f"TEST: loss={test_metrics['loss']:.4f} acc={test_metrics['pix_acc']:.3f} mIoU={test_metrics['miou']:.3f}")

TEST: loss=1.0980 acc=0.635 mIoU=0.230
