# In [None]:
"""
Unified DeepLabV3+ training / validation / testing pipeline
Supports: CE (with class weights), Dice, Focal losses
- Computes class weights from train masks
- Saves a model file after every epoch (named <filedate>_epochXXX.pth, epoch zero-padded to three digits), plus an extra *_periodic.pth copy every 10 epochs
- Evaluates all saved epoch models on test set and selects best by test mIoU
- Computes detailed metrics for best model and per-image CSV
- Generates side-by-side visualizations for each test image
- Saves training history (CSV) and plots
- Records run time to TXT

Requirements:
- torch
- torchvision
- segmentation_models_pytorch (pip install segmentation-models-pytorch)
- PIL, numpy, pandas, matplotlib

Place images and masks in separate folders (each mask must share its image's file stem, e.g. `img001.jpg` <-> `img001.png`) and set `train_img_dir`, `train_mask_dir`, etc. below, or pass them as command-line arguments.
"""

import os
import time
import datetime
import glob
import csv
import argparse
from collections import defaultdict

import numpy as np
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

try:
    import segmentation_models_pytorch as smp
except Exception as e:
    raise ImportError("Please install segmentation-models-pytorch (pip install segmentation-models-pytorch)")

# -----------------------------
# Configuration (default)
# -----------------------------
# These are defaults only; the argparse section at the bottom of the file
# uses most of them as the defaults of its command-line flags.
IMG_SIZE = (512, 512)  # (H, W) that images and masks are resized to
Channels = 3  # not referenced elsewhere in this file (the model hard-codes in_channels=3)
batch_size = 8
num_epochs = 50
valsplit = 0.1  # fraction of train pairs held out for validation when no val dirs are given
learning_rate = 1e-4
loss_type = "ce"  # "ce" / "dice" / "focal"
optimizer_type = "adam"  # adam / sgd
gamma_focal = 2.0  # focusing exponent of the focal loss
alpha_focal = None  # per-class focal weights; None -> main() uses the computed class weights

# Data paths (set to your dataset)
train_img_dir = "data/train/images"
train_mask_dir = "data/train/masks"
val_img_dir = None  # if either val dir is None, validation pairs are split off the training set
val_mask_dir = None
test_img_dir = "data/test/images"
test_mask_dir = "data/test/masks"

# Output locations: one timestamped folder per run, outputs/<YYYYmmdd_HHMMSS>/
filedate = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
out_dir = os.path.join("outputs", filedate)
model_dir, viz_dir, plots_dir = (
    os.path.join(out_dir, sub) for sub in ("models", "visualization", "plots")
)
history_csv = os.path.join(out_dir, f"history_{filedate}.csv")
runinfo_txt = os.path.join(out_dir, f"runinfo_{filedate}.txt")

# Import-time side effect: the default output folders are created right away.
for _default_dir in (model_dir, viz_dir, plots_dir):
    os.makedirs(_default_dir, exist_ok=True)
