# V3_v3: Architecture Universality (ResNet34)

**Purpose**: Confirm hysteresis is not ResNet18-specific

**Version History**:
- v1: VGG11, sweep LR too low → 90% stuck
- v2: VGG11, LR fixed → still 90% stuck at checkpoint creation
- **v3: Switch to ResNet34** (deeper but same family, known to be stable)

**Strategy** (per Sofia):
> PRX„Å´ÂøÖË¶Å„Å™„ÅÆ„ÅØ„ÄåVGG„Åß„ÇÇÂêå„ÅòÊï∞ÂÄ§„Äç„Åß„ÅØ„Å™„Åè„Äå‰∫åÊûù„ÅåÂà•„Ç¢„Éº„Ç≠„Åß„ÇÇÂá∫„Çã„Äç

**Design**: Tα縮小版, a reduced-scale version of Tα (3 seeds, coarser λ grid)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
import glob
import json
import time
from datetime import datetime

EXP_NAME = 'exp_V3_resnet34_v3'
NOTEBOOK_ID = 'V3_v3'
BASE_DIR = '/content/drive/MyDrive/dual-gradient-learning/Paper-A'

# Resume into the newest existing run directory for this experiment, if any.
# Directories created below end in a zero-padded YYYYmmdd_HHMMSS stamp, so
# lexicographic order is chronological. Only directories qualify: a stray file
# matching the pattern (e.g. an exported archive) would otherwise become
# SAVE_DIR and make the makedirs calls below fail.
existing = sorted(p for p in glob.glob(f'{BASE_DIR}/{EXP_NAME}_*') if os.path.isdir(p))
if existing:
    SAVE_DIR = existing[-1]
    print(f'🔄 Resuming: {SAVE_DIR}')
else:
    TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M%S')
    SAVE_DIR = f'{BASE_DIR}/{EXP_NAME}_{TIMESTAMP}'
    os.makedirs(SAVE_DIR, exist_ok=True)
    print(f'🆕 New: {SAVE_DIR}')

os.makedirs(f'{SAVE_DIR}/checkpoints', exist_ok=True)
os.makedirs(f'{SAVE_DIR}/figures', exist_ok=True)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils import parameters_to_vector
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet34
import numpy as np

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')

In [None]:
# ---- Core training hyper-parameters (same values as the ResNet18 runs) ----
BATCH_SIZE = 256
NUM_WORKERS = 4
LR = 0.1
K = 16            # the clean-label gradient is recomputed every K optimizer steps
NOISE_RATE = 0.4  # fraction of training labels flipped to a different class

# ---- Checkpoint creation (same as Tα) ----
# The ordered run must end below ORDERED_THRESHOLD test error to be kept.
ORDERED_LAMBDA, ORDERED_EPOCHS, ORDERED_THRESHOLD = 0.35, 50, 0.25
COLLAPSE_LAMBDA, COLLAPSE_EPOCHS, COLLAPSE_THRESHOLD = 0.60, 80, 0.45

# ---- λ sweep: reduced-scale version of Tα ----
LAMBDA_START, LAMBDA_END, LAMBDA_STEP = 0.30, 0.70, 0.05
EPOCHS_PER_LAMBDA = 3

# The half-step margin on the stop value keeps the end point despite float
# drift in arange; rounding to 2 decimals snaps each value onto the grid so
# later lookups by literal λ (e.g. 0.50) match exactly.
LAMBDA_GRID_UP = np.round(
    np.arange(LAMBDA_START, LAMBDA_END + LAMBDA_STEP / 2, LAMBDA_STEP), 2
)
LAMBDA_GRID_DOWN = np.round(
    np.arange(LAMBDA_END, LAMBDA_START - LAMBDA_STEP / 2, -LAMBDA_STEP), 2
)

N_SEEDS = 3

print('Architecture: ResNet34 (deeper than ResNet18)')
print('Same settings as TŒ± (known to work with ResNet18)')
print(f'Seeds: {N_SEEDS}')

