# FCN Semantic Segmentation 



In [None]:


import os, random, time, csv
from glob import glob
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from PIL import Image, ImageDraw

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset

import torchvision
import torchvision.transforms as T
import torchvision.models as models


# Seed every random number generator used in this script (Python's `random`,
# NumPy and PyTorch) so data generation and weight initialisation are repeatable.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    # Also seed the CUDA generators when a GPU is present.
    torch.cuda.manual_seed_all(RANDOM_SEED)

# Device on which models and batches are placed: GPU when available, else CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# Dataset config: choose 'synthetic' or 'local'
DATA_SOURCE = 'synthetic'   

# Synthetic dataset params
NUM_IMAGES = 80   # number of synthetic image/mask pairs to generate
IMG_SIZE = 128    # side length (pixels) of the square synthetic images
BATCH_SIZE = 8    # mini-batch size of the training DataLoader

# Model & training
NUM_CLASSES = 2   # binary synthetic masks: background (0), foreground (1)
BACKBONE = 'resnet34'   # torchvision model name handed to FCNResNet
MODE_LIST = ['32s','16s','8s']  # FCN variants iterated by the experiment loop
UPSAMPLE_MODES = ['transpose','bilinear'] # final-upsampling variants iterated by the experiment loop
EPOCHS = 20 # training epochs per experiment
LR = 1e-4             # Adam learning rate
WEIGHT_DECAY = 1e-4   # Adam weight decay
SAVE_DIR = 'fcn_experiments'   # output folder for checkpoints, plots and the summary CSV
os.makedirs(SAVE_DIR, exist_ok=True)

# ImageNet normalization: per-channel mean / std applied to images scaled to [0, 1]
IMAGENET_MEAN = np.array([0.485,0.456,0.406], dtype=np.float32)
IMAGENET_STD  = np.array([0.229,0.224,0.225], dtype=np.float32)


# Minimal Dataset Loader

class LocalSmallVOC(Dataset):
    """Paired image / segmentation-mask dataset read from two local folders.

    Images are resized to ``size x size``, converted to tensors and normalised
    with the ImageNet statistics. Masks are decoded to integer class indices
    and resized with nearest-neighbour interpolation so that no new label
    values are invented.

    Args:
        images_dir: folder containing the input images.
        masks_dir: folder containing the masks.
        size: side length of the square output images and masks.
        rgb_to_index_map: optional ``{(r, g, b): class_id}`` mapping used to
            decode RGB-coded masks. Colours missing from the map become 255.
        extensions: accepted file extensions (lower case, without the dot).

    Pairing: when both folders hold the same number of files they are paired
    by sorted order; otherwise only files sharing a basename (stem) are kept.
    """

    def __init__(self, images_dir, masks_dir, size=IMG_SIZE, rgb_to_index_map=None, extensions=('jpg','jpeg','png')):
        self.images = self._list_files(images_dir, extensions)
        self.masks = self._list_files(masks_dir, extensions)
        # try to pair by basename if counts differ
        if len(self.images) != len(self.masks):
            imgs_by_name = {Path(p).stem: p for p in self.images}
            masks_by_name = {Path(p).stem: p for p in self.masks}
            common = sorted(imgs_by_name.keys() & masks_by_name.keys())
            self.images = [imgs_by_name[n] for n in common]
            self.masks = [masks_by_name[n] for n in common]
        assert len(self.images) == len(self.masks) and len(self.images) > 0, f"No paired images/masks found in {images_dir}, {masks_dir}"
        self.size = size
        self.rgb_to_index_map = rgb_to_index_map
        self.img_tr = T.Compose([T.Resize((size,size)), T.ToTensor(), T.Normalize(IMAGENET_MEAN.tolist(), IMAGENET_STD.tolist())])
        # NEAREST keeps label values intact (no blending between class ids).
        self.mask_tr = T.Compose([T.Resize((size,size), interpolation=T.InterpolationMode.NEAREST)])

    @staticmethod
    def _list_files(folder, extensions):
        """Return the sorted paths in *folder* whose suffix is in *extensions*."""
        return sorted(
            p for p in glob(os.path.join(folder, '*'))
            if Path(p).suffix[1:].lower() in extensions
        )

    def __len__(self):
        return len(self.images)

    def _load_mask_index(self, path):
        """Decode the mask at *path* into an ``(H, W)`` int64 array of class ids."""
        # The context manager closes the underlying file handle once decoded.
        with Image.open(path) as m:
            if m.mode in ('P', 'L'):
                # For palette images the raw pixel values ARE the class indices.
                # Converting to 'L' first would replace them with the luminance
                # of the palette colours and corrupt the labels.
                return np.array(m, dtype=np.int64)
            if m.mode == 'RGB':
                arr = np.array(m, dtype=np.uint8)
                if self.rgb_to_index_map is None:
                    # best-effort: collapse R channel (not ideal)
                    return arr[:,:,0].astype(np.int64)
                # Start from 255 everywhere so unmapped colours stay marked as 255.
                out = np.full(arr.shape[:2], 255, dtype=np.int64)
                for rgb, cid in self.rgb_to_index_map.items():
                    colour = np.array(rgb, dtype=np.uint8).reshape(1,1,3)
                    out[np.all(arr == colour, axis=2)] = cid
                return out
            # Any other mode: fall back to a grayscale conversion.
            return np.array(m.convert('L'), dtype=np.int64)

    def __getitem__(self, idx):
        """Return ``(image, mask)``: float tensor ``(3, size, size)`` and long tensor ``(size, size)``."""
        with Image.open(self.images[idx]) as raw:
            # convert() returns an independent, fully loaded image.
            img = raw.convert('RGB')
        mask_arr = self._load_mask_index(self.masks[idx])
        img_t = self.img_tr(img)
        # Round-trip through an 8-bit PIL image so the mask can be resized.
        mask_img = self.mask_tr(Image.fromarray(mask_arr.astype(np.uint8)))
        mask_t = torch.as_tensor(np.array(mask_img), dtype=torch.long)
        return img_t, mask_t


