### Convergence plots for MRCG on multiclass logistic regression (CIFAR-10)

In [1]:
import jax
import jax.numpy as jnp
import jax.random as jr
from jax import grad
import matplotlib.pyplot as plt
import torch, torchvision
import torchvision.transforms as T
import numpy as np
from tqdm.auto import tqdm

from optimizers.mrcg import (
    scaling_selection as mrcg_scaling_selection,
    backtracking_LS   as mrcg_backtracking,
    forward_backward_LS as mrcg_forwardback,
)

### 1. Data and Hyperparameters

In [2]:
def load_cifar10_flatten(root="~/.torch/datasets", train=True):
    """Load CIFAR-10 as standardized, flattened feature vectors plus a bias column.

    Each image is flattened in channel-major (C, H, W) order and scaled to
    [0, 1], matching ``ToTensor()`` followed by ``.view(-1)``. Every feature
    is then standardized to zero mean / unit std over the split, and a
    constant-1 bias feature is appended.

    Args:
        root: dataset directory (downloaded there if missing).
        train: load the training split (True) or the test split (False).

    Returns:
        X: float32 array of shape (N, 3*32*32 + 1).
        y: int32 array of shape (N,) with class labels.
    """
    ds = torchvision.datasets.CIFAR10(root=root, train=train, download=True)

    # Read the raw arrays in one vectorized pass instead of iterating the
    # dataset twice (once for images, once for labels), which applied a
    # per-sample tensor transform 2*N times.
    # ds.data is uint8 HWC; transpose to CHW so the flattening order matches
    # ToTensor().view(-1), and divide by 255 like ToTensor does.
    images = ds.data.transpose(0, 3, 1, 2)
    X = images.reshape(images.shape[0], -1).astype(np.float32) / 255.0   # (N, 3072)
    y = np.asarray(ds.targets, dtype=np.int32)                           # (N,)

    # Per-feature standardization; the epsilon guards constant pixels.
    X = (X - X.mean(0, keepdims=True)) / (X.std(0, keepdims=True) + 1e-6)
    X = np.concatenate([X, np.ones((X.shape[0], 1), dtype=X.dtype)], axis=1)  # +bias feat
    return X.astype(np.float32), y

# Host-side (NumPy) copies of the standardized features and labels.
X_np, y_np = load_cifar10_flatten()
N, d = X_np.shape     # number of samples, number of features (incl. bias column)
C = 10                # number of CIFAR-10 classes
print(f"Loaded CIFAR-10: X {X_np.shape}, y {y_np.shape}")

# One-hot targets over all C classes. With the all-zero starting weights
# every logit is 0, so the first loss value is ln 10.
Y_onehot = np.equal(y_np[:, None], np.arange(C)[None, :]).astype(np.float32)

# Device arrays read by the JAX objective.
X = jnp.asarray(X_np)
Y = jnp.asarray(Y_onehot)

lambda_ = 1e-3                        # L2 penalty weight used by the objective
sigma, theta, rho = 0.0, 0.5, 1e-4    # MRCG scaling-selection / line-search parameters
key = jr.PRNGKey(42)

  5%|▌         | 9.01M/170M [02:14<40:13, 66.9kB/s]  


KeyboardInterrupt: 

### 2. Objective and oracle counting (same convention as the report)

In [None]:
def f_raw(params):
    """Return the L2-regularized softmax cross-entropy over the full dataset.

    ``params`` is a flat vector of shape (d*(C-1),), viewed as a (d, C-1)
    weight matrix for classes 0..C-2. The last class is a reference class:
    its logit is pinned to 0, so it carries no weights.

    The value is the mean cross-entropy over all N samples plus
    ``0.5 * lambda_ * ||params||^2``.
    """
    weights = params.reshape(d, C - 1)
    free_logits = X @ weights                       # (N, C-1)
    reference_logit = jnp.zeros((N, 1))             # last class fixed at 0
    all_logits = jnp.concatenate([free_logits, reference_logit], axis=1)  # (N, C)

    log_norm = jax.scipy.special.logsumexp(all_logits, axis=1, keepdims=True)
    log_probs = all_logits - log_norm

    # Y is one-hot over all C classes (no column slicing), so samples of the
    # reference class contribute through the pinned zero logit.
    cross_entropy = -jnp.mean(jnp.sum(Y * log_probs, axis=1))
    penalty = 0.5 * lambda_ * jnp.sum(params**2)
    return cross_entropy + penalty

