In [1]:
import os
import time
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import datasets, transforms
import kornia
import kornia.filters as kf
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PIN_MEMORY = True if DEVICE.type == "cuda" else False
NUM_WORKERS = 0

print(f"Running on: {DEVICE}")

DATA_DIR = "./data"
CHECKPOINT_DIR = "./checkpoints_joint_5k"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if DEVICE.type == "cuda":
    torch.cuda.manual_seed_all(SEED)

In [None]:
SUBSET_SIZE = 5000        
COORD_ITERS = 20            
PHI_EPOCHS_PER_ITER = 5    
THETA_STEPS_PER_ITER = 50  

PRETRAIN_EPOCHS = 3
PRETRAIN_BATCH = 64

HOAG_MAX_INNER_STEPS = 100
INNER_LR = 0.02
OUTER_LR = 0.01
CG_MAX_ITERS = 50           
CG_TOL = 1e-6

THETA_INIT = torch.tensor([0.0, -3.0], dtype=torch.float32)
BLUR_KERNEL_SIZE = (11, 11)
BLUR_SIGMA = (3.0, 3.0)
NOISE_STD = 0.12

GRAD_CLIP_VALUE = 2.0
THETA_CLAMP = (-12.0, 12.0)
EPSILON_MIN = 1e-8

torch.backends.cudnn.benchmark = True
THETA_CLAMP_MIN = torch.tensor([-12.0, -12.0])
THETA_CLAMP_MAX = torch.tensor([12.0, -2.0]) 

In [None]:
def clamp_theta(theta):
    return torch.max(torch.min(theta, THETA_CLAMP_MAX.to(theta.device)), THETA_CLAMP_MIN.to(theta.device))

def ensure_bchw(x):
    if isinstance(x, np.ndarray): x = torch.from_numpy(x)
    if not torch.is_tensor(x): x = torch.as_tensor(x)
    d = x.dim()
    if d == 4: return x
    if d == 3: return x.unsqueeze(0) if x.size(0) in [1,3] else x.unsqueeze(1)
    if d == 2: return x.unsqueeze(0).unsqueeze(0)
    raise ValueError(f"Unsupported shape {x.shape}")

def clamp_theta(theta):
    return torch.clamp(theta, THETA_CLAMP[0], THETA_CLAMP[1])

raw_k = kornia.filters.get_gaussian_kernel2d(BLUR_KERNEL_SIZE, BLUR_SIGMA)
kernel = torch.as_tensor(raw_k, dtype=torch.float32, device=DEVICE)
while kernel.dim() > 2: kernel = kernel.squeeze(0)
kernel = kernel / kernel.sum()
BLUR_KERNEL = kernel.unsqueeze(0).unsqueeze(0).to(DEVICE)

In [None]:
class PrecomputedBlurredDataset(Dataset):
    def __init__(self, path):
        d = torch.load(path, map_location='cpu')
        self.y = d['y']
        self.x = d.get('x', None)
        self.labels = d['labels']
    def __len__(self): return len(self.labels)
    def __getitem__(self, i): 
        return self.y[i], (self.x[i] if self.x is not None else torch.zeros_like(self.y[i])), self.labels[i]

def precompute_blur_to_disk(sharp_dataset, out_path):
    loader = DataLoader(sharp_dataset, batch_size=512, shuffle=False, num_workers=NUM_WORKERS)
    Ys, Xs, Ls = [], [], []
    print(f"Precomputing blur for {len(sharp_dataset)} samples -> {out_path}...")
    t0 = time.time()
    for xb, labs in loader:
        xb = xb.to(DEVICE)
        # y = A(x) + noise
        yb = kf.gaussian_blur2d(xb, kernel_size=BLUR_KERNEL_SIZE, sigma=BLUR_SIGMA)
        if NOISE_STD > 0: yb += torch.randn_like(yb) * NOISE_STD
        Ys.append(yb.clamp(0,1).cpu())
        Xs.append(xb.cpu())
        Ls.append(labs)
    torch.save({'y': torch.cat(Ys), 'x': torch.cat(Xs), 'labels': torch.cat(Ls)}, out_path)
    print(f"Done in {time.time()-t0:.1f}s")

