In [None]:
"""
Unified DeepLabV3+ training / validation / testing pipeline
Supports: CE (with class weights), Dice, Focal losses
- Computes class weights from train masks
- Saves periodic model checkpoints during training (filename: <filedate>_epochXXX.pth)
- Evaluates all saved epoch models on test set and selects best by test mIoU
- Computes detailed metrics for best model and per-image CSV
- Generates side-by-side visualizations for each test image
- Saves training history (CSV) and plots
- Records run time to TXT

Requirements:
- torch
- torchvision >= 0.13 (provides deeplabv3_resnet101 with the DeepLabV3_ResNet101_Weights API)
- PIL, numpy, pandas, matplotlib

Place your images and masks in folders and set `root_dir`, `train_img_dir`, `train_mask_dir`, etc. in the configuration section below (command-line arguments are not parsed yet).
"""

import os
import time
import datetime
import glob
import csv
import argparse
from collections import defaultdict
from pathlib import Path

import numpy as np
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.models.segmentation import deeplabv3_resnet101, DeepLabV3_ResNet101_Weights

print("Library import successful.")

# -----------------------------
# Configuration (default)
# -----------------------------
# NOTE: everything in this section runs at import time, including the
# os.makedirs() calls that create output directories.
IMG_SIZE = (64, 64)  # (H, W); MyDataset resizes both images and masks to this
Channels = 3  # NOTE(review): not referenced elsewhere in this file; MyDataset always converts inputs to RGB
batch_size = 2  # training batch size; validation uses min(len(valset), batch_size)
classes = ['background', 'debris']  # class index i corresponds to classes[i]
num_classes = len(classes)
# NOTE (translated from the original Japanese comment): with fewer than 10
# epochs the author found that no weights were saved and the run errored,
# because main() writes checkpoints on multiples of 10 epochs. Check main()'s
# checkpoint condition before lowering this.
num_epochs = 10
valsplit = 0.25  # fraction of the sorted train pairs held out for validation when val_img_dir/val_mask_dir are None
learning_rate = 1e-4
loss_type = "ce"  # "ce" / "dice" / "focal"
optimizer_type = "adam"  # "adam"; any other value selects SGD with momentum 0.9 in main()
gamma_focal = 2.0  # focusing parameter passed to FocalLoss
alpha_focal = None  # per-class alpha list for FocalLoss; if None, main() uses the computed class weights

# Data paths (set to your dataset)
# NOTE: the paths below are Windows-style and are built by string concatenation
# with backslashes, so they only resolve as intended on Windows.
root_dir = r"C:\Users\kyohe\Aerial_Photo_Segmenter\20251209Data"

# Input dirs: each image and its mask must share the same filename stem
# (extensions may differ); files without a matching partner are not paired.
train_img_dir = Path(root_dir + r"\TrainVal\img")
train_mask_dir = root_dir + r"\TrainVal\mask"

val_img_dir = None  # if None, split from train
val_mask_dir = None
test_img_dir = root_dir + r"\Test\img"
test_mask_dir = root_dir + r"\Test\mask"

# Output dirs (will include filedate)
# These raw strings end in two backslashes (a raw string cannot end in a single
# one), so e.g. history_root + filedate contains a doubled separator; Windows
# accepts this.
history_root = root_dir + r"\History\\"
model_root = root_dir + r"\Weights\\"
result_root = root_dir + r"\Result_Segmentation\\"
os.makedirs(history_root, exist_ok=True)
os.makedirs(result_root, exist_ok=True)
os.makedirs(model_root, exist_ok=True)

# Run timestamp with minute resolution: two runs started within the same minute
# share the same output directories and file names.
filedate = datetime.datetime.now().strftime('%Y%m%d_%H%M')
history_dir = history_root + filedate
model_dir = model_root + filedate
result_dir = result_root + filedate
viz_dir = result_dir + r"\Visualizations"  # side-by-side Input | GT | Pred figures
pred_dir = result_dir + r"\PredMasks"  # predicted-mask PNGs written by save_visualization

os.makedirs(history_dir, exist_ok=True)
history_csv = os.path.join(history_dir, f"history_{filedate}.csv")
runinfo_txt = os.path.join(history_dir, f"runinfo_{filedate}.txt")
plots_dir = os.path.join(result_dir, "plots")

