In [1]:
# NOTE(review): a later cell (In [31]) redefines AutoEncoder and defines `evaluate`
# in-notebook, shadowing this import on re-run; keep a single source of truth.
# NOTE(review): `load_best_model` and `plt` appear unused in the cells shown.
from model.cae import AutoEncoder, training, load_best_model
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split
import matplotlib.pyplot as plt
from util.load_data import load_images_as_tensor

In [25]:

# Positive images: the autoencoder is trained on these, with 5% (at least one image)
# held out for validation.
positive = load_images_as_tensor("/blue/iruchkin/pansiyuan/positive/")
print(positive.shape)

full_ds = TensorDataset(positive)
assert len(full_ds) > 1, "need at least 2 positive images for a train/val split"

val_sz = max(1, int(0.05 * len(full_ds)))
train_sz = len(full_ds) - val_sz
# Seeded generator: the split must be identical across kernel restarts, otherwise a
# reloaded checkpoint would be "validated" on images it was trained on.
SPLIT_SEED = 42
split_gen = torch.Generator().manual_seed(SPLIT_SEED)
train_ds, val_ds = random_split(full_ds, [train_sz, val_sz], generator=split_gen)

train_ld = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
val_ld   = DataLoader(val_ds,   batch_size=64, shuffle=False, num_workers=4, pin_memory=True)

# Negative images: in this notebook they are only evaluated, never trained on.
negative_ds = TensorDataset(load_images_as_tensor("/blue/iruchkin/pansiyuan/negative/images/"))
negative = DataLoader(negative_ds, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)

torch.Size([1324, 3, 256, 256])


In [9]:
# Fresh autoencoder: 3 input channels, 32 channels in the first stage.
model = AutoEncoder(base_ch=32, in_channels=3)

In [10]:
# Train for up to 150 epochs; the logged best checkpoint is written under `outdir`
# (runs/ae1024_tensor/ckpts/best.pt in the captured run).
out = training(
    model,
    train_ld,
    val_ld,
    epochs=150,
    lr=1e-2,
    outdir="runs/ae1024_tensor",
    loss_mix=0,
)
print("Best ckpt:", out['best_ckpt'])


[train] Epoch 1/150
  train: mse=0.009453  mae=0.061811
  valid: mse=0.006103  mae=0.050523  psnr=22.15 dB
  ↑ new best (mse 0.006103) -> runs/ae1024_tensor/ckpts/best.pt

[train] Epoch 2/150
  train: mse=0.006436  mae=0.051527
  valid: mse=0.005959  mae=0.048906  psnr=22.26 dB
  ↑ new best (mse 0.005959) -> runs/ae1024_tensor/ckpts/best.pt

[train] Epoch 3/150
  train: mse=0.006359  mae=0.050891
  valid: mse=0.005925  mae=0.049430  psnr=22.28 dB
  ↑ new best (mse 0.005925) -> runs/ae1024_tensor/ckpts/best.pt

[train] Epoch 4/150
  train: mse=0.006352  mae=0.050873
  valid: mse=0.005920  mae=0.049576  psnr=22.29 dB
  ↑ new best (mse 0.005920) -> runs/ae1024_tensor/ckpts/best.pt

[train] Epoch 5/150
  train: mse=0.006356  mae=0.050978
  valid: mse=0.005944  mae=0.050389  psnr=22.27 dB

[train] Epoch 6/150
  train: mse=0.006358  mae=0.050937
  valid: mse=0.005935  mae=0.050191  psnr=22.27 dB

[train] Epoch 7/150
  train: mse=0.006349  mae=0.051080
  valid: mse=0.005933  mae=0.048909  ps

In [33]:
# NOTE(review): `evaluate` is defined further down this notebook (the In [31] cell);
# run that cell first, or move the definition above this one.
# Use whatever device the model already lives on instead of hard-coding "cuda".
evaluate(model, val_ld, next(model.parameters()).device, sample_dir="positive")

{'mse': 0.00043633682335811585, 'mae': 0.013116729022427038}

In [32]:
# NOTE(review): `evaluate` is defined further down this notebook (the In [31] cell);
# run that cell first, or move the definition above this one.
# Use whatever device the model already lives on instead of hard-coding "cuda".
evaluate(model, negative, next(model.parameters()).device, sample_dir="negative")

{'mse': 0.004291121783242464, 'mae': 0.046508773152108204}

In [37]:
import os
import torch
from torchvision.utils import save_image

