In [1]:
import os
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image, ImageDraw
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.optim as optim

In [2]:
# Pick the compute device. The second GPU ("cuda:1") is preferred when the
# machine exposes at least two CUDA devices. On a single-GPU machine that
# ordinal does not exist, so fall back to "cuda:0", and to the CPU when CUDA
# is not available at all.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
# Absolute locations of the dataset inputs:
#   - annotations: JSON Lines file, parsed one JSON object per line
#   - info:        single JSON document whose "train" list names the images to use
#   - images:      directory that is listed to find the image files
ann_datapath = "/srv/data/lt2326-h25/a1/train.jsonl"
info_datapath = "/srv/data/lt2326-h25/a1/info.json"
img_datapath = "/srv/data/lt2326-h25/a1/images"

Dataset definition

In [4]:
class BinaryImageSegmentation(Dataset):
    """Image / binary-mask pairs for pixel-level segmentation.

    Each item is an RGB image resized to ``img_size`` x ``img_size`` together
    with a mask of the same size whose pixels are 1 inside every annotated
    polygon flagged ``is_chinese`` and not flagged ``ignore``, and 0 elsewhere.

    Only files that are listed under the ``"train"`` key of the info file,
    are present in ``img_dir`` and have a record in the annotation file are
    kept.

    Attributes:
        img_size: Side length in pixels of the square images and masks returned.
        names: File names listed under ``info["train"]``, in file order.
        img_files: Maps each usable file name to its path inside ``img_dir``.
        samples: Parsed annotation records, one per line of the annotation
            file whose ``"file_name"`` is a key of ``img_files``.
    """

    def __init__(self, img_dir, ann_file, info_file, img_size=256):
        """Index the images and annotation records that make up the dataset.

        Args:
            img_dir: Directory containing the image files.
            ann_file: JSON Lines file; every non-blank line is an object with
                at least a ``"file_name"`` key.
            info_file: JSON file whose ``"train"`` entry is a list of objects
                carrying a ``"file_name"`` key.
            img_size: Target side length for the resized images and masks.
        """
        self.img_size = img_size

        with open(info_file, "r", encoding="utf-8") as info_fh:
            info = json.load(info_fh)
        self.names = [item["file_name"] for item in info["train"]]

        # Membership is tested once per directory entry; a set keeps each
        # lookup O(1) instead of scanning the whole name list every time.
        wanted = set(self.names)
        # os.listdir already yields bare entry names, so they can be used
        # directly as dictionary keys.
        self.img_files = {
            entry: os.path.join(img_dir, entry)
            for entry in os.listdir(img_dir)
            if entry in wanted
        }

        self.samples = []
        with open(ann_file, "r", encoding="utf-8") as ann_fh:
            for line in ann_fh:
                # Tolerate blank lines (e.g. a trailing newline at end of
                # file); json.loads would otherwise raise on them.
                if not line.strip():
                    continue
                record = json.loads(line)
                if record["file_name"] in self.img_files:
                    self.samples.append(record)

    def __len__(self):
        """Return the number of usable annotation records."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair for sample ``idx``.

        Returns:
            A tuple ``(image, bi_mask)`` where ``image`` is a float tensor of
            shape ``(3, img_size, img_size)`` normalised with mean 0.5 and
            std 0.5 per channel, and ``bi_mask`` is a long tensor of shape
            ``(1, img_size, img_size)`` holding 0/1 labels.
        """
        sample = self.samples[idx]
        img_path = self.img_files[sample["file_name"]]

        # Load the image; the context manager releases the file handle as
        # soon as the RGB copy has been made.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")
        w, h = image.size  # original size, needed to rasterise the polygons
        image = image.resize((self.img_size, self.img_size), Image.BILINEAR)

        image = transforms.functional.to_tensor(image)
        image = transforms.functional.normalize(image, mean=[0.5] * 3, std=[0.5] * 3)

        # The mask is drawn at the original resolution because the polygon
        # coordinates are used unscaled, and only then resized.
        bi_mask = Image.new("L", (w, h), 0)
        draw = ImageDraw.Draw(bi_mask)

        # "annotations" is iterated as a sequence of groups, each holding
        # annotation dicts; anything that is not a dict is skipped.
        for ann_group in sample.get("annotations", []):
            for ann in ann_group:
                if not isinstance(ann, dict):
                    continue
                if ann.get("is_chinese") and not ann.get("ignore"):
                    polygon = [tuple(p) for p in ann["polygon"]]
                    draw.polygon(polygon, outline=1, fill=1)

        # Nearest-neighbour resampling keeps the mask strictly binary.
        bi_mask = bi_mask.resize((self.img_size, self.img_size), Image.NEAREST)
        bi_mask = torch.from_numpy(np.array(bi_mask)).long()
        bi_mask = bi_mask.unsqueeze(0)  # add the channel dimension
        return image, bi_mask
        
                           

Dataset split function