# Synthetic dataset builder (binary shapes)

def make_image_with_shape(size=IMG_SIZE, max_shapes=3):
    """Draw 1..max_shapes random filled shapes on a gray canvas.

    Args:
        size: side length of the square image and mask, in pixels.
        max_shapes: upper bound (inclusive) on the number of shapes drawn.

    Returns:
        ``(image, mask)``: ``image`` is an ``(size, size, 3)`` uint8 RGB array;
        ``mask`` is an ``(size, size)`` uint8 array with 1 on every pixel
        covered by a shape and 0 on the background.
    """
    img = Image.new('RGB', (size,size), (128,128,128))
    mask = Image.new('L', (size,size), 0)
    draw = ImageDraw.Draw(img)
    mdraw = ImageDraw.Draw(mask)
    n = random.randint(1, max_shapes)
    for _ in range(n):
        shape_type = random.choice(['rect','ellipse','triangle'])
        # Top-left corner in the first third, bottom-right in the second half,
        # so the bounding box is never degenerate.
        x1 = random.randint(5, size//3)
        y1 = random.randint(5, size//3)
        x2 = random.randint(size//2, size-5)
        y2 = random.randint(size//2, size-5)
        color = tuple(np.random.randint(50, 230, size=3).tolist())
        if shape_type == 'rect':
            draw.rectangle([x1,y1,x2,y2], fill=color)
            mdraw.rectangle([x1,y1,x2,y2], fill=1)
        elif shape_type == 'ellipse':
            draw.ellipse([x1,y1,x2,y2], fill=color)
            mdraw.ellipse([x1,y1,x2,y2], fill=1)
        else:
            pts = [(x1,y2), ((x1+x2)//2, y1), (x2,y2)]
            draw.polygon(pts, fill=color)
            mdraw.polygon(pts, fill=1)
    # The mask is drawn directly with the class id (1), so it must NOT be
    # divided by 255: doing so would floor every foreground pixel to 0 and
    # yield an all-background mask.
    return np.array(img), np.array(mask, dtype=np.uint8)

class SyntheticSegDataset(Dataset):
    """In-memory segmentation dataset built from NumPy arrays.

    Args:
        imgs_np: uint8 images laid out as (N, H, W, 3).
        masks_np: per-pixel labels laid out as (N, H, W).
        normalize: when true, apply the ImageNet mean/std after scaling the
            images to [0, 1].

    Each item is ``(image, mask)``: a float tensor in channel-first layout
    and a long tensor of labels.
    """

    def __init__(self, imgs_np, masks_np, normalize=True):
        self.normalize = normalize
        scaled = imgs_np.astype('float32') / 255.0
        if normalize:
            mean = IMAGENET_MEAN.reshape(1,1,3)
            std = IMAGENET_STD.reshape(1,1,3)
            scaled = (scaled - mean) / std
        # Images stay channel-last in memory; they are transposed per item.
        self.imgs = scaled
        self.masks = masks_np.astype('int64')

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        channel_first = self.imgs[idx].transpose(2,0,1)
        image = torch.from_numpy(channel_first).float()
        label = torch.from_numpy(self.masks[idx]).long()
        return image, label


# Backbone & FCN implementation

def make_resnet_backbone(name='resnet34', pretrained=True):
    """Split a torchvision ResNet into five sequential feature stages.

    Args:
        name: attribute name of the constructor in ``torchvision.models``.
        pretrained: load ImageNet weights when true, random init otherwise.

    Returns:
        ``(s0, s1, s2, s3, s4)`` with output strides /4, /4, /8, /16, /32:
        the stem (conv1 + bn1 + relu + maxpool) followed by layer1..layer4.

    Raises:
        AttributeError: if ``torchvision.models`` has no constructor *name*.
    """
    builder = getattr(models, name)
    try:
        model = builder(weights='IMAGENET1K_V1' if pretrained else None)
    except TypeError:
        # Older torchvision releases do not know the `weights` keyword. Only
        # that case falls back to the legacy flag; other failures (for example
        # a failed weight download) propagate instead of being retried blindly.
        model = builder(pretrained=pretrained)
    s0 = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool)  # /4
    return s0, model.layer1, model.layer2, model.layer3, model.layer4

class FCNResNet(nn.Module):
    """FCN-32s / 16s / 8s segmentation head on top of a ResNet backbone.

    Args:
        backbone_name: torchvision ResNet constructor name.
        num_classes: number of output channels (classes) of the score maps.
        mode: ``'32s'`` (deepest features only), ``'16s'`` (adds the /16 skip)
            or ``'8s'`` (adds the /16 and /8 skips).
        upsample_mode: ``'transpose'`` for learned transposed convolutions or
            ``'bilinear'`` for fixed bilinear upsampling in the final step.
        pretrained: forwarded to ``make_resnet_backbone``; defaults to True.

    ``forward`` returns logits of shape ``(B, num_classes, H, W)`` matching
    the spatial size of the input.
    """

    def __init__(self, backbone_name='resnet34', num_classes=21, mode='8s', upsample_mode='bilinear', pretrained=True):
        super().__init__()
        assert mode in ['32s','16s','8s']
        assert upsample_mode in ['transpose','bilinear']
        self.mode = mode
        self.upsample_mode = upsample_mode

        self.s0, self.s1, self.s2, self.s3, self.s4 = make_resnet_backbone(backbone_name, pretrained=pretrained)

        # Read the channel widths from the backbone instead of hard-coding
        # 512/256/128, so bottleneck ResNets (e.g. resnet50) work as well.
        self.score4 = nn.Conv2d(self._out_channels(self.s4), num_classes, kernel_size=1)
        if mode in ['16s','8s']:
            self.score3 = nn.Conv2d(self._out_channels(self.s3), num_classes, kernel_size=1)
        if mode == '8s':
            self.score2 = nn.Conv2d(self._out_channels(self.s2), num_classes, kernel_size=1)

        if upsample_mode == 'transpose':
            # kernel = 2*stride and padding = stride/2 give an output that is
            # exactly `stride` times larger than the input.
            self.up32 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=64, stride=32, padding=16, bias=False)
            self.up16 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=32, stride=16, padding=8, bias=False)
            self.up8  = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=16, stride=8, padding=4, bias=False)
        else:
            # Fixed bilinear upsampling; no learnable parameters.
            self.up32 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=False)
            self.up16 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=False)
            self.up8  = nn.Upsample(scale_factor=8,  mode='bilinear', align_corners=False)

    @staticmethod
    def _out_channels(stage):
        """Return the output channel count of a backbone stage.

        The last ``nn.Conv2d`` registered in the stage is either the final
        conv of the last residual block or that block's downsample conv;
        both emit the stage's output width.
        """
        convs = [m for m in stage.modules() if isinstance(m, nn.Conv2d)]
        return convs[-1].out_channels

    @staticmethod
    def _resize_like(src, ref):
        """Bilinearly resize *src* to the spatial size of *ref*."""
        return F.interpolate(src, size=ref.shape[2:], mode='bilinear', align_corners=False)

    def forward(self, x):
        h0 = self.s0(x)   # /4
        h1 = self.s1(h0)  # /4
        h2 = self.s2(h1)  # /8
        h3 = self.s3(h2)  # /16
        h4 = self.s4(h3)  # /32

        s4 = self.score4(h4)
        if self.mode == '32s':
            return self._match_input(self.up32(s4), x)

        # Fuse by resizing to the skip map's actual size rather than by a
        # fixed factor of 2: for inputs not divisible by 32 the strided stages
        # round up, so "2x the deeper map" can be one pixel larger than the
        # skip map and the addition would fail.
        s3 = self.score3(h3)
        fuse3 = self._resize_like(s4, s3) + s3
        if self.mode == '16s':
            return self._match_input(self.up16(fuse3), x)

        # mode == '8s'
        s2 = self.score2(h2)
        fuse2 = self._resize_like(fuse3, s2) + s2
        return self._match_input(self.up8(fuse2), x)

    def _match_input(self, out, x):
        """Resize *out* to the spatial size of *x* when they differ."""
        if out.shape[2:] != x.shape[2:]:
            out = F.interpolate(out, size=x.shape[2:], mode='bilinear', align_corners=False)
        return out


