# Bad Apple SIREN Training (H=16)

Train 658 SIREN neural networks (3→16→16→3), one per 10-frame segment, for Bad Apple video playback on a Zynq Z7020 FPGA.

**Before running:** Make sure runtime is set to **A100 GPU** (Runtime → Change runtime type → A100)

**Steps:**
1. Upload `source.mp4` when prompted
2. Frames get extracted automatically
3. Training runs (~10-15 min on A100)
4. Weights are exported and zipped for download

In [None]:
# Cell 1: Upload source video and extract frames
import os
from google.colab import files

# Upload source.mp4
print("Upload source.mp4 (Bad Apple, 320x172, 7MB)")
uploaded = files.upload()
assert 'source.mp4' in uploaded, f"Expected source.mp4, got {list(uploaded.keys())}"
print(f"Uploaded {len(uploaded['source.mp4'])} bytes")

# Extract frames
os.makedirs('frames', exist_ok=True)
!ffmpeg -i source.mp4 -vf "scale=320:172" frames/frame_%05d.png -y -loglevel warning
n_frames = len([f for f in os.listdir('frames') if f.endswith('.png')])
print(f"Extracted {n_frames} frames")

In [None]:
# Cell 2: Batched SIREN training — GPU-native sampling, no CPU bottleneck
import math, os, struct, time
import numpy as np
import torch
from pathlib import Path
from PIL import Image

# =========================================================
# Constants
# =========================================================
FRAME_DIR = Path('frames')
WEIGHTS_DIR = Path('weights')
if not FRAME_DIR.is_dir():
    raise FileNotFoundError(f"{FRAME_DIR}/ not found - run Cell 1 to extract frames first")
# Count PNGs in FRAME_DIR itself (not a separately hard-coded 'frames'
# literal) so the count always agrees with the directory the loader reads.
TOTAL_FRAMES = len([f for f in os.listdir(FRAME_DIR) if f.endswith('.png')])
FRAME_W, FRAME_H = 320, 172
ASPECT_Y = FRAME_H / FRAME_W  # 0.5375; y coordinates span [-ASPECT_Y, ASPECT_Y]
FRAMES_PER_SEGMENT = 10  # frames covered by each SIREN network
N_SEGMENTS = (TOTAL_FRAMES + FRAMES_PER_SEGMENT - 1) // FRAMES_PER_SEGMENT  # ceil division
FRAC_BITS = 28  # fractional bits of the Q4.28 .bin export
Q_SCALE = 1 << FRAC_BITS
HIDDEN = 16  # width of both hidden layers
OMEGA_0 = 10.0  # SIREN frequency scale applied inside the hidden-layer sines

# Training hyperparameters
EPOCHS = 5000
LR = 1e-4
SAMPLES = 50000  # random pixels drawn per segment per epoch
MINI_BATCH = 10000  # samples per segment per optimizer step
EVAL_EVERY = 50  # print progress every EVAL_EVERY epochs

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")
if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

print(f"Config: {TOTAL_FRAMES} frames, {N_SEGMENTS} segments, H={HIDDEN}")
print(f"Training: {EPOCHS} epochs, {SAMPLES} samples/seg, mini-batch {MINI_BATCH}")

# =========================================================
# Load ALL frame data into a single GPU tensor
# =========================================================
print(f"\nLoading {TOTAL_FRAMES} frames into GPU...")
t0 = time.time()


def _load_frame(path):
    """Read one PNG as grayscale float32, mapping [0, 255] to [-1, 1].

    Raises:
        ValueError: if the frame is not FRAME_W x FRAME_H pixels. Without this
            check the mismatch would surface later as an opaque np.stack error.
    """
    # Context manager closes the file handle deterministically.
    with Image.open(path) as img:
        arr = np.array(img.convert('L'), dtype=np.float32) / 127.5 - 1.0
    if arr.shape != (FRAME_H, FRAME_W):
        raise ValueError(
            f"{path}: expected {FRAME_W}x{FRAME_H} pixels, got {arr.shape[1]}x{arr.shape[0]}")
    return arr


