In [1]:
#sample data creation script

import numpy as np
import imageio.v2 as imageio
from scipy.ndimage import binary_dilation
import os

def generate_infestation_sequence(
    size=128, timesteps=10, n_seeds=3, spread_prob=0.3
):
    """
    Generates a sequence of binary masks simulating algae infestation over time.
    Returns a NumPy array of shape (timesteps, size, size)
    """
    rng = np.random.default_rng()
    masks = []

    # initialize empty grid
    mask = np.zeros((size, size), dtype=np.uint8)

    # random starting seeds
    seed_coords = rng.integers(0, size, (n_seeds, 2))
    for (x, y) in seed_coords:
        mask[x, y] = 1

    # 3x3 neighborhood kernel
    structure = np.ones((3, 3), dtype=bool)

    for _ in range(timesteps):
        masks.append(mask.copy())

        # grow probabilistically
        dilated = binary_dilation(mask, structure=structure)
        new_growth = (dilated & (mask == 0)) & (rng.random(mask.shape) < spread_prob)
        mask = mask | new_growth

    return np.stack(masks)


def generate_dataset(
    output_dir="infestation_dataset",
    n_samples=1000,
    size=128,
    timesteps=10,
    n_seeds_range=(2, 5),
    spread_prob_range=(0.2, 0.4),
    save_as_images=True
):
    """
    Generates a dataset of infestation masks and saves them to disk.
    Each sample folder contains a time series of infestation masks.
    """
    os.makedirs(output_dir, exist_ok=True)

    rng = np.random.default_rng()

    for i in range(n_samples):
        n_seeds = rng.integers(*n_seeds_range)
        spread_prob = rng.uniform(*spread_prob_range)

        masks = generate_infestation_sequence(
            size=size,
            timesteps=timesteps,
            n_seeds=n_seeds,
            spread_prob=spread_prob
        )

        sample_dir = os.path.join(output_dir, f"sample_{i:04d}")
        os.makedirs(sample_dir, exist_ok=True)

        if save_as_images:
            for t, mask in enumerate(masks):
                imageio.imwrite(
                    os.path.join(sample_dir, f"mask_{t:02d}.png"),
                    (mask * 255).astype(np.uint8)
                )
        else:
            np.save(os.path.join(sample_dir, "masks.npy"), masks)

        if (i + 1) % 50 == 0:
            print(f"Generated {i+1}/{n_samples} samples")

    print(f"\n✅ Dataset created at: {os.path.abspath(output_dir)}")


# --- Run the generator ---
generate_dataset(
    output_dir="infestation_dataset",
    n_samples=1000,    # total samples
    size=128,          # resolution
    timesteps=10,      # frames per sample
    n_seeds_range=(2, 5),
    spread_prob_range=(0.2, 0.4),
    save_as_images=True
)


Generated 50/1000 samples
Generated 100/1000 samples
Generated 150/1000 samples
Generated 200/1000 samples
Generated 250/1000 samples
Generated 300/1000 samples
Generated 350/1000 samples
Generated 400/1000 samples
Generated 450/1000 samples
Generated 500/1000 samples
Generated 550/1000 samples
Generated 600/1000 samples
Generated 650/1000 samples
Generated 700/1000 samples
Generated 750/1000 samples
Generated 800/1000 samples
Generated 850/1000 samples
Generated 900/1000 samples
Generated 950/1000 samples
Generated 1000/1000 samples

✅ Dataset created at: /content/infestation_dataset


In [2]:
#this creates and uses the data
"""
infestation_unet_train.py

Requirements:
- Python 3.8+
- numpy
- scipy
- imageio
- matplotlib (optional, for plotting)
- torch (PyTorch)
- torchvision
- scikit-learn

Run:
    python infestation_unet_train.py
"""

import os
import random
import numpy as np
from pathlib import Path
from scipy.ndimage import binary_dilation
import imageio.v2 as imageio
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

# -------------------------
# 1) Synthetic generator
# -------------------------
def generate_infestation_sequence(size=128, timesteps=6, n_seeds=3, spread_prob=0.3, rng=None):
    """
    Returns array shape (timesteps, H, W) with binary masks (0/1).
    """
    if rng is None:
        rng = np.random.default_rng()
    masks = []
    mask = np.zeros((size, size), dtype=np.uint8)

    seed_coords = rng.integers(0, size, (n_seeds, 2))
    for (x, y) in seed_coords:
        mask[x, y] = 1

    structure = np.ones((3, 3), dtype=bool)

    for _ in range(timesteps):
        masks.append(mask.copy())
        dilated = binary_dilation(mask, structure=structure)
        new_growth = (dilated & (mask == 0)) & (rng.random(mask.shape) < spread_prob)
        mask = mask | new_growth

    return np.stack(masks)  # (timesteps, H, W)


