# W1: Geometric Logs by Branch

**Purpose**: Record gradient geometry (cos, norms) to show branches are geometrically distinct

**Protocol**:
- Run sweep on both branches (ordered/collapse)
- Record at each λ: cos(g_s, g_v), cos(g_v, g_c), ||g_s||, ||g_v||, ||g_mix||
- Show systematic differences between branches

**Key Metrics**:
- c_sv: cos(g_struct, g_value)
- c_vc: cos(g_value, g_clean) — identically 1 in this notebook, because g_value is computed from the clean labels
- c_mc: cos(g_mix, g_clean)

In [None]:
# Mount Google Drive at /content/drive; all experiment paths below live under it.
from google.colab import drive
drive.mount('/content/drive')

import os, glob, json, time
from datetime import datetime

EXP_NAME = 'exp_W1_geometry'
NOTEBOOK_ID = 'W1'
BASE_DIR = '/content/drive/MyDrive/dual-gradient-learning/Paper-A'


def _latest_match(pattern):
    """Return the lexicographically last path matching *pattern*.

    Presumably the prep runs are time-stamped like this notebook's own
    SAVE_DIR, so "last in sorted order" means "most recent" -- confirm.

    Raises:
        FileNotFoundError: if nothing matches (instead of a bare IndexError).
    """
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(
            f'No path matches {pattern!r}; run the corresponding prep notebook first.'
        )
    return matches[-1]


# Checkpoints produced by the two preparation runs (one per branch).
ORDERED_CKPT_DIR = _latest_match(f'{BASE_DIR}/exp_Ta_prep_ordered_*') + '/checkpoints'
COLLAPSE_CKPT_DIR = _latest_match(f'{BASE_DIR}/exp_Ta_prep_collapse_*') + '/checkpoints'
print(f'Ordered: {ORDERED_CKPT_DIR}')
print(f'Collapse: {COLLAPSE_CKPT_DIR}')

# Reuse the latest output directory of this experiment if one exists,
# otherwise create a new time-stamped one.  Only directories qualify, so a
# stray file with the same prefix cannot be picked up as SAVE_DIR.
existing = [p for p in glob.glob(f'{BASE_DIR}/{EXP_NAME}_*') if os.path.isdir(p)]
if existing:
    SAVE_DIR = sorted(existing)[-1]
    print(f'🔄 Resuming: {SAVE_DIR}')
else:
    TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M%S')
    SAVE_DIR = f'{BASE_DIR}/{EXP_NAME}_{TIMESTAMP}'
    os.makedirs(SAVE_DIR, exist_ok=True)
    print(f'🆕 New: {SAVE_DIR}')

os.makedirs(f'{SAVE_DIR}/figures', exist_ok=True)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils import parameters_to_vector
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import numpy as np

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

In [None]:
BATCH_SIZE = 256
NUM_WORKERS = 4
LR = 0.1
K = 16  # Refresh the cached value gradient every K training steps

NOISE_RATE = 0.4  # Fraction of training labels that get flipped
LAMBDA_START = 0.30
LAMBDA_END = 0.70
LAMBDA_STEP = 0.04  # Coarser for speed (11 points)
EPOCHS_PER_LAMBDA = 3
LOG_FREQ = 50  # Log geometry every N steps

# The half-step margin on the stop value keeps the end point in the grid
# despite floating-point error; rounding gives clean two-decimal λ values.
LAMBDA_GRID_UP = np.round(np.arange(LAMBDA_START, LAMBDA_END + LAMBDA_STEP/2, LAMBDA_STEP), 2)
LAMBDA_GRID_DOWN = np.round(np.arange(LAMBDA_END, LAMBDA_START - LAMBDA_STEP/2, -LAMBDA_STEP), 2)

N_SEEDS = 3  # Reduced for speed

print(f'λ points: {len(LAMBDA_GRID_UP)}')
print(f'Log frequency: every {LOG_FREQ} steps')

