# Inpainting From Scratch Pipeline (Notebook)

Input image (occluded) → Mask (occlusion ∩ object) → Train inpainting model (from scratch) → Output → Evaluate (PSNR, SSIM, LPIPS)

In [10]:
import os, math, random
from dataclasses import dataclass, asdict
from typing import Optional, Tuple, List

import numpy as np
import cv2
from PIL import Image, ImageOps

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

from torchvision import transforms as T
from torchvision.transforms import ToTensor as ToTensorTorchvision

from skimage.metrics import structural_similarity as skimage_ssim
from skimage.metrics import peak_signal_noise_ratio as skimage_psnr

try:
    import lpips  # perceptual metric
except Exception:
    lpips = None

# Reuse preprocessing (binarize mask, pad, protect borders)
from inpainting import preprocess_img_and_mask
from seg_maskcnn import get_device as seg_get_device, _model as seg_model


def get_device():
    """Return the preferred torch device for this machine.

    Preference order is Apple MPS first, then CUDA, then CPU.
    """
    if torch.backends.mps.is_available():
        backend = 'mps'
    elif torch.cuda.is_available():
        backend = 'cuda'
    else:
        backend = 'cpu'
    return torch.device(backend)


# Resolve the compute device once; later cells read this module-level name.
device = get_device()
print('Using device:', device)

# Seed the three independent RNGs (Python, NumPy, PyTorch) used by the notebook.
random.seed(42)
np_random_seed = np.random.seed
np_random_seed(42)
torch.manual_seed(42)


Using device: mps