def generate_dataset_in_memory(n_sequences=1000, size=128, timesteps=6,
                               n_seeds_range=(1,4), spread_prob_range=(0.15,0.45), seed=42):
    """
    Generates sequences and returns tuple (X_pairs, Y_pairs)
    where each pair is (mask_t, mask_t+1)
    X_pairs shape: (N_pairs, 1, H, W)
    Y_pairs shape: (N_pairs, 1, H, W)
    """
    rng = np.random.default_rng(seed)
    X_list = []
    Y_list = []

    for i in range(n_sequences):
        n_seeds = int(rng.integers(n_seeds_range[0], n_seeds_range[1]+1))
        spread_prob = float(rng.uniform(spread_prob_range[0], spread_prob_range[1]))
        seq = generate_infestation_sequence(size=size, timesteps=timesteps, n_seeds=n_seeds, spread_prob=spread_prob, rng=rng)
        # create pairs (t -> t+1)
        for t in range(seq.shape[0] - 1):
            X_list.append(seq[t:t+1].astype(np.float32))  # (1,H,W)
            Y_list.append(seq[t+1:t+2].astype(np.float32))
    X = np.stack(X_list)  # (N_pairs, 1, H, W)
    Y = np.stack(Y_list)
    return X, Y

# -------------------------
# 2) PyTorch Dataset
# -------------------------
class InfestationPairsDataset(Dataset):
    def __init__(self, X, Y, transform=None):
        self.X = X
        self.Y = Y
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x = self.X[idx]  # float32 1xHxW
        y = self.Y[idx]
        # Optionally add small noise / augmentation
        if self.transform:
            x, y = self.transform(x, y)
        return torch.from_numpy(x), torch.from_numpy(y)

# -------------------------
# 3) Simple U-Net (small)
# -------------------------
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

class UNetSmall(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=32):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.pool = nn.MaxPool2d(2)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)

        self.bottleneck = ConvBlock(base_ch*4, base_ch*8)

        self.up3 = nn.ConvTranspose2d(base_ch*8, base_ch*4, kernel_size=2, stride=2)
        self.dec3 = ConvBlock(base_ch*8, base_ch*4)
        self.up2 = nn.ConvTranspose2d(base_ch*4, base_ch*2, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(base_ch*4, base_ch*2)
        self.up1 = nn.ConvTranspose2d(base_ch*2, base_ch, kernel_size=2, stride=2)
        self.dec1 = ConvBlock(base_ch*2, base_ch)

        self.head = nn.Conv2d(base_ch, out_ch, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        d3 = self.up3(b)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)
        d2 = self.up2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)
        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)
        out = self.head(d1)
        return out  # logits

# -------------------------
# 4) Metrics
# -------------------------
def iou_score(preds, targets, threshold=0.5, eps=1e-7):
    """
    preds: torch tensor logits or probabilities (B,1,H,W)
    targets: tensor (B,1,H,W)
    """
    if preds.dtype != torch.uint8 and preds.dtype != torch.bool:
        probs = torch.sigmoid(preds)
        preds_bin = (probs > threshold).float()
    else:
        preds_bin = preds.float()
    inter = (preds_bin * targets).sum(dim=(1,2,3))
    union = ((preds_bin + targets) >= 1).float().sum(dim=(1,2,3))
    iou = (inter + eps) / (union + eps)
    return iou.mean().item()

# -------------------------
# 5) Training loop
# -------------------------
def train_model(model, train_loader, val_loader, device, epochs=10, lr=1e-3, save_dir="checkpoints"):
    opt = Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    os.makedirs(save_dir, exist_ok=True)

    best_val_iou = 0.0
    for epoch in range(1, epochs+1):
        model.train()
        train_loss = 0.0
        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device)
            logits = model(xb)
            loss = criterion(logits, yb)
            opt.zero_grad()
            loss.backward()
            opt.step()
            train_loss += loss.item() * xb.size(0)
        train_loss /= len(train_loader.dataset)

        # Validation
        model.eval()
        val_loss = 0.0
        iou_acc = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device)
                yb = yb.to(device)
                logits = model(xb)
                loss = criterion(logits, yb)
                val_loss += loss.item() * xb.size(0)
                iou_acc += iou_score(logits.detach().cpu(), yb.detach().cpu()) * xb.size(0)
        val_loss /= len(val_loader.dataset)
        val_iou = iou_acc / len(val_loader.dataset)

        print(f"Epoch {epoch}/{epochs} — train_loss: {train_loss:.4f} val_loss: {val_loss:.4f} val_iou: {val_iou:.4f}")

        # Save best
        if val_iou > best_val_iou:
            best_val_iou = val_iou
            torch.save(model.state_dict(), os.path.join(save_dir, "best_model.pth"))
    print("Training finished. Best val IoU:", best_val_iou)