In [None]:
def get_resnet18():
    """Build an untrained 10-class ResNet-18 adapted for small inputs.

    The first convolution is replaced by a 3x3, stride-1, padding-1 conv
    without bias, and the max-pool layer is replaced by an identity.
    """
    net = resnet18(weights=None, num_classes=10)
    stem = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.conv1 = stem
    net.maxpool = nn.Identity()
    return net

class IndexedDataset(Dataset):
    """Wrap a dataset of (image, label) pairs so each item also carries its index."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # The wrapped dataset must yield exactly two values per item.
        image, target = self.dataset[idx]
        return image, target, idx

def set_seed(seed):
    """Seed the torch (CPU and all CUDA devices) and NumPy RNGs with *seed*.

    Also sets ``cudnn.deterministic = True`` and ``cudnn.benchmark = False``.
    """
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    np.random.seed(seed)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

def inject_label_noise(labels, noise_rate, seed, num_classes=10):
    """Return a copy of *labels* with a fixed fraction flipped to a different class.

    ``int(noise_rate * len(labels))`` distinct positions are chosen, and each
    is reassigned a class drawn uniformly from the classes other than its
    original label.  A private ``RandomState(seed)`` is used, so the result
    is the same stream the legacy ``np.random.seed(seed)`` produced, but the
    global NumPy RNG is left untouched.

    Args:
        labels: array of integer class labels (must support ``.copy()``).
        noise_rate: fraction of labels to corrupt.
        seed: seed for the private random generator.
        num_classes: number of classes; labels are assumed to lie in
            ``range(num_classes)``.

    Returns:
        A new array; *labels* itself is not modified.
    """
    rng = np.random.RandomState(seed)
    noisy = labels.copy()
    n_noisy = int(noise_rate * len(labels))
    flip_idx = rng.choice(len(labels), n_noisy, replace=False)
    for i in flip_idx:
        candidates = [c for c in range(num_classes) if c != labels[i]]
        noisy[i] = rng.choice(candidates)
    return noisy

def load_cifar10():
    """Return ``(train_set, test_set)`` CIFAR-10 datasets stored under ./data.

    Training images get random 32x32 crops (padding 4) and random horizontal
    flips before tensor conversion and normalization; test images are only
    converted and normalized.  The data is downloaded if missing.
    """
    normalize = transforms.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010))
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([transforms.ToTensor(), normalize])
    train_set = torchvision.datasets.CIFAR10('./data', True, train_tf, download=True)
    test_set = torchvision.datasets.CIFAR10('./data', False, test_tf, download=True)
    return train_set, test_set

def evaluate(model, loader):
    """Return the top-1 accuracy of *model* over *loader* as a fraction.

    The model is switched to eval mode and left that way; *loader* must
    yield ``(inputs, targets)`` batches.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predictions = model(inputs).argmax(1)
            n_correct += (predictions == targets).sum().item()
            n_seen += targets.size(0)
    return n_correct / n_seen

def cosine_sim(a, b):
    """Return the cosine similarity of two flat tensors as a Python float.

    A small epsilon in the denominator keeps zero vectors from dividing by zero.
    """
    denom = a.norm() * b.norm() + 1e-12
    similarity = (a @ b) / denom
    return similarity.item()