In [11]:
# Utils: tensor conversion, mask intersection, optional pet-mask prediction
def img_to_tensor(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a float32 channel-first tensor.

    Pixel values are divided by 255. A 2-D (single-channel) image gains a
    trailing channel axis before the axes are reordered from HWC to CHW.
    """
    pixels = np.array(img).astype(np.float32)
    pixels = pixels / 255.0
    if pixels.ndim == 2:
        # Grayscale input: add the channel axis so the transpose below works.
        pixels = np.expand_dims(pixels, axis=-1)
    chw = pixels.transpose(2, 0, 1)
    return torch.from_numpy(chw)

def tensor_to_pil(t: torch.Tensor) -> Image.Image:
    """Convert a channel-first tensor into a PIL image.

    Values are clamped to [0, 1], scaled by 255 and cast to uint8 (the cast
    truncates). A single-channel tensor yields a 2-D array, i.e. a grayscale
    image.
    """
    hwc = t.detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy()
    pixels = (hwc * 255.0).astype(np.uint8)
    if pixels.shape[2] == 1:
        # Drop the singleton channel axis so PIL sees an HxW array.
        pixels = pixels[:, :, 0]
    return Image.fromarray(pixels)

def safe_intersect_masks(occ_mask: Image.Image, pet_mask: Optional[Image.Image]) -> Image.Image:
    """Intersect the occlusion mask with an optional pet mask.

    Both masks are binarised at a threshold of 127 on their 'L' conversion.
    The occlusion mask alone is returned when ``pet_mask`` is None or when
    the intersection is empty. The result is an 'L' image holding 0 / 255.
    """
    def _binarize(mask_img):
        return (np.array(mask_img.convert('L')) > 127).astype(np.uint8)

    chosen = _binarize(occ_mask)
    if pet_mask is not None:
        overlap = (chosen & _binarize(pet_mask)).astype(np.uint8)
        # Keep the plain occlusion mask when the two masks do not overlap.
        if overlap.sum() != 0:
            chosen = overlap
    return Image.fromarray((chosen * 255).astype(np.uint8), 'L')

def predict_pet_mask_safe(img_pil: Image.Image) -> Optional[Image.Image]:
    """Best-effort pet segmentation mask for ``img_pil``.

    Runs ``seg_model`` on the image and merges every instance mask whose
    label is 17 or 18 and whose score is at least 0.5 into a single binary
    'L' image (0 / 255).

    Returns:
        The merged mask, or None when no instance qualifies or when the
        segmentation step fails. A failure is reported through a
        ``RuntimeWarning`` rather than being swallowed silently, so a broken
        segmentation setup does not go unnoticed.
    """
    import warnings  # function-scope import keeps this cell self-contained

    try:
        x = ToTensorTorchvision()(img_pil).unsqueeze(0).to(seg_get_device())
        with torch.no_grad():
            out = seg_model(x)[0]
        labels = out['labels'].tolist()
        scores = out['scores'].tolist()
        masks = out['masks']
        # NOTE(review): 17/18 are presumably the COCO cat/dog ids - confirm
        # against the label set used by seg_maskcnn.
        keep = [(l in (17, 18)) and (s >= 0.5) for l, s in zip(labels, scores)]
        if not any(keep):
            return None
        # Explicit boolean index tensor on the masks' own device.
        keep_idx = torch.tensor(keep, dtype=torch.bool, device=masks.device)
        # assumes masks is (N, 1, H, W) soft masks - TODO confirm
        m = (masks[keep_idx].squeeze(1) > 0.5).any(dim=0).float().cpu().numpy()
        return Image.fromarray((m * 255).astype(np.uint8), 'L')
    except Exception as e:
        # Deliberately best-effort: callers treat None as "no pet mask".
        warnings.warn(
            f'predict_pet_mask_safe failed ({type(e).__name__}: {e}); '
            'continuing without a pet mask',
            RuntimeWarning,
        )
        return None


In [12]:
# Dataset: occluded image + (occlusion ∩ pet) mask, target = clean image
class OccludedPetDataset(Dataset):
    """Occluded image + occlusion mask as input, clean image as target.

    Every ``.jpg`` / ``.png`` file in ``occluded_dir`` whose name does not
    contain ``_mask`` is one sample. For a file ``<stem>.<ext>`` the dataset
    expects ``<stem>_mask.png`` in the same directory and ``<stem>.jpg`` in
    ``base_img_dir``.

    Args:
        base_img_dir: directory holding the clean images.
        occluded_dir: directory holding occluded images and their masks.
        resize: side length of the square images the samples are resized to.
    """

    def __init__(self, base_img_dir='data/oxford-iiit-pet/images', occluded_dir='data/occluded', resize=256):
        self.base_img_dir = base_img_dir
        self.occluded_dir = occluded_dir
        self.resize = resize
        # Only take occluded images, skip mask files in the same folder.
        # Sorted because os.listdir returns entries in arbitrary order; a
        # stable index -> file mapping keeps a seeded random_split
        # reproducible across machines and runs.
        self.files = sorted(
            f for f in os.listdir(occluded_dir)
            if f.lower().endswith(('.jpg', '.png')) and '_mask' not in f.lower()
        )
        self.to_tensor = T.ToTensor()
        self.resize_img = T.Resize((resize, resize), interpolation=T.InterpolationMode.BICUBIC)
        # Nearest-neighbour keeps the mask binary after resizing.
        self.resize_mask = T.Resize((resize, resize), interpolation=T.InterpolationMode.NEAREST)

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _load(path, mode):
        """Open ``path``, convert it to ``mode`` and close the file handle."""
        with Image.open(path) as img:
            return img.convert(mode)

    def __getitem__(self, idx):
        """Return one sample as a dict.

        Keys: ``input`` (zeroed occluded RGB concatenated with the mask),
        ``target`` (clean RGB), ``mask`` (1 = region to inpaint), ``name``
        (file name) and ``occ_rgb`` (untouched occluded RGB).
        """
        name = self.files[idx]
        stem = name.rsplit('.', 1)[0]
        occ_path = os.path.join(self.occluded_dir, name)
        occ_mask_path = os.path.join(self.occluded_dir, stem + '_mask.png')
        clean_path = os.path.join(self.base_img_dir, stem + '.jpg')

        occ_img = self._load(occ_path, 'RGB')
        clean_img = self._load(clean_path, 'RGB')
        occ_mask = self._load(occ_mask_path, 'L')

        # Disable pet mask for stability/speed; use occlusion mask only
        pet_mask = None  # predict_pet_mask_safe(occ_img)
        inpaint_mask = safe_intersect_masks(occ_mask, pet_mask)

        occ_img_p, inpaint_mask_p = preprocess_img_and_mask(occ_img, inpaint_mask)
        # Guard: if mask covers ~all image, erode to avoid gray outputs
        mask_arr = np.array(inpaint_mask_p)
        if (mask_arr == 255).mean() > 0.95:
            # `iterations` must be passed by keyword: the third positional
            # parameter of cv2.erode is `dst`, not the iteration count.
            mask_arr = cv2.erode(mask_arr, np.ones((15, 15), np.uint8), iterations=1)
            inpaint_mask_p = Image.fromarray(mask_arr, 'L')
        clean_img_p, _ = preprocess_img_and_mask(clean_img, inpaint_mask)

        occ_img_r = self.resize_img(occ_img_p)
        clean_img_r = self.resize_img(clean_img_p)
        mask_r = self.resize_mask(inpaint_mask_p)

        occ_t = self.to_tensor(occ_img_r)
        clean_t = self.to_tensor(clean_img_r)
        mask_t = (self.to_tensor(mask_r) > 0.5).float()[:1]

        # Zero-out the masked region in the network input (remove gray occluder)
        occ_zeroed = occ_t.clone()
        occ_zeroed[:, mask_t[0] > 0.5] = 0.0
        x = torch.cat([occ_zeroed, mask_t], dim=0)  # [4,H,W]
        y = clean_t
        # Also return original occluded RGB (for visualization/blending)
        return {'input': x, 'target': y, 'mask': mask_t, 'name': name, 'occ_rgb': occ_t}


In [13]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by optional BatchNorm and ReLU.

    The convolutions use padding 1, so spatial size is preserved. The
    convolution bias is dropped whenever BatchNorm follows it.
    """

    def __init__(self, in_ch, out_ch, use_bn=True):
        super().__init__()
        stages = []
        channels = in_ch
        for _ in range(2):
            stages.append(
                nn.Conv2d(channels, out_ch, kernel_size=3, padding=1, bias=not use_bn)
            )
            if use_bn:
                stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
            # The second convolution maps out_ch -> out_ch.
            channels = out_ch
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)