In [5]:
def split_dataset(dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=26):
    """Randomly partition ``dataset`` into train, validation and test subsets.

    The train and validation sizes are the dataset length multiplied by the
    respective ratio and truncated to an integer; the test subset receives
    every remaining sample, so the three sizes always add up to
    ``len(dataset)``.

    Args:
        dataset: Any object accepted by ``torch.utils.data.random_split``.
        train_ratio: Fraction of samples for training.
        val_ratio: Fraction of samples for validation.
        test_ratio: Fraction of samples for testing; only used for the
            sanity check, the actual test size is the remainder.
        seed: Seed for the generator that drives the shuffle.

    Returns:
        The result of ``random_split``: train, validation and test subsets,
        in that order.

    Raises:
        AssertionError: If the three ratios do not sum to 1 (within 1e-6).
    """
    ratio_sum = train_ratio + val_ratio + test_ratio
    assert abs(ratio_sum - 1.0) < 1e-6

    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    lengths = [n_train, n_val, n_total - n_train - n_val]

    # A dedicated, seeded generator makes the partition reproducible without
    # touching the global RNG state.
    rng = torch.Generator()
    rng.manual_seed(seed)
    return random_split(dataset, lengths, generator=rng)
    

Create Dataset and split

In [6]:
# Index the images and their annotations; items come out as 256x256
# image/mask pairs.
dataset = BinaryImageSegmentation(img_datapath, ann_datapath, info_datapath, img_size=256)

# Seeded split using the function's default ratios (0.7 / 0.15 / 0.15).
train_dataset, val_dataset, test_dataset = split_dataset(dataset)

Create dataloaders

In [7]:
# One loader per split, all using batches of 8. Only the training loader
# shuffles its samples; validation and test keep a fixed order.
train_loader, val_loader, test_loader = (
    DataLoader(subset, batch_size=8, shuffle=reshuffle)
    for subset, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

Create training loop

In [8]:
def epoch(model, dataloader, optimizer, criterion, device):
    """Train ``model`` for one full pass over ``dataloader``.

    Args:
        model: Module to optimise; it is switched to training mode.
        dataloader: Iterable of ``(images, masks)`` batches that also
            supports ``len()``.
        optimizer: Optimiser that updates the model parameters.
        criterion: Callable ``criterion(outputs, masks)`` returning a scalar
            loss tensor; the masks are cast to float before the call.
        device: Device the batches are moved to.

    Returns:
        The sum of the per-batch loss values divided by the number of batches.
    """
    model.train()
    running_loss = 0
    for batch_imgs, batch_masks in dataloader:
        batch_imgs = batch_imgs.to(device)
        batch_masks = batch_masks.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_imgs), batch_masks.float())
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    n_batches = len(dataloader)
    return running_loss / n_batches

