In [1]:
import time
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

import cv2
import numpy as np
import os
from PIL import Image
import torchvision.transforms as T
from tqdm import tqdm
from glob import glob
from pathlib import Path

In [2]:


# Compute device shared by the landmark detector, training and inference cells.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


In [3]:
import face_alignment
# 68-point 2D facial landmark detector (used by create_nose_mask).
fa = face_alignment.FaceAlignment(
    face_alignment.LandmarksType.TWO_D,
    flip_input=False,
    device=device,
)


In [4]:
def create_nose_mask(image_path, save_path, log):
    """
    Detect 68-point facial landmarks and write a filled nose mask.

    The mask is a single-channel uint8 image (255 inside, 0 elsewhere) covering
    the convex hull of landmarks 27-35 (nose bridge + nostrils) of the first
    detected face. It is written to ``save_path``.

    Args:
        image_path: path of the input image.
        save_path: destination of the mask; the format follows its extension.
        log: logger-like object; unused here, kept for the caller's signature.

    Returns:
        True if a mask was written. False if the image could not be read, no
        face was found, the detector raised a Warning (the caller promotes
        "No faces were detected." to an error), or the write failed.
    """
    img = cv2.imread(image_path)
    if img is None:  # unreadable or missing file; cvtColor would raise otherwise
        return False

    try:
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        preds = fa.get_landmarks(img_rgb)
    except Warning:  # warnings promoted to exceptions by the caller
        return False
    if not preds:  # None or an empty list: no face detected
        return False

    nose_points = np.int32(preds[0][27:36])  # 0-based indices 27..35
    # fillConvexPoly needs a convex (scan-line monotone) polygon. The raw
    # landmark order (bridge top -> tip, then nostrils left -> right) does not
    # outline the nose and drops one side of the bridge, so fill the hull.
    hull = cv2.convexHull(nose_points)

    mask = np.zeros(img.shape[:2], dtype=np.uint8)
    cv2.fillConvexPoly(mask, hull, 255)

    return bool(cv2.imwrite(save_path, mask))

In [5]:
import os
import warnings
import logging

def batch_create_nose_masks(src_dir, dst_dir, create_func, exts=('.jpg','.jpeg','.png')):
    """
    Generate nose masks for all images in src_dir and save them in dst_dir.

    Each mask is written under the same filename as its source image. Images
    whose mask already exists are skipped, so the function can be re-run to
    fill in gaps.

    Args:
        src_dir: directory with the input images.
        dst_dir: output directory for the masks (created if missing).
        create_func: callable(in_path, out_path, logger) -> bool. A ``False``
            return is logged as a skipped file.
        exts: filename suffixes to process (compared case-insensitively).
    """
    os.makedirs(dst_dir, exist_ok=True)

    # Promote the detector's "no face" warning to an exception only for the
    # duration of this call, instead of mutating the process-wide filter.
    with warnings.catch_warnings():
        warnings.filterwarnings("error", message="No faces were detected.")

        for file in os.listdir(src_dir):
            if not file.lower().endswith(exts):
                continue
            in_path  = os.path.join(src_dir, file)
            out_path = os.path.join(dst_dir, file)
            if os.path.exists(out_path):
                continue
            try:
                ok = create_func(in_path, out_path, logging)
            except Warning as w:
                logging.warning(f"Skipped {file}: {w}")
            except Exception as e:
                logging.error(f"Error with {file}: {e}")
            else:
                if ok is False:
                    logging.warning(f"Skipped {file}: no mask written")

base_dir = "/workspace/data_splits"
splits = ["train", "val", "test"]

# For every split: read <split>/input, write masks to <split>/mask_input.
for split in splits:
    input_folder, mask_folder = (
        os.path.join(base_dir, split, sub) for sub in ("input", "mask_input")
    )
    batch_create_nose_masks(input_folder, mask_folder, create_nose_mask)

In [6]:
def img_to_patches(x, patch_size):
    """
    Split images into flattened, non-overlapping square patches.

    x: (B, C, H, W) with H and W divisible by patch_size (asserted).
    Returns (patches, (nh, nw)). patches is (B, nh*nw, patch_size*patch_size*C)
    with patches in row-major order; each patch is flattened as (row, col, channel).
    """
    B, C, H, W = x.shape
    assert H % patch_size == 0 and W % patch_size == 0
    p = patch_size
    grid = (H // p, W // p)
    tiles = x.reshape(B, C, grid[0], p, grid[1], p)
    tiles = tiles.permute(0, 2, 4, 3, 5, 1)          # (B, nh, nw, p, p, C)
    return tiles.reshape(B, grid[0] * grid[1], p * p * C), grid

In [7]:
def patches_to_img(patches, patch_size, nh_nw, C):
    """
    Inverse of ``img_to_patches``.

    patches: (B, nh*nw, patch_size*patch_size*C), row-major patch order, each
    patch flattened as (row, col, channel).
    Returns (B, C, nh*patch_size, nw*patch_size).
    """
    B, _, _ = patches.shape
    nh, nw = nh_nw
    p = patch_size
    grid = patches.reshape(B, nh, nw, p, p, C)
    grid = grid.permute(0, 5, 1, 3, 2, 4)            # (B, C, nh, p, nw, p)
    return grid.reshape(B, C, nh * p, nw * p)

In [8]:
def mask_to_patch_mask(mask, patch_size, threshold=0.1):
    """
    Binary per-patch flags from a (B,1,H,W) mask with values in [0,1].

    The mask is averaged over non-overlapping patch_size x patch_size patches
    (row-major patch order, as in ``img_to_patches``). A patch is 1.0 when its
    mean coverage exceeds ``threshold``, else 0.0. Trailing rows/columns that
    do not fill a whole patch are ignored instead of raising.

    Returns:
        (B, N) float tensor, N = (H // patch_size) * (W // patch_size);
        1 where the patch contains nose.
    """
    B, _, H, W = mask.shape
    p = patch_size
    nh, nw = H // p, W // p
    cropped = mask[:, :, :nh * p, :nw * p]
    coverage = cropped.reshape(B, 1, nh, p, nw, p).mean(dim=(3, 5))  # (B,1,nh,nw)
    return (coverage.reshape(B, nh * nw) > threshold).float()

In [9]:
from pathlib import Path

IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}