In [None]:
# Models
class Classifier(nn.Module):
    def __init__(self, n_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 48, 3, padding=1), nn.BatchNorm2d(48), nn.ReLU(),
            nn.Conv2d(48, 48, 3, padding=1), nn.BatchNorm2d(48), nn.ReLU(), nn.MaxPool2d(2,2),
            nn.Conv2d(48, 96, 3, padding=1), nn.BatchNorm2d(96), nn.ReLU(),
            nn.Conv2d(96, 96, 3, padding=1), nn.BatchNorm2d(96), nn.ReLU(), nn.MaxPool2d(2,2),
            nn.Conv2d(96, 192, 3, padding=1), nn.BatchNorm2d(192), nn.ReLU(), nn.AdaptiveAvgPool2d((3,3))
        )
        self.classifier = nn.Sequential(nn.Flatten(), nn.Linear(192*9, 256), nn.ReLU(), nn.Dropout(0.25), nn.Linear(256, n_classes))
    def forward(self, x): return self.classifier(self.features(x))

def total_variation_regularizer(x, theta):
    x = ensure_bchw(x)
    theta = clamp_theta(theta)
    scale = torch.exp(theta[0])           
    eps = torch.exp(theta[1])**2 + EPSILON_MIN  
    grad = kornia.filters.spatial_gradient(x, mode='diff')
    magsq = grad[:,:,0,:,:]**2 + grad[:,:,1,:,:]**2 + eps
    return scale * torch.sum(torch.sqrt(magsq))

def inner_loss_func(x, theta, fixed_params, y):
    x = ensure_bchw(x)
    y = ensure_bchw(y)
    pad_h, pad_w = BLUR_KERNEL.shape[-2]//2, BLUR_KERNEL.shape[-1]//2
    pred = F.conv2d(x, BLUR_KERNEL, padding=(pad_h, pad_w))
    recon = 0.5 * torch.sum((y - pred)**2)
    reg = total_variation_regularizer(x, theta)
    return recon + reg

def outer_loss_func(x_hat, fixed_params, label):
    classifier_net = Classifier(10).to(DEVICE)
    classifier_net.load_state_dict(fixed_params['phi'])
    classifier_net.eval()
    logits = classifier_net(x_hat.clamp(0,1))
    return nn.CrossEntropyLoss()(logits, label.long().to(DEVICE))

In [None]:
def hessian_vector_product(h_scalar, w, v):
    grad_h = autograd.grad(h_scalar, w, create_graph=True, retain_graph=True)[0]
    grad_h_v = torch.dot(grad_h.reshape(-1), v.reshape(-1))
    return autograd.grad(grad_h_v, w, retain_graph=True)[0]

def conjugate_gradient(A_func, b, max_iters=CG_MAX_ITERS, tol=CG_TOL):
    x = torch.zeros_like(b)
    r = b.clone(); p = r.clone()
    rsold = torch.dot(r.view(-1), r.view(-1))
    for i in range(max_iters):
        if rsold < tol: break
        Ap = A_func(p)
        denom = torch.dot(p.view(-1), Ap.view(-1)) + 1e-12
        alpha = rsold / denom
        x = x + alpha * p
        r = r - alpha * Ap
        rsnew = torch.dot(r.view(-1), r.view(-1))
        p = r + (rsnew/rsold) * p
        rsold = rsnew
    return x