@torch.no_grad()
def save_recon_diffs(
    model, loader,
    outdir="dif", device=None, suffix="neg",
    max_samples=100
):
    """
    Save combined images: [original | reconstruction | abs diff].

    Each sample is written to ``<outdir>/triple_<suffix>_<NNNNNN>.png`` (running
    6-digit index), with the three panels concatenated along the width.

    - Diff is grayscale (mean |rec - img| over channels), repeated to 3 channels so
      it can sit next to the other two panels.
    - max_samples limits how many to save; ``None`` saves every sample. Iteration
      stops as soon as the limit is reached, so no extra batch is pulled from the
      loader or run through the model.

    Args:
        model: module whose forward returns ``(reconstruction, latent)``. It is moved
            to ``device`` and left in eval mode.
        loader: iterable of batches; a batch is either an image tensor or a
            tuple/list whose first element is the image tensor.
        outdir: output directory (created if missing).
        device: target device; defaults to CUDA when available, otherwise CPU.
        suffix: tag placed in every filename.
        max_samples: maximum number of samples to save (``None`` = no limit).
    """
    os.makedirs(outdir, exist_ok=True)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()

    def _get_imgs_from_batch(batch):
        return batch[0] if isinstance(batch, (tuple, list)) else batch

    def _limit_reached(n_saved):
        return max_samples is not None and n_saved >= max_samples

    counter = 0
    if not _limit_reached(counter):
        for batch in loader:
            imgs = _get_imgs_from_batch(batch).to(device, non_blocking=True)  # [B,C,H,W]
            rec, _ = model(imgs)
            diff = (rec - imgs).abs().mean(dim=1, keepdim=True)               # [B,1,H,W]
            diff = diff.repeat(1, 3, 1, 1)                                    # -> 3-channel for saving

            for i in range(imgs.size(0)):
                if _limit_reached(counter):
                    break
                base = f"{suffix}_{counter:06d}"
                # concat along width: [C,H,W] → (orig | recon | diff)
                row = torch.cat([imgs[i], rec[i], diff[i]], dim=2)
                save_image(row, os.path.join(outdir, f"triple_{base}.png"))
                counter += 1

            # Stop here rather than loading and reconstructing a batch whose
            # images would never be saved.
            if _limit_reached(counter):
                break

    print(f"[done] Saved {counter} samples to: {outdir}")

# Write [original | reconstruction | diff] strips for the first 50 images of each set.
save_recon_diffs(
    model, negative,
    suffix="", outdir="negdif", max_samples=50,
)
save_recon_diffs(
    model, val_ld,
    suffix="", outdir="posdif", max_samples=50,
)

[done] Saved 50 samples to: negdif
[done] Saved 50 samples to: posdif


In [31]:
import torch.nn as nn
import os, math, torch
import torch.nn.functional as F
from torchvision import utils as vutils

@torch.no_grad()
def evaluate(model, loader, device, sample_dir=None, epoch=0, max_grids=1, grid_n=40):
    """
    Compute the per-image mean reconstruction MSE and MAE of ``model`` over ``loader``.

    Args:
        model: module whose forward returns ``(reconstruction, latent)``. The model
            itself is not moved; it must already live on ``device``.
        loader: iterable of batches; a batch is either an image tensor or a
            tuple/list whose first element is the image tensor.
        device: device the images are moved to.
        sample_dir: if given, save up to ``max_grids`` comparison grids here (one per
            batch, starting with the first): originals on the top row,
            reconstructions on the bottom row, at most ``grid_n`` images per row.
        epoch: used only in grid filenames (``val_ep{epoch:03d}_{idx:02d}.png``).
        max_grids: maximum number of grids to save.
        grid_n: maximum number of images per grid row.

    Returns:
        ``{'mse': float, 'mae': float}``, the per-batch means weighted by batch
        size; both are NaN if the loader yields no batches.

    The model's train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    totals = {'mse': 0.0, 'mae': 0.0}
    count, saved = 0, 0
    try:
        for batch in loader:
            imgs = batch[0] if isinstance(batch, (tuple, list)) else batch
            imgs = imgs.to(device, non_blocking=True)
            rec, _ = model(imgs)
            b = imgs.size(0)
            # Weighting each batch mean by its size yields the per-image average,
            # assuming every image has the same number of elements.
            totals['mse'] += F.mse_loss(rec, imgs, reduction='mean').item() * b
            totals['mae'] += F.l1_loss(rec, imgs, reduction='mean').item() * b
            count += b

            if sample_dir and saved < max_grids:
                # nrow must equal the number of images shown per row, otherwise a
                # batch smaller than grid_n mixes originals and reconstructions.
                n_show = min(grid_n, b)
                grid = vutils.make_grid(torch.cat([imgs[:n_show], rec[:n_show]], dim=0), nrow=n_show)
                os.makedirs(sample_dir, exist_ok=True)
                vutils.save_image(grid, os.path.join(sample_dir, f"val_ep{epoch:03d}_{saved:02d}.png"))
                saved += 1
    finally:
        model.train(was_training)

    if count == 0:
        return {k: float("nan") for k in totals}
    return {k: v / count for k, v in totals.items()}

def _num_groups(channels, max_groups=8):
    """Largest group count <= ``max_groups`` that evenly divides ``channels``.

    GroupNorm requires ``channels % num_groups == 0``. This returns ``channels`` when
    ``channels <= max_groups`` and ``max_groups`` for multiples of it (the same values
    ``min(max_groups, channels)`` gives), but also works for widths such as 12 or 20,
    where ``min(8, channels)`` is not a divisor and GroupNorm would raise.
    """
    return next(g for g in range(min(max_groups, channels), 0, -1) if channels % g == 0)


class DownBlock(nn.Module):
    """Conv -> GroupNorm -> SiLU, twice, with optional downsampling via stride.

    The first 3x3 conv (padding 1) uses ``stride`` (stride=2 halves H and W, rounding
    up); the second keeps the resolution. Channels go ``in_ch -> out_ch``.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        groups = _num_groups(out_ch)
        # Layer order fixes the state_dict keys (block.0 ... block.5); keep it stable.
        self.block = nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
                                   nn.GroupNorm(num_groups=groups, num_channels=out_ch), nn.SiLU(inplace=True),
                                   nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
                                   nn.GroupNorm(num_groups=groups, num_channels=out_ch),
                                   nn.SiLU(inplace=True), )

    def forward(self, x):
        return self.block(x)