In [None]:
def get_resnet34_cifar():
    """Build an untrained torchvision ResNet34 with 10 outputs, adapted to CIFAR-10.

    The stem's first convolution is swapped for a 3x3, stride-1, padding-1 conv
    (64 channels, no bias) and the max-pool is removed (same change as the
    ResNet18 variant), so small 32x32 inputs are not downsampled in the stem.
    """
    net = resnet34(weights=None, num_classes=10)
    # Construct the replacement conv only after the backbone so parameter
    # initialisation consumes the torch RNG in the same order as before.
    small_stem = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.conv1 = small_stem
    net.maxpool = nn.Identity()
    return net

class IndexedDataset(Dataset):
    """Wrap a dataset so every item also carries its position: ``(img, label, idx)``.

    The index lets the training loop look up per-sample labels (noisy or clean)
    from tensors aligned with the wrapped dataset.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, target = self.dataset[idx]
        return sample, target, idx

def set_seed(seed):
    """Seed torch (CPU and all CUDA devices) and NumPy, and force deterministic cuDNN.

    cuDNN benchmark mode is switched off because, per PyTorch's reproducibility
    notes, it may select different convolution algorithms from run to run.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def inject_label_noise(labels, noise_rate, seed, num_classes=10):
    """Return a copy of ``labels`` with a fraction ``noise_rate`` of entries flipped.

    Exactly ``int(noise_rate * len(labels))`` positions, chosen without
    replacement, receive a label drawn uniformly from the other
    ``num_classes - 1`` classes. ``labels`` itself is not modified.

    Randomness comes from a private ``np.random.RandomState(seed)``. It yields
    exactly the draws that ``np.random.seed(seed)`` followed by the same
    ``np.random.choice`` calls would, so the noise pattern for a given seed is
    unchanged, but the global NumPy RNG is left untouched.

    Args:
        labels: 1-D array of integer class labels.
        noise_rate: fraction of entries to corrupt.
        seed: seed for the private RNG.
        num_classes: number of classes to draw replacement labels from.

    Returns:
        A new array of the same length with the corrupted labels.
    """
    rng = np.random.RandomState(seed)
    noisy = labels.copy()
    n_noisy = int(noise_rate * len(labels))
    for i in rng.choice(len(labels), n_noisy, replace=False):
        noisy[i] = rng.choice([c for c in range(num_classes) if c != labels[i]])
    return noisy