def _map_stem_to_name(dir_path: Path):
    """Return {stem: filename} for allowed image files (first match per stem)."""
    m = {}
    for p in dir_path.iterdir():
        if p.is_file() and p.suffix.lower() in IMG_EXTS:
            m.setdefault(p.stem, p.name)  # keep first occurrence
    return m

def validate_files(img_dir, mask_dir, target_dir, verbose=True):
    """
    Match input images, masks and targets by filename stem.

    Returns a list of (img_filename, target_filename) for every stem present
    in all three directories, sorted by stem. When ``verbose``, prints counts
    and up to five examples of each kind of mismatch. Raises ValueError when
    no complete triple exists.
    """
    img_dir, mask_dir, target_dir = (Path(d) for d in (img_dir, mask_dir, target_dir))

    img_map, mask_map, target_map = (
        _map_stem_to_name(d) for d in (img_dir, mask_dir, target_dir)
    )
    img_stems, mask_stems, target_stems = set(img_map), set(mask_map), set(target_map)

    common = sorted(img_stems.intersection(mask_stems, target_stems))

    if verbose:
        print(f"[validate] counts: img={len(img_stems)} mask={len(mask_stems)} target={len(target_stems)}")
        print(f"[validate] common triples: {len(common)}")
        mismatch_reports = (
            ("[validate] missing masks for {} imgs, e.g. {}",   img_stems - mask_stems),
            ("[validate] missing targets for {} imgs, e.g. {}", img_stems - target_stems),
            ("[validate] masks without imgs: {}, e.g. {}",      mask_stems - img_stems),
            ("[validate] targets without imgs: {}, e.g. {}",    target_stems - img_stems),
        )
        for template, diff in mismatch_reports:
            if diff:
                listed = sorted(diff)
                print(template.format(len(listed), listed[:5]))

    # Use the exact filenames found on disk (extensions may differ per folder).
    pairs = [(img_map[stem], target_map[stem]) for stem in common]

    if not pairs:
        raise ValueError(
            "No valid (img, mask, target) triples found.\n"
            f"Checked:\n  img_dir={img_dir}\n  mask_dir={mask_dir}\n  target_dir={target_dir}\n"
            "See [validate] logs above for mismatches."
        )
    return pairs

In [10]:
import torch
import torch.nn.functional as F

def collate_keep_aspect(batch, multiple=32):
    """
    Collate variable-size samples into one batch without resizing.

    Every input/target/mask is zero-padded on the bottom and right to the
    largest H and W in the batch, each rounded up to a multiple of
    ``multiple`` (the generator's total stride).

    Returns a dict:
        "input"   [B,3,Ht,Wt], "target" [B,3,Ht,Wt], "mask" [B,1,Ht,Wt],
        "orig_hw" int32 [B,2] with each sample's pre-padding (H, W),
        "files"   list of (input_file, mask_file, target_file), "" when absent.
    """
    inputs  = [s["input"] for s in batch]    # each [3,H,W]
    targets = [s["target"] for s in batch]   # each [3,H,W]
    masks   = [s["mask"] for s in batch]     # each [1,H,W]

    def round_up(v):
        return (v + multiple - 1) // multiple * multiple

    Ht = round_up(max(t.shape[1] for t in inputs))
    Wt = round_up(max(t.shape[2] for t in inputs))

    def pad(t):
        # F.pad order for the last two dims is (left, right, top, bottom)
        return F.pad(t, (0, Wt - t.shape[2], 0, Ht - t.shape[1]))

    sizes = [torch.tensor([t.shape[1], t.shape[2]], dtype=torch.int32) for t in inputs]
    names = [
        (s.get("input_file", ""), s.get("mask_file", ""), s.get("target_file", ""))
        for s in batch
    ]

    return {
        "input": torch.stack([pad(t) for t in inputs]),
        "target": torch.stack([pad(t) for t in targets]),
        "mask": torch.stack([pad(t) for t in masks]),
        "orig_hw": torch.stack(sizes),
        "files": names,
    }

In [11]:
from pathlib import Path
from torch.utils.data import Dataset
import torch, cv2
import numpy as np