class SelfAttention(nn.Module):
    """Non-local self-attention over spatial positions (SAGAN style).

    Queries and keys are projected to ``in_dim // 8`` channels with 1x1
    convolutions; values keep ``in_dim`` channels. The attended features are
    scaled by the learnable ``gamma`` (initialised to zero) and added to the
    input, so the layer starts out as an identity mapping.
    """

    def __init__(self, in_dim):
        super().__init__()
        reduced = in_dim // 8
        self.query = nn.Conv2d(in_dim, reduced, kernel_size=1)
        self.key = nn.Conv2d(in_dim, reduced, kernel_size=1)
        self.value = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch, channels, height, width = x.shape
        positions = height * width

        queries = self.query(x).reshape(batch, -1, positions).transpose(1, 2)  # B,N,Cq
        keys = self.key(x).reshape(batch, -1, positions)                       # B,Ck,N
        # Row i holds the softmax weights of position i over all positions.
        weights = F.softmax(torch.bmm(queries, keys), dim=-1)                  # B,N,N

        values = self.value(x).reshape(batch, channels, positions)             # B,C,N
        attended = torch.bmm(values, weights.transpose(1, 2))                  # B,C,N
        attended = attended.reshape(batch, channels, height, width)

        return self.gamma * attended + x


class SimpleUNet(nn.Module):
    """Three-level U-Net generator with self-attention at the bottleneck."""

    def __init__(self, in_ch=4, out_ch=3, base=64):
        """
        Lighter UNet + BatchNorm + Self-Attention at the bottleneck
        - base=64 → 64-128-256-512
        """
        super().__init__()
        c1, c2, c3, c4 = base, base * 2, base * 4, base * 8

        # Encoder: each level halves the spatial size after its DoubleConv.
        self.down1 = DoubleConv(in_ch, c1)
        self.pool1 = nn.MaxPool2d(2)

        self.down2 = DoubleConv(c1, c2)
        self.pool2 = nn.MaxPool2d(2)

        self.down3 = DoubleConv(c2, c3)
        self.pool3 = nn.MaxPool2d(2)

        # Bottleneck: DoubleConv + SelfAttention
        self.mid = nn.Sequential(
            DoubleConv(c3, c4),
            SelfAttention(c4),
        )

        # Decoder: upsample, concatenate the matching skip, then DoubleConv.
        # Each decoder conv sees twice the upsampled width because of the skip.
        self.up3 = nn.ConvTranspose2d(c4, c3, 2, stride=2)
        self.dec3 = DoubleConv(c4, c3)

        self.up2 = nn.ConvTranspose2d(c3, c2, 2, stride=2)
        self.dec2 = DoubleConv(c3, c2)

        self.up1 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.dec1 = DoubleConv(c2, c1)

        self.outc = nn.Conv2d(c1, out_ch, 1)

    def forward(self, x):
        # Encoder path; keep each level's output as a skip connection.
        skip1 = self.down1(x)
        skip2 = self.down2(self.pool1(skip1))
        skip3 = self.down3(self.pool2(skip2))

        feat = self.mid(self.pool3(skip3))

        # Decoder path, deepest skip first.
        feat = self.dec3(torch.cat([self.up3(feat), skip3], dim=1))
        feat = self.dec2(torch.cat([self.up2(feat), skip2], dim=1))
        feat = self.dec1(torch.cat([self.up1(feat), skip1], dim=1))

        # Sigmoid keeps the predicted values inside (0, 1).
        return torch.sigmoid(self.outc(feat))