def load_cifar10():
    """Return ``(trainset, testset)`` CIFAR-10 datasets stored under ./data, downloading if needed.

    Training images get random-crop (padding 4) and horizontal-flip augmentation.
    Both splits are converted to tensors and normalized with the same
    per-channel mean/std constants.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([transforms.ToTensor(), normalize])
    train_ds = torchvision.datasets.CIFAR10(root='./data', train=True, transform=train_tf, download=True)
    test_ds = torchvision.datasets.CIFAR10(root='./data', train=False, transform=test_tf, download=True)
    return train_ds, test_ds

def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over all ``(inputs, targets)`` batches in ``loader``.

    Puts the model in eval mode (it stays there on return), disables gradient
    tracking, and moves each batch to the global ``device``.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            preds = model(inputs).argmax(dim=1)
            n_correct += (preds == targets).sum().item()
            n_seen += targets.size(0)
    return n_correct / n_seen

In [None]:
def train_one_epoch(model, train_loader, opt, clean_t, noisy_t, lam, state, k=None):
    """Run one epoch of dual-gradient training on ``model``.

    Each optimizer step follows the mixed direction

        g_mix = (1 - lam) * g_s / ||g_s|| + lam * g_v / ||g_v||

    where ``g_s`` is the cross-entropy gradient for the batch's noisy labels and
    ``g_v`` the gradient for its clean labels. ``g_v`` is only recomputed when
    the global step is a multiple of ``k`` (or nothing is cached yet); on the
    other steps the cached vector from the last refresh is reused.
    ``opt.step()`` then applies the optimizer's own rules (e.g. momentum /
    weight decay) to g_mix.

    Args:
        model: network being trained (updated in place, left in train mode).
        train_loader: iterable of ``(x, _, idx)`` batches; ``idx`` holds dataset
            positions used to index ``clean_t`` / ``noisy_t``.
        opt: optimizer over ``model.parameters()``.
        clean_t: clean-label tensor indexed by dataset position.
        noisy_t: noisy-label tensor indexed by dataset position.
        lam: weight of the clean-gradient direction.
        state: dict with ``'step'`` (global step counter) and ``'gv'`` (cached
            clean gradient or None). It is updated in place, so the refresh
            schedule and the cache persist across calls.
        k: clean-gradient refresh period; ``None`` means the module-level ``K``.
    """
    if k is None:
        k = K
    lam = float(lam)  # sweep values arrive as NumPy scalars
    params = list(model.parameters())
    # Batches go wherever the model lives (the global `device` in this notebook).
    dev = params[0].device
    crit = nn.CrossEntropyLoss()
    model.train()
    step = state['step']
    cached_gv = state['gv']

    for x, _, idx in train_loader:
        x, idx = x.to(dev), idx.to(dev)
        bn, bc = noisy_t[idx], clean_t[idx]

        # Noisy-label gradient. No retain_graph: this graph is never
        # back-propagated again (the clean loss below runs its own forward pass),
        # and retaining it kept two sets of activations alive on refresh steps.
        opt.zero_grad()
        loss_s = crit(model(x), bn)
        loss_s.backward()
        # parameters_to_vector concatenates, so the result is already a copy.
        g_s = parameters_to_vector([p.grad for p in params])

        if step % k == 0 or cached_gv is None:
            # NOTE(review): this second forward runs in train mode, so any
            # BatchNorm layers update their running stats twice on refresh
            # steps. Left unchanged, presumably to stay comparable with the
            # ResNet18 runs.
            opt.zero_grad()
            loss_v = crit(model(x), bc)
            loss_v.backward()
            cached_gv = parameters_to_vector([p.grad for p in params])

        g_s_n = g_s / (g_s.norm() + 1e-12)
        g_v_n = cached_gv / (cached_gv.norm() + 1e-12)
        g_mix = (1 - lam) * g_s_n + lam * g_v_n

        # Scatter the flat mixed vector back into per-parameter .grad tensors.
        opt.zero_grad()
        offset = 0
        for p in params:
            n = p.numel()
            p.grad = g_mix[offset:offset + n].view_as(p).clone()
            offset += n
        opt.step()
        step += 1

    state['step'] = step
    state['gv'] = cached_gv

In [None]:
trainset, testset = load_cifar10()
clean_labels = np.array(trainset.targets)  # un-noised labels, in dataset order
# Training batches are (img, label, idx) so train_one_epoch can look up per-sample labels.
train_loader = DataLoader(IndexedDataset(trainset), BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
test_loader = DataLoader(testset, BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

# Smoke test: build one model, report its size and push a few random batches
# through it. This runs under no_grad and does not keep the output. Binding it
# (e.g. `_ = m(...)`) would keep the output's autograd graph, which references
# the weights and saved activations, alive in a global. Then
# `del m` + empty_cache() would not actually free that GPU memory.
m = get_resnet34_cifar().to(device)
print(f'ResNet34 parameters: {sum(p.numel() for p in m.parameters()):,}')
with torch.no_grad():
    for _ in range(3):
        m(torch.randn(BATCH_SIZE, 3, 32, 32, device=device))
del m
torch.cuda.empty_cache()
print('Ready')

In [None]:
results = []
ckpt_file = f'{SAVE_DIR}/{NOTEBOOK_ID}_checkpoint.json'


def _save_results(path, records):
    """Write ``records`` to ``path`` as JSON, atomically.

    The data goes to a sibling ``.tmp`` file first and is then moved over
    ``path`` with ``os.replace``. An interrupted write (e.g. a Colab disconnect)
    therefore never leaves a truncated checkpoint that would break the resume
    logic below.
    """
    tmp_path = f'{path}.tmp'
    with open(tmp_path, 'w') as fh:
        json.dump(records, fh, indent=2, default=str)
    os.replace(tmp_path, path)


def _train_checkpoint(seed, lam, epochs, milestones, log_every, clean_t, noisy_t):
    """Train a fresh ResNet34 at fixed ``lam``; return ``(test_error, cpu_state_dict)``.

    Uses SGD (lr=LR, momentum 0.9, weight decay 5e-4) with a MultiStepLR schedule
    that multiplies the LR by 0.1 at each epoch in ``milestones``. The test error
    is printed every ``log_every`` epochs. The model, optimizer and cached clean
    gradient are all released before returning, so the next phase starts with
    that GPU memory free.
    """
    set_seed(seed)
    model = get_resnet34_cifar().to(device)
    opt = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
    sched = optim.lr_scheduler.MultiStepLR(opt, milestones, gamma=0.1)
    state = {'step': 0, 'gv': None}

    for ep in range(epochs):
        train_one_epoch(model, train_loader, opt, clean_t, noisy_t, lam, state)
        sched.step()
        if (ep + 1) % log_every == 0:
            err = 1 - evaluate(model, test_loader)
            print(f'  Epoch {ep+1}: error={err:.4f}')

    final_error = 1 - evaluate(model, test_loader)
    weights = {k: v.cpu().clone() for k, v in model.state_dict().items()}
    del model, opt, sched, state
    torch.cuda.empty_cache()
    return final_error, weights


def _sweep(seed, init_weights, lambda_grid, clean_t, noisy_t):
    """Fine-tune from ``init_weights`` while stepping λ through ``lambda_grid``.

    Trains EPOCHS_PER_LAMBDA epochs per λ at a fixed LR of LR * 0.01 (no
    schedule). After each λ it records ``{'lambda', 'error'}`` with the test
    error. Optimizer state and the clean-gradient cache carry over between λ
    values.
    """
    set_seed(seed)
    model = get_resnet34_cifar().to(device)
    model.load_state_dict({k: v.to(device) for k, v in init_weights.items()})
    opt = optim.SGD(model.parameters(), lr=LR * 0.01, momentum=0.9, weight_decay=5e-4)
    state = {'step': 0, 'gv': None}

    trajectory = []
    for lam in lambda_grid:
        for _ in range(EPOCHS_PER_LAMBDA):
            train_one_epoch(model, train_loader, opt, clean_t, noisy_t, lam, state)
        err = 1 - evaluate(model, test_loader)
        trajectory.append({'lambda': float(lam), 'error': err})
        print(f'  λ={lam:.2f}: {err:.4f}')

    del model, opt, state
    torch.cuda.empty_cache()
    return trajectory


if os.path.exists(ckpt_file):
    with open(ckpt_file) as fh:
        results = json.load(fh)
    done_seeds = {r['seed'] for r in results}
    print(f'Loaded: {len(done_seeds)} seeds done')
else:
    done_seeds = set()

# NOTE: seeds skipped because the ordered checkpoint failed are not recorded,
# so they are retrained from scratch on every resume.
for seed in range(N_SEEDS):
    if seed in done_seeds:
        print(f'Seed {seed}: Already done')
        continue

    print(f'\n{"="*60}')
    print(f'SEED {seed} (ResNet34)')
    print(f'{"="*60}')

    t0 = time.time()

    noisy_labels = inject_label_noise(clean_labels, NOISE_RATE, seed)
    clean_t = torch.tensor(clean_labels, device=device)
    noisy_t = torch.tensor(noisy_labels, device=device)

    # === Phase 1: Ordered checkpoint (must end below ORDERED_THRESHOLD) ===
    print(f'\n[Phase 1] Ordered (λ={ORDERED_LAMBDA})...')
    ordered_error, ordered_state = _train_checkpoint(
        seed, ORDERED_LAMBDA, ORDERED_EPOCHS, [30, 40], 10, clean_t, noisy_t)
    print(f'  Ordered final: {ordered_error:.4f}')

    if ordered_error >= ORDERED_THRESHOLD:
        print(f'  ⚠️ Failed to reach ordered state (>{ORDERED_THRESHOLD}), skipping seed')
        continue

    print(f'  ✅ Ordered checkpoint created')

    # === Phase 2: Collapse checkpoint ===
    print(f'\n[Phase 2] Collapse (λ={COLLAPSE_LAMBDA})...')
    collapse_error, collapse_state = _train_checkpoint(
        seed + 100, COLLAPSE_LAMBDA, COLLAPSE_EPOCHS, [40, 60], 20, clean_t, noisy_t)
    print(f'  Collapse final: {collapse_error:.4f}')

    # Unlike the ordered check, a weak collapse does not skip the seed; the
    # outcome is recorded so the analysis can filter on it later.
    collapse_reached = collapse_error > COLLAPSE_THRESHOLD
    if not collapse_reached:
        print(f'  ⚠️ Collapse error {collapse_error:.4f} did not exceed '
              f'{COLLAPSE_THRESHOLD}; collapse branch may not have formed')

    # === Phase 3: Ordered sweep (λ↑) ===
    print(f'\n[Phase 3] Ordered sweep (λ↑)...')
    ordered_traj = _sweep(seed + 200, ordered_state, LAMBDA_GRID_UP, clean_t, noisy_t)

    # === Phase 4: Collapse sweep (λ↓) ===
    print(f'\n[Phase 4] Collapse sweep (λ↓)...')
    collapse_traj = _sweep(seed + 300, collapse_state, LAMBDA_GRID_DOWN, clean_t, noisy_t)

    elapsed = time.time() - t0

    results.append({
        'seed': seed,
        'architecture': 'ResNet34',
        'ordered_init_error': ordered_error,
        'collapse_init_error': collapse_error,
        'collapse_reached': collapse_reached,
        'ordered_trajectory': ordered_traj,
        'collapse_trajectory': collapse_traj,
        'time_seconds': elapsed,
        'experiment_id': f'{NOTEBOOK_ID}-seed{seed:02d}'
    })

    _save_results(ckpt_file, results)
    done_seeds.add(seed)
    print(f'\n  ⏱️ Time: {elapsed/60:.1f} min')

print(f'\n{"="*60}')
print(f'{NOTEBOOK_ID} COMPLETE')
print(f'{"="*60}')

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

with open(f'{SAVE_DIR}/{NOTEBOOK_ID}_results.json', 'w') as fh:
    json.dump(results, fh, indent=2, default=str)


def _branch_stats(frame):
    """Return the per-λ mean and std of ``error`` for one branch.

    pandas returns NaN for the std of a single sample, which would make
    fill_between draw nothing. Filling with 0 gives a zero-width band when only
    one seed is available.
    """
    stats = frame.groupby('lambda')['error'].agg(['mean', 'std']).reset_index()
    stats['std'] = stats['std'].fillna(0.0)
    return stats


if len(results) == 0:
    print('❌ No successful runs to visualize')
else:
    # Long-format table: one row per (seed, branch, λ) measurement.
    all_data = []
    for r in results:
        for t in r['ordered_trajectory']:
            all_data.append({'seed': r['seed'], 'branch': 'ordered', 'lambda': t['lambda'], 'error': t['error']})
        for t in r['collapse_trajectory']:
            all_data.append({'seed': r['seed'], 'branch': 'collapse', 'lambda': t['lambda'], 'error': t['error']})
    df = pd.DataFrame(all_data)
    df.to_csv(f'{SAVE_DIR}/{NOTEBOOK_ID}_results.csv', index=False)

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # Left panel: each seed's trajectories (blue = ordered branch, red = collapse branch).
    ax = axes[0]
    for r in results:
        lams_o = [t['lambda'] for t in r['ordered_trajectory']]
        errs_o = [t['error'] for t in r['ordered_trajectory']]
        lams_c = [t['lambda'] for t in r['collapse_trajectory']]
        errs_c = [t['error'] for t in r['collapse_trajectory']]
        ax.plot(lams_o, errs_o, 'b-o', alpha=0.6, linewidth=2, markersize=5)
        ax.plot(lams_c, errs_c, 'r-s', alpha=0.6, linewidth=2, markersize=5)

    # Reference line at 40% error (NOTE(review): presumably matches NOISE_RATE; confirm).
    ax.axhline(0.40, color='orange', linestyle='--', alpha=0.5)
    ax.set_xlabel('λ', fontsize=12)
    ax.set_ylabel('Test Error', fontsize=12)
    ax.set_title('ResNet34: Individual Trajectories', fontsize=14)
    # The two labels attach to the first two plotted lines: the first seed's blue and red curves.
    ax.legend(['Ordered (λ↑)', 'Collapse (λ↓)'], fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0.28, 0.72)

    # Right panel: mean ± std across seeds.
    ax = axes[1]
    df_ord = df[df['branch'] == 'ordered']
    df_col = df[df['branch'] == 'collapse']

    if len(df_ord) > 0:
        m = _branch_stats(df_ord)
        ax.fill_between(m['lambda'], m['mean']-m['std'], m['mean']+m['std'], alpha=0.3, color='blue')
        ax.plot(m['lambda'], m['mean'], 'b-o', linewidth=2, markersize=6, label='Ordered (λ↑)')

    if len(df_col) > 0:
        m = _branch_stats(df_col)
        ax.fill_between(m['lambda'], m['mean']-m['std'], m['mean']+m['std'], alpha=0.3, color='red')
        ax.plot(m['lambda'], m['mean'], 'r-s', linewidth=2, markersize=6, label='Collapse (λ↓)')

    ax.axhline(0.40, color='orange', linestyle='--', alpha=0.5)
    ax.set_xlabel('λ', fontsize=12)
    ax.set_ylabel('Test Error', fontsize=12)
    ax.set_title('ResNet34: Hysteresis (Mean±Std)', fontsize=14, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0.28, 0.72)

    plt.tight_layout()
    plt.savefig(f'{SAVE_DIR}/figures/{NOTEBOOK_ID}_hysteresis_resnet34.png', dpi=150)
    plt.savefig(f'{SAVE_DIR}/figures/{NOTEBOOK_ID}_hysteresis_resnet34.pdf')
    plt.show()

In [None]:
# Summary
print('='*60)
print(f'{NOTEBOOK_ID} SUMMARY: ResNet34')
print('='*60)

print(f'\n📊 Completed runs: {len(results)}/{N_SEEDS}')


def _mean_error_at(frame, lam):
    """Return the mean ``error`` over rows whose λ matches ``lam``, or NaN if none match.

    np.isclose is used instead of ``==`` so the lookup does not depend on grid
    values being bit-identical to the literal.
    """
    return frame.loc[np.isclose(frame['lambda'], lam), 'error'].mean()


if len(results) > 0:
    # `df` is the long-format table built in the plotting cell.
    df_ord = df[df['branch'] == 'ordered']
    df_col = df[df['branch'] == 'collapse']

    if len(df_ord) > 0 and len(df_col) > 0:
        for check_lam in [0.40, 0.50, 0.60]:
            ord_err = _mean_error_at(df_ord, check_lam)
            col_err = _mean_error_at(df_col, check_lam)
            if not (np.isnan(ord_err) or np.isnan(col_err)):
                # Positive gap: the collapse branch has the higher error at this λ.
                gap = col_err - ord_err
                print(f'\n📊 At λ={check_lam}:')
                print(f'   Ordered:  {ord_err*100:.1f}%')
                print(f'   Collapse: {col_err*100:.1f}%')
                print(f'   Gap:      {gap*100:.1f}%')

        # Main conclusion, based on the gap at the middle of the λ range.
        mid_lam = 0.50
        ord_err = _mean_error_at(df_ord, mid_lam)
        col_err = _mean_error_at(df_col, mid_lam)
        gap = col_err - ord_err

        print(f'\n{"="*60}')
        print(f'CONCLUSION:')
        if np.isnan(gap):
            # Without this branch a missing λ would fall through to "no structure".
            print(f'  ⚠️ No data at λ={mid_lam}; cannot assess two-branch structure')
        elif gap > 0.10:
            print(f'  ✅ Two-branch structure confirmed in ResNet34')
            print(f'  ✅ Architecture universality demonstrated')
            print(f'  ✅ Hysteresis gap: {gap*100:.1f}%')
        elif gap > 0.05:
            print(f'  ⚠️ Weak two-branch structure in ResNet34')
        else:
            print(f'  ❌ No clear two-branch structure')
        print(f'{"="*60}')
else:
    print('\n❌ No successful runs')