# Metrics: Pixel accuracy & robust mean IoU

def pixel_accuracy(pred, target, ignore_index=255):
    """Fraction of non-ignored pixels whose arg-max class equals the label.

    Args:
        pred: logits of shape (B, C, H, W).
        target: integer labels of shape (B, H, W).
        ignore_index: label value excluded from the count.

    Returns:
        A Python float, or 0.0 when every pixel is ignored.
    """
    keep = target != ignore_index
    hits = (pred.argmax(dim=1) == target) & keep
    n_hits = hits.sum().item()
    n_kept = keep.sum().item()
    if n_kept > 0:
        return n_hits / n_kept
    return 0.0

def mean_iou(pred, target, num_classes=NUM_CLASSES, ignore_index=255):
    """Mean intersection-over-union across classes for one batch.

    Args:
        pred: logits of shape (B, C, H, W).
        target: integer labels of shape (B, H, W).
        num_classes: class ids ``0 .. num_classes - 1`` are evaluated.
        ignore_index: label value excluded from both prediction and target.

    Returns:
        The mean IoU (Python float) over the classes that occur in the
        prediction or the target; 0.0 when no class occurs at all.
    """
    pred_label = pred.argmax(dim=1)
    # Pixels labelled `ignore_index` must not count for either side:
    # keeping predictions there would inflate the union and lower the IoU.
    valid = target != ignore_index
    ious = []
    for cls in range(num_classes):
        p = (pred_label == cls) & valid
        t = (target == cls) & valid
        union = (p | t).sum().item()
        if union == 0:
            # skip class (no presence in gt & pred)
            continue
        ious.append((p & t).sum().item() / union)
    if not ious:
        return 0.0
    return float(np.mean(ious))


