In [1]:
import os, math, time
import numpy as np
import torch, torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image

# ---- Notebook-wide configuration ----

# Input image paths.
IMG_PATH = "test.jpg"
IMG_PATH_OWN = "airpods/IMG_1410.jpeg"

# Output directory; created up front so later saves cannot fail on a missing folder.
OUTDIR = "nf2d_out"
os.makedirs(OUTDIR, exist_ok=True)

# Device that tensors and modules are moved to, and the default floating dtype.
device = "cpu"
torch.set_default_dtype(torch.float32)

def load_img(path):
    """Load an image file as a float32 RGB array scaled to [0, 1].

    Args:
        path: filesystem path (or anything ``PIL.Image.open`` accepts).

    Returns:
        ``np.ndarray`` of dtype float32 holding the RGB-converted pixels
        divided by 255.
    """
    # The context manager releases the underlying file handle deterministically
    # instead of leaving it to garbage collection.
    with Image.open(path) as im:
        # Materialise the pixels while the file is still open.
        pixels = np.array(im.convert("RGB"), dtype=np.float32)
    return pixels / 255.0

def make_coords_and_targets(img_np):
    """Turn an image array into per-pixel (x, y) inputs and RGB targets.

    Args:
        img_np: image array; its first two axes are taken as (H, W) and it is
            flattened to rows of 3 values.

    Returns:
        ``(H, W, coords, rgbs)`` where ``coords`` is ``(H*W, 2)`` holding
        ``(x, y)`` in [0, 1] in row-major pixel order and ``rgbs`` is the
        image flattened to ``(H*W, 3)``; both live on ``device``.
    """
    height, width = img_np.shape[0], img_np.shape[1]

    # Normalised pixel-centre positions along each axis (endpoints included).
    row_pos = torch.linspace(0, 1, height)
    col_pos = torch.linspace(0, 1, width)
    grid_y, grid_x = torch.meshgrid(row_pos, col_pos, indexing="ij")

    # x goes first in the last axis, so each row is (x, y).
    coords = torch.stack((grid_x, grid_y), dim=-1).reshape(-1, 2).to(device)

    flat_pixels = img_np.reshape(-1, 3)
    rgbs = torch.from_numpy(flat_pixels).to(device)
    return height, width, coords, rgbs


In [3]:
class PosEnc(nn.Module):
    """Sinusoidal positional encoding with frequencies ``pi * 2**k``, k = 0..L-1.

    The output is the raw input followed by all sine features and then all
    cosine features, so the last dimension grows from ``D`` to ``D + 2*L*D``.
    """

    def __init__(self, L=10):
        super().__init__()
        self.L = L
        octaves = torch.arange(L)
        # Registered as a buffer so it follows the module across .to(device).
        self.register_buffer("freqs", (2.0 ** octaves).float() * math.pi)

    def forward(self, x):
        """Encode ``x`` of shape ``(..., D)``; returns ``(..., D + 2*L*D)``."""
        lead_shape = x.shape[:-1]
        # (..., 1, D) * (L, 1) -> (..., L, D): every coordinate at every frequency.
        phases = x.unsqueeze(-2) * self.freqs.view(-1, 1)
        sin_feats = torch.sin(phases).reshape(*lead_shape, -1)
        cos_feats = torch.cos(phases).reshape(*lead_shape, -1)
        return torch.cat([x, sin_feats, cos_feats], dim=-1)

def pe_dim(L):
    """Feature width produced by ``PosEnc(L)`` on 2-D coordinates."""
    raw_coords = 2          # the (x, y) pair passed through unchanged
    per_frequency = 4       # sin and cos for each of the two coordinates
    return raw_coords + per_frequency * L