# Use the GPU when available, otherwise fall back to CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# -----------------------------
# Dataset
# -----------------------------
class MyDataset(Dataset):
    """Image/mask pairs for binary segmentation.

    Each item is ``(image, mask, filename)``:
      * ``image``: RGB tensor resized to ``img_size``. By default it is converted
        with ToTensor and ImageNet-normalised; a custom ``transforms`` pipeline
        can replace this.
      * ``mask``: ``[H, W]`` long tensor with values in {0, 1}.
      * ``filename``: basename of the image path.
    """

    def __init__(self, img_paths, mask_paths, img_size=IMG_SIZE, transforms=None):
        """
        Args:
            img_paths: image file paths.
            mask_paths: mask file paths, aligned index-by-index with ``img_paths``.
            img_size: (H, W) target size for both image and mask.
            transforms: optional replacement for the default image pipeline
                (bilinear resize -> ToTensor -> ImageNet Normalize). Masks
                always use nearest-neighbour resizing.

        Raises:
            ValueError: if the two path lists differ in length.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(img_paths) != len(mask_paths):
            raise ValueError(
                f"img_paths ({len(img_paths)}) and mask_paths ({len(mask_paths)}) "
                "must have the same length"
            )
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.img_size = img_size
        # Image pipeline. InterpolationMode is torchvision's non-deprecated
        # spelling of the PIL resampling constants (same resampling for PIL inputs).
        self.transforms = transforms or T.Compose([
            T.Resize(img_size, interpolation=T.InterpolationMode.BILINEAR),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Nearest-neighbour resize so label values are never blended.
        self.mask_transform = T.Compose([
            T.Resize(img_size, interpolation=T.InterpolationMode.NEAREST),
        ])

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        # The context managers close the underlying file handles. convert() returns
        # an independent, fully loaded image, so it stays usable after the file closes.
        with Image.open(self.img_paths[idx]) as src:
            img = src.convert('RGB')
        with Image.open(self.mask_paths[idx]) as src:
            mask = src.convert('L')
        img = self.transforms(img)
        mask = np.array(self.mask_transform(mask), dtype=np.uint8)
        # Binarise: grey levels > 127 become class 1 (debris), everything else class 0.
        mask = torch.from_numpy((mask > 127).astype(np.uint8)).long()
        return img, mask, os.path.basename(self.img_paths[idx])

# -----------------------------
# Utilities: file pairing
# -----------------------------

def pair_images_and_masks(img_dir, mask_dir, img_exts=('.jpg', '.png', '.tif', '.tiff'), mask_exts=('.png',)):
    """Pair image files with mask files that share the same filename stem.

    Args:
        img_dir: directory searched (non-recursively) for images.
        mask_dir: directory searched (non-recursively) for masks.
        img_exts: image extensions to include (e.g. ``'.jpg'``).
        mask_exts: mask extensions to include. If two masks share a stem, the
            one found for the later extension wins.

    Returns:
        ``(img_paths, mask_paths)``: two aligned lists, ordered by sorted image
        path. Images without a matching mask are dropped.
    """
    # A set removes duplicates when extension patterns overlap (e.g. '.png' and
    # '.PNG' on a case-insensitive filesystem).
    img_candidates = sorted({
        path
        for ext in img_exts
        for path in glob.glob(os.path.join(img_dir, f"*{ext}"))
    })

    mask_by_stem = {}
    for ext in mask_exts:
        for path in glob.glob(os.path.join(mask_dir, f"*{ext}")):
            mask_by_stem[os.path.splitext(os.path.basename(path))[0]] = path

    img_paths = []
    mask_paths = []
    for path in img_candidates:
        mask_path = mask_by_stem.get(os.path.splitext(os.path.basename(path))[0])
        if mask_path is not None:
            img_paths.append(path)
            mask_paths.append(mask_path)
    return img_paths, mask_paths

# For the Test Dataset: record original sizes
def pair_test_images_and_masks(img_dir, mask_dir, img_exts=('.jpg', '.png', '.tif', '.tiff'), mask_exts=('.png',)):
    """Pair test images with masks and record each paired image's original size.

    Pairing is delegated to :func:`pair_images_and_masks`, so both functions
    pair files identically.

    Returns:
        ``(img_paths, mask_paths, test_orig_sizes)`` where ``test_orig_sizes``
        maps an image filename stem to its original ``(H, W)``. Images whose
        size cannot be read are skipped with a warning and are missing from the
        mapping.
    """
    img_paths, mask_paths = pair_images_and_masks(img_dir, mask_dir, img_exts=img_exts, mask_exts=mask_exts)

    test_orig_sizes = {}
    for path in img_paths:
        stem = os.path.splitext(os.path.basename(path))[0]
        try:
            with Image.open(path) as im:
                w, h = im.size  # PIL reports (width, height)
        except Exception as err:  # best-effort: a missing size only disables resizing for this image
            print(f"Warning: could not read original size of {path}: {err}")
            continue
        test_orig_sizes[stem] = (h, w)
    return img_paths, mask_paths, test_orig_sizes

# -----------------------------
# Class weight computation
# -----------------------------

def compute_class_weights(dataset, n_classes=None):
    """Compute inverse-frequency class weights from the masks in ``dataset``.

    ``weight_c = total_pixels / (n_classes * count_c)``, so perfectly balanced
    classes all get weight 1.0.

    Args:
        dataset: indexable sequence whose items are ``(image, mask, name)``;
            ``mask`` is a tensor of integer class labels.
        n_classes: number of classes; defaults to the module-level ``num_classes``.

    Returns:
        ``(weights, counts, freq)``: float32 weights, int64 pixel counts and
        float64 pixel frequencies, each of length ``n_classes``.

    Raises:
        ValueError: if the dataset contains no labelled pixels.
    """
    if n_classes is None:
        n_classes = num_classes
    counts = np.zeros(n_classes, dtype=np.int64)
    # Index-based loop: a torch Dataset is only guaranteed to support len() and [].
    for i in range(len(dataset)):
        _, mask, _ = dataset[i]
        labels = mask.numpy().ravel().astype(np.int64)
        # Labels outside [0, n_classes) are not counted.
        in_range = labels[(labels >= 0) & (labels < n_classes)]
        counts += np.bincount(in_range, minlength=n_classes)
    total = int(counts.sum())
    if total == 0:
        raise ValueError("Cannot compute class weights: dataset contains no labelled pixels")
    freq = counts / total
    # Clamp zero counts to 1 so a class absent from the training masks gets a
    # large but finite weight instead of inf (which would poison the loss when
    # that class appears in validation/test data).
    weights = (total / (n_classes * np.maximum(counts, 1))).astype(np.float32)
    return weights, counts, freq

# -----------------------------
# Losses
# -----------------------------
class DiceLoss(nn.Module):
    """Soft multi-class Dice loss, averaged over classes.

    ``loss = 1 - mean_c (2 * sum(p_c * y_c) + eps) / (sum(p_c + y_c) + eps)``

    Here ``p_c`` are the softmax probabilities and ``y_c`` the one-hot target
    for class ``c``. Both sums run over the batch and all pixels.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps  # smoothing term; keeps the ratio defined for empty classes

    def forward(self, logits, target):
        """Compute the loss.

        Args:
            logits: raw scores of shape ``[B, C, H, W]``.
            target: integer class indices of shape ``[B, H, W]`` in ``[0, C)``.

        Returns:
            Scalar loss tensor.
        """
        # Take the class count from the logits rather than a module-level global,
        # so the loss works for any head width.
        n_classes = logits.shape[1]
        probs = torch.softmax(logits, dim=1)
        target_onehot = nn.functional.one_hot(target, n_classes).permute(0, 3, 1, 2).to(probs.dtype)
        dims = (0, 2, 3)  # reduce over batch and spatial dims -> one score per class
        intersection = torch.sum(probs * target_onehot, dims)
        cardinality = torch.sum(probs + target_onehot, dims)
        dice_score = (2. * intersection + self.eps) / (cardinality + self.eps)
        return 1. - dice_score.mean()