In [None]:
# PHASE A
def hoag_step_for_theta(theta, fixed_params, loader, steps=10):
    theta = theta.detach().requires_grad_(True)
    history = []
    iterator = iter(loader)
    
    for i in range(steps):
        try:
            batch = next(iterator)
        except StopIteration:
            iterator = iter(loader)
            batch = next(iterator)
            
        y_cpu, _, label_cpu = batch
        y = ensure_bchw(y_cpu).to(DEVICE)
        label = label_cpu.to(DEVICE)
        w = y.clone().detach().requires_grad_(True)
        v_mom = torch.zeros_like(w)
        beta = 0.9
        
        for _ in range(HOAG_MAX_INNER_STEPS):
            loss = inner_loss_func(w, theta, fixed_params, y)
            gw = autograd.grad(loss, w, create_graph=True)[0]
            with torch.no_grad():
                v_mom = beta * v_mom + (1-beta) * gw
                w = w - INNER_LR * v_mom
                w.clamp_(0, 1)
            w.requires_grad_(True)
            if torch.norm(gw) < 1e-4: break
            
        outer_loss = outer_loss_func(w, fixed_params, label)
        b = autograd.grad(outer_loss, w, retain_graph=True)[0]
        
        # Inverse Hessian (v = H^-1 g) via CG
        w_cg = w.detach().requires_grad_(True)
        theta_cg = theta.detach()
        def A_hvp(v):
            h = inner_loss_func(w_cg, theta_cg, fixed_params, y)
            return hessian_vector_product(h, w_cg, v)
        
        q = conjugate_gradient(A_hvp, b)
        
        # Hypergradient
        w_clean = w.detach().requires_grad_(True)
        theta_clean = theta.detach().requires_grad_(True)
        scalar = inner_loss_func(w_clean, theta_clean, fixed_params, y)
        gw_h = autograd.grad(scalar, w_clean, create_graph=True)[0]
        
        # Cross derivative vector product
        cross = autograd.grad(gw_h, theta_clean, grad_outputs=q, allow_unused=True)[0]
        
        hypergrad = -cross if cross is not None else torch.zeros_like(theta)
        
        # Gradient Descent on Theta
        with torch.no_grad():
            theta = theta - OUTER_LR * hypergrad
            theta = clamp_theta(theta)
            theta.requires_grad_(True)
            
        history.append(outer_loss.item())
        if (i+1) % 10 == 0:
            print(f"    [Theta-HOAG] Step {i+1}/{steps} | Loss: {outer_loss.item():.4f} | |Grad|: {torch.norm(hypergrad):.4e}")

    return theta.detach(), np.mean(history)