# frames_gpu: (N_SEGMENTS, FRAMES_PER_SEGMENT, H, W) on GPU
all_seg_frames = []
for seg in range(N_SEGMENTS):
    # Frame files are 1-based: frame_00001.png is the first frame.
    start_frame = seg * FRAMES_PER_SEGMENT + 1
    end_frame = min(start_frame + FRAMES_PER_SEGMENT, TOTAL_FRAMES + 1)
    frames = []
    for i in range(start_frame, end_frame):
        path = FRAME_DIR / f"frame_{i:05d}.png"
        if not path.exists():
            break
        frames.append(_load_frame(path))
    if not frames:
        break
    # Pad a short segment by repeating its last frame so every segment holds
    # exactly FRAMES_PER_SEGMENT frames.
    frames.extend([frames[-1]] * (FRAMES_PER_SEGMENT - len(frames)))
    all_seg_frames.append(np.stack(frames))

# An empty frame directory would otherwise fail inside np.stack with
# "need at least one array to stack".
if not all_seg_frames:
    raise RuntimeError(f"No frames found in {FRAME_DIR}/ - run Cell 1 to extract frames first")

n_seg = len(all_seg_frames)
frames_gpu = torch.from_numpy(np.stack(all_seg_frames)).to(device)  # (N, F, H, W)
del all_seg_frames  # free CPU memory

# Pre-compute coordinate mappings on GPU:
# x in [-1, 1], y in [-ASPECT_Y, ASPECT_Y], t in [-1, 1] across a segment's frames.
x_coords = (torch.arange(FRAME_W, device=device, dtype=torch.float32) / (FRAME_W - 1)) * 2.0 - 1.0
y_coords = (torch.arange(FRAME_H, device=device, dtype=torch.float32) / (FRAME_H - 1)) * 2.0 * ASPECT_Y - ASPECT_Y
t_coords = (torch.arange(FRAMES_PER_SEGMENT, device=device, dtype=torch.float32) / max(FRAMES_PER_SEGMENT - 1, 1)) * 2.0 - 1.0

print(f"Loaded {n_seg} segments in {time.time()-t0:.1f}s")
print(f"  frames_gpu: {frames_gpu.shape} = {frames_gpu.element_size() * frames_gpu.nelement() / 1e6:.0f} MB on GPU")


# =========================================================
# GPU-native sampling — no numpy, no CPU-GPU transfer
# =========================================================
def sample_gpu(frames_gpu, n_seg, samples_per_seg):
    """Draw random training pixels for every segment without leaving the GPU.

    For each of the first ``n_seg`` segments, ``samples_per_seg``
    (frame, row, col) triples are drawn uniformly with replacement. They are
    mapped to normalized coordinates through the precomputed ``x_coords`` /
    ``y_coords`` / ``t_coords`` tables and paired with the matching pixel of
    ``frames_gpu``.

    Args:
        frames_gpu: Pixel tensor indexed as (segment, frame, row, col).
        n_seg: Number of leading segments to sample.
        samples_per_seg: Samples drawn per segment.

    Returns:
        ``(coords, targets)``:
        - ``coords`` is (n_seg, samples_per_seg, 3), holding (x, y, t).
        - ``targets`` is (n_seg, samples_per_seg, 3), the grayscale value
          repeated across 3 channels (an expanded view, not a copy).
    """
    shape = (n_seg, samples_per_seg)
    # Draw order (frame, row, col) is fixed so the RNG stream is unchanged.
    frame_idx = torch.randint(0, FRAMES_PER_SEGMENT, shape, device=device)
    row_idx = torch.randint(0, FRAME_H, shape, device=device)
    col_idx = torch.randint(0, FRAME_W, shape, device=device)

    coords = torch.stack(
        (x_coords[col_idx], y_coords[row_idx], t_coords[frame_idx]), dim=2)

    # An (n_seg, 1) column of segment ids broadcasts against the
    # (n_seg, samples_per_seg) index tensors during advanced indexing.
    seg_col = torch.arange(n_seg, device=device)[:, None]
    gray = frames_gpu[seg_col, frame_idx, row_idx, col_idx]
    return coords, gray[..., None].expand(*shape, 3)