del _default_dir

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# -----------------------------
# Dataset
# -----------------------------
class MyDataset(Dataset):
    """Paired image / mask dataset for binary segmentation.

    Each item is ``(image, mask, image_basename)`` where

    * ``image`` is the file loaded as RGB and passed through ``transforms``
      (default: bilinear resize to ``img_size``, ``ToTensor`` and ImageNet
      mean/std normalisation);
    * ``mask`` is the mask file loaded as 8-bit grayscale (``'L'``), resized
      to ``img_size`` with nearest-neighbour interpolation (so grey levels are
      not blended), binarised as ``pixel > mask_threshold`` and returned as a
      ``torch.long`` tensor of 0/1 labels.

    Args:
        img_paths: image file paths.
        mask_paths: mask file paths, aligned index-by-index with ``img_paths``.
        img_size: ``(H, W)`` target size passed to ``T.Resize``.
        transforms: optional replacement for the default image pipeline; the
            mask pipeline is unaffected.
        mask_threshold: grey level above which a mask pixel is foreground.
            The default 127 fits masks stored as {0, 255}. Pass 0 for masks
            stored as {0, 1}; with 127 those would become all background.

    Raises:
        ValueError: if ``img_paths`` and ``mask_paths`` differ in length.
    """

    def __init__(self, img_paths, mask_paths, img_size=IMG_SIZE, transforms=None, mask_threshold=127):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(img_paths) != len(mask_paths):
            raise ValueError(
                f"img_paths and mask_paths differ in length: {len(img_paths)} != {len(mask_paths)}"
            )
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.img_size = img_size
        self.mask_threshold = mask_threshold
        # Image pipeline: resize, to tensor, ImageNet normalisation.
        if transforms is None:
            transforms = T.Compose([
                T.Resize(img_size, interpolation=T.InterpolationMode.BILINEAR),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        self.transforms = transforms
        # Mask pipeline: nearest-neighbour resize keeps label values intact.
        self.mask_transform = T.Compose([
            T.Resize(img_size, interpolation=T.InterpolationMode.NEAREST),
        ])

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        # `with` closes the file handles deterministically instead of leaving
        # that to PIL / the garbage collector.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        with Image.open(self.mask_paths[idx]) as src:
            mask = src.convert('L')
        img = self.transforms(img)
        mask = np.asarray(self.mask_transform(mask), dtype=np.uint8)
        # Binarise: foreground = grey level above the threshold.
        mask = (mask > self.mask_threshold).astype(np.uint8)
        return img, torch.from_numpy(mask).long(), os.path.basename(img_path)

# -----------------------------
# Utilities: file pairing
# -----------------------------

def pair_images_and_masks(img_dir, mask_dir, img_exts=('.jpg', '.png', '.tif', '.tiff'), mask_exts=('.png',)):
    """Match image files to mask files by file stem (name without extension).

    All files in ``img_dir`` matching one of ``img_exts`` are sorted by path.
    An image is kept only if ``mask_dir`` contains a file with the same stem
    and one of ``mask_exts``; images without a mask are skipped silently.
    If mask files with the same stem exist for several extensions, the one for
    the later extension in ``mask_exts`` wins. Extension matching follows
    ``glob``'s platform rules (case-sensitive on POSIX), so e.g. ``.JPG``
    files may not be picked up by the default ``'.jpg'``.

    Args:
        img_dir: directory containing the images.
        mask_dir: directory containing the masks.
        img_exts: image extensions to look for (any iterable of strings).
        mask_exts: mask extensions to look for (any iterable of strings).

    Returns:
        ``(img_paths, mask_paths)``: two equally long lists aligned by index.
    """
    def _stem(path):
        return os.path.splitext(os.path.basename(path))[0]

    img_candidates = sorted(
        path for ext in img_exts for path in glob.glob(os.path.join(img_dir, f"*{ext}"))
    )
    mask_by_stem = {}
    for ext in mask_exts:
        for path in glob.glob(os.path.join(mask_dir, f"*{ext}")):
            mask_by_stem[_stem(path)] = path

    img_paths = []
    mask_paths = []
    for path in img_candidates:
        mask_path = mask_by_stem.get(_stem(path))
        if mask_path is not None:
            img_paths.append(path)
            mask_paths.append(mask_path)
    return img_paths, mask_paths

# -----------------------------
# Class weight computation
# -----------------------------

def compute_class_weights(dataset, num_classes=2):
    """Compute inverse-frequency ("balanced") class weights from dataset masks.

    Every item of ``dataset`` is indexed once. Each item is expected to be an
    ``(img, mask, name)`` tuple, as returned by ``MyDataset``; only the mask
    is used. Pixel values outside ``[0, num_classes)`` are not counted.

        weight[c] = total_pixels / (num_classes * count[c])

    A class with zero pixels gets a neutral weight of 1.0 instead of the
    infinite value the formula would give; an ``inf`` weight would make a
    weighted CE / focal loss evaluate to inf or NaN.

    Args:
        dataset: indexable object with ``len()``.
        num_classes: number of label values to count.

    Returns:
        ``(weights, counts, freq)``: float32 weights, int64 pixel counts and
        float64 pixel frequencies, each of length ``num_classes``. ``freq`` is
        all zeros when the dataset contains no countable pixels.
    """
    counts = np.zeros(num_classes, dtype=np.int64)
    # Index-based loop: torch Datasets are indexable but need not be iterable.
    for i in range(len(dataset)):
        _, mask, _ = dataset[i]
        labels = np.asarray(mask).ravel().astype(np.int64)
        labels = labels[(labels >= 0) & (labels < num_classes)]
        counts += np.bincount(labels, minlength=num_classes)

    total = counts.sum()
    if total > 0:
        freq = counts / total
    else:
        freq = np.zeros(num_classes, dtype=np.float64)

    weights = np.ones(num_classes, dtype=np.float64)
    present = counts > 0
    weights[present] = total / (num_classes * counts[present])
    return weights.astype(np.float32), counts, freq

# -----------------------------
# Losses
# -----------------------------
class DiceLoss(nn.Module):
    """Soft multi-class Dice loss.

    Softmax probabilities are compared with a one-hot encoding of the target.
    Intersection and cardinality are summed over the batch and both spatial
    axes, giving one Dice score per class for the whole batch. The loss is
    ``1 - mean(per-class Dice)``. ``eps`` is added to the numerator and the
    denominator so that a class absent from both tensors stays finite.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, target):
        """logits: [B, C, H, W] raw scores; target: [B, H, W] integer labels in [0, C)."""
        n_cls = logits.shape[1]
        p = torch.softmax(logits, dim=1)
        onehot = nn.functional.one_hot(target, n_cls)  # [B, H, W, C]
        onehot = onehot.permute(0, 3, 1, 2).float()   # -> [B, C, H, W]
        reduce_axes = (0, 2, 3)
        inter = torch.sum(p * onehot, reduce_axes)
        card = torch.sum(p + onehot, reduce_axes)
        per_class_dice = (2. * inter + self.eps) / (card + self.eps)
        return 1. - per_class_dice.mean()

class FocalLoss(nn.Module):
    """Multi-class focal loss on per-pixel logits.

    Per pixel: ``alpha[target] * (1 - p_t) ** gamma * CE``. Here ``p_t`` is the
    softmax probability of the true class and ``CE`` is the unreduced
    cross-entropy. The ``alpha`` factor is applied only when ``alpha`` is given.

    Args:
        gamma: focusing exponent.
        alpha: optional per-class weights (anything ``torch.tensor`` accepts),
            stored as a float32 tensor and moved to the logits' device on use.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced [B, H, W] loss.
    """

    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = None if alpha is None else torch.tensor(alpha, dtype=torch.float32)
        self.reduction = reduction

    def forward(self, logits, target):
        """logits: [B, C, H, W]; target: [B, H, W] integer class labels."""
        ce = nn.functional.cross_entropy(logits, target, reduction='none')
        p_t = torch.softmax(logits, dim=1).gather(1, target.unsqueeze(1)).squeeze(1)
        loss = (1 - p_t) ** self.gamma * ce
        if self.alpha is not None:
            class_w = self.alpha.to(logits.device)
            per_pixel_w = class_w.gather(0, target.flatten()).view_as(target).to(logits.device)
            loss = per_pixel_w * loss
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

# -----------------------------
# Metrics
# -----------------------------

def confusion_matrix_from_logits(logits, target, num_classes=2):
    """Build a confusion matrix from per-pixel logits and integer targets.

    Args:
        logits: [B, C, H, W] scores. The prediction for each pixel is the
            argmax over dim 1, so a single-channel input always predicts
            class 0.
        target: integer labels with one entry per predicted pixel
            (e.g. [B, H, W]).
        num_classes: size of the square output matrix.

    Returns:
        int64 array ``[num_classes, num_classes]`` with rows = ground truth
        and columns = prediction. Pixels whose label or prediction lies
        outside ``[0, num_classes)`` are ignored.
    """
    preds = torch.argmax(logits, dim=1).reshape(-1).cpu().numpy().astype(np.int64)
    labels = target.reshape(-1).cpu().numpy().astype(np.int64)
    valid = (labels >= 0) & (labels < num_classes) & (preds >= 0) & (preds < num_classes)
    # Vectorised counting: encode each (gt, pred) pair as one flat index and
    # count all pairs in one pass, instead of a Python loop over every pixel.
    pair_index = labels[valid] * num_classes + preds[valid]
    counts = np.bincount(pair_index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes).astype(np.int64)


def _mean_ignoring_nan(values):
    """Mean of the non-NaN entries of ``values`` as a float; NaN if there are none."""
    defined = values[~np.isnan(values)]
    return float(defined.mean()) if defined.size else float('nan')


def compute_metrics_from_cm(cm):
    """Derive segmentation metrics from a confusion matrix.

    Args:
        cm: square array ``[num_classes, num_classes]``, rows = ground truth,
            columns = prediction.

    Returns:
        dict with

        * ``iou_per_class``: TP / (TP + FP + FN); NaN for a class that occurs
          in neither ground truth nor prediction, where IoU is undefined.
        * ``mean_iou``: mean of the defined per-class IoUs (NaN if none).
        * ``pixel_acc``: correctly classified pixels / all pixels.
        * ``mean_acc``: mean per-class accuracy TP / (TP + FN) over the
          classes present in the ground truth (NaN if none).
        * ``precision_per_class``, ``recall_per_class``, ``f1_per_class``:
          eps-smoothed, so an undefined value is reported as 0.
    """
    cm = np.asarray(cm, dtype=np.float64)
    eps = 1e-6
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    union = tp + fp + fn
    gt_pixels = tp + fn

    # Undefined ratios become NaN rather than 0 so the means skip them. With
    # an eps in the denominator an absent class would count as IoU 0, and a
    # perfect background-only prediction would score mIoU 0.5.
    iou = np.full(tp.shape, np.nan)
    np.divide(tp, union, out=iou, where=union > 0)
    class_acc = np.full(tp.shape, np.nan)
    np.divide(tp, gt_pixels, out=class_acc, where=gt_pixels > 0)

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return {
        'iou_per_class': iou,
        'mean_iou': _mean_ignoring_nan(iou),
        'pixel_acc': float(tp.sum() / (cm.sum() + eps)),
        'mean_acc': _mean_ignoring_nan(class_acc),
        'precision_per_class': precision,
        'recall_per_class': recall,
        'f1_per_class': f1,
    }

# -----------------------------
# Train / Eval loops
# -----------------------------

def train_one_epoch(model, loader, optimizer, criterion, device, num_classes=2):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: segmentation network. It may return either the logits tensor or
            a dict holding the logits under ``'out'``.
        loader: yields ``(imgs, masks, names)`` batches.
        optimizer: optimizer stepped once per batch.
        criterion: loss called as ``criterion(logits, masks)``.
        device: device the batches are moved to.
        num_classes: size of the accumulated confusion matrix.

    Returns:
        ``(avg_loss, mean_iou)``. ``avg_loss`` is the batch loss weighted by
        batch size, divided by the number of samples. ``mean_iou`` comes from
        one confusion matrix accumulated over every pixel of the epoch.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.train()
    total_loss = 0.0
    epoch_cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    n_samples = 0
    for imgs, masks, _ in loader:
        imgs = imgs.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        # Run the forward pass exactly once per step. Calling the model a second
        # time just to inspect its output type would double the compute and
        # update BatchNorm running statistics twice.
        out = model(imgs)
        logits = out['out'] if isinstance(out, dict) else out
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()
        batch_n = imgs.size(0)
        total_loss += loss.item() * batch_n
        epoch_cm += confusion_matrix_from_logits(logits.detach(), masks, num_classes=num_classes)
        n_samples += batch_n
    if n_samples == 0:
        raise ValueError('train_one_epoch: the loader yielded no samples')
    return total_loss / n_samples, compute_metrics_from_cm(epoch_cm)['mean_iou']


def eval_one_epoch(model, loader, criterion, device, num_classes=2):
    """Evaluate ``model`` on ``loader`` in eval mode without gradient tracking.

    Args:
        model: segmentation network. It may return either the logits tensor or
            a dict holding the logits under ``'out'``.
        loader: yields ``(imgs, masks, names)`` batches.
        criterion: loss called as ``criterion(logits, masks)``.
        device: device the batches are moved to.
        num_classes: size of the accumulated confusion matrix.

    Returns:
        ``(avg_loss, mean_iou)``. ``avg_loss`` is the batch loss weighted by
        batch size, divided by the number of samples. ``mean_iou`` comes from
        one confusion matrix accumulated over every pixel in the loader.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    n_samples = 0
    with torch.no_grad():
        for imgs, masks, _ in loader:
            imgs = imgs.to(device)
            masks = masks.to(device)
            # Single forward pass; output may be a tensor or a dict with 'out'.
            out = model(imgs)
            logits = out['out'] if isinstance(out, dict) else out
            loss = criterion(logits, masks)
            batch_n = imgs.size(0)
            total_loss += loss.item() * batch_n
            total_cm += confusion_matrix_from_logits(logits, masks, num_classes=num_classes)
            n_samples += batch_n
    if n_samples == 0:
        raise ValueError('eval_one_epoch: the loader yielded no samples')
    return total_loss / n_samples, compute_metrics_from_cm(total_cm)['mean_iou']

# -----------------------------
# Full evaluation utilities
# -----------------------------

def evaluate_model_on_test(model, testset, criterion, device, batch_size=None):
    """Return ``(loss, mIoU)`` of ``model`` on ``testset`` via ``eval_one_epoch``.

    By default the whole test set is a single batch, as the pipeline
    originally requested, so the loss is the criterion over all test images
    at once. That holds every test image in memory simultaneously. Pass a
    smaller ``batch_size`` to bound (GPU) memory use. mIoU does not depend on
    the batch size, because it comes from an accumulated confusion matrix.
    The loss then becomes the batch-size-weighted average of the batch losses.
    """
    if batch_size is None:
        batch_size = len(testset)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)
    return eval_one_epoch(model, testloader, criterion, device)


def detailed_test_evaluation(model, testset, criterion, device, viz_dir=None, per_image_csv=None, num_classes=2):
    """Evaluate ``model`` image by image and collect overall and per-image metrics.

    Images are processed one at a time in dataset order. For each image the
    criterion loss and confusion-matrix metrics are computed. If ``viz_dir``
    is given, an Input / GT / Pred figure is saved there. If
    ``per_image_csv`` is given, all per-image rows are written to it at the end.

    Returns:
        ``(overall, per_image_rows)``.

        * ``overall`` is the ``compute_metrics_from_cm`` dict for the confusion
          matrix summed over all images, plus ``'loss'``: the mean of the
          per-image losses (NaN for an empty test set).
        * ``per_image_rows`` is a list of dicts with keys ``image``, ``loss``,
          ``mIoU``, ``pixel_acc``, ``mean_acc`` and, for each class c,
          ``precision_class{c}``, ``recall_class{c}`` and ``f1_class{c}``.
    """
    model.eval()
    all_cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    per_image_rows = []
    testloader = DataLoader(testset, batch_size=1, shuffle=False)
    with torch.no_grad():
        for imgs, masks, names in testloader:
            imgs = imgs.to(device)
            masks = masks.to(device)
            # Single forward pass; output may be a tensor or a dict with 'out'.
            out = model(imgs)
            logits = out['out'] if isinstance(out, dict) else out
            loss = criterion(logits, masks).item()
            cm = confusion_matrix_from_logits(logits, masks, num_classes=num_classes)
            all_cm += cm
            m = compute_metrics_from_cm(cm)
            if viz_dir is not None:
                pred = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
                gt = masks.squeeze(0).cpu().numpy().astype(np.uint8)
                save_visualization(imgs.squeeze(0).cpu(), gt, pred, names[0], viz_dir)
            row = {
                'image': names[0],
                'loss': loss,
                'mIoU': m['mean_iou'],
                'pixel_acc': m['pixel_acc'],
                'mean_acc': m['mean_acc'],
            }
            # Column order: all precisions, then recalls, then F1s (class 0 first).
            for prefix, key in (('precision', 'precision_per_class'),
                                ('recall', 'recall_per_class'),
                                ('f1', 'f1_per_class')):
                for c in range(num_classes):
                    row[f'{prefix}_class{c}'] = m[key][c]
            per_image_rows.append(row)

    overall = compute_metrics_from_cm(all_cm)
    losses = [r['loss'] for r in per_image_rows]
    overall['loss'] = float(np.mean(losses)) if losses else float('nan')
    if per_image_csv is not None:
        pd.DataFrame(per_image_rows).to_csv(per_image_csv, index=False)
    return overall, per_image_rows


def save_visualization(img_tensor, gt_mask, pred_mask, name, viz_dir):
    """Save a 1x3 "Input | GT | Pred" figure as ``<viz_dir>/<stem>_viz.png``.

    Args:
        img_tensor: image tensor normalised with the ImageNet mean/std (assumed
            [3, H, W], as produced by MyDataset's default transform). It is
            un-normalised and clamped to [0, 1] for display.
        gt_mask: 2-D ground-truth label array, drawn with a gray colormap.
        pred_mask: 2-D predicted label array, drawn with a gray colormap.
        name: source file name; its extension is replaced by ``_viz.png``.
        viz_dir: output directory; it is not created here.
    """
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    rgb = (img_tensor * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()

    panels = (
        ('Input', rgb, None),
        ('GT', gt_mask, 'gray'),
        ('Pred', pred_mask, 'gray'),
    )
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, (title, data, cmap) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')
    fig.tight_layout()
    target_file = os.path.join(viz_dir, os.path.splitext(name)[0] + "_viz.png")
    fig.savefig(target_file, dpi=150)
    plt.close(fig)

# -----------------------------
# Plot helpers
# -----------------------------

def _plot_history_columns(history_csv, outpath, columns, ylabel):
    """Plot ``columns`` of the history CSV against its ``'epoch'`` column and save to ``outpath``."""
    df = pd.read_csv(history_csv)
    fig, ax = plt.subplots()
    try:
        for col in columns:
            ax.plot(df['epoch'], df[col], label=col)
        ax.set_xlabel('epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)
        fig.savefig(outpath)
    finally:
        # Release the figure even if a column is missing or saving fails.
        plt.close(fig)


def plot_loss(history_csv, outpath):
    """Plot ``train_loss`` and ``val_loss`` per epoch from the history CSV."""
    _plot_history_columns(history_csv, outpath, ('train_loss', 'val_loss'), 'loss')


def plot_acc(history_csv, outpath):
    """Plot ``train_mIoU`` and ``val_mIoU`` per epoch (the name says acc, the metric is mIoU)."""
    _plot_history_columns(history_csv, outpath, ('train_mIoU', 'val_mIoU'), 'mIoU')

# -----------------------------
# Main pipeline
# -----------------------------

def main(args):
    """Run the full pipeline.

    Steps:

    1. Pair and split the data, and compute class weights from the training
       masks.
    2. Train DeepLabV3+ (ResNet-50 encoder, 2 classes), saving a checkpoint
       and the history CSV after every epoch.
    3. Evaluate every epoch checkpoint on the test set and reload the one with
       the highest test mIoU.
    4. Write per-image and overall metrics, visualisations, history plots and
       run timing.

    Args:
        args: argparse namespace with the attributes defined under
            ``if __name__ == '__main__':`` (paths, hyper-parameters, filedate).

    Raises:
        ValueError: if ``num_epochs < 1``, if no train/val/test pairs are found,
            if the split leaves the training set empty, or if ``loss_type`` is
            unknown.

    NOTE(review): choosing the checkpoint by *test* mIoU, as the module
    docstring specifies, makes the reported test metrics optimistically
    biased. Selecting on val_mIoU would keep the test set held out.
    """
    start_time = time.time()
    if args.num_epochs < 1:
        raise ValueError('num_epochs must be at least 1')

    # Module-level defaults are created at import time, but paths passed on
    # the command line may not exist yet; torch.save / to_csv / savefig would
    # then fail part-way through the run.
    for path in (args.out_dir, args.model_dir, args.viz_dir, args.plots_dir,
                 os.path.dirname(args.history_csv), os.path.dirname(args.runinfo_txt)):
        if path:
            os.makedirs(path, exist_ok=True)

    # Prepare datasets
    train_imgs, train_masks = pair_images_and_masks(args.train_img_dir, args.train_mask_dir)
    test_imgs, test_masks = pair_images_and_masks(args.test_img_dir, args.test_mask_dir)
    if not train_imgs:
        raise ValueError('No training images found')
    if not test_imgs:
        raise ValueError('No test images found')

    # Split train/val if val dirs not provided
    if args.val_img_dir is None or args.val_mask_dir is None:
        # Deterministic split: the first `nval` pairs (in path order) become validation.
        nval = max(1, int(len(train_imgs) * args.valsplit))
        val_imgs, val_masks = train_imgs[:nval], train_masks[:nval]
        train_imgs2, train_masks2 = train_imgs[nval:], train_masks[nval:]
    else:
        val_imgs, val_masks = pair_images_and_masks(args.val_img_dir, args.val_mask_dir)
        train_imgs2, train_masks2 = train_imgs, train_masks
    if not train_imgs2:
        raise ValueError('No training images left after the train/val split')
    if not val_imgs:
        raise ValueError('No validation images found')

    trainset = MyDataset(train_imgs2, train_masks2, img_size=args.img_size)
    valset = MyDataset(val_imgs, val_masks, img_size=args.img_size)
    testset = MyDataset(test_imgs, test_masks, img_size=args.img_size)

    # Dataloaders
    trainloader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)
    valloader = DataLoader(valset, batch_size=min(len(valset), args.batch_size), shuffle=False, num_workers=2)

    # Class weights from the training masks
    class_weights, counts, freq = compute_class_weights(trainset)
    class_weights = torch.tensor(class_weights, dtype=torch.float32).to(DEVICE)
    print(f"Class counts: {counts}, freq: {freq}, weights: {class_weights}")

    # Prepare model
    model = smp.DeepLabV3Plus(encoder_name='resnet50', encoder_weights='imagenet', in_channels=3, classes=2)
    model.to(DEVICE)

    # Loss selection
    if args.loss_type == 'ce':
        criterion = nn.CrossEntropyLoss(weight=class_weights)
    elif args.loss_type == 'dice':
        criterion = DiceLoss()
    elif args.loss_type == 'focal':
        alpha = args.alpha_focal if args.alpha_focal is not None else class_weights.cpu().numpy().tolist()
        criterion = FocalLoss(gamma=args.gamma_focal, alpha=alpha)
    else:
        raise ValueError('Unknown loss type')

    # Optimizer (anything other than 'adam' means SGD with momentum)
    if args.optimizer_type.lower() == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9)

    history = []
    # Exactly one checkpoint per epoch, in epoch order. Tracking the paths
    # (instead of globbing the folder afterwards) keeps the *_periodic.pth
    # duplicates, or stray files, from being evaluated as extra epochs.
    epoch_model_files = []

    for epoch in range(1, args.num_epochs + 1):
        train_loss, train_miou = train_one_epoch(model, trainloader, optimizer, criterion, DEVICE)
        val_loss, val_miou = eval_one_epoch(model, valloader, criterion, DEVICE)

        print(f"Epoch {epoch}: train_loss={train_loss:.4f}, train_mIoU={train_miou:.4f}, val_loss={val_loss:.4f}, val_mIoU={val_miou:.4f}")
        history.append({'epoch': epoch, 'train_loss': train_loss, 'train_mIoU': train_miou, 'val_loss': val_loss, 'val_mIoU': val_miou})

        model_path = os.path.join(args.model_dir, f"{args.filedate}_epoch{epoch:03d}.pth")
        torch.save(model.state_dict(), model_path)
        epoch_model_files.append(model_path)
        if epoch % 10 == 0:
            # Extra, identical copy marking every 10th epoch.
            periodic_path = os.path.join(args.model_dir, f"{args.filedate}_epoch{epoch:03d}_periodic.pth")
            torch.save(model.state_dict(), periodic_path)

        # Rewrite the history CSV each epoch so progress survives an interruption.
        pd.DataFrame(history).to_csv(args.history_csv, index=False)

    # After training: evaluate every epoch checkpoint on the test set
    print('Evaluating saved epoch models on test set...')
    per_epoch_results = []
    for p in epoch_model_files:
        model.load_state_dict(torch.load(p, map_location=DEVICE))
        test_loss, test_miou = evaluate_model_on_test(model, testset, criterion, DEVICE)
        model_name = os.path.basename(p)
        per_epoch_results.append({'model_file': model_name, 'test_loss': test_loss, 'test_mIoU': test_miou})
        print(f"Model {model_name} -> test_loss={test_loss:.4f}, test_mIoU={test_miou:.4f}")

    per_epoch_df = pd.DataFrame(per_epoch_results)
    per_epoch_csv = os.path.join(args.out_dir, f"per_epoch_test_results_{args.filedate}.csv")
    per_epoch_df.to_csv(per_epoch_csv, index=False)

    # Select best model by test mIoU
    best_row = per_epoch_df.loc[per_epoch_df['test_mIoU'].idxmax()]
    best_model_file = os.path.join(args.model_dir, best_row['model_file'])
    print(f"Best model: {best_model_file} with test_mIoU={best_row['test_mIoU']:.4f}")
    model.load_state_dict(torch.load(best_model_file, map_location=DEVICE))

    # Detailed evaluation with best model
    per_image_csv = os.path.join(args.out_dir, f"per_image_metrics_{args.filedate}.csv")
    overall_metrics, per_image_rows = detailed_test_evaluation(model, testset, criterion, DEVICE, viz_dir=args.viz_dir, per_image_csv=per_image_csv)

    # Save overall metrics
    overall_csv = os.path.join(args.out_dir, f"overall_test_metrics_{args.filedate}.csv")
    with open(overall_csv, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['metric', 'value'])
        # The detailed evaluation may or may not report a loss.
        writer.writerow(['loss', overall_metrics.get('loss', 'n/a')])
        writer.writerow(['mIoU', overall_metrics['mean_iou']])
        writer.writerow(['pixel_acc', overall_metrics['pixel_acc']])
        writer.writerow(['mean_acc', overall_metrics['mean_acc']])
        for i, p in enumerate(overall_metrics['precision_per_class']):
            writer.writerow([f'precision_class{i}', p])
        for i, r in enumerate(overall_metrics['recall_per_class']):
            writer.writerow([f'recall_class{i}', r])
        for i, f1 in enumerate(overall_metrics['f1_per_class']):
            writer.writerow([f'f1_class{i}', f1])

    # Plots
    plot_loss(args.history_csv, os.path.join(args.plots_dir, f"loss_{args.filedate}.png"))
    plot_acc(args.history_csv, os.path.join(args.plots_dir, f"acc_{args.filedate}.png"))

    # Read the clock once so that End - Start equals the reported elapsed time.
    end_time = time.time()
    with open(args.runinfo_txt, 'w') as f:
        f.write(f"Start: {start_time}\n")
        f.write(f"End: {end_time}\n")
        f.write(f"Elapsed seconds: {end_time - start_time}\n")

    print('Done. Outputs:')
    print(f" - models: {args.model_dir}")
    print(f" - per-epoch test CSV: {per_epoch_csv}")
    print(f" - best model: {best_model_file}")
    print(f" - per-image CSV: {per_image_csv}")
    print(f" - visualizations: {args.viz_dir}")
    print(f" - history CSV: {args.history_csv}")
    print(f" - plots: {args.plots_dir}")
    print(f" - run info: {args.runinfo_txt}")


if __name__ == '__main__':
    # Every flag defaults to the module-level configuration above.
    parser = argparse.ArgumentParser(
        description='DeepLabV3+ segmentation: train, evaluate each epoch on the test set, report the best model.'
    )
    parser.add_argument('--img_size', type=lambda s: tuple(map(int, s.split(','))), default=IMG_SIZE,
                        help='target size as "H,W", e.g. 512,512')
    parser.add_argument('--batch_size', type=int, default=batch_size)
    parser.add_argument('--num_epochs', type=int, default=num_epochs)
    parser.add_argument('--valsplit', type=float, default=valsplit,
                        help='fraction of train pairs used for validation when no val dirs are given')
    parser.add_argument('--learning_rate', type=float, default=learning_rate)
    # choices= turns a typo into an immediate CLI error. Otherwise an unknown
    # optimizer name would silently fall back to SGD in main(), and an unknown
    # loss would only fail after data loading.
    parser.add_argument('--loss_type', type=str.lower, choices=['ce', 'dice', 'focal'], default=loss_type)
    parser.add_argument('--optimizer_type', type=str.lower, choices=['adam', 'sgd'], default=optimizer_type)
    parser.add_argument('--train_img_dir', type=str, default=train_img_dir)
    parser.add_argument('--train_mask_dir', type=str, default=train_mask_dir)
    parser.add_argument('--val_img_dir', type=str, default=val_img_dir)
    parser.add_argument('--val_mask_dir', type=str, default=val_mask_dir)
    parser.add_argument('--test_img_dir', type=str, default=test_img_dir)
    parser.add_argument('--test_mask_dir', type=str, default=test_mask_dir)
    parser.add_argument('--model_dir', type=str, default=model_dir)
    parser.add_argument('--out_dir', type=str, default=out_dir)
    parser.add_argument('--viz_dir', type=str, default=viz_dir)
    parser.add_argument('--history_csv', type=str, default=history_csv)
    parser.add_argument('--plots_dir', type=str, default=plots_dir)
    parser.add_argument('--runinfo_txt', type=str, default=runinfo_txt)
    parser.add_argument('--filedate', type=str, default=filedate)
    parser.add_argument('--gamma_focal', type=float, default=gamma_focal)
    parser.add_argument('--alpha_focal', type=lambda s: list(map(float, s.split(','))) if s else None, default=alpha_focal,
                        help='comma-separated per-class focal weights (default: computed class weights)')
    # parse_known_args instead of parse_args: when this file runs as a notebook
    # cell, the kernel's own argv (e.g. ipykernel's `-f <connection-file>`)
    # would otherwise abort with "unrecognized arguments".
    args, unknown_args = parser.parse_known_args()
    if unknown_args:
        print(f"Warning: ignoring unrecognized arguments: {unknown_args}")
    main(args)