class FocalLoss(nn.Module):
    """Focal loss for dense multi-class prediction.

    Per pixel the loss is ``alpha[y] * (1 - p_y) ** gamma * CE``. Here ``p_y``
    is the softmax probability of the true class ``y`` and ``CE`` the
    unreduced cross-entropy.

    Args:
        gamma: focusing exponent; 0 reduces to (optionally alpha-weighted) CE.
        alpha: optional per-class weight sequence, indexed by class id.
        reduction: 'mean', 'sum', or any other value for the unreduced
            ``[B, H, W]`` loss map.
    """

    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = None if alpha is None else torch.tensor(alpha, dtype=torch.float32)
        self.reduction = reduction

    def forward(self, logits, target):
        # logits: [B, C, H, W] raw scores; target: [B, H, W] class indices.
        per_pixel = nn.functional.cross_entropy(logits, target, reduction='none')
        true_class_prob = (
            torch.softmax(logits, dim=1)
            .gather(1, target.unsqueeze(1))
            .squeeze(1)
        )
        # Down-weight well-classified pixels (p_y close to 1).
        per_pixel = (1 - true_class_prob) ** self.gamma * per_pixel
        if self.alpha is not None:
            class_alpha = self.alpha.to(logits.device)
            per_pixel = class_alpha[target] * per_pixel
        if self.reduction == 'mean':
            return per_pixel.mean()
        if self.reduction == 'sum':
            return per_pixel.sum()
        return per_pixel

# -----------------------------
# Metrics
# -----------------------------

def confusion_matrix_from_logits(logits, target, n_classes=None):
    """Return the ``[n_classes, n_classes]`` confusion matrix (rows = GT, cols = prediction).

    Args:
        logits: ``[B, C, H, W]`` scores; the prediction is ``argmax`` over dim 1.
        target: ``[B, H, W]`` integer labels.
        n_classes: matrix size; defaults to the module-level ``num_classes``.

    Returns:
        int64 numpy array of pixel counts.

    Raises:
        ValueError: if a label or prediction lies outside ``[0, n_classes)``.
    """
    if n_classes is None:
        n_classes = num_classes
    preds = torch.argmax(logits, dim=1).reshape(-1).cpu().numpy().astype(np.int64)
    gts = target.reshape(-1).cpu().numpy().astype(np.int64)
    for values, what in ((gts, 'label'), (preds, 'prediction')):
        if values.size and (values.min() < 0 or values.max() >= n_classes):
            raise ValueError(
                f"{what} values must lie in [0, {n_classes}), got [{values.min()}, {values.max()}]"
            )
    # Vectorised histogram of (gt, pred) pairs: each pair is encoded as
    # gt * n_classes + pred. This avoids a per-pixel Python loop.
    counts = np.bincount(gts * n_classes + preds, minlength=n_classes * n_classes)
    return counts.reshape(n_classes, n_classes).astype(np.int64)


def compute_metrics_from_cm(cm):
    """Derive segmentation metrics from a confusion matrix.

    Args:
        cm: ``[C, C]`` array of pixel counts, rows = ground truth, cols = prediction.

    Returns:
        dict with per-class arrays ``iou_per_class``, ``precision_per_class``,
        ``recall_per_class``, ``f1_per_class``, ``tp``, ``fp``, ``tn`` and ``fn``,
        and scalars ``mean_iou``, ``pixel_acc`` and ``mean_acc``.

    A class that appears in neither the ground truth nor the prediction has an
    undefined IoU and per-class accuracy. Those entries are NaN and are left out
    of ``mean_iou`` / ``mean_acc`` via ``nanmean``; scoring them as 0 would
    unfairly lower the mean (e.g. on an image with no debris that is correctly
    predicted as all background). Precision, recall and F1 keep the
    eps-smoothed value (0) for such classes because they are reported per
    class, not averaged.
    """
    eps = 1e-6
    tp = np.diag(cm).astype(float)
    fp = np.sum(cm, axis=0) - tp
    fn = np.sum(cm, axis=1) - tp
    tn = cm.sum() - (tp + fp + fn)

    # Per-class IoU: NaN where the class is absent from both GT and prediction.
    iou_denom = tp + fp + fn
    iou = np.where(iou_denom > 0, tp / (iou_denom + eps), np.nan)
    mean_iou = np.nanmean(iou) if np.any(iou_denom > 0) else np.nan

    pixel_acc = tp.sum() / (cm.sum() + eps)

    # Per-class accuracy tp / (tp + fn): NaN where the class is absent from GT.
    acc_denom = tp + fn
    class_acc = np.where(acc_denom > 0, tp / (acc_denom + eps), np.nan)
    mean_acc = np.nanmean(class_acc) if np.any(acc_denom > 0) else np.nan

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return {
        'iou_per_class': iou,
        'mean_iou': mean_iou,
        'pixel_acc': pixel_acc,
        'mean_acc': mean_acc,
        'precision_per_class': precision,
        'recall_per_class': recall,
        'f1_per_class': f1,
        'tp': tp,
        'fp': fp,
        'tn': tn,
        'fn': fn,
    }

# -----------------------------
# Train / Eval loops
# -----------------------------