In [None]:
def train_one_epoch_with_logging(model, train_loader, opt, clean_t, noisy_t, lam, state):
    """Train one epoch on a λ-weighted mix of two normalized gradients, logging geometry.

    For every batch the "structure" gradient g_s (cross-entropy on the noisy
    labels) and the "value" gradient g_v (cross-entropy on the clean labels)
    are computed with separate forward/backward passes.  The update written
    to ``p.grad`` is ``(1 - lam) * g_s/||g_s|| + lam * g_c/||g_c||`` where
    ``g_c`` is a cached copy of g_v that is refreshed every ``K`` steps.
    Every ``LOG_FREQ`` steps a dict of cosines, norms and losses is recorded;
    those use the freshly computed g_v, not the cached one.

    Args:
        model: network to train; switched to train mode here.
        train_loader: yields ``(x, _, idx)`` batches, ``idx`` being the
            positions of the samples in the label tensors.
        opt: optimizer; its ``step()`` consumes the mixed gradient.
        clean_t: tensor of clean labels, indexed with ``idx``.
        noisy_t: tensor of noisy labels, indexed with ``idx``.
        lam: mixing weight given to the value gradient.
        state: dict with keys ``'step'`` (global step counter) and ``'gv'``
            (cached value gradient or None); updated in place so the cache
            and counter carry across epochs.

    Returns:
        List of per-step geometry dicts logged during this epoch.
    """
    crit = nn.CrossEntropyLoss()
    model.train()

    params = list(model.parameters())
    # A plain float keeps the mixing arithmetic independent of whether the
    # caller passes a Python or a NumPy scalar.
    lam = float(lam)

    step = state['step']
    cached_gv = state['gv']
    geometry_logs = []

    for x, _, idx in train_loader:
        x, idx = x.to(device), idx.to(device)
        noisy_y = noisy_t[idx]  # Noisy labels (structure)
        clean_y = clean_t[idx]  # Clean labels (value)

        # Compute g_struct (noisy labels).  No retain_graph: the value
        # gradient below runs its own forward pass, so keeping this graph
        # alive would only hold a second set of activations in memory.
        opt.zero_grad()
        loss_s = crit(model(x), noisy_y)
        loss_s.backward()
        # parameters_to_vector concatenates into a new tensor, so the result
        # is already independent of the .grad buffers.
        g_s = parameters_to_vector([p.grad for p in params])

        # Compute g_value (clean labels) - always computed, for logging
        opt.zero_grad()
        loss_v = crit(model(x), clean_y)
        loss_v.backward()
        g_v = parameters_to_vector([p.grad for p in params])

        # The gradient used for the actual update is only refreshed every K
        # steps.  g_v is never modified in place, so it can be stored as is.
        if step % K == 0 or cached_gv is None:
            cached_gv = g_v

        # Normalize both directions to unit length before mixing
        g_s_n = g_s / (g_s.norm() + 1e-12)
        g_v_n = cached_gv / (cached_gv.norm() + 1e-12)

        # Mix
        g_mix = (1 - lam) * g_s_n + lam * g_v_n

        # Log geometry at intervals
        if step % LOG_FREQ == 0:
            c_sv = cosine_sim(g_s, g_v)
            geometry_logs.append({
                'step': step,
                'lambda': lam,
                'c_sv': c_sv,                        # cos(struct, value)
                'c_sc': c_sv,                        # Same as c_sv when value=clean
                'c_vc': 1.0,                         # cos(value, clean) = 1 when value IS clean
                'c_mc': cosine_sim(g_mix, g_v),      # cos(mix, clean)
                'norm_s': g_s.norm().item(),
                'norm_v': g_v.norm().item(),
                'norm_mix': g_mix.norm().item(),
                'loss_s': loss_s.item(),
                'loss_v': loss_v.item()
            })

        # Write the flat mixed gradient back into each parameter's .grad
        opt.zero_grad()
        offset = 0
        for p in params:
            n = p.numel()
            p.grad = g_mix[offset:offset + n].view(p.shape).clone()
            offset += n
        opt.step()
        step += 1

    state['step'] = step
    state['gv'] = cached_gv

    return geometry_logs

In [None]:
def _aggregate_geometry(lam, err, logs):
    """Collapse the per-step geometry dicts logged at one λ into a summary row."""
    def column(key):
        return [entry[key] for entry in logs]

    # Plain floats keep the row JSON-serializable without relying on NumPy types.
    return {
        'lambda': float(lam),
        'error': err,
        'c_sv_mean': float(np.mean(column('c_sv'))),
        'c_sv_std': float(np.std(column('c_sv'))),
        'c_mc_mean': float(np.mean(column('c_mc'))),
        'c_mc_std': float(np.std(column('c_mc'))),
        'norm_s_mean': float(np.mean(column('norm_s'))),
        'norm_v_mean': float(np.mean(column('norm_v'))),
        'norm_mix_mean': float(np.mean(column('norm_mix'))),
        'loss_s_mean': float(np.mean(column('loss_s'))),
        'loss_v_mean': float(np.mean(column('loss_v'))),
        'n_logs': len(logs)
    }