# =========================================================
# Batched SIREN
# =========================================================
def init_weights(n_seg, hidden, omega_0, device):
    """Create parameters for ``n_seg`` independent 3 -> hidden -> hidden -> 3 SIRENs.

    Initialization:
    - First-layer weights are U(-1/3, 1/3), i.e. 1/fan_in with fan_in = 3.
    - The second and output layers use
      U(-sqrt(6/hidden)/omega_0, +sqrt(6/hidden)/omega_0).
    - All biases start at zero.

    Returns:
        ``[W1, b1, W2, b2, W3, b3]``, each a leaf tensor with requires_grad=True.
        - W1 is (n_seg, hidden, 3), W2 is (n_seg, hidden, hidden), and
          W3 is (n_seg, 3, hidden).
        - Biases are (n_seg, 1, out_features) so they broadcast over the
          sample axis.
    """
    def uniform(shape, bound):
        return torch.empty(*shape, device=device).uniform_(-bound, bound)

    def zero_bias(width):
        return torch.zeros(n_seg, 1, width, device=device)

    later_bound = math.sqrt(6.0 / hidden) / omega_0
    # List elements are evaluated left to right, so the uniform draws happen in
    # W1, W2, W3 order and consume the RNG stream identically.
    tensors = [
        uniform((n_seg, hidden, 3), 1.0 / 3),
        zero_bias(hidden),
        uniform((n_seg, hidden, hidden), later_bound),
        zero_bias(hidden),
        uniform((n_seg, 3, hidden), later_bound),
        zero_bias(3),
    ]
    return [t.requires_grad_(True) for t in tensors]


def batched_forward(coords, params, omega_0):
    """Evaluate N independent SIREN networks in one batched pass.

    Args:
        coords: (N, S, 3) input coordinates, S points per network.
        params: ``[W1, b1, W2, b2, W3, b3]`` batched over the N networks.
        omega_0: Frequency scale applied before each of the two hidden sines.

    Returns:
        (N, S, 3) tensor: ``sin`` of the final affine layer. omega_0 is not
        applied there, and every output lies in [-1, 1].
    """
    W1, b1, W2, b2, W3, b3 = params

    def affine(x, weight, bias):
        # Per-network x @ weight^T + bias, batched over the leading axis.
        return torch.bmm(x, weight.transpose(1, 2)) + bias

    hidden = coords
    for weight, bias in ((W1, b1), (W2, b2)):
        hidden = torch.sin(omega_0 * affine(hidden, weight, bias))
    return torch.sin(affine(hidden, W3, b3))


# =========================================================
# Q4.28 export
# =========================================================
def float_to_q428(val):
    """Encode ``val`` as a Q4.28 fixed-point word, returned as an unsigned 32-bit int.

    Steps:
    1. Saturate the value to [-8, 8 - 1/Q_SCALE].
    2. Scale it by ``Q_SCALE``.
    3. Round half-to-even with ``round``.
    4. Return negative results as their 32-bit two's-complement bit pattern.

    NOTE: a NaN input passes through the saturation as -8.0, so it encodes
    as 0x80000000.
    """
    lower, upper = -8.0, 8.0 - 1.0 / Q_SCALE
    saturated = max(lower, min(val, upper))
    # After saturation the rounded value lies in [-2**31, 2**31 - 1]. Masking
    # leaves non-negative words unchanged and wraps negatives to two's complement.
    return int(round(saturated * Q_SCALE)) & 0xFFFFFFFF