def init_weights(m):
    """Kaiming-normal init for (transposed) 2-D convolutions; other modules are left alone.

    Intended for use with ``Module.apply``. Biases, when present, are zeroed.
    """
    if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        return
    nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias)

In [14]:
class PatchDiscriminator(nn.Module):
    """
    PatchGAN discriminator with spectral normalisation.

    Input: 3-channel RGB
    Output: N×N map of per-patch realism logits (no sigmoid is applied)
    """

    def __init__(self, ch=64):
        super().__init__()
        widths = [3, ch, ch * 2, ch * 4, ch * 8]

        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # Stride-2 4x4 convolution halves the spatial size at every stage.
            layers.append(
                nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, 4, 2, 1))
            )
            # The first stage has no BatchNorm.
            if stage > 0:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Final stride-1 convolution collapses to a single-channel patch map.
        layers.append(nn.utils.spectral_norm(nn.Conv2d(widths[-1], 1, 4, 1, 1)))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


In [15]:
# -------------------------
# Config, loaders, and losses
# -------------------------
@dataclass
class TrainConfig:
    """Hyper-parameters and data locations for the training run."""
    # Directory of clean images (handed to OccludedPetDataset by build_loaders).
    base_img_dir: str = '../data/oxford-iiit-pet/images'
    # Directory of occluded images and their `_mask.png` files.
    occluded_dir: str = '../data/occluded'
    # Side length of the square images fed to the network.
    resize: int = 256
    batch_size: int = 4
    epochs: int = 20
    # Generator learning rate; the training cell gives the discriminator twice this value.
    lr: float = 3e-4
    # Fraction of the dataset held out for validation.
    val_split: float = 0.2
    # DataLoader worker processes (0 = load in the main process).
    num_workers: int = 0
    lpips_weight: float = 0.1  # set >0 to enable LPIPS term