def run_sweep_with_geometry(ckpt_path, train_loader, test_loader, clean_labels, noisy_labels, direction='up'):
    """Sweep λ from a checkpoint and collect test error plus gradient geometry.

    The model stored at *ckpt_path* is loaded and then trained for
    ``EPOCHS_PER_LAMBDA`` epochs at each λ of the grid, reusing one optimizer
    and one step/cache state across the whole sweep.  After each λ the test
    error is measured and the geometry logged at that λ is aggregated.

    Args:
        ckpt_path: checkpoint with keys ``'seed'``, ``'final_error'`` and
            ``'model_state_dict'``.
        train_loader: loader yielding ``(x, label, idx)`` batches.
        test_loader: loader yielding ``(x, y)`` batches.
        clean_labels: clean training labels (array-like).
        noisy_labels: noisy training labels (array-like).
        direction: ``'up'`` sweeps ``LAMBDA_GRID_UP`` and tags the run
            ``'ordered_up'``; anything else sweeps ``LAMBDA_GRID_DOWN`` and
            tags it ``'collapse_down'``.

    Returns:
        Dict with the seed, branch tag, initial/final error, the per-λ
        error trajectory, the per-λ geometry summaries and the checkpoint
        file name.
    """
    ckpt = torch.load(ckpt_path, map_location=device)
    seed = ckpt['seed']
    init_error = ckpt['final_error']

    if direction == 'up':
        lambda_grid, branch = LAMBDA_GRID_UP, 'ordered_up'
    else:
        lambda_grid, branch = LAMBDA_GRID_DOWN, 'collapse_down'

    print(f'    Loaded: seed={seed}, init_error={init_error:.4f}, branch={branch}')

    clean_t = torch.tensor(clean_labels, device=device)
    noisy_t = torch.tensor(noisy_labels, device=device)

    set_seed(seed + 5000)
    model = get_resnet18().to(device)
    model.load_state_dict(ckpt['model_state_dict'])

    opt = optim.SGD(model.parameters(), lr=LR * 0.01, momentum=0.9, weight_decay=5e-4)
    state = {'step': 0, 'gv': None}

    trajectory = []
    all_geometry_logs = []

    for lam in lambda_grid:
        lambda_logs = []
        for _ in range(EPOCHS_PER_LAMBDA):
            lambda_logs.extend(
                train_one_epoch_with_logging(model, train_loader, opt, clean_t, noisy_t, lam, state)
            )

        err = 1 - evaluate(model, test_loader)
        trajectory.append({'lambda': float(lam), 'error': err})

        # Aggregate geometry for this λ.  The summary is only built -- and
        # only reported -- when something was logged; otherwise the progress
        # line would read an undefined or stale aggregate.
        if lambda_logs:
            agg = _aggregate_geometry(lam, err, lambda_logs)
            all_geometry_logs.append(agg)
            print(f'      λ={lam:.2f}: err={err:.4f}, c_sv={agg["c_sv_mean"]:.3f}, c_mc={agg["c_mc_mean"]:.3f}')
        else:
            print(f'      λ={lam:.2f}: err={err:.4f} (no geometry logged)')

    return {
        'seed': seed,
        'branch': branch,
        'init_error': init_error,
        'final_error': trajectory[-1]['error'],
        'trajectory': trajectory,
        'geometry': all_geometry_logs,
        'checkpoint_source': os.path.basename(ckpt_path)
    }