# Training & evaluation helpers

def evaluate_model(model, loader, num_classes=NUM_CLASSES, ignore_index=255):
    """Run *model* over *loader* without gradients and average batch metrics.

    The model is switched to eval mode and batches are moved to ``DEVICE``.

    Returns:
        ``(loss, pixel_accuracy, mean_iou)``, each the arithmetic mean of the
        per-batch values, or ``(0, 0, 0)`` when the loader yields no batch.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index)
    sums = [0.0, 0.0, 0.0]  # running loss, accuracy, mIoU
    seen = 0
    with torch.no_grad():
        for batch_imgs, batch_masks in loader:
            batch_imgs = batch_imgs.to(DEVICE)
            batch_masks = batch_masks.to(DEVICE)
            scores = model(batch_imgs)
            sums[0] += loss_fn(scores, batch_masks).item()
            sums[1] += pixel_accuracy(scores, batch_masks, ignore_index=ignore_index)
            sums[2] += mean_iou(scores, batch_masks, num_classes=num_classes, ignore_index=ignore_index)
            seen += 1
    if not seen:
        return 0,0,0
    return sums[0] / seen, sums[1] / seen, sums[2] / seen

def train_one_epoch(model, loader, optimizer, criterion):
    """Train *model* for one pass over *loader*.

    Args:
        model: network returning logits of shape (B, C, H, W).
        loader: iterable of ``(images, masks)`` batches.
        optimizer: optimizer stepping the model's parameters.
        criterion: loss callable taking ``(logits, masks)``.

    Returns:
        ``(loss, pixel_accuracy, mean_iou)`` averaged over the batches, or
        ``(0, 0, 0)`` when the loader yields no batch.
    """
    model.train()
    # Score with the same ignore label the loss uses, so both agree on which
    # pixels count; fall back to 255 for criteria without that attribute.
    ignore_index = getattr(criterion, 'ignore_index', 255)
    running_loss = 0.0
    acc_sum = 0.0
    iou_sum = 0.0
    batches = 0
    for imgs, masks in loader:
        imgs = imgs.to(DEVICE)
        masks = masks.to(DEVICE)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Metrics need no autograd graph. The class count is taken from the
        # logits, so it always matches the model rather than a module global.
        scores = logits.detach()
        acc_sum += pixel_accuracy(scores, masks, ignore_index=ignore_index)
        iou_sum += mean_iou(scores, masks, num_classes=scores.shape[1], ignore_index=ignore_index)
        batches += 1
    if batches == 0:
        return 0,0,0
    return running_loss / batches, acc_sum / batches, iou_sum / batches

# Prepare dataset (synthetic only; the 'local' source is not wired up in this section)


# Only the synthetic source is implemented in this section. Fail loudly for any
# other DATA_SOURCE value instead of silently training on synthetic data.
if DATA_SOURCE != 'synthetic':
    raise NotImplementedError(
        f"DATA_SOURCE={DATA_SOURCE!r} is not supported here: only 'synthetic' data "
        "generation is implemented (construct a LocalSmallVOC dataset manually for local data)."
    )

print("Generating synthetic dataset:", NUM_IMAGES, "images of size", IMG_SIZE)
images, masks = [], []
for _ in range(NUM_IMAGES):
    im, m = make_image_with_shape(size=IMG_SIZE, max_shapes=3)
    images.append(im)
    masks.append(m)
images = np.stack(images)        # (N,H,W,3) uint8
masks = np.stack(masks)          # (N,H,W) per-pixel labels
# create dataset objects
full_ds = SyntheticSegDataset(images, masks, normalize=True)


# split train/val: shuffle the indices once, then take 80% / 20%
N = len(full_ds)
indices = list(range(N))
random.shuffle(indices)
split = int(0.8 * N)
train_idx = indices[:split]
val_idx   = indices[split:]
train_ds = Subset(full_ds, train_idx)
val_ds   = Subset(full_ds, val_idx)
# Training batches are reshuffled every epoch; validation runs one image at a
# time in a fixed order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=False)
val_loader   = DataLoader(val_ds,   batch_size=1,           shuffle=False, num_workers=0, pin_memory=False)

print(f"Dataset total={N}, train={len(train_ds)}, val={len(val_ds)}, batch={BATCH_SIZE}")


# Main experiments loop: for each FCN mode and upsample mode

# One row per experiment: [mode, upsample, val_loss, val_acc, val_mIoU, model_path].
summary_rows = []
experiment_count = 0
# Train and evaluate a fresh model for every (FCN mode, upsampling mode) pair.
for mode in MODE_LIST:
    for up_mode in UPSAMPLE_MODES:
        experiment_count += 1
        print(f"\n=== Experiment {experiment_count}: FCN-{mode} | upsample={up_mode} ===")
        model = FCNResNet(backbone_name=BACKBONE, num_classes=NUM_CLASSES, mode=mode, upsample_mode=up_mode).to(DEVICE)
        optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
        # Pixels labelled 255 do not contribute to the loss.
        criterion = nn.CrossEntropyLoss(ignore_index=255)

        # Per-epoch metric curves for this experiment.
        history = {'train_loss':[], 'train_acc':[], 'train_miou':[], 'val_loss':[], 'val_acc':[], 'val_miou':[]}

        start_time = time.time()
        for epoch in range(EPOCHS):
            tr_loss, tr_acc, tr_iou = train_one_epoch(model, train_loader, optimizer, criterion)
            va_loss, va_acc, va_iou = evaluate_model(model, val_loader, num_classes=NUM_CLASSES, ignore_index=255)
            history['train_loss'].append(tr_loss)
            history['train_acc'].append(tr_acc)
            history['train_miou'].append(tr_iou)
            history['val_loss'].append(va_loss)
            history['val_acc'].append(va_acc)
            history['val_miou'].append(va_iou)
            # Log every second epoch, plus the final one.
            if (epoch % 2 == 0) or (epoch == EPOCHS-1):
                print(f"Epoch {epoch+1}/{EPOCHS} | tr_loss={tr_loss:.4f} tr_acc={tr_acc:.4f} tr_mIoU={tr_iou:.4f} || val_loss={va_loss:.4f} val_acc={va_acc:.4f} val_mIoU={va_iou:.4f}")
        elapsed = time.time() - start_time
        print(f"Finished experiment in {elapsed:.1f}s")

        # save model & history
        model_path = os.path.join(SAVE_DIR, f"fcn_{mode}_{up_mode}.pth")
        torch.save(model.state_dict(), model_path)
        # save history plot: loss curves on the left, validation mIoU on the right
        plt.figure(figsize=(10,4))
        plt.subplot(1,2,1)
        plt.plot(history['train_loss'], label='train_loss')
        plt.plot(history['val_loss'], label='val_loss', linestyle='--')
        plt.title(f'Loss ({mode}, {up_mode})'); plt.legend()
        plt.subplot(1,2,2)
        plt.plot(history['val_miou'], label='val_mIoU')
        plt.title('Validation mIoU'); plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(SAVE_DIR, f"curve_{mode}_{up_mode}.png"))
        # Close the figure so figures do not accumulate across experiments.
        plt.close()

        # evaluate final metrics (evaluate_model's default ignore_index is used here)
        final_val_loss, final_val_acc, final_val_miou = evaluate_model(model, val_loader, num_classes=NUM_CLASSES)
        summary_rows.append([mode, up_mode, final_val_loss, final_val_acc, final_val_miou, model_path])

        # store results
        # NOTE(review): `results` is rebuilt every iteration and is not read
        # again in the visible source; only the last experiment's dict survives.
        results = {
            'mode': mode, 'upsample': up_mode, 'model': model, 'history': history
        }
        # visualize 3 examples from val set
        num_show = min(3, len(val_ds))
        if num_show > 0:
            plt.figure(figsize=(12,4*num_show))
            with torch.no_grad():
                for i in range(num_show):
                    img_t, mask_t = val_ds[i]
                    # Add a batch dimension before feeding the single image to the model.
                    inp = img_t.unsqueeze(0).to(DEVICE)
                    out = model(inp)
                    pred = out.argmax(dim=1).squeeze(0).cpu().numpy()
                    # reconstruct image for display (undo normalization)
                    img_np = img_t.cpu().numpy().transpose(1,2,0)
                    img_np = (img_np * IMAGENET_STD.reshape(1,1,3)) + IMAGENET_MEAN.reshape(1,1,3)
                    img_np = np.clip(img_np, 0, 1)
                    # One row per example: input image, ground-truth mask, predicted mask.
                    ax = plt.subplot(num_show, 3, i*3+1); ax.imshow(img_np); ax.set_title('Image'); ax.axis('off')
                    ax = plt.subplot(num_show, 3, i*3+2); ax.imshow(mask_t.numpy(), cmap='gray'); ax.set_title('GT'); ax.axis('off')
                    ax = plt.subplot(num_show, 3, i*3+3); ax.imshow(pred, cmap='gray'); ax.set_title(f'Pred ({mode},{up_mode})'); ax.axis('off')
            plt.tight_layout()
            plt.savefig(os.path.join(SAVE_DIR, f'viz_{mode}_{up_mode}.png'))
            plt.close()


# Save summary table CSV and print a short table

# Write one CSV row per experiment, then echo the same numbers to stdout.
summary_csv = os.path.join(SAVE_DIR, 'summary_table.csv')
with open(summary_csv, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['mode','upsample','val_loss','val_pixel_acc','val_mIoU','model_path'])
    writer.writerows(summary_rows)

print("\n=== Summary (final val metrics) ===")
print("mode\tupsample\tval_acc\t\tval_mIoU\tval_loss")
# Rows are stored as [mode, upsample, loss, acc, mIoU, path]; the table prints
# accuracy and mIoU before the loss.
for row_mode, row_up, row_loss, row_acc, row_miou, _row_path in summary_rows:
    print(f"{row_mode}\t{row_up}\t{row_acc:.4f}\t\t{row_miou:.4f}\t{row_loss:.4f}")

print(f"\nModels, plots and visualizations saved to: {os.path.abspath(SAVE_DIR)}")
print("Notebook/Script complete. You can increase NUM_IMAGES and EPOCHS for more reliable results.")


Device: cpu
Generating synthetic dataset: 80 images of size 128
Dataset total=80, train=64, val=16, batch=8

=== Experiment 1: FCN-32s | upsample=transpose ===
Epoch 1/20 | tr_loss=0.6905 tr_acc=0.5476 tr_mIoU=0.2738 || val_loss=0.6857 val_acc=0.6389 val_mIoU=0.3195
Epoch 3/20 | tr_loss=0.6552 tr_acc=0.6835 tr_mIoU=0.3417 || val_loss=0.6237 val_acc=0.7459 val_mIoU=0.3729
Epoch 5/20 | tr_loss=0.5606 tr_acc=0.8465 tr_mIoU=0.4233 || val_loss=0.5118 val_acc=0.8953 val_mIoU=0.4476
Epoch 7/20 | tr_loss=0.4473 tr_acc=0.9535 tr_mIoU=0.4768 || val_loss=0.4063 val_acc=0.9582 val_mIoU=0.4791
Epoch 9/20 | tr_loss=0.3477 tr_acc=0.9770 tr_mIoU=0.4885 || val_loss=0.3171 val_acc=0.9801 val_mIoU=0.4900
Epoch 11/20 | tr_loss=0.2697 tr_acc=0.9842 tr_mIoU=0.4921 || val_loss=0.2466 val_acc=0.9844 val_mIoU=0.4922
Epoch 13/20 | tr_loss=0.2097 tr_acc=0.9973 tr_mIoU=0.8737 || val_loss=0.1968 val_acc=1.0000 val_mIoU=1.0000
Epoch 15/20 | tr_loss=0.1639 tr_acc=1.0000 tr_mIoU=1.0000 || val_loss=0.1588 val_acc=1.00