def export_segment_binary(params, seg_idx, omega_0):
    """Write one segment's network to ``WEIGHTS_DIR`` as little-endian Q4.28 words.

    Layout: uint32, little-endian, no header. Layers are written in order
    (layer 1, layer 2, output layer). For each layer, the weight matrix is
    written row-major (out_features x in_features), followed by its bias
    vector.

    omega_0 is folded into layers 1 and 2 before encoding, using
    sin(omega_0 * (W x + b)) == sin((omega_0 * W) x + omega_0 * b).
    The output layer is written unscaled. Each value goes through
    ``float_to_q428``, which saturates out-of-range values.

    Returns:
        Path of the written ``segment_XXX.bin``.
    """
    W1, b1, W2, b2, W3, b3 = params

    def host(t):
        return t.detach().cpu().numpy()

    # (weight, bias, scale) per layer; a scale of None means "write as-is".
    layer_specs = (
        (W1[seg_idx], b1[seg_idx, 0], omega_0),
        (W2[seg_idx], b2[seg_idx, 0], omega_0),
        (W3[seg_idx], b3[seg_idx, 0], None),
    )

    words = []
    for weight, bias, scale in layer_specs:
        w_np, b_np = host(weight), host(bias)
        if scale is not None:
            w_np, b_np = w_np * scale, b_np * scale
        words.extend(float_to_q428(v) for v in w_np.reshape(-1))
        words.extend(float_to_q428(v) for v in b_np)

    bin_path = WEIGHTS_DIR / f"segment_{seg_idx:03d}.bin"
    with open(bin_path, 'wb') as f:
        # One pack call; '<' means no padding, so the bytes equal the
        # concatenation of per-word '<I' packs.
        f.write(struct.pack(f'<{len(words)}I', *words))
    return bin_path


def export_segment_pt(params, seg_idx, omega_0, out_dir=None):
    """Save one segment's network as a PyTorch state dict.

    Weights are stored unscaled; omega_0 is not folded in. Keys are
    ``layers.{0,1}.linear.{weight,bias}`` and ``output_layer.{weight,bias}``.

    Args:
        params: ``[W1, b1, W2, b2, W3, b3]`` batched over segments.
        seg_idx: Index of the segment to export.
        omega_0: Unused; kept so the signature matches export_segment_binary.
        out_dir: Destination directory. Defaults to ``WEIGHTS_DIR``.

    Returns:
        Path of the written ``segment_XXX.pt``.
    """
    W1, b1, W2, b2, W3, b3 = params

    def snapshot(t):
        # .clone() gives each saved tensor its own compact storage. Without it,
        # when params already live on the CPU, .cpu() returns a view into the
        # batched (n_seg, ...) tensor. torch.save would then serialize the
        # whole underlying storage (every segment) into each file.
        return t.detach().cpu().clone()

    state_dict = {
        'layers.0.linear.weight': snapshot(W1[seg_idx]),
        'layers.0.linear.bias': snapshot(b1[seg_idx, 0]),
        'layers.1.linear.weight': snapshot(W2[seg_idx]),
        'layers.1.linear.bias': snapshot(b2[seg_idx, 0]),
        'output_layer.weight': snapshot(W3[seg_idx]),
        'output_layer.bias': snapshot(b3[seg_idx, 0]),
    }
    directory = WEIGHTS_DIR if out_dir is None else Path(out_dir)
    pt_path = directory / f"segment_{seg_idx:03d}.pt"
    torch.save(state_dict, pt_path)
    return pt_path