In [None]:
trainset, testset = load_cifar10()
clean_labels = np.array(trainset.targets)
train_loader = DataLoader(IndexedDataset(trainset), BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
test_loader = DataLoader(testset, BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

# Get checkpoints
ordered_ckpts = sorted(glob.glob(f'{ORDERED_CKPT_DIR}/ordered_seed*.pth'))[:N_SEEDS]

# Keep the first N_SEEDS collapse checkpoints whose recorded final_error is
# below 0.85 (presumably to drop runs that degenerated completely -- confirm).
# Stop as soon as enough are found so the remaining checkpoints are never
# loaded just to read one number.
collapse_ckpts = []
for _candidate in sorted(glob.glob(f'{COLLAPSE_CKPT_DIR}/collapse_seed*.pth')):
    if len(collapse_ckpts) >= N_SEEDS:
        break
    if torch.load(_candidate, map_location='cpu')['final_error'] < 0.85:
        collapse_ckpts.append(_candidate)

print(f'Ordered: {len(ordered_ckpts)}, Collapse: {len(collapse_ckpts)}')
if not ordered_ckpts or not collapse_ckpts:
    print('WARNING: at least one branch has no usable checkpoints; its sweep will be skipped.')

# Warm-up: a few forward passes on random data before the timed sweeps.
# no_grad avoids building autograd graphs that would be thrown away.
m = get_resnet18().to(device)
with torch.no_grad():
    for _ in range(10):
        m(torch.randn(BATCH_SIZE, 3, 32, 32, device=device))
del m
torch.cuda.empty_cache()
print('Ready')

In [None]:
results = []
_RESULTS_PATH = f'{SAVE_DIR}/{NOTEBOOK_ID}_results.json'


def _save_results():
    """Write all results collected so far to the results JSON file."""
    # The context manager guarantees the file is flushed and closed.
    with open(_RESULTS_PATH, 'w') as fh:
        json.dump(results, fh, indent=2, default=str)


def _run_branch(ckpt_paths, direction, tag):
    """Run one geometry sweep per checkpoint and append each result to ``results``."""
    for i, ckpt_path in enumerate(ckpt_paths):
        print(f'\n[{i+1}/{len(ckpt_paths)}] {os.path.basename(ckpt_path)}')

        # The label noise is regenerated from the seed stored in the checkpoint.
        seed = torch.load(ckpt_path, map_location='cpu')['seed']
        noisy_labels = inject_label_noise(clean_labels, NOISE_RATE, seed)

        result = run_sweep_with_geometry(ckpt_path, train_loader, test_loader, clean_labels, noisy_labels, direction)
        result['experiment_id'] = f'{NOTEBOOK_ID}-{tag}-{i+1:03d}'
        results.append(result)
        # Save after every sweep so an interrupted session keeps finished runs.
        _save_results()
        torch.cuda.empty_cache()


# Ordered branch
print('='*60)
print('ORDERED BRANCH (λ ↑)')
print('='*60)
_run_branch(ordered_ckpts, 'up', 'ord')

# Collapse branch
print('\n' + '='*60)
print('COLLAPSE BRANCH (λ ↓)')
print('='*60)
_run_branch(collapse_ckpts, 'down', 'col')

_save_results()
print('\n' + '='*60)
print(f'{NOTEBOOK_ID} COMPLETE')
print('='*60)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Flatten the per-run geometry summaries into one row per (run, λ),
# tagging each row with the seed and branch of the run it came from.
geom_data = [
    {'seed': run['seed'], 'branch': run['branch'], **summary}
    for run in results
    for summary in run['geometry']
]
df_geom = pd.DataFrame(geom_data)
df_geom.to_csv(f'{SAVE_DIR}/{NOTEBOOK_ID}_geometry.csv', index=False)

# One frame per branch for the plots and summary below
ordered_mask = df_geom['branch'] == 'ordered_up'
collapse_mask = df_geom['branch'] == 'collapse_down'
df_ord = df_geom[ordered_mask]
df_col = df_geom[collapse_mask]

print(f'Geometry data: {len(df_geom)} points')
print(f'  Ordered: {len(df_ord)}')
print(f'  Collapse: {len(df_col)}')

In [None]:
# Visualization: Geometry by branch

def _plot_lambda_means(ax, frame, column, fmt, label, **plot_kwargs):
    """Plot the per-λ mean of *column* (averaged over the rows of *frame*).

    Does nothing for an empty frame, so a missing branch simply leaves no line.
    """
    if len(frame) == 0:
        return
    means = frame.groupby('lambda')[column].mean()
    ax.plot(means.index, means.values, fmt, linewidth=2, label=label, **plot_kwargs)


def _finish_axis(ax, xlabel, ylabel, title, **legend_kwargs):
    """Apply the labels, title, legend and grid shared by every panel."""
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(**legend_kwargs)
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Panels 1-3: one marked line per branch.
# (axis, column, y-label, title)
_branch_panels = [
    (axes[0, 0], 'c_sv_mean', 'cos(g_struct, g_value)', 'c_sv: Structure-Value Alignment'),
    (axes[0, 1], 'c_mc_mean', 'cos(g_mix, g_clean)', 'c_mc: Mix-Clean Alignment'),
    (axes[0, 2], 'error', 'Test Error', 'Performance'),
]
for ax, column, ylabel, title in _branch_panels:
    _plot_lambda_means(ax, df_ord, column, 'b-o', 'Ordered', markersize=6)
    _plot_lambda_means(ax, df_col, column, 'r-s', 'Collapse', markersize=6)
    if column == 'error':
        # Reference line at 0.40 (presumably the label-noise rate -- confirm)
        ax.axhline(0.40, color='orange', linestyle='--', alpha=0.5)
    _finish_axis(ax, 'λ', ylabel, title)

# Panels 4-5: struct (solid) and value (dashed) curves for each branch.
# (axis, struct column, value column, struct label, value label, y-label, title)
_pair_panels = [
    (axes[1, 0], 'norm_s_mean', 'norm_v_mean', '||g_s||', '||g_v||', 'Gradient Norm', 'Gradient Magnitudes'),
    (axes[1, 1], 'loss_s_mean', 'loss_v_mean', 'L_s', 'L_v', 'Loss', 'Loss Values'),
]
for ax, col_s, col_v, name_s, name_v, ylabel, title in _pair_panels:
    _plot_lambda_means(ax, df_ord, col_s, 'b-', f'Ord {name_s}')
    _plot_lambda_means(ax, df_ord, col_v, 'b--', f'Ord {name_v}')
    _plot_lambda_means(ax, df_col, col_s, 'r-', f'Col {name_s}')
    _plot_lambda_means(ax, df_col, col_v, 'r--', f'Col {name_v}')
    _finish_axis(ax, 'λ', ylabel, title, fontsize=8)

# 6. c_sv vs error scatter (one point per seed and λ)
ax = axes[1, 2]
if len(df_ord) > 0:
    ax.scatter(df_ord['c_sv_mean'], df_ord['error'], c='blue', alpha=0.6, label='Ordered', s=50)
if len(df_col) > 0:
    ax.scatter(df_col['c_sv_mean'], df_col['error'], c='red', alpha=0.6, label='Collapse', s=50)
_finish_axis(ax, 'cos(g_struct, g_value)', 'Test Error', 'Geometry-Performance Relationship')

plt.tight_layout()
plt.savefig(f'{SAVE_DIR}/figures/{NOTEBOOK_ID}_geometry_comparison.png', dpi=150)
plt.show()

# Summary statistics
print('\n' + '='*60)
print(f'{NOTEBOOK_ID} GEOMETRY SUMMARY')
print('='*60)

# Only printed when both branches have data, so the comparison is meaningful.
if len(df_ord) > 0 and len(df_col) > 0:
    _summaries = [
        ('c_sv_mean', 'c_sv (struct-value alignment)'),
        ('c_mc_mean', 'c_mc (mix-clean alignment)'),
    ]
    for column, description in _summaries:
        print(f'\n📊 Mean {description}:')
        print(f'   Ordered:  {df_ord[column].mean():.3f} ± {df_ord[column].std():.3f}')
        print(f'   Collapse: {df_col[column].mean():.3f} ± {df_col[column].std():.3f}')