class NoseFolderDataset(Dataset):
    """
    Paired nose-edit samples read from three folders and matched by filename stem.

    Images keep their native resolution; batch them with ``collate_keep_aspect``.
    Each item is a dict:
        "input", "target": float32 [3,H,W] RGB in [0,1]
        "mask":            float32 [1,H,W] in {0,1} (grayscale > 127)
        "orig_h", "orig_w": size of the input image
        "input_file", "mask_file", "target_file": file paths as strings
    Target and mask are resized to the input image's size when they differ.
    """

    def __init__(self, img_dir, mask_dir, target_dir, size=256):
        # ``size`` is accepted for backward compatibility but is not used:
        # no fixed-size resizing happens in this dataset.
        self.img_dir = Path(img_dir)
        self.mask_dir = Path(mask_dir)
        self.target_dir = Path(target_dir)
        # [(img_name, tgt_name), ...] for stems present in all three folders
        self.file_pairs = validate_files(self.img_dir, self.mask_dir, self.target_dir)
        # validate_files matches masks by stem only, so a mask may have another
        # extension than its image (e.g. a.png vs a.jpg); remember real names.
        self._mask_names = _map_stem_to_name(self.mask_dir)

    def __len__(self):
        return len(self.file_pairs)

    def _mask_name_for(self, img_name):
        """Mask filename for an image: same filename if it exists, else matched by stem."""
        if (self.mask_dir / img_name).is_file():
            return img_name
        return self._mask_names.get(Path(img_name).stem, img_name)

    def _read_rgb(self, path):
        img = cv2.imread(str(path), cv2.IMREAD_COLOR)
        if img is None:
            raise FileNotFoundError(path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)        # (H,W,3) uint8
        return img

    def _read_mask(self, path):
        m = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if m is None:
            raise FileNotFoundError(path)
        m = (m > 127).astype(np.uint8)                    # (H,W) 0/1; tolerant of JPEG artefacts
        return m

    def __getitem__(self, idx):
        fname, target_f = self.file_pairs[idx]
        mask_f = self._mask_name_for(fname)
        ip, mp, tp = self.img_dir/fname, self.mask_dir/mask_f, self.target_dir/target_f

        img    = self._read_rgb(ip)     # H,W,3
        target = self._read_rgb(tp)     # H,W,3
        mask   = self._read_mask(mp)    # H,W

        # The input image defines the sample's geometry.
        if img.shape[:2] != target.shape[:2]:
            target = cv2.resize(target, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_AREA)
        if img.shape[:2] != mask.shape[:2]:
            # nearest keeps the mask binary
            mask = cv2.resize(mask, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST)

        H, W = img.shape[:2]
        img_t  = torch.from_numpy(img).permute(2,0,1).float()/255.0
        tgt_t  = torch.from_numpy(target).permute(2,0,1).float()/255.0
        mask_t = torch.from_numpy(mask).unsqueeze(0).float()

        return {
            "input": img_t, "target": tgt_t, "mask": mask_t,
            "orig_h": H, "orig_w": W,
            "input_file": str(ip), "mask_file": str(mp), "target_file": str(tp),
        }

In [13]:
# --------------------------- Generator (UNet) ----------

import torch
import torch.nn as nn
import torch.nn.functional as F

# ----- tiny mask helpers
def dilate_mask_binary(mask, k=11, iters=2):
    """Grow a mask tensor (e.g. B,1,H,W) with ``iters`` stride-1 k x k max-pools; clamp to [0,1]."""
    pad = k // 2
    for _ in range(iters):
        mask = F.max_pool2d(mask, kernel_size=k, stride=1, padding=pad)
    return mask.clamp(0, 1)