# -------------------------
# 6) Utilities: visualize some examples
# -------------------------
def save_predictions_sample(model, dataset, device, out_dir="pred_samples", n_samples=8):
    os.makedirs(out_dir, exist_ok=True)
    model.eval()
    with torch.no_grad():
        for i in range(n_samples):
            xb, yb = dataset[i]
            xb_t = xb.unsqueeze(0).to(device)  # add batch
            logits = model(xb_t)
            probs = torch.sigmoid(logits).squeeze().cpu().numpy()
            pred_bin = (probs > 0.5).astype(np.uint8)
            input_mask = xb.squeeze().numpy().astype(np.uint8)
            true_mask = yb.squeeze().numpy().astype(np.uint8)

            # Compose image: left=input, middle=truth, right=pred
            H, W = input_mask.shape
            canvas = np.zeros((H, W*3), dtype=np.uint8)
            canvas[:, :W] = input_mask * 255
            canvas[:, W:2*W] = true_mask * 255
            canvas[:, 2*W:3*W] = pred_bin * 255
            imageio.imwrite(os.path.join(out_dir, f"sample_{i:02d}.png"), canvas)

    print("Saved sample predictions to", out_dir)

# -------------------------
# 7) Main: generate data, train, eval
# -------------------------
def main():
    # SETTINGS - adjust as needed
    IMG_SIZE = 128
    SEQS = 250           # number of sequences (each has timesteps frames)
    TIMESTEPS = 6
    BATCH_SIZE = 32
    EPOCHS = 12
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", DEVICE)

    # 1) Generate dataset (in memory)
    print("Generating synthetic dataset ...")
    X, Y = generate_dataset_in_memory(n_sequences=SEQS, size=IMG_SIZE, timesteps=TIMESTEPS,
                                      n_seeds_range=(1,4), spread_prob_range=(0.12,0.45), seed=123)
    # X,Y shapes: (N_pairs, 1, H, W)
    print("Total pairs:", X.shape[0])

    # 2) Split into train/val
    X_train, X_val, Y_train, Y_val = train_test_split(X, Y, test_size=0.12, random_state=42)
    train_ds = InfestationPairsDataset(X_train, Y_train)
    val_ds = InfestationPairsDataset(X_val, Y_val)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # 3) Model
    model = UNetSmall(in_ch=1, out_ch=1, base_ch=32).to(DEVICE)

    # 4) Train
    train_model(model, train_loader, val_loader, device=DEVICE, epochs=EPOCHS, lr=1e-3, save_dir="checkpoints")

    # 5) Save some predictions
    save_predictions_sample(model, val_ds, device=DEVICE, out_dir="pred_samples", n_samples=12)

if __name__ == "__main__":
    main()


Using device: cuda
Generating synthetic dataset ...
Total pairs: 1250
Epoch 1/12 — train_loss: 0.3587 val_loss: 0.2337 val_iou: 0.2220
Epoch 2/12 — train_loss: 0.1999 val_loss: 0.1630 val_iou: 0.5082
Epoch 3/12 — train_loss: 0.1255 val_loss: 0.1047 val_iou: 0.5352
Epoch 4/12 — train_loss: 0.0810 val_loss: 0.0642 val_iou: 0.5358
Epoch 5/12 — train_loss: 0.0549 val_loss: 0.0465 val_iou: 0.5379
Epoch 6/12 — train_loss: 0.0391 val_loss: 0.0349 val_iou: 0.5403
Epoch 7/12 — train_loss: 0.0291 val_loss: 0.0253 val_iou: 0.5408
Epoch 8/12 — train_loss: 0.0226 val_loss: 0.0204 val_iou: 0.5385
Epoch 9/12 — train_loss: 0.0181 val_loss: 0.0162 val_iou: 0.5391
Epoch 10/12 — train_loss: 0.0149 val_loss: 0.0135 val_iou: 0.5364
Epoch 11/12 — train_loss: 0.0126 val_loss: 0.0115 val_iou: 0.5394
Epoch 12/12 — train_loss: 0.0108 val_loss: 0.0099 val_iou: 0.5386
Training finished. Best val IoU: 0.5407860430081686
Saved sample predictions to pred_samples