def train_one_epoch(model, loader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: segmentation model; its output is either a dict with an ``'out'``
            entry (torchvision segmentation models) or a logits tensor.
        loader: yields ``(images, masks, names)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss taking ``(logits, masks)``.
        device: device to move each batch to.

    Returns:
        ``(avg_loss, mean_iou)``. ``avg_loss`` is the sample-weighted mean of the
        per-batch losses; ``mean_iou`` is computed from the epoch's accumulated
        confusion matrix.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    running_cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    n_samples = 0
    for imgs, masks, _ in loader:
        imgs = imgs.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        # Run the forward pass exactly once. Calling model(imgs) again just to
        # test the output type doubles the compute and updates BatchNorm running
        # statistics twice per step.
        out = model(imgs)
        logits = out['out'] if isinstance(out, dict) else out
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()
        batch_n = imgs.size(0)
        running_loss += loss.item() * batch_n
        running_cm += confusion_matrix_from_logits(logits.detach(), masks)
        n_samples += batch_n
    if n_samples == 0:
        raise ValueError("train_one_epoch: loader yielded no batches")
    avg_loss = running_loss / n_samples
    return avg_loss, compute_metrics_from_cm(running_cm)['mean_iou']


def eval_one_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: segmentation model; its output is either a dict with an ``'out'``
            entry or a logits tensor.
        loader: yields ``(images, masks, names)`` batches.
        criterion: loss taking ``(logits, masks)``.
        device: device to move each batch to.

    Returns:
        ``(avg_loss, mean_iou)``. ``avg_loss`` is the sample-weighted mean of the
        per-batch losses; ``mean_iou`` is computed from the accumulated
        confusion matrix.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    running_loss = 0.0
    running_cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    n_samples = 0
    with torch.no_grad():
        for imgs, masks, _ in loader:
            imgs = imgs.to(device)
            masks = masks.to(device)
            # A single forward pass; the output is re-used for the type check.
            out = model(imgs)
            logits = out['out'] if isinstance(out, dict) else out
            loss = criterion(logits, masks)
            batch_n = imgs.size(0)
            running_loss += loss.item() * batch_n
            running_cm += confusion_matrix_from_logits(logits, masks)
            n_samples += batch_n
    if n_samples == 0:
        raise ValueError("eval_one_epoch: loader yielded no batches")
    avg_loss = running_loss / n_samples
    return avg_loss, compute_metrics_from_cm(running_cm)['mean_iou']

# -----------------------------
# Full evaluation utilities
# -----------------------------

def evaluate_model_on_test(model, testset, criterion, device, batch_size=None):
    """Return ``(avg_loss, mean_iou)`` of ``model`` on ``testset``.

    By default the whole test set is evaluated as a single batch (the
    originally requested behaviour). Pass ``batch_size`` to bound GPU memory on
    larger test sets. mIoU is unaffected because eval_one_epoch accumulates one
    confusion matrix across batches. The reported loss can differ slightly for
    a class-weighted criterion, since per-batch losses are averaged weighted by
    batch size.
    """
    loader_batch_size = len(testset) if batch_size is None else batch_size
    testloader = DataLoader(testset, batch_size=loader_batch_size)
    return eval_one_epoch(model, testloader, criterion, device)


def detailed_test_evaluation(model, testset, criterion, device, viz_dir=None, per_image_csv=None, overall_image_csv=None, test_orig_sizes=None, pred_mask_dir=None):
    """Evaluate overall metrics and per-image metrics, and save visualizations.

    Each per-image row holds loss, mIoU, pixel/mean accuracy, and per-class IoU,
    precision, recall and F1. It also holds the binary pixel counts TP/FP/TN/FN
    for the positive class 1 (debris) against class 0 (background), plus the
    Total pixel count. The overall row has the same metric columns (and mIoU),
    computed from the confusion matrix accumulated over all images.

    Args:
        model, testset, criterion, device: as for eval_one_epoch; the test set is
            iterated one image at a time.
        viz_dir: if not None, save an Input | GT | Pred figure per image here and
            also save the predicted mask.
        per_image_csv: optional path for the per-image metrics CSV.
        overall_image_csv: optional path for the one-row overall metrics CSV.
        test_orig_sizes: optional mapping ``stem -> (H, W)`` forwarded to
            ``save_visualization`` to resize saved predicted masks.
        pred_mask_dir: directory for predicted-mask PNGs; defaults to the
            module-level ``pred_dir``.

    Returns:
        ``(overall_row, per_image_rows)``: a dict and a list of dicts.
    """
    def _summary_columns(m):
        # Metric columns shared by the per-image and overall rows. Index 1 is
        # the positive class, so tp[1] == cm[1,1], fp[1] == cm[0,1],
        # fn[1] == cm[1,0] and tn[1] == cm[0,0]: the binary counts described above.
        return {
            'pixel_acc': m['pixel_acc'],
            'mean_acc': m['mean_acc'],
            'iou_class0': float(m['iou_per_class'][0]),
            'iou_class1': float(m['iou_per_class'][1]),
            'precision_class0': m['precision_per_class'][0],
            'precision_class1': m['precision_per_class'][1],
            'recall_class0': m['recall_per_class'][0],
            'recall_class1': m['recall_per_class'][1],
            'f1_class0': m['f1_per_class'][0],
            'f1_class1': m['f1_per_class'][1],
            'TP': int(m['tp'][1]),
            'FP': int(m['fp'][1]),
            'TN': int(m['tn'][1]),
            'FN': int(m['fn'][1]),
        }

    model.eval()
    all_cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    per_image_rows = []
    mask_out_dir = None
    if viz_dir is not None:
        mask_out_dir = pred_dir if pred_mask_dir is None else pred_mask_dir
    testloader = DataLoader(testset, batch_size=1, shuffle=False)
    with torch.no_grad():
        for imgs, masks, names in testloader:
            imgs = imgs.to(device)
            masks = masks.to(device)
            # Single forward pass (the model output is reused for the type check).
            out = model(imgs)
            logits = out['out'] if isinstance(out, dict) else out
            loss = criterion(logits, masks).item()
            cm = confusion_matrix_from_logits(logits, masks)
            all_cm += cm
            m = compute_metrics_from_cm(cm)

            if viz_dir is not None:
                pred = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
                gt = masks.squeeze(0).cpu().numpy().astype(np.uint8)
                save_visualization(imgs.squeeze(0).cpu(), gt, pred, names[0], viz_dir, mask_out_dir, test_orig_sizes=test_orig_sizes)

            row = {'image': names[0], 'loss': loss, 'mIoU': m['mean_iou']}
            row.update(_summary_columns(m))
            row['Total'] = int(cm.sum())
            per_image_rows.append(row)

    if per_image_csv is not None:
        pd.DataFrame(per_image_rows).to_csv(per_image_csv, index=False)

    overall = compute_metrics_from_cm(all_cm)
    overall_row = {'mIoU': overall['mean_iou']}
    overall_row.update(_summary_columns(overall))
    overall_row['Total'] = int(all_cm.sum())

    if overall_image_csv is not None:
        pd.DataFrame([overall_row]).to_csv(overall_image_csv, index=False)

    return overall_row, per_image_rows


def save_visualization(img_tensor, gt_mask, pred_mask, name, viz_dir, pred_dir, test_orig_sizes=None, orig_size=None):
    """Save an Input | GT | Pred figure and the predicted mask as a PNG.

    The figure is written to ``viz_dir/<stem>_viz.png``. The predicted mask is
    written as a 0/255 grayscale PNG to ``pred_dir/<stem>_pred_mask.png``.

    Args:
        img_tensor: ``[3, H, W]`` ImageNet-normalised image tensor (on CPU).
        gt_mask, pred_mask: ``[H, W]`` numpy arrays with values 0/1.
        name: image filename; its stem names the output files.
        viz_dir, pred_dir: output directories (created if missing).
        test_orig_sizes: optional mapping ``stem -> (H, W)``.
        orig_size: optional explicit ``(H, W)``; takes precedence over the mapping.

    If an original size is found, the predicted mask is resized to it with
    nearest-neighbour resampling. Otherwise it is saved at the model resolution.
    """
    stem = os.path.splitext(name)[0]

    # Undo the ImageNet normalisation so the input displays with true colours.
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    img = (img_tensor * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        panels = ((img, 'Input', None), (gt_mask, 'GT', 'gray'), (pred_mask, 'Pred', 'gray'))
        for ax, (data, title, cmap) in zip(axes, panels):
            ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')
        fig.tight_layout()
        os.makedirs(viz_dir, exist_ok=True)
        fig.savefig(os.path.join(viz_dir, f"{stem}_viz.png"), dpi=150)
    finally:
        # Release the figure even if saving fails, so figures do not accumulate.
        plt.close(fig)

    # Predicted mask as 0/255 grayscale. A 2-D uint8 array is inferred as mode
    # 'L' by Image.fromarray, so the mode argument is not needed.
    os.makedirs(pred_dir, exist_ok=True)
    pred_pil = Image.fromarray(pred_mask.astype(np.uint8) * 255)

    # Original size: explicit argument first, then the stem mapping. A missing
    # entry leaves the mask at model resolution instead of crashing.
    size = orig_size
    if size is None and test_orig_sizes is not None:
        size = test_orig_sizes.get(stem)
    if size is not None:
        orig_h, orig_w = int(size[0]), int(size[1])
        pred_pil = pred_pil.resize((orig_w, orig_h), resample=Image.NEAREST)
    pred_pil.save(os.path.join(pred_dir, f"{stem}_pred_mask.png"))

# -----------------------------
# Plot helpers
# -----------------------------

def _plot_history_columns(history_csv, outpath, columns, ylabel):
    """Plot the given columns of the history CSV against its 'epoch' column and save to ``outpath``."""
    df = pd.read_csv(history_csv)
    fig, ax = plt.subplots()
    try:
        for col in columns:
            ax.plot(df['epoch'], df[col], label=col)
        ax.set_xlabel('epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)
        fig.savefig(outpath)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def plot_loss(history_csv, outpath):
    """Save a train/val loss-vs-epoch plot from the history CSV written by main()."""
    _plot_history_columns(history_csv, outpath, ('train_loss', 'val_loss'), 'loss')


def plot_acc(history_csv, outpath):
    """Save a train/val mIoU-vs-epoch plot from the history CSV written by main()."""
    _plot_history_columns(history_csv, outpath, ('train_mIoU', 'val_mIoU'), 'mIoU')

# -----------------------------
# Main pipeline
# -----------------------------

def main():
    """Run the full train -> checkpoint -> test-selection -> report pipeline.

    All paths and hyper-parameters come from the module-level configuration.

    Steps:
      1. Pair train/test images with masks. Split off a validation set when
         ``val_img_dir``/``val_mask_dir`` are None.
      2. Compute inverse-frequency class weights from the training masks.
      3. Fine-tune torchvision's DeepLabV3-ResNet101 (heads replaced for
         ``num_classes``), rewriting ``history_csv`` every epoch.
      4. Save a checkpoint every 10 epochs and at the final epoch.
      5. Evaluate every checkpoint on the test set and pick the best by test mIoU.
         Then write per-image/overall metrics, visualisations and plots.
      6. Record start, end and elapsed time in ``runinfo_txt``.

    Raises:
        FileNotFoundError: if no paired train, validation or test images are found.
        ValueError: if the validation split leaves no training images.
        RuntimeError: if no checkpoint was written (e.g. ``num_epochs`` < 1).
    """
    run_start = time.time()
    start_time = datetime.datetime.now().strftime('%Y%m%d_%H:%M:%S')

    # ---- Pair images with masks ----
    train_imgs, train_masks = pair_images_and_masks(train_img_dir, train_mask_dir)
    test_imgs, test_masks, test_orig_sizes = pair_test_images_and_masks(test_img_dir, test_mask_dir)
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not train_imgs:
        raise FileNotFoundError(f'No training images found in {train_img_dir} (masks: {train_mask_dir})')
    if not test_imgs:
        raise FileNotFoundError(f'No test images found in {test_img_dir} (masks: {test_mask_dir})')

    # ---- Train / validation split ----
    if val_img_dir is None or val_mask_dir is None:
        # Deterministic, unshuffled split: the first `valsplit` fraction of the
        # (sorted) pairs is used for validation.
        nval = max(1, int(len(train_imgs) * valsplit))
        val_imgs, val_masks = train_imgs[:nval], train_masks[:nval]
        train_imgs2, train_masks2 = train_imgs[nval:], train_masks[nval:]
    else:
        val_imgs, val_masks = pair_images_and_masks(val_img_dir, val_mask_dir)
        train_imgs2, train_masks2 = train_imgs, train_masks
    if not train_imgs2:
        raise ValueError('No training images left after the validation split; add data or lower valsplit')
    if not val_imgs:
        raise FileNotFoundError(f'No validation images found in {val_img_dir} (masks: {val_mask_dir})')

    trainset = MyDataset(train_imgs2, train_masks2, img_size=IMG_SIZE)
    valset = MyDataset(val_imgs, val_masks, img_size=IMG_SIZE)
    testset = MyDataset(test_imgs, test_masks, img_size=IMG_SIZE)

    # ---- Dataloaders ----
    # torchvision's DeepLabV3 ASPP image-pooling branch applies BatchNorm to a
    # 1x1 feature map, which raises in train mode for a single-sample batch.
    # Drop the trailing batch only when it would contain exactly one sample.
    drop_singleton_tail = len(trainset) > batch_size and len(trainset) % batch_size == 1
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0,
                             drop_last=drop_singleton_tail)
    valloader = DataLoader(valset, batch_size=min(len(valset), batch_size), shuffle=False, num_workers=0)

    # ---- Class weights (inverse pixel frequency over the training masks) ----
    class_weights, counts, freq = compute_class_weights(trainset)
    class_weights = torch.tensor(class_weights, dtype=torch.float32).to(DEVICE)
    print(f"Class counts: {counts}, freq: {freq}, weights: {class_weights}")

    # ---- Model: torchvision deeplabv3_resnet101 with DEFAULT pretrained weights ----
    model = deeplabv3_resnet101(weights=DeepLabV3_ResNet101_Weights.DEFAULT)
    # Replace the final 1x1 convs of the main and auxiliary heads so they
    # output num_classes channels.
    model.classifier[4] = nn.Conv2d(256, num_classes, kernel_size=1)
    model.aux_classifier[4] = nn.Conv2d(256, num_classes, kernel_size=1)
    model.to(DEVICE)

    # ---- Loss selection ----
    if loss_type == 'ce':
        criterion = nn.CrossEntropyLoss(weight=class_weights)
    elif loss_type == 'dice':
        criterion = DiceLoss()
    elif loss_type == 'focal':
        alpha = alpha_focal if alpha_focal is not None else class_weights.cpu().numpy().tolist()
        criterion = FocalLoss(gamma=gamma_focal, alpha=alpha)
    else:
        raise ValueError('Unknown loss type')

    # ---- Optimizer ----
    if optimizer_type.lower() == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

    # ---- Training ----
    history = []
    for epoch in range(1, num_epochs + 1):
        train_loss, train_miou = train_one_epoch(model, trainloader, optimizer, criterion, DEVICE)
        val_loss, val_miou = eval_one_epoch(model, valloader, criterion, DEVICE)

        print(f"Epoch {epoch}: train_loss={train_loss:.4f}, train_mIoU={train_miou:.4f}, val_loss={val_loss:.4f}, val_mIoU={val_miou:.4f}")
        history.append({'epoch': epoch, 'train_loss': train_loss, 'train_mIoU': train_miou, 'val_loss': val_loss, 'val_mIoU': val_miou})

        # Checkpoint every 10 epochs and always at the final epoch, so at least
        # one checkpoint exists for test evaluation even when num_epochs < 10.
        if epoch % 10 == 0 or epoch == num_epochs:
            model_path = os.path.join(model_dir, f"{filedate}_epoch{epoch:03d}.pth")
            os.makedirs(model_dir, exist_ok=True)
            torch.save(model.state_dict(), model_path)

        # Rewrite the history CSV every epoch so progress survives an interrupted run.
        pd.DataFrame(history).to_csv(history_csv, index=False)

    # ---- Evaluate every saved checkpoint on the test set ----
    print('Evaluating saved epoch models on test set...')
    model_files = sorted(glob.glob(os.path.join(model_dir, f"{filedate}_epoch*.pth")))
    if not model_files:
        raise RuntimeError(f'No checkpoints found in {model_dir}; check num_epochs')
    per_epoch_results = []
    for p in model_files:
        state = torch.load(p, map_location=DEVICE)
        model.load_state_dict(state)
        test_loss, test_miou = evaluate_model_on_test(model, testset, criterion, DEVICE)
        per_epoch_results.append({'model_file': os.path.basename(p), 'test_loss': test_loss, 'test_mIoU': test_miou})
        print(f"Model {os.path.basename(p)} -> test_loss={test_loss:.4f}, test_mIoU={test_miou:.4f}")

    per_epoch_df = pd.DataFrame(per_epoch_results)
    per_epoch_csv = os.path.join(result_dir, f"per_10epoch_test_results_{filedate}.csv")
    os.makedirs(result_dir, exist_ok=True)
    per_epoch_df.to_csv(per_epoch_csv, index=False)

    # Select the best checkpoint by test mIoU.
    # NOTE(review): selecting on the test set makes the reported test metrics
    # optimistically biased; selecting by val_mIoU would avoid that.
    best_row = per_epoch_df.loc[per_epoch_df['test_mIoU'].idxmax()]
    best_model_file = os.path.join(model_dir, best_row['model_file'])
    print(f"Best model: {best_model_file} with test_mIoU={best_row['test_mIoU']:.4f}")
    best_epoch = best_row['model_file'].split('_')[-1].split('.')[0]  # e.g. "epoch010"
    state = torch.load(best_model_file, map_location=DEVICE)
    model.load_state_dict(state)

    # ---- Detailed evaluation with the best model ----
    per_image_csv = os.path.join(result_dir, f"Bestmodel_{best_epoch}_per_image_metrics_{filedate}.csv")
    overall_image_csv = os.path.join(result_dir, f"Bestmodel_{best_epoch}_overall_image_metrics_{filedate}.csv")
    overall_metrics, per_image_rows = detailed_test_evaluation(
        model, testset, criterion, DEVICE, viz_dir=viz_dir, per_image_csv=per_image_csv, overall_image_csv=overall_image_csv, test_orig_sizes=test_orig_sizes)

    # ---- Plots ----
    os.makedirs(plots_dir, exist_ok=True)
    plot_loss(history_csv, os.path.join(plots_dir, f"loss_{filedate}.png"))
    plot_acc(history_csv, os.path.join(plots_dir, f"acc_{filedate}.png"))

    # ---- Run info (start / end / elapsed) ----
    end_time = datetime.datetime.now().strftime('%Y%m%d_%H:%M:%S')
    elapsed = datetime.timedelta(seconds=round(time.time() - run_start))
    with open(runinfo_txt, 'w', encoding='utf-8') as f:
        f.write(f"Start: {start_time}\n")
        f.write(f"End: {end_time}\n")
        f.write(f"Elapsed: {elapsed}\n")

    print('-----TRAINING AND EVALUATION ALL DONE!!!-----')
    print('Outputs:')
    print(f" - models: {model_dir}")
    print(f" - per-epoch test CSV: {per_epoch_csv}")
    print(f" - best model: {best_model_file}")
    print(f" - per-image CSV: {per_image_csv}")
    print(f" - overall metrics CSV: {overall_image_csv}")
    print(f" - visualizations: {viz_dir}")
    print(f" - history CSV: {history_csv}")
    print(f" - plots: {plots_dir}")
    print(f" - run info: {runinfo_txt}")

# Guarded so that importing this file (e.g. from another script or a test) does
# not start a full training run. In a notebook cell __name__ is "__main__", so
# the pipeline still runs there.
if __name__ == "__main__":
    main()


Library import successful.
Class counts: [21940  2636], freq: [0.89274089 0.10725911], weights: tensor([0.5601, 4.6616], device='cuda:0')
Epoch 1: train_loss=0.6671, train_mIoU=0.3787, val_loss=0.7981, val_mIoU=0.3262
Epoch 2: train_loss=0.6413, train_mIoU=0.3968, val_loss=0.7598, val_mIoU=0.3282
Epoch 3: train_loss=0.6020, train_mIoU=0.4157, val_loss=0.7702, val_mIoU=0.3563
Epoch 4: train_loss=0.5479, train_mIoU=0.4774, val_loss=0.8045, val_mIoU=0.3338
Epoch 5: train_loss=0.5157, train_mIoU=0.4993, val_loss=0.7192, val_mIoU=0.3192
Epoch 6: train_loss=0.4658, train_mIoU=0.4826, val_loss=0.8254, val_mIoU=0.2977
Epoch 7: train_loss=0.4228, train_mIoU=0.5110, val_loss=0.9108, val_mIoU=0.2831
Epoch 8: train_loss=0.3859, train_mIoU=0.5439, val_loss=0.9730, val_mIoU=0.2951
Epoch 9: train_loss=0.3654, train_mIoU=0.5791, val_loss=0.9736, val_mIoU=0.2572
Epoch 10: train_loss=0.3285, train_mIoU=0.5951, val_loss=0.9957, val_mIoU=0.3104
Evaluating saved epoch models on test set...
Model 20260103_1

In [None]:
# Record original sizes of the test masks. save_visualization below uses them
# to resize predicted masks back to full resolution.
# Builds mapping: stem -> (H, W)
test_orig_sizes = {}
for _mask_path in sorted(glob.glob(os.path.join(test_mask_dir, '*'))):
    if not os.path.isfile(_mask_path):
        continue  # skip sub-directories
    _stem = os.path.splitext(os.path.basename(_mask_path))[0]
    try:
        with Image.open(_mask_path) as _im:
            _w, _h = _im.size  # PIL reports (width, height)
    except Exception as _err:  # best-effort: unreadable/non-image files are skipped, but reported
        print(f"Warning: skipping {_mask_path} (could not read size: {_err})")
        continue
    test_orig_sizes[_stem] = (_h, _w)


def save_visualization(img_tensor, gt_mask, pred_mask, name, viz_dir, pred_dir, orig_size=None):
    """Write a three-panel figure (Input | GT | Pred) and the predicted-mask PNG.

    The figure is drawn at model resolution and saved as ``viz_dir/<stem>_viz.png``.
    The 0/1 ``pred_mask`` is saved as a 0/255 grayscale PNG,
    ``pred_dir/<stem>_pred_mask.png``. It is resized with nearest-neighbour
    resampling to ``orig_size`` (H, W) when given. Otherwise it uses the
    module-level ``test_orig_sizes`` entry for the stem, if any.
    """
    stem = os.path.splitext(name)[0]

    # Reverse the ImageNet normalisation so the input image is viewable.
    std = torch.tensor([0.229,0.224,0.225]).view(3,1,1)
    mean = torch.tensor([0.485,0.456,0.406]).view(3,1,1)
    display = (img_tensor.clone() * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()

    fig, axes = plt.subplots(1,3, figsize=(12,4))
    panels = [
        (display, 'Input', {}),
        (gt_mask, 'GT', {'cmap': 'gray'}),
        (pred_mask, 'Pred', {'cmap': 'gray'}),
    ]
    for ax, (data, title, imshow_kwargs) in zip(axes, panels):
        ax.imshow(data, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()

    os.makedirs(viz_dir, exist_ok=True)
    plt.savefig(os.path.join(viz_dir, f"{stem}_viz.png"), dpi=150)
    plt.close(fig)

    # Predicted mask as 0/255 grayscale, optionally restored to original size.
    os.makedirs(pred_dir, exist_ok=True)
    mask_img = Image.fromarray(pred_mask.astype(np.uint8) * 255, mode='L')
    target_size = orig_size if orig_size is not None else test_orig_sizes.get(stem, None)
    if target_size is not None:
        # target_size is (H, W); PIL's resize expects (W, H).
        mask_img = mask_img.resize((int(target_size[1]), int(target_size[0])), resample=Image.NEAREST)
    mask_img.save(os.path.join(pred_dir, f"{stem}_pred_mask.png"))


# Redefine detailed_test_evaluation to use orig sizes when saving pred masks
def detailed_test_evaluation(model, testset, criterion, device, viz_dir=None, per_image_csv=None, overall_image_csv=None, test_orig_sizes=None, pred_mask_dir=None):
    """Per-image and overall test evaluation; saves predicted masks at original resolution.

    Each per-image row holds loss, mIoU, pixel/mean accuracy, and per-class IoU,
    precision, recall and F1. It also holds the binary pixel counts TP/FP/TN/FN
    for the positive class 1 against class 0, plus the Total pixel count. The
    overall row has the same metric columns (and mIoU), computed from the
    confusion matrix accumulated over all images.

    For each saved predicted mask, the original ``(H, W)`` comes from an
    optional 4th item of each batch, else from ``test_orig_sizes``. If neither
    supplies it, ``orig_size=None`` is passed and save_visualization applies its
    own fallback.

    Args:
        model, testset, criterion, device: as for eval_one_epoch; the test set is
            iterated one image at a time.
        viz_dir: if not None, save a figure and the predicted mask per image.
        per_image_csv: optional path for the per-image metrics CSV.
        overall_image_csv: optional path for the one-row overall metrics CSV.
        test_orig_sizes: optional mapping ``stem -> (H, W)``. Accepted so that
            main()'s call signature works with this definition.
        pred_mask_dir: directory for predicted-mask PNGs; defaults to the
            module-level ``pred_dir``.

    Returns:
        ``(overall_row, per_image_rows)``: a dict and a list of dicts.
    """
    def _orig_size_from_batch(batch):
        # Optional 4th batch item holding the sample's (H, W). default_collate turns
        # a per-sample tuple into [tensor([H, ...]), tensor([W, ...])]; a per-sample
        # tensor/sequence of 2 values is instead stacked/listed per sample.
        if len(batch) <= 3:
            return None
        sizes = batch[3]
        if len(sizes) == 2 and all(torch.is_tensor(s) for s in sizes):
            return int(sizes[0][0]), int(sizes[1][0])
        return int(sizes[0][0]), int(sizes[0][1])

    def _summary_columns(m):
        # Metric columns shared by the per-image and overall rows. Index 1 is
        # the positive class, so tp[1] == cm[1,1], fp[1] == cm[0,1],
        # fn[1] == cm[1,0] and tn[1] == cm[0,0] (summing over classes instead
        # would make TP == TN and FP == FN).
        return {
            'pixel_acc': m['pixel_acc'],
            'mean_acc': m['mean_acc'],
            'iou_class0': float(m['iou_per_class'][0]),
            'iou_class1': float(m['iou_per_class'][1]),
            'precision_class0': m['precision_per_class'][0],
            'precision_class1': m['precision_per_class'][1],
            'recall_class0': m['recall_per_class'][0],
            'recall_class1': m['recall_per_class'][1],
            'f1_class0': m['f1_per_class'][0],
            'f1_class1': m['f1_per_class'][1],
            'TP': int(m['tp'][1]),
            'FP': int(m['fp'][1]),
            'TN': int(m['tn'][1]),
            'FN': int(m['fn'][1]),
        }

    model.eval()
    all_cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    per_image_rows = []
    mask_out_dir = None
    if viz_dir is not None:
        mask_out_dir = pred_dir if pred_mask_dir is None else pred_mask_dir
    testloader = DataLoader(testset, batch_size=1, shuffle=False)
    with torch.no_grad():
        for batch in testloader:
            imgs = batch[0].to(device)
            masks = batch[1].to(device)
            name = batch[2][0]
            # Single forward pass (the model output is reused for the type check).
            out = model(imgs)
            logits = out['out'] if isinstance(out, dict) else out
            loss = criterion(logits, masks).item()
            cm = confusion_matrix_from_logits(logits, masks)
            all_cm += cm
            m = compute_metrics_from_cm(cm)

            # Save visualization and predicted mask (resized to orig size when known)
            if viz_dir is not None:
                orig_size = _orig_size_from_batch(batch)
                if orig_size is None and test_orig_sizes is not None:
                    orig_size = test_orig_sizes.get(os.path.splitext(name)[0])
                pred = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
                gt = masks.squeeze(0).cpu().numpy().astype(np.uint8)
                save_visualization(imgs.squeeze(0).cpu(), gt, pred, name, viz_dir, mask_out_dir, orig_size=orig_size)

            row = {'image': name, 'loss': loss, 'mIoU': m['mean_iou']}
            row.update(_summary_columns(m))
            row['Total'] = int(cm.sum())
            per_image_rows.append(row)

    if per_image_csv is not None:
        pd.DataFrame(per_image_rows).to_csv(per_image_csv, index=False)

    overall = compute_metrics_from_cm(all_cm)
    overall_row = {'mIoU': overall['mean_iou']}
    overall_row.update(_summary_columns(overall))
    overall_row['Total'] = int(all_cm.sum())

    if overall_image_csv is not None:
        pd.DataFrame([overall_row]).to_csv(overall_image_csv, index=False)

    return overall_row, per_image_rows