class MLP(nn.Module):
    """Fully connected network mapping encoded coordinates to RGB in (0, 1).

    Built as ``depth - 1`` hidden ``Linear -> ReLU`` pairs of size ``width``
    followed by a ``Linear -> Sigmoid`` head with 3 outputs.
    """

    def __init__(self, in_dim, width=128, depth=4):
        super().__init__()
        n_hidden = max(depth - 1, 0)
        # Input size of every Linear layer, in construction order.
        fan_ins = [in_dim] + [width] * n_hidden

        modules = []
        for fan_in in fan_ins[:-1]:
            modules.append(nn.Linear(fan_in, width))
            modules.append(nn.ReLU(inplace=True))
        # Sigmoid keeps the predicted colour inside the valid [0, 1] range.
        modules.extend([nn.Linear(fan_ins[-1], 3), nn.Sigmoid()])

        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Run the stack on ``x`` whose last dimension is ``in_dim``."""
        return self.net(x)


In [4]:
def make_sampler(coords_full, rgbs_full):
    """Build a minibatch sampler over paired coordinate / colour rows.

    Returns a function ``sample(B)`` that draws ``B`` row indices uniformly
    at random (with replacement) and returns the matching rows of
    ``coords_full`` and ``rgbs_full``.
    """
    num_rows = coords_full.shape[0]

    def sample(B):
        # Indices are created on the data's device so indexing needs no transfer.
        picks = torch.randint(0, num_rows, (B,), device=coords_full.device)
        batch_coords = coords_full[picks]
        batch_rgbs = rgbs_full[picks]
        return batch_coords, batch_rgbs

    return sample

@torch.no_grad()
def reconstruct_image(model, pe, H, W, coords_full):
    """Render the network's prediction for every pixel as an 8-bit image.

    Args:
        model: callable mapping encoded coordinates to RGB values.
        pe: callable that encodes ``coords_full`` before it is fed to ``model``.
        H, W: output image height and width; ``coords_full`` must hold
            ``H * W`` rows in row-major order.
        coords_full: coordinates of every pixel.

    Returns:
        ``np.ndarray`` of shape ``(H, W, 3)`` and dtype ``uint8``.
    """
    x = pe(coords_full)
    pred = model(x).clamp(0, 1)
    im = pred.reshape(H, W, 3).detach().cpu().numpy()
    # Round to the nearest level. A bare astype(uint8) truncates toward zero,
    # which darkens every pixel by up to one level (0.999 -> 254) and can turn
    # a float32 round-trip of k/255 into k-1.
    return np.rint(im * 255).astype(np.uint8)

def psnr_from_mse(m):
    """Convert a mean-squared error (signal peak 1.0) to PSNR in decibels.

    The error is floored at 1e-10 so a zero loss yields a finite value
    instead of a math domain error.
    """
    floor = 1e-10
    bounded = floor if floor > m else m
    return -10.0 * math.log10(bounded)


In [5]:
def train_nf(img_np, image_name="img", L=10, width=128, depth=4, iters=1500, bs=10_000, lr=1e-2, snaps=(0,100,500,1000,1500), seed=None):
    """Fit a positional-encoded MLP to a single image and save the results.

    Writes, under ``OUTDIR``: one reconstruction PNG per iteration listed in
    ``snaps``, a final reconstruction PNG, and a PNG of the PSNR curve.

    Args:
        img_np: image array handed to ``make_coords_and_targets`` (it is
            flattened there to rows of 3 values).
        image_name: prefix used in every output filename and in the log line.
        L: number of positional-encoding frequencies.
        width: hidden width of the MLP.
        depth: ``depth`` argument of ``MLP``.
        iters: last iteration index. The loop covers ``0..iters`` inclusive
            (``iters + 1`` optimizer steps) so a snapshot at ``iters`` exists.
        bs: number of pixels sampled per step.
        lr: Adam learning rate.
        snaps: iteration indices at which a full reconstruction is saved.
            Each snapshot is taken after that iteration's optimizer step.
        seed: optional int. When given, torch's global RNG is seeded before
            the model is built so weight init and batch sampling are
            repeatable. ``None`` (default) leaves the RNG untouched.

    Returns:
        ``(psnr_hist, final_path, curve_path)``: the per-iteration PSNR list
        and the paths of the final image and the PSNR plot. Each PSNR value
        is computed from that iteration's minibatch loss (before the step),
        not from the full image.
    """
    if seed is not None:
        # Must happen before MLP(...) so the initial weights are covered too.
        torch.manual_seed(seed)

    H, W, coords_full, rgbs_full = make_coords_and_targets(img_np)
    pe = PosEnc(L=L).to(device)
    model = MLP(pe_dim(L), width=width, depth=depth).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    mse_fn = nn.MSELoss()
    sampler = make_sampler(coords_full, rgbs_full)
    psnr_hist = []

    t0 = time.time()
    for it in range(0, iters+1):
        coords, rgbs = sampler(bs)
        x = pe(coords)
        pred = model(x)
        loss = mse_fn(pred, rgbs)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        mse = float(loss.item())
        psnr_hist.append(psnr_from_mse(mse))

        if it in snaps:
            rec = reconstruct_image(model, pe, H, W, coords_full)
            Image.fromarray(rec).save(os.path.join(OUTDIR, f"{image_name}_L{L}_W{width}_it{it}.png"))

    rec = reconstruct_image(model, pe, H, W, coords_full)
    final_path = os.path.join(OUTDIR, f"{image_name}_L{L}_W{width}_final.png")
    Image.fromarray(rec).save(final_path)

    # Explicit figure handle: we save and close exactly this figure instead of
    # relying on pyplot's notion of the "current" one.
    fig, ax = plt.subplots(figsize=(5,3))
    ax.plot(psnr_hist)
    ax.set_xlabel("iter")
    ax.set_ylabel("PSNR (dB)")
    fig.tight_layout()
    curve_path = os.path.join(OUTDIR, f"{image_name}_L{L}_W{width}_psnr.png")
    fig.savefig(curve_path, dpi=150)
    plt.close(fig)

    # Reported time covers training plus snapshot/plot writing.
    print(f"[{image_name}] L={L} W={width} iters={iters} final PSNR={psnr_hist[-1]:.2f}dB time={time.time()-t0:.1f}s")
    return psnr_hist, final_path, curve_path


In [13]:
# Train on the image at IMG_PATH with a single encoding frequency and a
# 24-unit-wide network; all hyperparameters for the run are gathered here.
run_cfg = dict(
    image_name="provided",
    L=1,
    width=24,
    depth=4,
    iters=1500,
    bs=10_000,
    lr=1e-2,
    snaps=(0,100,500,1000,1500),
)

img_ref = load_img(IMG_PATH)
psnr_hist, final_img, psnr_curve = train_nf(img_np=img_ref, **run_cfg)

print("saved:", final_img)
print("psnr curve:", psnr_curve)


[provided] L=1 W=24 iters=1500 final PSNR=18.40dB time=342.4s
saved: nf2d_out/provided_L1_W24_final.png
psnr curve: nf2d_out/provided_L1_W24_psnr.png