In [9]:
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Args:
        model: Module to evaluate; it is switched to evaluation mode.
        dataloader: Iterable of ``(images, masks)`` batches that also
            supports ``len()``.
        criterion: Callable ``criterion(outputs, masks)`` returning a scalar
            loss tensor; the masks are cast to float before the call.
        device: Device the batches are moved to.

    Returns:
        A tuple ``(mean_loss, preds, gts)``: the summed per-batch loss divided
        by the number of batches, the sigmoid of the model outputs for all
        batches concatenated along axis 0 as a NumPy array, and the matching
        masks concatenated the same way.
    """
    model.eval()
    loss_sum = 0
    pred_chunks = []
    mask_chunks = []

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch_imgs, batch_masks in dataloader:
            batch_imgs = batch_imgs.to(device)
            batch_masks = batch_masks.to(device)

            raw_out = model(batch_imgs)
            loss_sum += criterion(raw_out, batch_masks.float()).item()

            # Keep results on the CPU as NumPy arrays for the later metrics.
            pred_chunks.append(torch.sigmoid(raw_out).cpu().numpy())
            mask_chunks.append(batch_masks.cpu().numpy())

    all_preds = np.concatenate(pred_chunks, axis=0)
    all_masks = np.concatenate(mask_chunks, axis=0)
    mean_loss = loss_sum / len(dataloader)
    return mean_loss, all_preds, all_masks

In [10]:
def dice_loss(preds, masks, smooth=1e-6):
    """Soft Dice loss computed over every element of both tensors at once.

    Both inputs are flattened first, so a single global overlap score is
    produced rather than one score per sample.

    Args:
        preds: Tensor of predicted values.
        masks: Tensor of target values with the same number of elements.
        smooth: Constant added to numerator and denominator; it keeps the
            ratio defined when both tensors sum to zero.

    Returns:
        Scalar tensor ``1 - (2 * sum(preds * masks) + smooth) /
        (sum(preds) + sum(masks) + smooth)``.
    """
    flat_preds = preds.reshape(-1)
    flat_masks = masks.reshape(-1)

    overlap = torch.sum(flat_preds * flat_masks)
    total = flat_preds.sum() + flat_masks.sum()
    return 1 - (2. * overlap + smooth) / (total + smooth)

In [11]:
class DiceLoss(nn.Module):
    """``nn.Module`` wrapper around ``dice_loss`` so it can serve as a criterion.

    Args:
        smooth: Smoothing constant forwarded to ``dice_loss``.
        from_logits: When True, ``preds`` is passed through ``torch.sigmoid``
            before the Dice computation, for models that emit raw scores.
            Defaults to False, which forwards ``preds`` unchanged.
    """

    def __init__(self, smooth=1e-6, from_logits=False):
        super().__init__()
        self.smooth = smooth
        self.from_logits = from_logits

    def forward(self, preds, masks):
        """Return the Dice loss between ``preds`` and ``masks``."""
        if self.from_logits:
            # Map raw scores into (0, 1) so the overlap terms are meaningful.
            preds = torch.sigmoid(preds)
        return dice_loss(preds, masks, smooth=self.smooth)

In [12]:
def train_model(model, train_loader, val_loader, test_loader, device, epochs=50):
    """Train ``model`` with Dice loss, checkpoint the best epoch and evaluate.

    For every epoch the model is trained on ``train_loader`` and scored on
    ``val_loader``; whenever the validation loss improves, the weights are
    written to ``best_model.pth``. After training, the loss curves are
    plotted, a decision threshold is chosen by F1 on the validation
    predictions of the last epoch, test metrics are printed for that
    threshold and up to three test predictions are visualised.

    Args:
        model: Module to train, already placed on ``device``.
        train_loader: Batches used for optimisation.
        val_loader: Batches used for model selection and threshold tuning.
        test_loader: Batches used for the final report; its ``dataset`` must
            support integer indexing.
        device: Device the batches are moved to.
        epochs: Number of training epochs; must be at least 1.

    Raises:
        ValueError: If ``epochs`` is smaller than 1.
    """
    # Imported here, as in the original cell, so the notebook does not need
    # lab1_eval until training actually starts.
    from lab1_eval import (plot_training_curve, eval_threshold, eval_testset, visualize_pred)

    if epochs < 1:
        # Without at least one epoch there are no validation predictions to
        # tune the threshold on.
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    # NOTE(review): validate() applies torch.sigmoid to the model outputs,
    # which suggests the model emits raw scores, yet the Dice loss below is
    # computed on those raw outputs. Confirm whether a sigmoid should be
    # applied before the loss.
    criterion = DiceLoss()
    optimizer = optim.Adam(model.parameters())

    train_losses, val_losses = [], []
    best_val_loss = float("inf")

    for e in range(epochs):
        train_loss = epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_preds, val_masks = validate(model, val_loader, criterion, device)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        print(f"Epoch {e+1}/{epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Lower validation loss is better. The comparison must be "<":
        # starting from infinity, ">" can never be true and no checkpoint
        # would ever be written.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_model.pth")

    # Evaluation after training.
    # NOTE(review): everything below uses the weights and validation
    # predictions of the final epoch, not the checkpoint saved above.
    plot_training_curve(train_losses, val_losses)

    # NOTE(review): eval_threshold is called with (masks, preds) while
    # eval_testset is called with (preds, masks); verify both orders against
    # lab1_eval.
    thresholds = np.linspace(0.1, 0.9, 9)
    metrics = eval_threshold(val_masks, val_preds, thresholds)
    for m in metrics:
        print(f"Threshold={m['threshold']:.2f} | Precision={m['precision']:.3f} | Recall={m['recall']:.3f} | F1={m['f1']:.3f} | Accuracy={m['accuracy']:.3f}")

    best_threshold = max(metrics, key=lambda x: x['f1'])['threshold']
    print(f'Best threshold for validations set: {best_threshold}')

    _, test_preds, test_masks = validate(model, test_loader, criterion, device)
    test_metrics = eval_testset(test_preds, test_masks, threshold=best_threshold)
    print(f"Final test metrics: {test_metrics}")

    # Show at most three examples; small test sets may hold fewer.
    n_examples = min(3, len(test_loader.dataset), len(test_preds))
    for i in range(n_examples):
        img, mask = test_loader.dataset[i]
        # Move channels last before handing the image to visualize_pred.
        visualize_pred(img.permute(1, 2, 0), mask, test_preds[i])

Initialise model

In [13]:
from UNet import PixelPredictModel_UNet

# Model from the local UNet module: 3 input channels, 1 output channel,
# moved to the selected device.
UNet_model = PixelPredictModel_UNet(in_ch=3, out_ch=1)
UNet_model = UNet_model.to(device)

# Train for 25 epochs; train_model also runs the follow-up evaluation.
train_model(UNet_model, train_loader, val_loader, test_loader, device, epochs=25)

NameError: name 'torch' is not defined

In [None]:
from SegNet import PixelPredictModel_SegNet

# Model from the local SegNet module: 3 input channels, 1 output channel,
# moved to the selected device.
SegNet_model = PixelPredictModel_SegNet(in_ch=3, out_ch=1)
SegNet_model = SegNet_model.to(device)

# Same training and evaluation run as for the other model: 25 epochs.
train_model(SegNet_model, train_loader, val_loader, test_loader, device, epochs=25)