# =========================================================
# Evaluation (GPU-native)
# =========================================================
def evaluate_psnr(params, frames_gpu, omega_0, device, max_segs=None):
    """Score each segment's network by full-resolution PSNR.

    Every pixel of every frame in a segment is evaluated. Channel 0 of the
    network output, clamped to [-1, 1], is compared with the ground truth.
    The per-frame MSE is averaged over the FRAMES_PER_SEGMENT frames and
    converted to PSNR with MAX^2 = 4, since values span [-1, 1].

    Args:
        params: ``[W1, b1, W2, b2, W3, b3]`` batched over segments.
        frames_gpu: Ground truth indexed (segment, frame, row, col).
        omega_0: SIREN frequency scale passed to ``batched_forward``.
        device: Device for temporary tensors.
        max_segs: If given, only the first ``max_segs`` segments are scored.

    Returns:
        List of ``(segment_index, psnr_db)`` tuples in segment order.
    """
    W1, b1, W2, b2, W3, b3 = params
    available = frames_gpu.shape[0]
    n_seg = available if max_segs is None else min(max_segs, available)
    H, W = FRAME_H, FRAME_W
    n_pixels = H * W

    # (x, y) for every pixel in row-major order: shape (H*W, 2).
    grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing='ij')
    xy_flat = torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=1)

    chunk = 64  # segments per batched forward pass, to bound activation memory
    scores = []
    for lo in range(0, n_seg, chunk):
        hi = min(lo + chunk, n_seg)
        width = hi - lo
        with torch.no_grad():
            sub_params = [t[lo:hi] for t in (W1, b1, W2, b2, W3, b3)]
        mse_acc = torch.zeros(width, device=device)

        for frame in range(FRAMES_PER_SEGMENT):
            t_col = torch.full((n_pixels, 1), t_coords[frame], device=device)
            grid = torch.cat([xy_flat, t_col], dim=1).unsqueeze(0).expand(width, -1, -1)
            with torch.no_grad():
                out = batched_forward(grid, sub_params, omega_0)  # (width, H*W, 3)
            recon = out[:, :, 0].reshape(width, H, W).clamp(-1.0, 1.0)
            truth = frames_gpu[lo:hi, frame]
            mse_acc += ((recon - truth) ** 2).mean(dim=(1, 2))

        mean_mse = mse_acc / FRAMES_PER_SEGMENT
        psnr_db = 10 * torch.log10(4.0 / mean_mse)
        scores.extend(zip(range(lo, hi), psnr_db.tolist()))

    return scores


# =========================================================
# Main training loop
# =========================================================
WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)

print(f"\nInitializing {n_seg} SIREN networks (3->{HIDDEN}->{HIDDEN}->3)...")
params = init_weights(n_seg, HIDDEN, OMEGA_0, device)
W1, b1, W2, b2, W3, b3 = params

optimizer = torch.optim.Adam(params, lr=LR)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=EPOCHS, eta_min=LR * 0.01)

# Best-weight tracking scores a FIXED probe set with the CURRENT parameters.
# - The score stored in best_loss then belongs to exactly the weights copied
#   into best_params, and scores from different checkpoints are comparable.
# - Reusing the last training mini-batch's `pred` would instead score the
#   weights from before that batch's optimizer.step(), on a different
#   random batch each time.
# The probe set is the size of one training mini-batch, so its forward pass
# fits the same memory budget.
probe_coords, probe_targets = sample_gpu(frames_gpu, n_seg, MINI_BATCH)
best_loss = torch.full((n_seg,), float('inf'), device=device)
best_params = [p.detach().clone() for p in params]