# Global tally of oracle calls, charged per call as: f = 1, gradient = 2,
# Hessian-vector product = 4.
oracle_calls = 0


def f_counted(p):
    """Evaluate the objective at ``p``, charging 1 oracle call."""
    global oracle_calls
    oracle_calls += 1
    return f_raw(p)


def grad_count(p):
    """Return the gradient of the objective at ``p``, charging 2 oracle calls (f + g)."""
    global oracle_calls
    oracle_calls += 2
    # Differentiate the *uncounted* objective: Python side effects inside a
    # traced function run once per trace, not once per call. Relying on them
    # makes the charge depend on JAX tracing/caching behavior.
    return grad(f_raw)(p)


def hvp_count(p, v):
    """Return the Hessian-vector product H(p) @ v, charging 4 oracle calls.

    The previous version added 2 here and relied on tracing ``f_counted``
    once more. That charged only 3 in total, not the intended 4.
    """
    global oracle_calls
    oracle_calls += 4
    # Forward-over-reverse: the JVP of the gradient along v is H(p) @ v.
    return jax.jvp(grad(f_raw), (p,), (v,))[1]


### 3. Optimizer steps

In [None]:
def mrcg_step(state):
    """Perform one MRCG iteration.

    ``state`` is a ``(params, key)`` pair. The key is split, a search
    direction and a flag are obtained from ``mrcg_scaling_selection``, and a
    step size comes from one of two line searches chosen by that flag.

    Returns the new ``(params, key)`` pair.
    """
    params, key = state
    key, subkey = jr.split(key)

    # The MRCG helpers call the objective with two positional arguments; the
    # second one is ignored here.
    def objective(p, _):
        return f_counted(p)

    gradient = grad_count(params)
    direction, flag = mrcg_scaling_selection(
        gradient, objective, params, sigma, subkey,
        hv_fun=hvp_count,
    )

    # "SPC"/"LPC" flags use backtracking; any other flag uses the
    # forward/backward line search.
    # NOTE(review): the same subkey is passed to both the scaling selection
    # and the line search — confirm that reuse is intended.
    line_search = mrcg_backtracking if flag in ("SPC", "LPC") else mrcg_forwardback
    step_size = line_search(objective, subkey, theta, rho,
                            params, gradient, direction)

    return (params + step_size * direction, key)

### 4. Run it

In [None]:
# Reset the global tally. Otherwise re-running this cell starts from the
# previous run's count and exits the loop immediately.
oracle_calls = 0
state = (jnp.zeros(d*(C-1), jnp.float32), key)   # all-zero weights => first loss = ln 10
MAX_CALLS, GRAD_TOL = 100_000, 1e-4
obj, orc = [], []          # objective value / cumulative oracle calls after each step
g_norm = float("inf")      # defined even if the loop body never executes

# Context manager closes the progress bar even if the run is interrupted.
with tqdm(total=MAX_CALLS, desc="Oracle calls", unit="call", dynamic_ncols=True) as bar:
    while oracle_calls < MAX_CALLS:
        oc_prev = oracle_calls
        state = mrcg_step(state)
        # Monitoring evaluations at the new iterate are charged like any other
        # oracle call.
        # NOTE(review): the next mrcg_step recomputes (and is charged again for)
        # this same gradient — confirm whether monitoring should be counted.
        f_val = f_counted(state[0])                              # +1
        g_norm = float(jnp.linalg.norm(grad_count(state[0])))    # +2 (f + g)
        obj.append(float(f_val))
        orc.append(oracle_calls)
        bar.update(oracle_calls - oc_prev)
        if g_norm <= GRAD_TOL:
            break

print(f"Stopped after {oracle_calls} calls; final ‖g‖={float(g_norm):.2e}")

### 5. Plot

In [None]:
import matplotlib.ticker as mtick

# Objective value vs. cumulative oracle calls (log-scaled x axis).
fig, ax = plt.subplots(figsize=(4.5, 3.8))
ax.set_xscale("log")
ax.plot(orc, obj, label="MRCG")

# Scientific-notation y tick labels with a shared mathtext offset.
ax.yaxis.set_major_formatter(mtick.ScalarFormatter(useMathText=True))
ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))

ax.set_xlabel("Oracle Calls")
ax.set_ylabel("Objective Value")
ax.grid(True, which="both", lw=0.3)
ax.legend()

fig.tight_layout()
fig.savefig("mrcg_cifar10_curve.png", dpi=150)   # save before show()
plt.show()