def feather_mask(mask, k=9):
    """Soften mask edges with a k x k stride-1 box blur (zero-padded avg-pool), clamped to [0,1]."""
    blurred = F.avg_pool2d(mask, k, 1, k // 2)
    return blurred.clamp(0, 1)


# --------------------------- UNet blocks ---------------------------
class ConvBlock(nn.Module):
    """Two stages of 3x3 conv (no bias) -> InstanceNorm(affine) -> SiLU; H and W preserved."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1, bias=False),
                nn.InstanceNorm2d(out_ch, affine=True),
                nn.SiLU(inplace=True),
            ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class Down(nn.Module):
    """Encoder stage: 2x2 max-pool (halves H and W), then ConvBlock(in_ch -> out_ch)."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.conv = ConvBlock(in_ch, out_ch)

    def forward(self, x):
        pooled = self.pool(x)
        return self.conv(pooled)

class Up(nn.Module):
    """Decoder stage: bilinear 2x upsample, align to the skip tensor, concat [skip, x], ConvBlock."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up   = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv = ConvBlock(in_ch, out_ch)

    def forward(self, x, skip):
        x = self.up(x)
        # Pad (or crop, for negative differences) bottom/right so x matches skip (odd dims).
        dh = skip.size(2) - x.size(2)
        dw = skip.size(3) - x.size(3)
        if dh or dw:
            x = F.pad(x, (0, dw, 0, dh))
        return self.conv(torch.cat((skip, x), dim=1))


# --------------------------- UNet Generator (residual) ---------------------------
class UNetNoseGenerator(nn.Module):
    """
    UNet that takes RGB+mask (4ch) and predicts a bounded 3ch residual 'delta'.

    Output = rgb + soft_mask * delta, where
        delta     = tanh(raw) * sigmoid(alpha) * res_max   (alpha is learnable)
        soft_mask = feather_mask(dilate_mask_binary(mask))
    so pixels far from the input mask are returned unchanged.
    """
    def __init__(self, in_ch=4, out_ch=3, base=64, depth=5, res_max=0.75):
        """
        Args:
            in_ch:   input channels (RGB + mask).
            out_ch:  residual channels.
            base:    channel width of the first encoder level.
            depth:   >= 5 -> downsample x2 five times (stride 32; use collate multiple=32).
                     Any smaller value uses four downsamplings (stride 16, less memory).
            res_max: upper bound on |delta| (reached as sigmoid(alpha) -> 1).
        """
        super().__init__()
        self.out_ch = out_ch
        self.res_max = res_max

        # encoder
        self.inc   = ConvBlock(in_ch, base)                # H
        self.down1 = Down(base,     base*2)                # H/2
        self.down2 = Down(base*2,   base*4)                # H/4
        self.down3 = Down(base*4,   base*8)                # H/8
        self.down4 = Down(base*8,   base*8)                # H/16
        self.has_down5 = (depth >= 5)
        if self.has_down5:
            self.down5 = Down(base*8, base*8)              # H/32
            self.up1   = Up(base*8 + base*8, base*8)       # H/32 -> H/16, + down4 skip

        # decoder (up2 is the first decoder stage when depth < 5)
        self.up2 = Up(base*8 + base*8, base*8)             # -> H/8, + down3 skip
        self.up3 = Up(base*8 + base*4, base*4)             # -> H/4, + down2 skip
        self.up4 = Up(base*4 + base*2, base*2)             # -> H/2, + down1 skip
        self.up5 = Up(base*2 + base,   base)               # -> H,   + inc skip

        # Zero-initialised head: raw == 0 at start, so the output starts equal to rgb.
        self.outc = nn.Conv2d(base, out_ch, kernel_size=3, padding=1)
        nn.init.zeros_(self.outc.weight)
        nn.init.zeros_(self.outc.bias)

        # learnable residual cap: scale = sigmoid(_alpha) * res_max (0.5 * res_max at init)
        self._alpha = nn.Parameter(torch.tensor(0.0))

    def forward(self, inp, return_full=False):
        """
        Args:
            inp: [B,4,H,W] with RGB in channels 0-2 and a binary/soft mask in channel 3.
            return_full: also return the unmasked result and the blend mask.

        Returns:
            out [B,3,H,W] = rgb + m_soft * delta. If ``return_full``, the tuple
            (out, rgb + delta, m_soft [B,1,H,W]).
        """
        rgb   = inp[:, :3]
        mask1 = inp[:, 3:4]

        # encoder
        x1 = self.inc(inp)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # decoder: first bring the features back to H/8
        if self.has_down5:
            u = self.up1(self.down5(x5), x5)   # H/16
            u = self.up2(u, x4)                # H/8
        else:
            u = self.up2(x5, x4)               # H/8
        u = self.up3(u, x3)                    # H/4
        u = self.up4(u, x2)                    # H/2
        u = self.up5(u, x1)                    # H

        raw   = self.outc(u)                   # [B,out_ch,H,W]
        scale = torch.sigmoid(self._alpha) * self.res_max
        delta = torch.tanh(raw) * scale        # bounded residual, |delta| < scale

        # blend only inside a dilated + feathered band around the mask
        hard   = dilate_mask_binary(mask1, k=11, iters=2)
        m_soft = feather_mask(hard, k=9)       # [B,1,H,W]
        out    = rgb + delta * m_soft          # mask broadcasts over channels
        if return_full:
            return out, rgb + delta, m_soft
        return out



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class VGGPerceptualLoss(nn.Module):
    """
    VGG16 feature L1 loss for images in [0,1].

    Features are compared after the layers listed in ``layers`` (indices into
    torchvision's ``vgg16().features``). ``layer_weights`` pairs positionally
    with ``layers``. If a mask (B,1,H,W) is given, it is resized to each
    feature map and used as spatial weights. The masked loss is a weighted mean
    over channels and pixels, on the same scale as the unmasked ``F.l1_loss``.
    """
    def __init__(self, layers=(3, 8, 15, 22), layer_weights=None):
        super().__init__()
        # relu1_2=3, relu2_2=8, relu3_3=15, relu4_3=22 in torchvision VGG16.features
        self.layers = tuple(layers)
        self.layer_weights = (
            [1.0] * len(self.layers) if layer_weights is None else list(layer_weights)
        )
        if len(self.layer_weights) != len(self.layers):
            raise ValueError(
                f"layer_weights has {len(self.layer_weights)} entries, expected {len(self.layers)}"
            )

        weights = models.VGG16_Weights.IMAGENET1K_FEATURES
        vgg = models.vgg16(weights=weights).features
        self.vgg = vgg.eval()
        for p in self.vgg.parameters():
            p.requires_grad_(False)

        # Normalize with the statistics these weights were trained with.
        # IMAGENET1K_FEATURES (ported Caffe-style VGG) expects a different mean
        # and std = 1/255, not the usual ImageNet mean/std of IMAGENET1K_V1.
        # As a result the feature (and loss) magnitudes are larger than with
        # ImageNet-normalized inputs.
        preset = weights.transforms()
        mean = torch.tensor(preset.mean, dtype=torch.float32).view(1,3,1,1)
        std  = torch.tensor(preset.std,  dtype=torch.float32).view(1,3,1,1)
        self.register_buffer('mean', mean, persistent=False)
        self.register_buffer('std',  std,  persistent=False)

    def _norm(self, x):
        # x expected in [0,1]
        return (x - self.mean) / self.std

    def forward(self, pred, target, mask=None):
        """
        pred, target: (B,3,H,W) in [0,1]
        mask (optional): (B,1,H,W) with 0..1 weights (e.g. the nose mask)
        Returns the weighted sum of per-layer feature L1 terms (0.0 if no layers).
        """
        x = self._norm(pred)
        y = self._norm(target)

        weight_of = dict(zip(self.layers, self.layer_weights))
        last = max(weight_of, default=-1)

        loss = 0.0
        for i, layer in enumerate(self.vgg):
            if i > last:
                break  # deeper layers never contribute; skip their compute
            x = layer(x)
            y = layer(y)
            if i not in weight_of:
                continue
            w = weight_of[i]
            if mask is None:
                # plain feature L1 (mean over all elements)
                loss += w * F.l1_loss(x, y)
            else:
                # masked feature L1: the mask is resized to the feature map, and the
                # denominator counts every channel so this is a true weighted mean.
                m = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False)
                num = (m * (x - y).abs()).sum()
                den = m.expand_as(x).sum() + 1e-6
                loss += w * (num / den)

        return loss

In [None]:
def dilate_mask_binary(mask, k=11, iters=2):
    """Binary dilation of a (B,1,H,W) {0,1} float mask: ``iters`` stride-1 k x k max-pools, clamped to [0,1]."""
    grown = mask
    for _ in range(iters):
        grown = F.max_pool2d(grown, kernel_size=k, stride=1, padding=k // 2)
    return grown.clamp(0, 1)

def erode_mask_binary(mask, k=11, iters=1):
    """Binary erosion: dilate the complement ``iters`` times (k x k max-pool), then invert back; clamp to [0,1]."""
    complement = 1.0 - mask
    for _ in range(iters):
        complement = F.max_pool2d(complement, kernel_size=k, stride=1, padding=k // 2)
    eroded = 1.0 - complement
    return eroded.clamp(0, 1)

def feather_mask(mask, k=9):
    """Feathered soft edge (0..1): k x k stride-1 average pool with zero padding, then clamp."""
    softened = F.avg_pool2d(mask, k, 1, k // 2)
    return softened.clamp(0, 1)

# ---------- patch mask helper ----------
def mask_to_patch_mask(mask, patch_size):
    """
    Soft per-patch coverage of a (B,1,H,W) 0/1 (or soft) mask.

    Averages the mask within non-overlapping p x p patches (trailing rows and
    columns that do not fill a patch are dropped). The patch grid is flattened
    row-major to (B, N), N = (H // p) * (W // p).
    """
    B, _, H, W = mask.shape
    p = patch_size
    rows, cols = H // p, W // p
    cropped = mask[:, :, : rows * p, : cols * p]
    pooled = cropped.view(B, 1, rows, p, cols, p).mean(dim=(3, 5))  # (B,1,rows,cols)
    return pooled.flatten(1)

# ---------- simple PatchGAN (1-channel real/fake map) ----------
class PatchDiscriminator(nn.Module):
    """
    PatchGAN critic: four stride-2 4x4 convs (channels base, 2b, 4b, 8b; BatchNorm
    on all but the first, LeakyReLU 0.2) followed by a stride-1 4x4 conv that
    outputs a 1-channel map of per-patch real/fake logits.
    """
    def __init__(self, in_ch=6, base=64):
        super().__init__()
        layers = [nn.Conv2d(in_ch, base, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True)]
        ch = base
        for _ in range(3):
            layers += [
                nn.Conv2d(ch, ch * 2, 4, 2, 1),
                nn.BatchNorm2d(ch * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            ch *= 2
        layers.append(nn.Conv2d(ch, 1, 4, 1, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

        

# ---------- validation (masked L1 using the same soft mask) ----------
@torch.no_grad()
def validate_epoch(G, loader, device):
    """
    Masked L1 of generator ``G`` over ``loader``.

    Each batch contributes its masked L1 (|pred - target| weighted by the same
    dilated + feathered mask used in training), weighted by its batch size.
    G runs in eval mode and is restored to its previous train/eval mode on
    exit, even if an error is raised.

    Returns:
        Sample-weighted mean of the per-batch masked L1 (0.0 for an empty loader).
    """
    was_training = G.training
    G.eval()
    tot, n = 0.0, 0
    try:
        for batch in loader:
            rgb    = batch["input"].to(device)   # (B,3,H,W)
            target = batch["target"].to(device)  # (B,3,H,W)
            mask   = batch["mask"].to(device)    # (B,1,H,W)
            inp = torch.cat([rgb, mask], dim=1)  # (B,4,H,W)  [RGB+mask]
            # same expansion/feathering as train
            hard   = dilate_mask_binary(mask, k=11, iters=2)
            m_soft = feather_mask(hard, k=9)

            pred = G(inp)

            l1m = ((pred - target).abs() * m_soft).sum() / (m_soft.sum() + 1e-6)
            b = inp.size(0)
            tot += l1m.item() * b
            n   += b
    finally:
        G.train(was_training)
    return tot / max(1, n)

# ---------- training ----------
def train_gan(
    train_loader, val_loader, *,
    epochs=20, out_dir="ckpts_gan",
    lr_G=2e-4, lr_D=1e-4,
    adv_start_w=0.05,
    lambda_l1=5.0, lambda_out_id=0.5,
    use_perc=False, perceptual_fn=None,
    amp=True, device="cuda",
    resume_from=None,
    lambda_perc=0.2,
):
    """
    Train the UNet nose generator against a PatchGAN critic (hinge loss).

    Generator loss =
          lambda_l1     * L1 inside the dilated + feathered nose mask
        + lambda_out_id * identity L1 (pred vs input) outside erode(dilate(mask))
        + adv_w         * adversarial term, adv_w = adv_start_w * min(1, ep/5)
        + lambda_perc   * perceptual_fn(pred, target, mask) if use_perc and perceptual_fn.

    Every epoch writes ``<out_dir>/last.pt``. ``<out_dir>/best_l1_mask.pt`` is
    overwritten whenever the validation masked L1 beats the best seen so far.
    The best value is carried across resumes via the checkpoint's "best_val_l1m".

    Args:
        resume_from: optional checkpoint path. Loads G/D and optimizers and,
            when present, LR-scheduler and GradScaler state, then continues at
            saved epoch + 1. A missing file starts fresh.
        lambda_perc: weight of the perceptual term (formerly fixed at 0.2).
        amp: mixed precision; only enabled when ``device`` is a CUDA device.

    Returns:
        Path of the best checkpoint if one exists, else the path of ``last.pt``.
    """
    os.makedirs(out_dir, exist_ok=True)
    last_path = os.path.join(out_dir, "last.pt")   # defined even if no epoch runs
    best_file = os.path.join(out_dir, "best_l1_mask.pt")

    dev_type = torch.device(device).type
    amp = bool(amp) and dev_type == "cuda"

    # ----- build models -----
    G = UNetNoseGenerator(in_ch=4, out_ch=3, base=64, depth=5, res_max=0.75).to(device)
    D = PatchDiscriminator(in_ch=6, base=64).to(device)

    # ----- opt -----
    opt_G = torch.optim.AdamW(G.parameters(), lr=lr_G, betas=(0.5, 0.999))
    opt_D = torch.optim.AdamW(D.parameters(), lr=lr_D, betas=(0.5, 0.999))

    # ----- (optional) resume -----
    start_ep = 1
    best = float("inf")
    best_path = None
    state = None
    if resume_from and os.path.isfile(resume_from):
        state = torch.load(resume_from, map_location=device)

        G.load_state_dict(state["G"], strict=True)
        D.load_state_dict(state["D"], strict=True)
        if "opt_G" in state: opt_G.load_state_dict(state["opt_G"])
        if "opt_D" in state: opt_D.load_state_dict(state["opt_D"])

        start_ep = int(state.get("epoch", 0)) + 1
        # "val_l1m" in last.pt is only the last epoch's score. Prefer the tracked
        # best so a resume cannot overwrite a better best_l1_mask.pt.
        best = float(state.get("best_val_l1m", state.get("val_l1m", best)))
        if "best_val_l1m" not in state and os.path.isfile(best_file):
            prev_best = torch.load(best_file, map_location="cpu").get("val_l1m")
            if prev_best is not None:
                best = min(best, float(prev_best))
        if os.path.isfile(best_file):
            best_path = best_file
        print(f"[resume] from {resume_from} → start_ep={start_ep}  best={best:.4f}")
    elif resume_from:
        print(f"[resume] file not found: {resume_from} (starting fresh)")

    # ----- schedulers -----
    # New checkpoints carry scheduler state and are restored exactly (including
    # the original T_max). Legacy ones are positioned via last_epoch; the
    # loaded optimizer param_groups already contain 'initial_lr'. This
    # approximates the original cosine schedule.
    has_sched = state is not None and "sch_G" in state and "sch_D" in state
    sched_last = -1 if has_sched else start_ep - 2
    sch_G = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt_G, T_max=epochs, eta_min=lr_G*0.1, last_epoch=sched_last
    )
    sch_D = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt_D, T_max=epochs, eta_min=lr_D*0.1, last_epoch=sched_last
    )
    if has_sched:
        sch_G.load_state_dict(state["sch_G"])
        sch_D.load_state_dict(state["sch_D"])

    scaler_G = torch.amp.GradScaler('cuda', enabled=amp)
    scaler_D = torch.amp.GradScaler('cuda', enabled=amp)
    if state is not None and amp:
        for scaler, key in ((scaler_G, "scaler_G"), (scaler_D, "scaler_D")):
            if state.get(key):  # empty when saved from a disabled scaler
                scaler.load_state_dict(state[key])
    state = None  # release the loaded checkpoint

    for ep in range(start_ep, epochs+1):
        G.train(); D.train()
        pbar = tqdm(train_loader, desc=f"train {ep}/{epochs}")
        adv_w = adv_start_w * min(1.0, ep/5.0)   # linear warm-up of the adversarial term

        for batch in pbar:
            rgb    = batch["input"].to(device)   # (B,3,H,W)
            target = batch["target"].to(device)  # (B,3,H,W)
            mask   = batch["mask"].to(device)    # (B,1,H,W)
            inp    = torch.cat([rgb, mask], dim=1)  # (B,4,H,W)

            hard    = dilate_mask_binary(mask, k=11, iters=2)
            m_soft  = feather_mask(hard, k=9)
            outside = 1.0 - erode_mask_binary(hard, k=11, iters=2)

            # ---- D step (hinge loss) ----
            opt_D.zero_grad(set_to_none=True)
            with torch.amp.autocast(dev_type, enabled=amp):
                with torch.no_grad():
                    fake = G(inp)
                real_logits = D(torch.cat([rgb, target], dim=1))
                fake_logits = D(torch.cat([rgb, fake],   dim=1))
                d_loss = torch.relu(1. - real_logits).mean() + torch.relu(1. + fake_logits).mean()
            scaler_D.scale(d_loss).backward()
            scaler_D.step(opt_D)
            scaler_D.update()

            # ---- G step ----
            opt_G.zero_grad(set_to_none=True)
            with torch.amp.autocast(dev_type, enabled=amp):
                pred = G(inp)
                l1_mask  = ((pred - target).abs() * m_soft).sum() / (m_soft.sum() + 1e-6)
                l_out_id = ((pred - rgb).abs()    * outside).sum() / (outside.sum() + 1e-6)
                fake_logits_g = D(torch.cat([rgb, pred], dim=1))
                adv_loss = -fake_logits_g.mean()
                if use_perc and (perceptual_fn is not None):
                    perc_loss = perceptual_fn(pred, target, mask=m_soft)
                else:
                    perc_loss = 0.0
                g_loss = (lambda_l1 * l1_mask
                          + lambda_out_id * l_out_id
                          + adv_w * adv_loss
                          + lambda_perc * perc_loss)
            scaler_G.scale(g_loss).backward()
            scaler_G.step(opt_G)
            scaler_G.update()

            pbar.set_postfix({
                "D": f"{d_loss.item():.3f}",
                "G": f"{g_loss.item():.3f}",
                "L1m": f"{l1_mask.item():.4f}",
                "OID": f"{l_out_id.item():.4f}",
                "ADV": f"{adv_loss.item():.4f}",
                "lrG": f"{sch_G.get_last_lr()[0]:.2e}"
            })

        sch_G.step(); sch_D.step()

        # ---- validation ----
        val_l1m = validate_epoch(G, val_loader, device)
        improved = val_l1m < best
        if improved:
            best = val_l1m

        # ---- save ----
        ckpt = {
            "epoch": ep,
            "G": G.state_dict(),
            "D": D.state_dict(),
            "opt_G": opt_G.state_dict(),
            "opt_D": opt_D.state_dict(),
            "sch_G": sch_G.state_dict(),
            "sch_D": sch_D.state_dict(),
            "scaler_G": scaler_G.state_dict(),
            "scaler_D": scaler_D.state_dict(),
            "val_l1m": val_l1m,
            "best_val_l1m": best,
            "cfg": {
                "lambda_l1": lambda_l1,
                "lambda_out_id": lambda_out_id,
                "adv_start_w": adv_start_w,
                "lambda_perc": lambda_perc,
            },
        }
        torch.save(ckpt, last_path)

        if improved:
            best_path = best_file
            torch.save(ckpt, best_path)

        print(f"[epoch {ep}] val L1(mask)={val_l1m:.4f} | best={best:.4f} | saved_best={improved}")

    return best_path or last_path

In [None]:
# Sanity check: show whether a stray global `autocast` alias exists (it would
# shadow torch's) and which autocast / GradScaler implementations are in use.
# globals() is a dict, so look the name up as a key (getattr always hit the default).
print("autocast is:", globals().get("autocast", "<no alias>"))
print("torch.autocast:", torch.autocast)
print("GradScaler:", torch.amp.GradScaler)

In [None]:
from torch.utils.data import DataLoader

# Keep-size datasets: each sample holds RGB (3ch), mask (1ch) and target at native resolution.
train_ds = NoseFolderDataset(
    "/workspace/data_splits/train/input",
    "/workspace/data_splits/train/mask_input",
    "/workspace/data_splits/train/target"
)
val_ds = NoseFolderDataset(
    "/workspace/data_splits/val/input",
    "/workspace/data_splits/val/mask_input",
    "/workspace/data_splits/val/target"
)

# Batches are padded to a multiple of the generator's total stride (32 for depth=5).
MULT = 32

loader_kwargs = dict(
    batch_size=2,
    num_workers=2,
    pin_memory=True,
    persistent_workers=False,
    collate_fn=lambda b: collate_keep_aspect(b, multiple=MULT),
)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader   = DataLoader(val_ds, shuffle=False, **loader_kwargs)

In [None]:
if __name__ == "__main__":
    # Expected data layout. The loaders actually used by train_gan below come
    # from the DataLoader cell above.
    TRAIN_IMG = "/workspace/data_splits/train/input"
    TRAIN_MSK = "/workspace/data_splits/train/mask_input"
    TRAIN_TGT = "/workspace/data_splits/train/target"

    VAL_IMG   = "/workspace/data_splits/val/input"
    VAL_MSK   = "/workspace/data_splits/val/mask_input"
    VAL_TGT   = "/workspace/data_splits/val/target"

    # Not used by train_gan below; the loaders above use batch_size=2, not bs.
    size = 256
    patch_size = 8
    bs = 8

    # The VGG perceptual term only participates when use_perc=True, so build the
    # (large) VGG model only in that case.
    USE_PERC = False
    perceptual_fn = VGGPerceptualLoss(layers=(3,8,15,22)).to(device) if USE_PERC else None

    resume_ckpt = "ckpts_Unet_PatchGan_Resv1/last.pt"

    best_path = train_gan(
        train_loader,
        val_loader,
        epochs=40,                     # total epochs you want to run
        out_dir="ckpts_Unet_PatchGan_Resv1",
        use_perc=USE_PERC,
        perceptual_fn=perceptual_fn,
        amp=True,
        device=device,
        resume_from=resume_ckpt,       # resume here (starts fresh if missing)
    )
   

In [60]:
import os, torch
import torch.nn.functional as F
from torchvision.utils import save_image

# ---- helpers: torch <-> numpy, mask prep ----
def _torch_rgb_to_u8_bgr(x: torch.Tensor) -> np.ndarray:
    """Convert a [3,H,W] float RGB tensor in [0,1] to an HxWx3 uint8 BGR array for OpenCV."""
    as_u8 = (x.clamp(0, 1) * 255.0).round().byte()
    hwc = as_u8.cpu().numpy().transpose(1, 2, 0)   # CHW -> HWC
    return cv2.cvtColor(hwc, cv2.COLOR_RGB2BGR)

def _bgr_u8_to_torch_rgb(x: np.ndarray) -> torch.Tensor:
    """Convert an HxWx3 uint8 BGR array to a [3,H,W] float32 RGB tensor in [0,1] (on CPU)."""
    rgb = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
    chw = np.transpose(rgb, (2, 0, 1))
    return torch.from_numpy(chw).float().div(255.0)

def _mask_to_u8(mask_01: torch.Tensor, thresh=0.05, dilate=3):
    # mask_01: [1,H,W] float in [0,1], soft or hard
    m = (mask_01.squeeze(0).clamp(0,1) > thresh).byte().cpu().numpy() * 255  # H,W uint8 {0,255}
    if dilate and dilate > 0:
        k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*dilate+1, 2*dilate+1))
        m = cv2.dilate(m, k)
    return m  # H,W uint8

def _mask_center(mask_u8: np.ndarray):
    # center for seamlessClone
    M = cv2.moments(mask_u8)
    if M["m00"] > 0:
        cx = int(M["m10"] / M["m00"])
        cy = int(M["m01"] / M["m00"])
    else:
        h, w = mask_u8.shape[:2]
        cy, cx = h//2, w//2
    return (cx, cy)

# ---- main: Poisson blend one sample ----
def poisson_blend_sample(pred_rgb, base_rgb, mask_01, mode="normal", dilate=2, thresh=0.3):
    """
    Poisson-blend the generator's full repaint into the original image.

    Args:
        pred_rgb: [3,H,W] float RGB in [0,1]; the source (full repaint).
        base_rgb: [3,H,W] float RGB in [0,1]; the destination (original image).
        mask_01:  [1,H,W] soft/hard mask, thresholded at ``thresh`` and
                  dilated by ``dilate`` px before blending.
        mode:     "normal" -> cv2.NORMAL_CLONE, anything else -> cv2.MIXED_CLONE.

    Returns:
        [3,H,W] float32 RGB in [0,1] (8-bit quantized) on pred_rgb's device.
        If the thresholded mask is empty, this is the (quantized) base image.
    """
    src = _torch_rgb_to_u8_bgr(pred_rgb)   # full repaint
    dst = _torch_rgb_to_u8_bgr(base_rgb)   # original
    msk = _mask_to_u8(mask_01, thresh=thresh, dilate=dilate)  # binary 0/255
    if not msk.any():
        # Nothing to blend. Skip seamlessClone, which has no region to work
        # on (and may error on an empty mask in some OpenCV versions).
        return _bgr_u8_to_torch_rgb(dst).to(pred_rgb.device)
    center = _mask_center(msk)
    flag = cv2.NORMAL_CLONE if mode == "normal" else cv2.MIXED_CLONE
    out = cv2.seamlessClone(src, dst, msk, center, flag)
    return _bgr_u8_to_torch_rgb(out).to(pred_rgb.device)


def save_triplet_fullsize(rgb_i, pred_i, tgt_i, save_dir, stem):
    """Write input / prediction / target (clamped to [0,1]) as <stem>_{input,pred,target}.png at native size."""
    os.makedirs(save_dir, exist_ok=True)
    for tensor, suffix in ((rgb_i, "input"), (pred_i, "pred"), (tgt_i, "target")):
        save_image(tensor.clamp(0, 1), os.path.join(save_dir, f"{stem}_{suffix}.png"))

def pad_to_square(img, value=0.0):
    """Centre-pad a [C,H,W] tensor with ``value`` to [C,S,S], S = max(H, W); odd remainders go bottom/right."""
    _, h, w = img.shape
    side = max(h, w)
    top = (side - h) // 2
    left = (side - w) // 2
    return F.pad(img, (left, side - w - left, top, side - h - top), value=value)

@torch.no_grad()
def run_test_unet(
    ckpt_path,
    test_loader,
    device="cuda",
    save_dir="results_unet",
    panel_mode="none",          # "none" or "letterbox"
    panel_size=512,             # used only when panel_mode="letterbox"
    amp=True
):
    """
    Run a trained UNet generator over ``test_loader`` and save the results.

    Each sample is cropped back to its original size (via "orig_hw" when the
    keep-aspect collate provides it). Files are written to ``save_dir`` with a
    running 5-digit index as stem:
        <stem>_pred_masked.png   model output (residual applied inside the soft mask)
        <stem>_pred_full.png     residual applied to the whole image
        <stem>_pred_blended.png  full repaint Poisson-blended into the input
        <stem>_input.png, <stem>_pred.png (= masked output), <stem>_target.png
    With panel_mode="letterbox", also writes <stem>_panel.png: input | masked
    prediction | target, each padded to a square and resized to panel_size
    (unless panel_size is None).
    """
    os.makedirs(save_dir, exist_ok=True)

    # Build the UNet exactly as in training
    G = UNetNoseGenerator(in_ch=4, out_ch=3, base=64, depth=5, res_max=0.75).to(device)
    ckpt = torch.load(ckpt_path, map_location=device)
    G.load_state_dict(ckpt["G"], strict=True)
    G.eval()

    dev_type = torch.device(device).type
    use_amp = bool(amp) and dev_type == "cuda"

    idx = 0
    for batch in test_loader:
        rgb  = batch["input"].to(device)    # [B,3,Ht,Wt] (padded)
        mask = batch["mask"].to(device)     # [B,1,Ht,Wt]
        tgt  = batch["target"].to(device)   # [B,3,Ht,Wt]
        ohw  = batch.get("orig_hw", None)   # [B,2] if using keep-aspect collate

        inp = torch.cat([rgb, mask], dim=1) # [B,4,Ht,Wt]

        with torch.amp.autocast(dev_type, enabled=use_amp):
            out_masked, full_rgb, m_soft = G(inp, return_full=True)   # [B,3,Ht,Wt] x2, [B,1,Ht,Wt]

        for i in range(full_rgb.size(0)):
            # crop back to each sample's original H,W (or the padded size if not provided)
            if ohw is not None:
                H, W = int(ohw[i, 0]), int(ohw[i, 1])
            else:
                _, H, W = rgb[i].shape
            rgb_i         = rgb[i, :, :H, :W]
            tgt_i         = tgt[i, :, :H, :W]
            pred_masked_i = out_masked[i, :, :H, :W]
            pred_full_i   = full_rgb[i, :, :H, :W]
            soft_i        = m_soft[i, :, :H, :W]   # blend region for the Poisson step

            blended_i = poisson_blend_sample(
                pred_rgb=pred_full_i,
                base_rgb=rgb_i,
                mask_01=soft_i,         # internally thresholded & slightly dilated
                mode="normal",
                dilate=2,
                thresh=0.3              # higher than 0.05 -> crisper mask
            )

            stem = f"{idx:05d}"
            save_image(pred_masked_i.clamp(0,1), os.path.join(save_dir, f"{stem}_pred_masked.png"))
            save_image(pred_full_i.clamp(0,1),   os.path.join(save_dir, f"{stem}_pred_full.png"))
            save_image(blended_i.clamp(0,1),     os.path.join(save_dir, f"{stem}_pred_blended.png"))
            # full-size input / masked prediction / target (no stretching)
            save_triplet_fullsize(rgb_i, pred_masked_i, tgt_i, save_dir, stem)

            # optional preview panel (letterboxed, not stretched)
            if panel_mode == "letterbox":
                tiles = [pad_to_square(x.clamp(0,1)) for x in (rgb_i, pred_masked_i, tgt_i)]
                # (optional) downscale large squares for disk space
                if panel_size is not None:
                    tiles = [F.interpolate(t.unsqueeze(0), size=(panel_size, panel_size),
                                           mode="bilinear", align_corners=False).squeeze(0)
                             for t in tiles]
                row = torch.stack(tiles, dim=0)  # [3,3,S,S]
                save_image(row, os.path.join(save_dir, f"{stem}_panel.png"), nrow=3)

            idx += 1

In [61]:
# NOTE: this "test" run reads the *val* split. panel_size only takes effect
# together with panel_mode="letterbox"; with the default "none" it is ignored.
test_ds = NoseFolderDataset(
    "/workspace/data_splits/val/input",
    "/workspace/data_splits/val/mask_input",
    "/workspace/data_splits/val/target"
)

from torch.utils.data import DataLoader
test_loader = DataLoader(
    test_ds,
    batch_size=2,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    persistent_workers=False,
    collate_fn=lambda samples: collate_keep_aspect(samples, multiple=32),
)

run_test_unet(
    ckpt_path="ckpts_Unet_PatchGan_Resv1/best_l1_mask.pt",
    test_loader=test_loader,
    device=device,
    save_dir="/workspace/results_unet_v3.1_postprocess",
    panel_size=256,
    amp=True,
)

[validate] counts: img=67 mask=63 target=67
[validate] common triples: 63
[validate] missing masks for 4 imgs, e.g. ['WhatsApp Image 2025-07-12 at 5.46.30 PM', 'WhatsApp Image 2025-07-12 at 5.46.31 PM (2)', 'WhatsApp Image 2025-07-12 at 5.53.17 PM (1)', 'WhatsApp Image 2025-07-12 at 6.20.03 PM']