class UpBlock(nn.Module):
    """Upsample by 2x (ConvT) -> Conv block.

    A 4x4 transposed conv with stride 2 and padding 1 doubles H and W and maps
    ``in_ch -> out_ch``; a 3x3 conv then refines at the new resolution. Each conv is
    followed by GroupNorm and SiLU.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        groups = _num_groups(out_ch)
        # Layer order fixes the state_dict keys (block.0 ... block.5); keep it stable.
        self.block = nn.Sequential(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
                                   nn.GroupNorm(num_groups=groups, num_channels=out_ch), nn.SiLU(inplace=True),
                                   nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
                                   nn.GroupNorm(num_groups=groups, num_channels=out_ch),
                                   nn.SiLU(inplace=True), )

    def forward(self, x):
        return self.block(x)


# ---------- the autoencoder ----------


class AutoEncoder(nn.Module):
    """
    Five-stage fully convolutional autoencoder assembled from DownBlock/UpBlock only.

    Every encoder stage is a stride-2 DownBlock and every decoder stage an UpBlock,
    so for a 256x256 input the spatial size goes

      Encoder: 256 -> 128 -> 64 -> 32 -> 16 -> 8    (DownBlock x5)
      Bottleneck: (16 * base_ch, 8, 8)
      Decoder: 8 -> 16 -> 32 -> 64 -> 128 -> 256    (UpBlock x5)

    A 1x1 conv maps back to ``in_channels`` and a sigmoid squashes the output into
    (0, 1). ``forward`` returns ``(reconstruction, bottleneck)``.
    """
    def __init__(self, in_channels: int = 3, base_ch: int = 32):
        super().__init__()
        # Stage widths: base_ch * 1, 2, 4, 8, 16 (32 ... 512 for the default base_ch).
        ch1, ch2, ch3, ch4, ch5 = (base_ch * 2 ** k for k in range(5))

        # Attribute names and assignment order (enc1..enc5, dec1..dec5, out) fix the
        # state_dict keys and the order in which _init_weights visits modules.
        # Encoder: every stage halves H and W.
        self.enc1 = DownBlock(in_channels, ch1, stride=2)
        self.enc2 = DownBlock(ch1, ch2, stride=2)
        self.enc3 = DownBlock(ch2, ch3, stride=2)
        self.enc4 = DownBlock(ch3, ch4, stride=2)
        self.enc5 = DownBlock(ch4, ch5, stride=2)

        # Decoder: mirrors the encoder; every stage doubles H and W.
        self.dec1 = UpBlock(ch5, ch4)
        self.dec2 = UpBlock(ch4, ch3)
        self.dec3 = UpBlock(ch3, ch2)
        self.dec4 = UpBlock(ch2, ch1)
        self.dec5 = UpBlock(ch1, ch1)

        # Output head: per-pixel projection back to the input channel count.
        self.out = nn.Sequential(
            nn.Conv2d(ch1, in_channels, kernel_size=1, stride=1, bias=True),
            nn.Sigmoid(),
        )

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (fan_in, relu gain) conv weights; GroupNorm affine set to identity."""
        conv_types = (nn.Conv2d, nn.ConvTranspose2d)
        for module in self.modules():
            if isinstance(module, conv_types):
                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
                continue
            if isinstance(module, nn.GroupNorm):
                if module.weight is not None:
                    nn.init.ones_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def encode(self, x):
        """Run the five encoder stages; [B, C, H, W] -> [B, 16*base_ch, H/32, W/32]."""
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            x = stage(x)
        return x

    def decode(self, h):
        """Run the five decoder stages, then the sigmoid output head."""
        for stage in (self.dec1, self.dec2, self.dec3, self.dec4, self.dec5):
            h = stage(h)
        return self.out(h)

    def forward(self, x):
        latent = self.encode(x)
        return self.decode(latent), latent
