<a href="https://colab.research.google.com/github/Ryu1231/TR-vs-SGD/blob/main/Reproduce_TR_%26_SGD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Subset
import time
import pickle
import os

# All tensors and the model in this notebook are placed on this device (CPU).
device = torch.device('cpu')
def load_cifar10_data(subset_size=5000, seed=1234):
    """Load CIFAR-10 as flattened, standardized feature matrices.

    A random subset of ``subset_size`` training images (chosen with
    ``torch.randperm`` after seeding with ``seed``) and the full test split
    are converted to matrices with one column per image.  Both matrices are
    standardized per feature using the mean/std of the *training subset*.

    Args:
        subset_size: Number of training images to keep.
        seed: Seed applied to both the torch and numpy global RNGs.

    Returns:
        ``(X, y, X_test, y_test)`` where ``X``/``X_test`` have shape
        (features, samples) and ``y``/``y_test`` are one-hot float matrices
        of shape (samples, 10), all on ``device``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    indices = torch.randperm(len(train_dataset))[:subset_size]
    train_subset = Subset(train_dataset, indices)

    def _to_matrices(dataset):
        """Return (features x samples matrix, one-hot labels) in one pass."""
        # Every item access re-applies the transform, so images and labels
        # are collected together instead of iterating the dataset twice.
        images, labels = [], []
        for img, label in dataset:
            images.append(img.flatten())
            labels.append(label)
        features = torch.stack(images).t()
        one_hot = torch.zeros(len(labels), 10, device=device)
        rows = torch.arange(len(labels), device=device)
        cols = torch.tensor(labels, dtype=torch.long, device=device)
        one_hot[rows, cols] = 1
        return features, one_hot

    X, y = _to_matrices(train_subset)
    X_test, y_test = _to_matrices(test_dataset)

    # Per-feature statistics come from the training subset only and are
    # reused for the test split.
    mean_data = X.mean(dim=1, keepdim=True)
    std_data = X.std(dim=1, keepdim=True)
    std_data[std_data == 0] = 1  # avoid dividing constant features by zero
    X = (X - mean_data) / std_data
    X_test = (X_test - mean_data) / std_data

    return X.to(device), y, X_test.to(device), y_test

class MLP(nn.Module):
    """Single-hidden-layer perceptron with a sigmoid hidden activation."""

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x, apply_softmax=True):
        """Run the network and return every intermediate tensor.

        Returns:
            A four-element list ``[input, hidden activation, logits, output]``.
            ``output`` is the row-wise softmax of the logits when
            ``apply_softmax`` is true, otherwise the logits themselves.
        """
        hidden = torch.sigmoid(self.fc1(self.flatten(x)))
        logits = self.fc2(hidden)
        final = F.softmax(logits, dim=1) if apply_softmax else logits
        return [x, hidden, logits, final]

def unflatten_parameters(model, params):
    """Overwrite the model's parameters in place from a flat vector.

    Consecutive slices of ``params`` are copied into each tensor yielded by
    ``model.parameters()``, in that order, reshaped to the tensor's shape.
    """
    offset = 0
    for tensor in model.parameters():
        stop = offset + tensor.numel()
        tensor.data.copy_(params[offset:stop].reshape(tensor.shape))
        offset = stop

def compute_model(model, params, X, y, lambda_reg=0, compute_hess=False):
    """Evaluate loss, error rate and gradient of ``model`` at ``params``.

    ``params`` is first copied into the model, then a single forward pass is
    used for the loss, the misclassification rate and the gradient.

    Args:
        model: Network whose last ``forward`` output holds class probabilities.
        params: Flat parameter vector written into ``model`` before evaluating.
        X: Feature matrix with one column per sample (it is transposed before
            being fed to the model).
        y: One-hot label matrix with one row per sample.
        lambda_reg: L2 regularisation weight.
        compute_hess: When true, also build a Hessian-vector-product closure.

    Returns:
        ``(loss, error, grad)`` when ``compute_hess`` is false, and
        ``(loss, grad, hess, error)`` when it is true -- note the different
        ordering.  ``loss`` includes the regularisation term.  ``grad`` does
        NOT: ``params`` is copied into the model rather than being part of
        the autograd graph, so the regulariser contributes no gradient here.
        ``hess(V)`` returns the Hessian-vector product plus ``lambda_reg * V``.
    """
    n = X.size(1)
    unflatten_parameters(model, params)

    model.zero_grad()
    outputs = model(X.t())[-1]

    # Average negative log-likelihood; the epsilon guards against log(0).
    ll = torch.sum(y * torch.log(outputs + 1e-10))
    reg_term = 0.5 * lambda_reg * torch.sum(params ** 2)
    total_loss = -ll / n + reg_term
    perr = torch.mean((torch.argmax(y, dim=1) != torch.argmax(outputs, dim=1)).float())

    # create_graph keeps the graph of the gradient itself so that hess() can
    # differentiate through it a second time.
    grad = torch.autograd.grad(total_loss, model.parameters(), create_graph=compute_hess)
    grad_flat = torch.cat([g.flatten() for g in grad])

    if not compute_hess:
        return total_loss.item(), perr.item(), grad_flat

    def hess(V):
        """Return (H + lambda_reg * I) @ V via double back-propagation."""
        # NOTE(review): this differentiates through the graph built above, so
        # the model's weights should still hold `params` when it is called.
        model.zero_grad()
        grad_v = torch.sum(grad_flat * V)
        Hv = torch.autograd.grad(grad_v, model.parameters(), retain_graph=True)
        Hv_flat = torch.cat([hv.flatten() for hv in Hv])
        return Hv_flat + lambda_reg * V

    return total_loss.item(), grad_flat, hess, perr.item()

def cg_steihaug(H, g, delta, params, x0):
    """Approximately minimise ``0.5 x'Hx + g'x`` subject to ``||x|| <= delta``.

    Conjugate-gradient iteration that stops on the trust-region boundary
    or when non-positive curvature is met.

    Args:
        H: Callable returning the Hessian-vector product for a vector.
        g: Gradient vector.
        delta: Trust-region radius.
        params: ``(errtol, maxit, _)``; the third entry is ignored.
        x0: Starting point, or ``None`` to start from zero.

    Returns:
        ``(x, model_value, iterations, flag)`` where ``flag`` is
        ``'RS'`` (residual tolerance reached), ``'NC'`` (non-positive
        curvature), ``'TR'`` (trust-region boundary reached) or ``'MX'``
        (loop ended without one of the above).
    """
    def quadratic(v):
        return 0.5 * torch.dot(v, H(v)) + torch.dot(v, g)

    def step_to_boundary(point, direction):
        # Positive root of ||point + t * direction||^2 = delta^2; falls back
        # to a zero step when the quadratic has no real root.
        ac = torch.dot(direction, direction)
        bc = 2 * torch.dot(point, direction)
        cc = torch.dot(point, point) - delta * delta
        disc = bc**2 - 4 * ac * cc
        if disc >= 0:
            return (-bc + torch.sqrt(disc)) / (2 * ac)
        return 0

    errtol, maxit, _ = params
    x = x0 if x0 is not None else torch.zeros_like(g)
    r = -g - H(x)
    z = r
    rho = torch.dot(z, r)
    tst = torch.norm(r)
    terminate = errtol * torch.norm(r)
    it = 0
    hatdel = delta
    rhoold = 1.0
    p = None

    if tst <= terminate:
        return x, quadratic(x), 0, 'RS'

    while tst > terminate and it < maxit and torch.norm(x) <= hatdel:
        # The first direction is the residual; later ones are conjugated.
        p = z if it == 0 else z + (rho / rhoold) * p

        w = H(p)
        curvature = torch.dot(w, p)
        if curvature <= 0:
            # Non-positive curvature: follow p all the way to the boundary.
            x = x + step_to_boundary(x, p) * p
            return x, quadratic(x), it, 'NC'

        alpha = rho / curvature
        if torch.norm(x + alpha * p) > delta:
            # The full CG step would leave the region: stop on the boundary.
            x = x + step_to_boundary(x, p) * p
            return x, quadratic(x), it, 'TR'

        x = x + alpha * p
        r = r - alpha * w
        tst = torch.norm(r)
        if tst <= terminate:
            return x, quadratic(x), it, 'RS'
        if torch.norm(x) >= hatdel:
            return x, quadratic(x), it, 'TR'

        rhoold = rho
        z = r
        rho = torch.dot(z, r)
        it += 1

    return x, quadratic(x), it, 'MX'

def subsampled_tr(model, X, y, X_test, y_test, lambda_reg, options):
    """Trust-region training with a full gradient and a sub-sampled Hessian.

    Each outer iteration evaluates the loss/gradient on all of ``X``, builds
    a Hessian-vector product on a random column subset, records training
    loss / test error / propagation count, and then solves the trust-region
    sub-problem with ``cg_steihaug``, shrinking the radius until a step is
    accepted (or the model predicts no decrease).

    Args:
        model: Network whose weights are overwritten from the flat vector.
        X, y: Training features (one column per sample) and one-hot labels.
        X_test, y_test: Test features and one-hot labels.
        lambda_reg: L2 regularisation weight.
        options: Dict read for ``params`` (required start point), ``delta``,
            ``eta1``, ``eta2``, ``gamma1``, ``gamma2``, ``maxNoProps``,
            ``max_iters``, ``cur_iter`` and ``hs`` (Hessian sample size).
            When ``cur_iter >= 1`` the histories ``tr_losses``, ``te_errs``
            and ``tr_noProps`` from a previous call are continued.  The CG
            iteration cap is fixed at 250; ``inner_iters`` is not consulted.

    Returns:
        ``(params, options)`` where the new ``options`` holds only
        ``params``, ``cur_iter``, ``tr_losses``, ``te_errs``, ``tr_noProps``.
    """
    n = X.size(1)

    delta = options.get('delta', 5)
    eta1 = options.get('eta1', 0.8)
    eta2 = options.get('eta2', 1e-4)
    gamma1 = options.get('gamma1', 2)
    gamma2 = options.get('gamma2', 1.2)
    maxNoProps = options.get('maxNoProps', float('inf'))
    max_iters = options.get('max_iters', 100)
    cur = options.get('cur_iter', 0)
    sz = options.get('hs', int(0.05 * n))

    if cur >= 1:
        # The histories returned by a previous call have exactly `cur`
        # entries, so the latest propagation count lives at index cur - 1.
        noProps = options['tr_noProps'][cur - 1]
        tr_losses = options['tr_losses'][:cur] + [0] * max_iters
        tr_noProps = options['tr_noProps'][:cur] + [0] * max_iters
        te_errs = options['te_errs'][:cur] + [0] * max_iters
    else:
        noProps = 1
        tr_losses = [0] * max_iters
        tr_noProps = [0] * max_iters
        te_errs = [0] * max_iters
    # Matrix-vector-product counter; it is not persisted in `options`, so it
    # starts over on every call (fresh or resumed).
    noMVPs = 1

    params = options.get('params')
    steihaugParams = [1e-9, 250, 0]

    print("\nStart training...\n")
    completed = cur  # number of iterations whose metrics have been recorded
    for it in range(cur + 1, cur + max_iters + 1):
        if noProps > maxNoProps:
            break

        idx = torch.randperm(n)[:sz]
        x_sample = X[:, idx]
        y_sample = y[idx]

        tr_loss, _, grad = compute_model(model, params, X, y, lambda_reg)
        # compute_model's gradient excludes the regulariser; add it here.
        grad = grad + lambda_reg * params
        _, _, hess, _ = compute_model(model, params, x_sample, y_sample, lambda_reg, compute_hess=True)
        noProps += n

        _, te_err, _ = compute_model(model, params, X_test, y_test, lambda_reg)

        tr_losses[it - 1] = tr_loss
        te_errs[it - 1] = te_err
        tr_noProps[it - 1] = noProps

        print(f"Training loss: {tr_loss:.4f}")
        print(f"Test Error: {te_err:.4f}")

        fail_count = 0
        while True:
            if fail_count == 0:
                # First attempt: random start just inside the trust region.
                s0 = torch.randn_like(params)
                s0 = 0.99 * delta * s0 / torch.norm(s0)
            s, m, num_cg, iflag = cg_steihaug(hess, grad, delta, steihaugParams, s0)
            noProps += num_cg * 2 * x_sample.size(1)
            noMVPs += num_cg

            if m >= 0:
                # The quadratic model predicts no decrease: take no step.
                break

            # The loss returned by compute_model already contains the
            # regularisation term, so it must not be added a second time.
            newll, _, _ = compute_model(model, params + s, X, y, lambda_reg)
            noProps += n
            rho = (tr_loss - newll) / -m

            if rho < eta2:
                fail_count += 1
                delta = delta / gamma1
                s0 = delta * s / torch.norm(s)
                # The trial evaluation left the model at params + s; put the
                # current iterate back before `hess` is used again.
                unflatten_parameters(model, params)
            elif rho < eta1:
                params = params + s
                delta = gamma2 * delta
                break
            else:
                params = params + s
                delta = gamma1 * delta
                break

        completed = it

    options = {
        'params': params,
        'cur_iter': completed,
        'tr_losses': tr_losses[:completed],
        'te_errs': te_errs[:completed],
        'tr_noProps': tr_noProps[:completed],
    }
    return params, options

def momentum_sgd(model, X, y, X_test, y_test, lambda_reg, options):
    """Mini-batch SGD with momentum, logging loss/error per iteration.

    Each iteration draws a random column subset for the gradient, records
    the full training loss, the test error and the running propagation
    count, then applies ``v = beta * v - alpha * grad; params += v``.

    Args:
        model: Network whose weights are overwritten from the flat vector.
        X, y: Training features (one column per sample) and one-hot labels.
        X_test, y_test: Test features and one-hot labels.
        lambda_reg: L2 regularisation weight.
        options: Dict read for ``params`` (required start point), ``alpha``
            (step size), ``beta`` (momentum), ``maxNoProps``, ``max_iters``,
            ``cur_iter`` and ``hs`` (batch size).  When ``cur_iter >= 1`` the
            histories ``tr_losses``, ``te_errs`` and ``tr_noProps`` from a
            previous call are continued; the momentum buffer is not stored
            and restarts from zero.

    Returns:
        ``(params, options)`` where the new ``options`` holds only
        ``params``, ``cur_iter``, ``tr_losses``, ``te_errs``, ``tr_noProps``.
    """
    n = X.size(1)

    alpha = options.get('alpha', 0.005)
    beta = options.get('beta', 0.9)
    maxNoProps = options.get('maxNoProps', float('inf'))
    max_iters = options.get('max_iters', 100)
    cur = options.get('cur_iter', 0)
    sz = options.get('hs', int(0.05 * n))

    if cur >= 1:
        # The histories returned by a previous call have exactly `cur`
        # entries, so the latest propagation count lives at index cur - 1.
        noProps = options['tr_noProps'][cur - 1]
        tr_losses = options['tr_losses'][:cur] + [0] * max_iters
        tr_noProps = options['tr_noProps'][:cur] + [0] * max_iters
        te_errs = options['te_errs'][:cur] + [0] * max_iters
    else:
        noProps = 1
        tr_losses = [0] * max_iters
        tr_noProps = [0] * max_iters
        te_errs = [0] * max_iters

    params = options.get('params')

    print("\nStart training...\n")
    momentum_params = torch.zeros_like(params)
    completed = cur  # number of iterations whose metrics have been recorded
    for it in range(cur + 1, cur + max_iters + 1):
        if noProps > maxNoProps:
            break

        idx = torch.randperm(n)[:sz]
        x_sample = X[:, idx]
        y_sample = y[idx]

        _, _, grad = compute_model(model, params, x_sample, y_sample, lambda_reg)
        # compute_model's gradient excludes the regulariser; add it here.
        grad = grad + lambda_reg * params
        noProps += 2 * x_sample.size(1)

        # Metrics are logged at the current iterate, before the update.
        tr_loss, _, _ = compute_model(model, params, X, y, lambda_reg)
        _, te_err, _ = compute_model(model, params, X_test, y_test, lambda_reg)

        tr_losses[it - 1] = tr_loss
        te_errs[it - 1] = te_err
        tr_noProps[it - 1] = noProps

        print(f"Training loss: {tr_loss:.4f}")
        print(f"Test Error: {te_err:.4f}")

        momentum_params = beta * momentum_params - alpha * grad
        params = params + momentum_params

        completed = it

    options = {
        'params': params,
        'cur_iter': completed,
        'tr_losses': tr_losses[:completed],
        'te_errs': te_errs[:completed],
        'tr_noProps': tr_noProps[:completed],
    }
    return params, options

def cifar_classification(method, hs_sub=0.05, delta=1000, alpha=0.05, init=0, maxNP=1e6, seed=1234):
    """Train the MLP on a CIFAR-10 subset with the chosen optimiser.

    Args:
        method: ``'TR'`` for ``subsampled_tr`` or ``'SGD'`` for ``momentum_sgd``.
        hs_sub: Fraction of the training samples used per Hessian/mini-batch.
        delta: Initial trust-region radius (used by ``'TR'``).
        alpha: Step size (used by ``'SGD'``).
        init: ``0`` starts from the all-zero vector; any other value starts
            from a random vector scaled to unit norm.
        maxNP: Propagation budget after which training stops.
        seed: Seed for the torch/numpy RNGs and the data subset.

    Returns:
        The options dict returned by the optimiser (parameters, iteration
        count and the loss / test-error / propagation histories).

    Raises:
        ValueError: If ``method`` is neither ``'TR'`` nor ``'SGD'``.
    """
    # Reject unknown methods before paying for the dataset load.
    if method not in ('TR', 'SGD'):
        raise ValueError(f"Unknown method {method!r}; expected 'TR' or 'SGD'")

    torch.manual_seed(seed)
    np.random.seed(seed)
    X, y, X_test, y_test = load_cifar10_data(subset_size=5000, seed=seed)

    input_size = 32 * 32 * 3
    hidden_size = 512
    num_classes = 10
    model = MLP(input_size, hidden_size, num_classes).to(device)
    psize = sum(p.numel() for p in model.parameters())
    lambda_reg = 0

    if init == 0:
        initial_guess = torch.zeros(psize, device=device)
        print("Zero Initialization")
    else:
        initial_guess = torch.randn(psize, device=device)
        initial_guess = initial_guess / torch.norm(initial_guess)
        print("Normalized Random Initialization")

    options = {
        'params': initial_guess,
        'name': 'cifar10_classification',
        'inner_iters': 250,
        'alpha': alpha,
        'delta': delta,
        'max_iters': 500,
        'cur_iter': 0,
        'hs': int(hs_sub * X.size(1)),
        'maxNoProps': maxNP
    }

    if method == 'TR':
        _, options = subsampled_tr(model, X, y, X_test, y_test, lambda_reg, options)
    else:
        _, options = momentum_sgd(model, X, y, X_test, y_test, lambda_reg, options)
    return options

# Random Initialization

TR

In [None]:
# Random-initialisation experiment: one trust-region run per initial radius.
# Results are kept in the same order as `delta_list`.
# NOTE: the loop variables `delta` and `options` stay defined in the notebook
# namespace after this cell finishes.
delta_list = [10, 100]
maxNP = 1e6
tr_results = []
for delta in delta_list:
    print(f"Running TR with delta={delta}")
    options = cifar_classification('TR', hs_sub=0.05, delta=delta, init=1, maxNP=maxNP, seed=1234)
    tr_results.append(options)

SGD

In [None]:
# Random-initialisation experiment: one momentum-SGD run per step size.
# Results are kept in the same order as `alpha_list`.
# NOTE: the loop variables `alpha` and `options` stay defined in the notebook
# namespace after this cell finishes.
alpha_list = [0.05, 0.2]
maxNP = 1e6
sgd_results = []
for alpha in alpha_list:
    print(f"Running SGD with alpha={alpha}")
    options = cifar_classification('SGD', hs_sub=0.05, alpha=alpha, init=1, maxNP=maxNP, seed=1234)
    sgd_results.append(options)

Plot

In [None]:
# Figure (a): training loss against number of propagations for the
# random-initialisation runs, on log-log axes.
plt.figure(figsize=(8, 6))
colors_tr = ['blue', 'purple']
# One colour per SGD run; 'orange' was previously misspelled, which would
# fail as soon as a third step size is plotted.
colors_sgd = ['green', 'red', 'orange']
for i, (results, delta) in enumerate(zip(tr_results, delta_list)):
    plt.plot(results['tr_noProps'], results['tr_losses'], label=f'TR-CG: Δ0={delta}', color=colors_tr[i])
for i, (results, alpha) in enumerate(zip(sgd_results, alpha_list)):
    plt.plot(results['tr_noProps'], results['tr_losses'], label=f'SGD: α={alpha}', color=colors_sgd[i], linestyle='--')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('# of Props')
plt.ylabel('Training Loss')
plt.title('Image Classification: CIFAR-10 (Random Init)')
plt.legend()
plt.grid(True, which="both", ls="--")
plt.savefig('figure_a.png')
plt.show()
plt.close()

# Figure (c): test error against number of propagations for the same runs
# (log x-axis only).
plt.figure(figsize=(8, 6))
for i, (results, delta) in enumerate(zip(tr_results, delta_list)):
    plt.plot(results['tr_noProps'], results['te_errs'], label=f'TR-CG: Δ0={delta}', color=colors_tr[i])
for i, (results, alpha) in enumerate(zip(sgd_results, alpha_list)):
    plt.plot(results['tr_noProps'], results['te_errs'], label=f'SGD: α={alpha}', color=colors_sgd[i], linestyle='--')
plt.xscale('log')
plt.xlabel('# of Props')
plt.ylabel('Test Error')
plt.title('Image Classification: CIFAR-10 (Random Init)')
plt.legend()
plt.grid(True, which="both", ls="--")
plt.savefig('figure_c.png')
plt.show()
plt.close()

# Saddle Point Initialization

TR

In [None]:
# Zero-initialisation trust-region run.
# `delta` is set explicitly so that the message below (and any later label
# that reads `delta`) reports the radius actually used for this run rather
# than a value left over from an earlier cell.
maxNP = 1e5
delta = 1
print(f"Running TR with delta={delta}")
tr_zero = cifar_classification('TR', hs_sub=0.05, delta=delta, init=0, maxNP=maxNP, seed=1234)

SGD

In [None]:
# Zero-initialisation momentum-SGD run with a fixed step size.
# `alpha` is also read later when labelling the plots.
maxNP = 1e5
alpha = 0.01
print(f"Running SGD with alpha={alpha}")
sgd_zero = cifar_classification('SGD', hs_sub=0.05, alpha=alpha, init=0, maxNP=maxNP, seed=1234)

Plot

In [None]:
# Two figures for the zero-initialisation runs: training loss (log-log axes,
# figure_d.png) and test error (log x-axis only, figure_f.png).
# NOTE(review): the legend text reads `delta` and `alpha` from the notebook
# namespace -- make sure they match the values the runs were started with.
saddle_runs = [
    (tr_zero, f'TR-CG: Δ={delta}', {'color': 'blue'}),
    (sgd_zero, f'SGD: α={alpha}', {'color': 'orange', 'linestyle': '--'}),
]
saddle_figures = [
    ('tr_losses', 'Training Loss', 'figure_d.png', True),
    ('te_errs', 'Test Error', 'figure_f.png', False),
]
for metric_key, y_label, out_file, log_y in saddle_figures:
    plt.figure(figsize=(8, 6))
    for run, legend_text, line_kwargs in saddle_runs:
        plt.plot(run['tr_noProps'], run[metric_key], label=legend_text, **line_kwargs)
    plt.xscale('log')
    if log_y:
        plt.yscale('log')
    plt.xlabel('# of Props')
    plt.ylabel(y_label)
    plt.title('Image Classification: CIFAR-10 (Saddle Point Init)')
    plt.legend()
    plt.grid(True, which="both", ls="--")
    plt.savefig(out_file)
    plt.show()
    plt.close()