In [6]:
import os
import numpy as np
from datasets import load_dataset
from scipy.fftpack import dct, idct
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import normalized_root_mse as nrmse
from skimage.metrics import structural_similarity as ssim
import torch
import torchvision.transforms as T
from torchvision.models import inception_v3
from scipy.linalg import sqrtm
from PIL import Image

def get_device():
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return torch.device("mps")
    elif torch.cuda.is_available():
        return torch.device("cuda")
    else:
        return torch.device("cpu")

device = get_device()
print("Device:", device)

base = (
    "/Users/byeongchanmac/Library/CloudStorage/"
    "GoogleDrive-jeong382@umn.edu/My Drive/"
    "University of Minnesota/CSCI 5527/Project/"
    "aloha_sim_insertion_human_image/data/chunk-000"
)
ds = load_dataset(
    "parquet",
    data_files={"train": os.path.join(base, "episode_*.parquet")},
    split="train"
)
img_key = "observation.images.top"
images = [
    np.array(ex[img_key], np.float32) / 255.0
    for ex in ds.select(range(min(100, len(ds))))
]
H, W = images[0].shape[:2]
print("Original Resolution:", (H, W))

def fft_tr(x):
    return np.fft.fft2(x)

def fft_inv(f):
    return np.abs(np.fft.ifft2(f))

def dct_tr(x):
    return dct(dct(x.T, norm="ortho").T, norm="ortho")

def dct_inv(c):
    return idct(idct(c.T, norm="ortho").T, norm="ortho")

def reconstruct_gray(im, transform, inverse, bh, bw):
    coeff = transform(im)
    masked = np.zeros_like(coeff)
    masked[:bh, :bw] = coeff[:bh, :bw]
    return inverse(masked)

def reconstruct_color(im_color, transform, inverse, bh, bw):
    chans = []
    for c in range(3):
        rec_c = reconstruct_gray(im_color[..., c], transform, inverse, bh, bw)
        chans.append(rec_c)
    rec = np.stack(chans, axis=-1)
    assert rec.shape == im_color.shape
    return rec

def eval_gray(origs, recs):
    ps, nr, ss_vals = [], [], []
    for o, r in zip(origs, recs):
        dr = o.max() - o.min() or 1.0
        ps.append(psnr(o, r, data_range=dr))
        nr.append(nrmse(o, r))
        ss_vals.append(ssim(o, r, data_range=dr))
    return {
        "PSNR": np.mean(ps),
        "NRMSE": np.mean(nr),
        "SSIM": np.mean(ss_vals)
    }

def eval_color(origs, recs):
    ps, nr, ss_vals = [], [], []
    for o, r in zip(origs, recs):
        dr = o.max() - o.min() or 1.0
        ps.append(psnr(o, r, data_range=dr))
        nr.append(nrmse(o.flatten(), r.flatten()))
        ss_vals.append(ssim(o, r, data_range=dr, channel_axis=-1))
    return {
        "PSNR": np.mean(ps),
        "NRMSE": np.mean(nr),
        "SSIM": np.mean(ss_vals)
    }

def get_acts(imgs, model, device, bs=32):
    model.eval()
    tf = T.Compose([
        T.ToTensor(),
        T.Resize((299, 299)),
        T.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
    ])
    acts = []
    with torch.no_grad():
        for i in range(0, len(imgs), bs):
            batch = imgs[i:i+bs]
            tns = []
            for im in batch:
                im255 = (im * 255).clip(0,255).astype(np.uint8)
                tns.append(tf(im255).unsqueeze(0))
            tns = torch.cat(tns, 0).to(device)
            acts.append(model(tns).cpu().numpy())
    return np.concatenate(acts, 0)

def compute_fid(origs, recs, device):
    net = inception_v3(pretrained=True, transform_input=False)
    net.fc = torch.nn.Identity()
    net.to(device)
    a1 = get_acts(origs, net, device)
    a2 = get_acts(recs, net, device)
    mu1, mu2 = a1.mean(0), a2.mean(0)
    s1, s2 = np.cov(a1, rowvar=False), np.cov(a2, rowvar=False)
    cs = sqrtm(s1.dot(s2))
    if np.iscomplexobj(cs):
        cs = cs.real
    return np.sum((mu1 - mu2)**2) + np.trace(s1 + s2 - 2*cs)

gray = [im.mean(-1) for im in images]