In [None]:
# PHASE B
def update_phi_on_reconstructions(classifier, theta, loader, epochs=1):
    print(f"  [Phi-SGD] Training classifier on reconstructions...")
    classifier.train()
    optimizer = torch.optim.AdamW(classifier.parameters(), lr=2e-4, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()
    theta_fixed = theta.detach().to(DEVICE)
    
    for epoch in range(epochs):
        avg_loss = 0
        count = 0
        for y_cpu, _, labels in loader:
            y = ensure_bchw(y_cpu).to(DEVICE)
            labels = labels.to(DEVICE)
            w_hat = y.clone().detach().requires_grad_(True)
            for _ in range(50): 
                loss = inner_loss_func(w_hat, theta_fixed, None, y)
                gw = autograd.grad(loss, w_hat, create_graph=False)[0]
                with torch.no_grad():
                    w_hat = w_hat - 0.02 * gw
                    w_hat.clamp_(0, 1)
                w_hat.requires_grad_(True)
            
            optimizer.zero_grad()
            logits = classifier(w_hat)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            
            avg_loss += loss.item() * y.size(0)
            count += y.size(0)
            
        print(f"    [Phi-SGD] Epoch {epoch+1} Avg Loss: {avg_loss/count:.4f}")
    
    return classifier

In [None]:
def evaluate_pipeline(theta, classifier, test_loader):
    classifier.eval()
    theta_eval = clamp_theta(theta)
    correct = 0; total = 0
    
    # Process in chunks
    for y_cpu, _, labels in test_loader:
        y = ensure_bchw(y_cpu).to(DEVICE)
        labels = labels.to(DEVICE)
        
        # High quality reconstruction for evaluation (100 steps)
        w_hat = y.clone().detach().requires_grad_(True)
        for _ in range(100):
            loss = inner_loss_func(w_hat, theta_eval, None, y)
            gw = autograd.grad(loss, w_hat, create_graph=False)[0]
            with torch.no_grad():
                w_hat = (w_hat - 0.01 * gw).clamp(0,1)
            w_hat.requires_grad_(True)
            
        with torch.no_grad():
            preds = classifier(w_hat).argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            
    return correct / total

In [None]:
def main():
    print(f"\n (Subset: {SUBSET_SIZE})")
    
    train_sharp = datasets.MNIST(root=DATA_DIR, train=True, download=True, transform=transforms.ToTensor())
    test_sharp = datasets.MNIST(root=DATA_DIR, train=False, download=True, transform=transforms.ToTensor())
    
    indices = torch.randperm(len(train_sharp))[:SUBSET_SIZE]
    train_subset = Subset(train_sharp, indices)

    train_blur_path = f"./train_blur_{SUBSET_SIZE}.pt"
    test_blur_path = "./test_blur_full.pt"
    
    if not os.path.exists(train_blur_path):
        precompute_blur_to_disk(train_subset, train_blur_path)
    else:
        print(f"Found cached training data: {train_blur_path}")
        
    if not os.path.exists(test_blur_path):
        precompute_blur_to_disk(test_sharp, test_blur_path)
    else:
        print(f"Found cached test data: {test_blur_path}")
        
    train_ds = PrecomputedBlurredDataset(train_blur_path)
    test_ds = PrecomputedBlurredDataset(test_blur_path)
    

    theta_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=0)
    phi_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=100, shuffle=False)

    classifier = Classifier(10).to(DEVICE)
    theta = THETA_INIT.clone().to(DEVICE).requires_grad_(True)
    
    optim = torch.optim.Adam(classifier.parameters(), lr=1e-3)
    for ep in range(PRETRAIN_EPOCHS):
        for y, _, l in phi_loader:
            y, l = y.to(DEVICE), l.to(DEVICE)
            loss = nn.CrossEntropyLoss()(classifier(y), l)
            optim.zero_grad(); loss.backward(); optim.step()
    
    # Coordinate Descent Loop
    best_acc = 0.0
    
    for cycle in range(COORD_ITERS):
        print(f"\n>>> CYCLE {cycle+1} / {COORD_ITERS} <<<")
        
        # PHASE A: Update Regularizer (Theta)
        #print("  [Phase A] Optimizing Reconstruction Parameters (Theta)...")
        fixed_phi = {'phi': {k: v.cpu().clone() for k, v in classifier.state_dict().items()}}
        theta, _ = hoag_step_for_theta(theta, fixed_phi, theta_loader, steps=THETA_STEPS_PER_ITER)
        print(f"   => New Theta: {theta.cpu().numpy()}")

        # PHASE B: Update Classifier (Phi)
        #print("  [Phase B] Adapting Classifier to New Reconstructions...")
        classifier = update_phi_on_reconstructions(classifier, theta, phi_loader, epochs=PHI_EPOCHS_PER_ITER)
        
        # Eval
        acc = evaluate_pipeline(theta, classifier, test_loader)
        print(f"  [Result] Cycle {cycle+1} Test Acc: {acc*100:.2f}%")
        
        if acc > best_acc:
            best_acc = acc
            torch.save({'theta': theta, 'phi': classifier.state_dict()}, os.path.join(CHECKPOINT_DIR, "best_model.pt"))
            print("   * New Best Model Saved *")

if __name__ == "__main__":
    main()

Running on: cuda

=== Starting Approach 2: Joint Optimization (Subset: 5000) ===
Selecting 5000 random samples from training set...
Found cached training data: ./train_blur_5000.pt
Found cached test data: ./test_blur_full.pt
Pretraining Classifier on raw blurred images (warmup)...

>>> CYCLE 1 / 20 <<<
  [Phase A] Optimizing Reconstruction Parameters (Theta)...
    [Theta-HOAG] Step 10/50 | Loss: 0.0002 | |Grad|: 4.0098e-04
    [Theta-HOAG] Step 20/50 | Loss: 16.5209 | |Grad|: 8.6301e+00
    [Theta-HOAG] Step 30/50 | Loss: 8.0150 | |Grad|: 1.5566e+01
    [Theta-HOAG] Step 40/50 | Loss: 0.0027 | |Grad|: 1.0438e-02
    [Theta-HOAG] Step 50/50 | Loss: 0.0483 | |Grad|: 3.5922e-01
   => New Theta: [-1.7385604 -1.4856956]
  [Phase B] Adapting Classifier to New Reconstructions...
  [Phi-SGD] Training classifier on reconstructions...
    [Phi-SGD] Epoch 1 Avg Loss: 0.3245
    [Phi-SGD] Epoch 2 Avg Loss: 0.2818
    [Phi-SGD] Epoch 3 Avg Loss: 0.2539
    [Phi-SGD] Epoch 4 Avg Loss: 0.2426
    [P