# Ceil division: a partial final mini-batch is trained on instead of being
# dropped when SAMPLES is not a multiple of MINI_BATCH.
n_mini = max(1, (SAMPLES + MINI_BATCH - 1) // MINI_BATCH)
print(f"Training: {EPOCHS} epochs, {SAMPLES} samples/seg, {n_mini} mini-batches")
print(f"Printing every {EVAL_EVERY} epochs\n")

t_train = time.time()
for epoch in range(EPOCHS):
    # GPU-native sampling — no CPU involved
    coords_all, targets_all = sample_gpu(frames_gpu, n_seg, SAMPLES)

    # Summed on the device. Converting with .item() only when printing avoids
    # a host-device sync after every mini-batch.
    epoch_loss = torch.zeros((), device=device)
    perm = torch.randperm(SAMPLES, device=device)

    for mb in range(n_mini):
        idx = perm[mb * MINI_BATCH:(mb + 1) * MINI_BATCH]
        batch_coords = coords_all[:, idx]
        batch_targets = targets_all[:, idx]
        pred = batched_forward(batch_coords, params, OMEGA_0)
        loss = ((pred - batch_targets) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach()

    scheduler.step()

    # Track best weights per segment, scored on the fixed probe set with the
    # parameters as they are right now.
    if (epoch + 1) % 100 == 0 or epoch == 0:
        with torch.no_grad():
            probe_pred = batched_forward(probe_coords, params, OMEGA_0)
            seg_mse = ((probe_pred - probe_targets) ** 2).mean(dim=(1, 2))
            del probe_pred
            improved = seg_mse < best_loss
            if improved.any():
                for best, p in zip(best_params, params):
                    best[improved] = p[improved]
                best_loss[improved] = seg_mse[improved]

    if (epoch + 1) % EVAL_EVERY == 0 or epoch == 0:
        avg_loss = epoch_loss.item() / n_mini
        elapsed = time.time() - t_train
        remaining = elapsed / (epoch + 1) * (EPOCHS - epoch - 1)
        print(f"epoch {epoch+1:5d}/{EPOCHS}: loss={avg_loss:.6f} "
              f"lr={scheduler.get_last_lr()[0]:.2e} "
              f"[{elapsed:.0f}s elapsed, ~{remaining:.0f}s remaining]")
        if (epoch + 1) % 500 == 0:
            psnrs = evaluate_psnr(params, frames_gpu, OMEGA_0, device, max_segs=5)
            avg_psnr = np.mean([p for _, p in psnrs])
            print(f"  PSNR (first 5 segs): {avg_psnr:.1f} dB")

total_time = time.time() - t_train
print(f"\nTraining done in {total_time:.0f}s ({total_time/60:.1f} min)")

# Copy the per-segment best checkpoint back into the live parameter tensors,
# in place and without recording the copy in autograd.
with torch.no_grad():
    for live, snapshot in zip(params, best_params):
        live.copy_(snapshot)

# Score every segment at full resolution.
print("\nEvaluating all segments...")
t_eval = time.time()
psnrs = evaluate_psnr(params, frames_gpu, OMEGA_0, device)
eval_time = time.time() - t_eval
all_psnr = [score for _seg, score in psnrs]
psnr_arr = np.asarray(all_psnr)
print(f"PSNR: avg={psnr_arr.mean():.1f}dB "
      f"min={psnr_arr.min():.1f}dB max={psnr_arr.max():.1f}dB "
      f"[{eval_time:.0f}s]")

# Write both export formats for each segment (.pt first, then .bin).
print(f"\nExporting {n_seg} segments...")
t_export = time.time()
exporters = (export_segment_pt, export_segment_binary)
for seg in range(n_seg):
    for export in exporters:
        export(params, seg, OMEGA_0)
print(f"Exported in {time.time()-t_export:.1f}s")
print(f"  .pt files: {WEIGHTS_DIR}/segment_*.pt")
print(f"  .bin files: {WEIGHTS_DIR}/segment_*.bin")

banner = '=' * 60
print(f"\n{banner}")
print(f"Total: {n_seg} segments, {EPOCHS} epochs")
print(f"Time: {total_time:.0f}s training + {eval_time:.0f}s eval")
print(f"PSNR: {psnr_arr.mean():.1f} dB average")

In [None]:
# Cell 3: Zip weights and download
# Packs the weights/ directory written by Cell 2 into weights_h16.zip
# (h16 = hidden width 16) and sends it to the browser.
import shutil
from google.colab import files

# Create zip of all weight files.
# make_archive(base_name, format, root_dir, base_dir): archives 'weights'
# relative to '.', so entries inside the zip are stored under weights/.
shutil.make_archive('weights_h16', 'zip', '.', 'weights')
print(f"Created weights_h16.zip")
# IPython shell escapes: show the archive size, the first 10 file names,
# and the total file count in weights/.
!ls -lh weights_h16.zip
!echo "Contents:" && ls weights/ | head -10 && echo "..." && ls weights/ | wc -l && echo "total files"

# Download via Colab's files.download helper.
files.download('weights_h16.zip')