out_dir = "recons"
os.makedirs(out_dir, exist_ok=True)
scales = {
    "1":    (H, W),
    "1/2":  (H//2, W//2),
    "1/4":  (H//4, W//4),
    "1/8":  (H//8, W//8),
    "1/16": (H//16, W//16)
}
methods = [("FFT", (fft_tr, fft_inv)), ("DCT", (dct_tr, dct_inv))]

sample_idx = 0
orig_c = (images[sample_idx] * 255).clip(0,255).astype(np.uint8)
Image.fromarray(orig_c).save(os.path.join(out_dir, "original_color.png"))
orig_g = (gray[sample_idx] * 255).clip(0,255).astype(np.uint8)
Image.fromarray(orig_g).save(os.path.join(out_dir, "original_gray.png"))
print("Saved: original_color.png, original_gray.png")

results = {}
for method, (tr, inv) in methods:
    results[method] = {}
    for name, (bh, bw) in scales.items():
        recs_g = [reconstruct_gray(im, tr, inv, bh, bw) for im in gray]
        recs_c = [reconstruct_color(im, tr, inv, bh, bw) for im in images]

        m_gray  = eval_gray(gray, recs_g)
        m_color = eval_color(images, recs_c)

        orig_gray_rgb = [np.stack([img]*3, axis=-1) for img in gray]
        rec_gray_rgb  = [np.stack([img]*3, axis=-1) for img in recs_g]
        fid_gray  = compute_fid(orig_gray_rgb, rec_gray_rgb, device)
        fid_color = compute_fid(images, recs_c, device)

        results[method][name] = {
            "gray": m_gray,
            "color": m_color,
            "FID_gray": fid_gray,
            "FID_color": fid_color
        }

        safe = name.replace("/", "_")
        
        g_arr = (recs_g[sample_idx] * 255).clip(0,255).astype(np.uint8)
        Image.fromarray(g_arr).save(
            os.path.join(out_dir, f"recon_gray_{method.lower()}_{safe}.png")
        )
        c_arr = (recs_c[sample_idx] * 255).clip(0,255).astype(np.uint8)
        Image.fromarray(c_arr).save(
            os.path.join(out_dir, f"recon_color_{method.lower()}_{safe}.png")
        )
        print(f"Saved: recon_gray_{method.lower()}_{safe}.png, recon_color_{method.lower()}_{safe}.png")

label_map = {
    "1": "480×640", "1/2": "240×320",
    "1/4": "120×160", "1/8": "60×80",
    "1/16": "30×40"
}
for method in results:
    print(f"\n=== {method} Reconstruction Metrics ===")
    for name, m in results[method].items():
        size = label_map[name]
        print(
            f"[{size}] "
            f"GRAY → PSNR={m['gray']['PSNR']:.3f}, "
            f"NRMSE={m['gray']['NRMSE']:.3f}, "
            f"SSIM={m['gray']['SSIM']:.3f}, "
            f"FID={m['FID_gray']:.3f} | "
            f"COLOR → PSNR={m['color']['PSNR']:.3f}, "
            f"NRMSE={m['color']['NRMSE']:.3f}, "
            f"SSIM={m['color']['SSIM']:.3f}, "
            f"FID={m['FID_color']:.3f}"
        )


Device: mps


Resolving data files:   0%|          | 0/50 [00:00<?, ?it/s]

원본 해상도: (480, 640)
Saved: original_color.png, original_gray.png




Saved: recon_gray_fft_1.png, recon_color_fft_1.png
Saved: recon_gray_fft_1_2.png, recon_color_fft_1_2.png
Saved: recon_gray_fft_1_4.png, recon_color_fft_1_4.png
Saved: recon_gray_fft_1_8.png, recon_color_fft_1_8.png
Saved: recon_gray_fft_1_16.png, recon_color_fft_1_16.png
Saved: recon_gray_dct_1.png, recon_color_dct_1.png
Saved: recon_gray_dct_1_2.png, recon_color_dct_1_2.png
Saved: recon_gray_dct_1_4.png, recon_color_dct_1_4.png
Saved: recon_gray_dct_1_8.png, recon_color_dct_1_8.png
Saved: recon_gray_dct_1_16.png, recon_color_dct_1_16.png

=== FFT Reconstruction Metrics ===
[480×640] GRAY → PSNR=322.443, NRMSE=0.000, SSIM=1.000, FID=-0.000 | COLOR → PSNR=323.275, NRMSE=0.000, SSIM=1.000, FID=0.001
[240×320] GRAY → PSNR=16.925, NRMSE=0.502, SSIM=0.371, FID=432.233 | COLOR → PSNR=17.788, NRMSE=0.505, SSIM=0.375, FID=353.908
[120×160] GRAY → PSNR=16.906, NRMSE=0.503, SSIM=0.367, FID=436.031 | COLOR → PSNR=17.769, NRMSE=0.506, SSIM=0.371, FID=356.581
[60×80] GRAY → PSNR=16.871, NRMSE=0.50