def masked_l1(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mask-weighted L1 error, normalised by the sum of the mask.

    ``mask`` is 1 inside the region to inpaint. The numerator sums over the
    broadcast product while the denominator sums ``mask`` as given, so a
    single-channel mask against a C-channel prediction yields C times the
    per-element mean. The 1e-6 term guards against an all-zero mask.
    """
    abs_err = (pred - target).abs()
    weighted_sum = (abs_err * mask).sum()
    normaliser = mask.sum() + 1e-6
    return weighted_sum / normaliser


def build_loaders(cfg: TrainConfig) -> Tuple[DataLoader, DataLoader]:
    """Build the training and validation DataLoaders described by ``cfg``.

    The dataset is split randomly, with at least one sample on each side.
    Only the training loader shuffles.

    Raises:
        ValueError: if the dataset holds fewer than two samples, since a
            non-empty train/validation split is then impossible.
    """
    ds = OccludedPetDataset(cfg.base_img_dir, cfg.occluded_dir, resize=cfg.resize)
    total = len(ds)
    print("Total samples:", total)
    if total < 2:
        raise ValueError(
            f"Need at least 2 samples to build train/val splits, found {total} "
            f"in {cfg.occluded_dir!r}"
        )
    # Clamp so the two sizes always sum to `total`; random_split rejects
    # lengths that do not add up to the dataset size.
    n_val = min(max(1, int(total * cfg.val_split)), total - 1)
    n_train = total - n_val
    train_ds, val_ds = random_split(ds, [n_train, n_val])
    train_dl = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers)
    val_dl = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)
    print("Train:", len(train_ds), "Val:", len(val_ds))
    return train_dl, val_dl


def init_lpips(device):
    """Return an AlexNet-based LPIPS network in eval mode on ``device``.

    Returns None when the ``lpips`` package could not be imported or when
    constructing the network fails.
    """
    if lpips is not None:
        try:
            return lpips.LPIPS(net='alex').to(device).eval()
        except Exception:
            # Best-effort: fall through and report "no LPIPS available".
            pass
    return None


In [16]:
# -------------------------
# Train and evaluate
# -------------------------
def train_one_epoch(
    model, D,
    opt_G, opt_D,
    train_dl, device,
    lpips_net=None, lpips_w=0.0
):
    """Run one adversarial training epoch and return the mean generator loss.

    For every batch the generator output is computed once, the discriminator
    is updated on (clean target, detached blended prediction), and then the
    generator is updated on a weighted sum of masked L1, an outside-mask
    identity term, optional LPIPS and the GAN loss.

    Args:
        model: generator; called on ``batch['input']``.
        D: discriminator; its raw outputs go into ``BCEWithLogitsLoss``.
        opt_G, opt_D: optimizers for the generator and the discriminator.
        train_dl: loader yielding dicts with ``input``, ``target`` and ``mask``.
        device: device the batch tensors are moved to.
        lpips_net: optional perceptual network; skipped when None.
        lpips_w: weight of the LPIPS term; the term is skipped when <= 0.

    Returns:
        Sum of per-batch generator losses weighted by batch size, divided by
        ``len(train_dl.dataset)``.
    """
    model.train()
    D.train()
    total = 0.0

    criterion_gan = nn.BCEWithLogitsLoss()

    for batch in train_dl:
        x = batch['input'].to(device)
        y = batch['target'].to(device)
        m = batch['mask'].to(device)

        # =======================================================
        # 1) GENERATOR FORWARD
        # =======================================================
        y_hat = model(x)

        # reconstruction inside mask
        l1_in = masked_l1(y_hat, y, m)
        # identity term: outside the mask the prediction should match the
        # first three input channels
        id_loss = torch.mean(torch.abs(y_hat - x[:, :3]) * (1 - m))

        # lpips on the raw prediction; `* 2 - 1` rescales both images
        # (assumes values in [0, 1] - TODO confirm for custom models)
        lp_loss = 0
        if lpips_net is not None and lpips_w > 0:
            lp_loss = lpips_net(y_hat * 2 - 1, y * 2 - 1).mean()

        # final blended output (for D and for realism): prediction inside
        # the mask, input pixels outside it
        final_fake = y_hat * m + x[:, :3] * (1 - m)

        # =======================================================
        # 2) TRAIN DISCRIMINATOR
        # =======================================================
        opt_D.zero_grad(set_to_none=True)

        real_out = D(y)             # ground truth clean
        # detach() stops the discriminator loss from reaching the generator
        fake_out = D(final_fake.detach())  # blended prediction

        d_loss_real = criterion_gan(real_out, torch.ones_like(real_out))
        d_loss_fake = criterion_gan(fake_out, torch.zeros_like(fake_out))
        d_loss = 0.5 * (d_loss_real + d_loss_fake)

        d_loss.backward()
        opt_D.step()

        # =======================================================
        # 3) TRAIN GENERATOR (GAN Loss)
        # =======================================================
        opt_G.zero_grad(set_to_none=True)

        # Fresh pass through the just-updated discriminator, this time with
        # the graph attached so gradients flow back into the generator.
        fake_out = D(final_fake)
        gan_loss = criterion_gan(fake_out, torch.ones_like(fake_out))

        loss = (
            2.0 * l1_in +
            0.005 * id_loss +
            lpips_w * lp_loss +
            0.01 * gan_loss
        )

        # This backward also deposits gradients on D's parameters; they are
        # discarded by opt_D.zero_grad() at the start of the next D step.
        loss.backward()
        opt_G.step()

        # Weight by batch size so a short final batch does not skew the mean.
        total += float(loss.detach().cpu().item()) * x.size(0)

    return total / max(1, len(train_dl.dataset))


In [17]:
@torch.no_grad()
def evaluate_model(model, val_dl, device):
    """Return (mean PSNR, mean SSIM) of the blended predictions on ``val_dl``.

    The prediction is blended with the input outside the mask, converted to
    uint8, and every pixel outside the mask is then replaced by the ground
    truth, so only the inpainted region contributes error. Samples with an
    empty mask are skipped. Both means are 0.0 when nothing was scored.
    The model is left in eval mode.
    """
    model.eval()
    psnr_scores = []
    ssim_scores = []
    for batch in val_dl:
        inputs = batch['input'].to(device)
        targets = batch['target'].to(device)
        masks = batch['mask'].to(device)
        preds = model(inputs)
        # Blend outside mask for fair evaluation
        blended = preds * masks + inputs[:, :3] * (1 - masks)

        targets_np = targets.cpu().numpy()
        blended_np = blended.cpu().numpy()
        masks_np = masks.cpu().numpy()
        for gt_chw, pr_chw, mask_chw in zip(targets_np, blended_np, masks_np):
            gt = (gt_chw.transpose(1, 2, 0) * 255.0).astype(np.uint8)
            pr = (pr_chw.transpose(1, 2, 0) * 255.0).astype(np.uint8)
            inside = mask_chw[0] > 0.5
            if not inside.any():
                continue
            gt_eval = gt.copy()
            pr_eval = pr.copy()
            # Outside the mask the comparison is forced to be exact.
            pr_eval[~inside] = gt[~inside]
            psnr_scores.append(skimage_psnr(gt_eval, pr_eval, data_range=255))
            ssim_scores.append(skimage_ssim(gt_eval, pr_eval, channel_axis=2, data_range=255))

    if psnr_scores:
        mean_psnr = float(np.mean(psnr_scores))
    else:
        mean_psnr = 0.0
    if ssim_scores:
        mean_ssim = float(np.mean(ssim_scores))
    else:
        mean_ssim = 0.0
    return mean_psnr, mean_ssim


In [18]:
# -------------------------
# Run training pipeline
# -------------------------
cfg = TrainConfig()
train_dl, val_dl = build_loaders(cfg)

# Generator: 4 input channels (RGB + mask), 3 output channels.
model = SimpleUNet(in_ch=4, out_ch=3, base=64).to(device)
# init_weights is applied to the generator only; the discriminator keeps
# PyTorch's default initialisation.
model.apply(init_weights)
D = PatchDiscriminator().to(device)

# The discriminator runs at twice the generator's learning rate.
opt_G = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=1e-4)
opt_D = torch.optim.AdamW(D.parameters(), lr=cfg.lr * 2, weight_decay=1e-4)

# Only build the perceptual network when the LPIPS term is enabled.
lpips_net = init_lpips(device) if cfg.lpips_weight > 0 else None

best_ssim = -1.0
best_ckpt = 'outputs/inpaint_unet_best.pt'
os.makedirs('outputs', exist_ok=True)

for epoch in range(cfg.epochs):
    loss = train_one_epoch(model, D, opt_G, opt_D, train_dl, device, lpips_net, cfg.lpips_weight)
    psnr, ssim = evaluate_model(model, val_dl, device)
    print(f'Epoch {epoch+1}/{cfg.epochs} - loss={loss:.4f} PSNR={psnr:.2f} SSIM={ssim:.3f}')

    # Keep the checkpoint with the best validation SSIM so far
    # (generator weights and config only; the discriminator is not saved).
    if ssim > best_ssim:
        best_ssim = ssim
        torch.save(
            {'state_dict': model.state_dict(), 'cfg': asdict(cfg)},
            best_ckpt
        )

Total samples: 999
Train: 800 Val: 199
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /Users/quan0207/miniforge3/envs/deeprestore/lib/python3.12/site-packages/lpips/weights/v0.1/alex.pth
Epoch 1/20 - loss=1.0725 PSNR=29.03 SSIM=0.965
Epoch 2/20 - loss=0.8487 PSNR=30.12 SSIM=0.968
Epoch 3/20 - loss=0.7988 PSNR=29.73 SSIM=0.968
Epoch 4/20 - loss=0.7792 PSNR=30.33 SSIM=0.969
Epoch 5/20 - loss=0.7631 PSNR=30.63 SSIM=0.970
Epoch 6/20 - loss=0.7384 PSNR=30.54 SSIM=0.970
Epoch 7/20 - loss=0.7383 PSNR=30.53 SSIM=0.970
Epoch 8/20 - loss=0.7252 PSNR=30.62 SSIM=0.970
Epoch 9/20 - loss=0.7201 PSNR=30.75 SSIM=0.970
Epoch 10/20 - loss=0.7159 PSNR=30.51 SSIM=0.970
Epoch 11/20 - loss=0.7052 PSNR=30.57 SSIM=0.970
Epoch 12/20 - loss=0.7037 PSNR=30.95 SSIM=0.971
Epoch 13/20 - loss=0.6868 PSNR=30.66 SSIM=0.971
Epoch 14/20 - loss=0.6839 PSNR=30.77 SSIM=0.970
Epoch 15/20 - loss=0.6757 PSNR=30.73 SSIM=0.971
Epoch 16/20 - loss=0.6638 PSNR=31.01 SSIM=0.971
Epoch 17/20 - loss=0.6705 PSNR=30.87 SSIM=0.971
Epoch 18/20 - loss=0.6617 PSNR=30.72 SSIM=0.971
Epoch 19/20 - loss=0.650

In [19]:
# -------------------------
# Save a few qualitative results
# -------------------------
# Dump the first 8 validation samples (occluded input, blended prediction,
# raw prediction, mask, ground truth) into outputs/ for visual inspection.
saved = 0
model.eval()
for batch in val_dl:
    x = batch['input'].to(device)
    y = batch['target'].to(device)
    occ_rgb = batch.get('occ_rgb', x[:, :3]).to(device)
    names = batch['name']
    m = batch['mask'].to(device)
    with torch.no_grad():
        y_hat = model(x)
        final = y_hat * m + x[:, :3] * (1 - m)

    for i, fname in enumerate(names):
        occ = tensor_to_pil(occ_rgb[i])
        pred = tensor_to_pil(final[i])
        gt = tensor_to_pil(y[i])
        raw_pred = tensor_to_pil(y_hat[i])
        mask_img = tensor_to_pil(m[i])

        # Sanity check: how many pixels are 0 / 1 in this sample's mask.
        _u, _c = torch.unique(m[i].round().to(torch.int32), return_counts=True)
        counts = {value: int(count) for value, count in zip(_u.tolist(), _c.tolist())}
        print(fname, 'mask counts:', counts)

        occ.save(os.path.join('outputs', f'occ_{fname}'))
        pred.save(os.path.join('outputs', f'pred_{fname}'))
        raw_pred.save(os.path.join('outputs', f'debug_raw_{fname}'))
        mask_img.save(os.path.join('outputs', f'debug_mask_{fname}'))
        gt.save(os.path.join('outputs', f'gt_{fname}'))

        saved += 1
        if saved >= 8:
            break
    # Stop pulling batches as soon as the quota is reached.
    if saved >= 8:
        break

# -------------------------
# Final metrics (PSNR / SSIM / LPIPS)
# -------------------------
psnr, ssim = evaluate_model(model, val_dl, device)

# NaN means "LPIPS was not computed". Lower LPIPS is better, so the previous
# 0.0 fallback made a failed evaluation look like a perfect score.
mean_lpips = float('nan')
lp = None
lpips_vals = []
if lpips is None:
    print('LPIPS package not available; LPIPS will be reported as nan')
else:
    try:
        lp = lpips.LPIPS(net='alex').to(device).eval()
        for batch in val_dl:
            x = batch['input'].to(device)
            y = batch['target'].to(device)
            with torch.no_grad():
                y_hat = model(x)
                m = batch['mask'].to(device)
                final = y_hat * m + x[:, :3] * (1 - m)
                # Collect per-sample distances so a short last batch does
                # not get the same weight as a full one.
                lpips_vals.extend(lp(final * 2 - 1, y * 2 - 1).flatten().tolist())
        if lpips_vals:
            mean_lpips = float(np.mean(lpips_vals))
    except Exception as exc:
        # Best-effort metric: report the failure instead of hiding it.
        print(f'LPIPS evaluation failed ({type(exc).__name__}: {exc}); LPIPS will be reported as nan')

print(f'Final metrics → PSNR: {psnr:.2f}, SSIM: {ssim:.3f}, LPIPS: {mean_lpips:.3f}')

# Save the weights as they are after the last epoch (not necessarily the best).
ckpt_path = 'outputs/inpaint_unet_final.pt'
torch.save(
    {
        'state_dict': model.state_dict(),
        'cfg': asdict(cfg),
    },
    ckpt_path
)
print('Saved last model to', ckpt_path)
print('Best model (by SSIM) saved to', best_ckpt)


Sphynx_162.jpg mask counts: {0: 64032, 1: 1504}
Bombay_190.jpg mask counts: {0: 63880, 1: 1656}
newfoundland_171.jpg mask counts: {0: 64032, 1: 1504}
basset_hound_187.jpg mask counts: {0: 60824, 1: 4712}
Bengal_117.jpg mask counts: {0: 63802, 1: 1734}
newfoundland_2.jpg mask counts: {0: 63096, 1: 2440}
boxer_79.jpg mask counts: {0: 63891, 1: 1645}
pomeranian_175.jpg mask counts: {0: 61047, 1: 4489}
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /Users/quan0207/miniforge3/envs/deeprestore/lib/python3.12/site-packages/lpips/weights/v0.1/alex.pth
Final metrics → PSNR: 31.01, SSIM: 0.971, LPIPS: 0.037
Saved last model to outputs/inpaint_unet_final.pt
Best model (by SSIM) saved to outputs/inpaint_unet_best.pt


In [20]:
# -------------------------
# Inference helper
# -------------------------
@torch.no_grad()
def infer_one(
    model,
    occ_img: Image.Image,
    occ_mask: Image.Image,
    use_segmentation=False,
    resize=256,
    debug=False
):
    """Inpaint a single occluded image and return the result as a PIL image.

    Args:
        model: trained generator taking a 4-channel (RGB + mask) input.
        occ_img: occluded image; converted to RGB before use.
        occ_mask: occlusion mask; converted to 'L' before use.
        use_segmentation: accepted for API compatibility and ignored
            (segmentation is disabled for stability).
        resize: side length of the square image fed to the network.
        debug: when True, print the mask sum and write ``debug_mask.png``,
            ``debug_raw.png`` and ``debug_final.png`` to the working directory.

    Returns:
        The blended result (prediction inside the mask, original pixels
        outside), resized back to ``occ_img.size`` when the sizes differ.
        The model is left in eval mode.
    """
    # Run on the device that actually holds the weights; fall back to the
    # notebook-level device for a parameter-free model.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    # Normalise modes the same way the training dataset does, so RGBA or
    # grayscale files do not change the channel count of the network input.
    occ_rgb = occ_img.convert('RGB')
    # segmentation is disabled for stability
    inpaint_mask = occ_mask.convert('L')

    occ_img_p, mask_p = preprocess_img_and_mask(occ_rgb, inpaint_mask)

    resize_img = T.Resize((resize, resize),
                          interpolation=T.InterpolationMode.BICUBIC)
    resize_mask = T.Resize((resize, resize),
                           interpolation=T.InterpolationMode.NEAREST)

    occ_r = resize_img(occ_img_p)
    mask_r = resize_mask(mask_p)

    occ_t = img_to_tensor(occ_r)
    mask_t = (img_to_tensor(mask_r)[:1] > 0.5).float()

    # Zero out the masked region, matching the training input.
    occ_zeroed = occ_t.clone()
    occ_zeroed[:, mask_t[0] > 0.5] = 0.0

    x = torch.cat([occ_zeroed, mask_t], dim=0).unsqueeze(0).to(run_device)

    model.eval()
    y_hat = model(x)[0]

    mask_t = mask_t.to(run_device)
    occ_t = occ_t.to(run_device)

    # Keep original pixels outside the mask; use the prediction inside it.
    final = y_hat * mask_t + occ_t * (1 - mask_t)

    out = tensor_to_pil(final)

    # resize back to original
    if out.size != occ_img.size:
        out = out.resize(occ_img.size, Image.BICUBIC)

    if debug:
        print("Mask sum:", mask_t.sum().item())
        tensor_to_pil(mask_t).save("debug_mask.png")
        tensor_to_pil(y_hat).save("debug_raw.png")
        tensor_to_pil(final).save("debug_final.png")

    return out


# -------------------------
# Load best checkpoint & test
# -------------------------
# 1) Load the best checkpoint written by the training cell.
ckpt = torch.load('outputs/inpaint_unet_best.pt', map_location=device)
# Training config stored next to the weights (not used further in this cell).
cfg_loaded = ckpt["cfg"]

# The architecture arguments must match the ones used for training.
best_model = SimpleUNet(in_ch=4, out_ch=3, base=64).to(device)
best_model.load_state_dict(ckpt["state_dict"])
best_model.eval()

# 2) Load input images (example paths; change these to your own paths)
occ = Image.open("/Users/quan0207/School/Computer vision/DeepRestore/data/occluded/yorkshire_terrier_8.jpg")
mask = Image.open("/Users/quan0207/School/Computer vision/DeepRestore/data/occluded/yorkshire_terrier_8_mask.png")

# 3) Inference
res = infer_one(best_model, occ, mask, use_segmentation=False, debug=True)
res.save("final_result.png")

Mask sum: 2256.0
