# Tδ1: 90% Trapping Analysis

**Purpose**: Identify early indicators that distinguish 90% trapped seeds from 40-55% collapse seeds

**Protocol**:
- Train multiple seeds at collapse-inducing λ
- Log detailed metrics throughout training
- Compare trajectories of seeds that end up at 90% vs 40-55%

**Key Metrics**:
- cos(g_s, g_v): Structure-value alignment
- ||g_s||, ||g_v||: Gradient norms
- Error trajectory
- Loss trajectory

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; BASE_DIR below points inside this mount.
drive.mount('/content/drive')

import os, glob, json, time
from datetime import datetime

EXP_NAME = 'exp_Td1_90_analysis'
NOTEBOOK_ID = 'Td1'
BASE_DIR = '/content/drive/MyDrive/dual-gradient-learning/Paper-A'

# Pick the output directory: start a fresh timestamped one when no run of this
# experiment exists yet, otherwise reuse the one whose name sorts last.
existing = glob.glob(f'{BASE_DIR}/{EXP_NAME}_*')
if not existing:
    TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M%S')
    SAVE_DIR = f'{BASE_DIR}/{EXP_NAME}_{TIMESTAMP}'
    os.makedirs(SAVE_DIR, exist_ok=True)
    print(f'üÜï New: {SAVE_DIR}')
else:
    SAVE_DIR = max(existing)
    print(f'üîÑ Resuming: {SAVE_DIR}')

# Figures get their own sub-directory in either case.
os.makedirs(f'{SAVE_DIR}/figures', exist_ok=True)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils import parameters_to_vector
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import numpy as np

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

In [None]:
# Parameters
BATCH_SIZE = 256
NUM_WORKERS = 4
LR = 0.1
K = 16  # the cached clean-label gradient is refreshed every K steps
NOISE_RATE = 0.4  # fraction of training labels replaced by a different class

# Training for collapse
COLLAPSE_LAMBDA = 0.60  # weight of the clean-label gradient in the mixed update
N_EPOCHS = 100
LOG_FREQ = 20  # Log every N steps
EVAL_FREQ = 5  # Evaluate every N epochs

# More seeds to capture both 90% and 40-55% outcomes
N_SEEDS = 10

# Classification thresholds on the final test error, checked from the top:
#   error >= TRAPPED_THRESHOLD                      -> 'trapped_90'
#   PARTIAL_THRESHOLD <= error < TRAPPED_THRESHOLD  -> 'partial_collapse'
#   error <  PARTIAL_THRESHOLD                      -> 'ordered'
TRAPPED_THRESHOLD = 0.85  # >85% = trapped (90% class)
# The classifier uses this value as the LOWER edge of the partial band. It was
# 0.55 (the band's upper edge), which sent every 40-55% seed into 'ordered'
# even though the summary reports 'ordered' as <40%. 0.40 matches the labels.
PARTIAL_THRESHOLD = 0.40

print(f'Œª = {COLLAPSE_LAMBDA}')
print(f'Seeds: {N_SEEDS}')

In [None]:
def get_resnet18():
    """Build an untrained 10-class ResNet-18 with a modified stem.

    The stock first convolution is replaced by a 3x3, stride-1, padding-1
    convolution without bias, and the max-pool layer becomes an identity.
    """
    net = resnet18(weights=None, num_classes=10)
    net.maxpool = nn.Identity()
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    return net

class IndexedDataset(Dataset):
    """Wrap a dataset of (image, label) pairs so each item also returns its index."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # The index travels with the sample so callers can look up
        # per-sample data (e.g. alternative labels) for the batch.
        sample, target = self.dataset[idx]
        return sample, target, idx

def set_seed(seed):
    """Seed NumPy and PyTorch (CPU and all CUDA devices) and pin cuDNN settings.

    cuDNN is switched to deterministic mode with benchmarking turned off.
    Python's own ``random`` module is not seeded here.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def inject_label_noise(labels, noise_rate, seed, num_classes=10):
    """Return a copy of *labels* with a fixed fraction flipped to a wrong class.

    ``int(noise_rate * len(labels))`` distinct positions are chosen without
    replacement, and each one is reassigned to a class drawn uniformly from
    the classes other than its original label. The input is not modified.

    Args:
        labels: label container supporting ``.copy()``, ``len()`` and integer
            indexing (the notebook passes a NumPy array).
        noise_rate: fraction of labels to corrupt.
        seed: seed for the private random generator used here.
        num_classes: number of classes; labels are drawn from
            ``range(num_classes)``.

    Returns:
        The corrupted copy of *labels*.
    """
    # A private RandomState leaves the global NumPy RNG untouched. Seeded with
    # the same value it yields the same stream as np.random.seed(seed) followed
    # by the module-level functions, so the generated labels are unchanged.
    rng = np.random.RandomState(seed)
    noisy = labels.copy()
    n_noisy = int(noise_rate * len(labels))
    flip_positions = rng.choice(len(labels), n_noisy, replace=False)
    for i in flip_positions:
        wrong_classes = [c for c in range(num_classes) if c != labels[i]]
        noisy[i] = rng.choice(wrong_classes)
    return noisy

def load_cifar10():
    """Return ``(train_set, test_set)`` for CIFAR-10, downloading into ./data.

    The training set uses random 32x32 crops (padding 4) and random
    horizontal flips before tensor conversion and normalization; the test
    set only converts and normalizes.
    """
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)

    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_set = torchvision.datasets.CIFAR10(
        root='./data', train=True, transform=train_tf, download=True)
    test_set = torchvision.datasets.CIFAR10(
        root='./data', train=False, transform=test_tf, download=True)
    return train_set, test_set

def evaluate(model, loader):
    """Return the fraction of samples in *loader* that *model* classifies correctly.

    *loader* must yield ``(inputs, targets)`` batches. The model is switched
    to eval mode and is left in that mode on return.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predictions = model(inputs).argmax(1)
            n_correct += (predictions == targets).sum().item()
            n_seen += targets.size(0)
    return n_correct / n_seen

def cosine_sim(a, b):
    """Return the cosine similarity of two 1-D tensors as a Python float.

    A small constant (1e-12) is added to the product of the norms so a zero
    vector gives 0.0 instead of a division by zero.
    """
    inner = a @ b
    scale = a.norm() * b.norm() + 1e-12
    return (inner / scale).item()

In [None]:
def train_with_logging(model, train_loader, opt, clean_t, noisy_t, lam, state, log_list):
    """Train one epoch with detailed logging.

    For every batch two gradients are computed on the same inputs, each from
    its own forward pass: ``g_s`` from the noisy labels and ``g_v`` from the
    clean labels. The update direction is
    ``(1 - lam) * g_s / ||g_s|| + lam * g_v_cached / ||g_v_cached||``, where
    the cached clean gradient is refreshed every ``K`` steps (or when no
    cached value exists yet).

    Args:
        model: network to train; switched to train mode here.
        train_loader: yields ``(x, _, idx)`` batches; the loader's own labels
            are ignored and ``idx`` selects labels from the tensors below.
        opt: optimizer whose ``step()`` applies the mixed gradient.
        clean_t: clean label tensor indexed by sample index.
        noisy_t: noisy label tensor indexed by sample index.
            (Both are indexed with ``idx`` after it is moved to ``device``,
            so presumably they live on ``device`` too - confirm at call site.)
        lam: weight of the cached clean-label gradient in the mix.
        state: dict with keys ``'step'``, ``'gv'`` and ``'epoch'``. ``'step'``
            and ``'gv'`` are written back so they carry over between epochs;
            ``'epoch'`` is only read, for logging.
        log_list: list that receives one metrics dict every ``LOG_FREQ`` steps.

    Returns:
        None. Results are communicated through ``state`` and ``log_list``.
    """
    crit = nn.CrossEntropyLoss()
    model.train()
    step = state['step']
    cached_gv = state['gv']
    epoch = state['epoch']

    for x, _, idx in train_loader:
        x, idx = x.to(device), idx.to(device)
        bn, bc = noisy_t[idx], clean_t[idx]

        # g_struct
        opt.zero_grad()
        loss_s = crit(model(x), bn)
        # No retain_graph: the clean-label loss below runs its own forward
        # pass, so this graph is never backpropagated through again. Retaining
        # it only kept its saved activations alive during the second pass.
        loss_s.backward()
        g_s = parameters_to_vector([p.grad for p in model.parameters()]).clone()

        # g_value (always compute for logging)
        opt.zero_grad()
        loss_v = crit(model(x), bc)
        loss_v.backward()
        g_v = parameters_to_vector([p.grad for p in model.parameters()]).clone()

        if step % K == 0 or cached_gv is None:
            cached_gv = g_v.clone()

        # Normalize and mix (epsilon guards against a zero-norm gradient)
        g_s_n = g_s / (g_s.norm() + 1e-12)
        g_v_n = cached_gv / (cached_gv.norm() + 1e-12)
        g_mix = (1 - lam) * g_s_n + lam * g_v_n

        # Log. Cosines and norms use the fresh g_v of this step, not the
        # cached one that went into g_mix.
        if step % LOG_FREQ == 0:
            log_list.append({
                'epoch': epoch,
                'step': step,
                'c_sv': cosine_sim(g_s, g_v),
                'c_mc': cosine_sim(g_mix, g_v),
                'norm_s': g_s.norm().item(),
                'norm_v': g_v.norm().item(),
                'norm_mix': g_mix.norm().item(),
                'loss_s': loss_s.item(),
                'loss_v': loss_v.item()
            })

        # Apply: scatter the flat mixed vector back into each parameter's
        # .grad, walking the parameters in the order used to flatten them.
        opt.zero_grad()
        i = 0
        for p in model.parameters():
            n = p.numel()
            p.grad = g_mix[i:i+n].view(p.shape).clone()
            i += n
        opt.step()
        step += 1

    state['step'] = step
    state['gv'] = cached_gv

In [None]:
trainset, testset = load_cifar10()
clean_labels = np.array(trainset.targets)

# The training loader wraps the dataset so every batch carries sample indices.
train_loader = DataLoader(
    IndexedDataset(trainset),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
test_loader = DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Warm-up: ten forward passes of a throwaway model on random input, after
# which the model is dropped and the CUDA cache is released.
m = get_resnet18().to(device)
for _ in range(10):
    _ = m(torch.randn(BATCH_SIZE, 3, 32, 32, device=device))
del m
torch.cuda.empty_cache()
print('Ready')

In [None]:
def _save_results(results, path):
    """Write *results* to *path* as JSON and close the file deterministically."""
    with open(path, 'w') as f:
        # default=str stringifies anything json cannot encode natively.
        json.dump(results, f, indent=2, default=str)


all_results = []
results_path = f'{SAVE_DIR}/{NOTEBOOK_ID}_results.json'

# One independent training run per seed: fresh noisy labels, model, optimizer
# and LR schedule each time.
for seed in range(N_SEEDS):
    print(f'\n{"="*60}')
    print(f'SEED {seed}')
    print(f'{"="*60}')

    set_seed(seed)
    noisy_labels = inject_label_noise(clean_labels, NOISE_RATE, seed)
    clean_t = torch.tensor(clean_labels, device=device)
    noisy_t = torch.tensor(noisy_labels, device=device)

    model = get_resnet18().to(device)
    opt = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
    sched = optim.lr_scheduler.MultiStepLR(opt, [50, 80], gamma=0.1)

    # 'step' and 'gv' persist across epochs inside train_with_logging.
    state = {'step': 0, 'gv': None, 'epoch': 0}
    step_logs = []
    error_trajectory = []

    for ep in range(N_EPOCHS):
        state['epoch'] = ep
        train_with_logging(model, train_loader, opt, clean_t, noisy_t, COLLAPSE_LAMBDA, state, step_logs)
        sched.step()

        if (ep + 1) % EVAL_FREQ == 0:
            err = 1 - evaluate(model, test_loader)
            error_trajectory.append({'epoch': ep + 1, 'error': err})
            print(f'  Epoch {ep+1}: error={err:.4f}')

    final_error = 1 - evaluate(model, test_loader)

    # Classify by final test error, highest band first.
    if final_error >= TRAPPED_THRESHOLD:
        category = 'trapped_90'
        print(f'  üî¥ TRAPPED (90%): {final_error:.4f}')
    elif final_error >= PARTIAL_THRESHOLD:
        category = 'partial_collapse'
        print(f'  üü° PARTIAL COLLAPSE (40-55%): {final_error:.4f}')
    else:
        category = 'ordered'
        print(f'  üü¢ ORDERED: {final_error:.4f}')

    all_results.append({
        'seed': seed,
        'final_error': final_error,
        'category': category,
        'error_trajectory': error_trajectory,
        'step_logs': step_logs,
        'experiment_id': f'{NOTEBOOK_ID}-seed{seed:02d}'
    })

    # Persist after every seed so an interrupted session keeps the seeds that
    # already finished; the file is simply rewritten with the longer list.
    _save_results(all_results, results_path)

    del model; torch.cuda.empty_cache()

# Final write (also covers the N_SEEDS == 0 case, which yields an empty list).
_save_results(all_results, results_path)
print(f'\n{"="*60}')
print(f'{NOTEBOOK_ID} COMPLETE')
print(f'{"="*60}')

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Separate by category: one list of per-seed result dicts per outcome class.
trapped, partial, ordered = (
    [res for res in all_results if res['category'] == cat]
    for cat in ('trapped_90', 'partial_collapse', 'ordered')
)

print(f'\nüìä Classification:')
print(f'   Trapped (90%): {len(trapped)}')
print(f'   Partial (40-55%): {len(partial)}')
print(f'   Ordered: {len(ordered)}')

In [None]:
# Create step-level DataFrames for each category
def create_step_df(results_list, category_name):
    """Flatten the step logs of several runs into one DataFrame.

    Each row is one entry of a run's ``'step_logs'`` with the run's ``'seed'``
    and *category_name* in front. Keys of the log entry are merged last, so
    they win if they collide with ``'seed'`` or ``'category'``.
    """
    rows = [
        {'seed': run['seed'], 'category': category_name, **entry}
        for run in results_list
        for entry in run['step_logs']
    ]
    return pd.DataFrame(rows)

# A category with no runs gets an empty DataFrame (no columns).
df_trapped = pd.DataFrame() if not trapped else create_step_df(trapped, 'trapped_90')
df_partial = pd.DataFrame() if not partial else create_step_df(partial, 'partial_collapse')
df_ordered = pd.DataFrame() if not ordered else create_step_df(ordered, 'ordered')

# Stack all categories into one table and store it next to the JSON results.
df_all = pd.concat(
    [df_trapped, df_partial, df_ordered],
    ignore_index=True,
)
df_all.to_csv(f'{SAVE_DIR}/{NOTEBOOK_ID}_step_logs.csv', index=False)

print(f'Total step logs: {len(df_all)}')

In [None]:
# Visualization: Compare trapped vs partial
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Color scheme
colors = {'trapped_90': 'red', 'partial_collapse': 'orange', 'ordered': 'blue'}
labels = {'trapped_90': 'Trapped (90%)', 'partial_collapse': 'Partial (40-55%)', 'ordered': 'Ordered'}

# The per-metric panels only contrast these two categories.
compared = [('trapped_90', df_trapped), ('partial_collapse', df_partial)]

# 1. Error trajectory: one line per seed, colored by its final category
ax = axes[0, 0]
for r in all_results:
    trajectory = r['error_trajectory']
    ax.plot(
        [point['epoch'] for point in trajectory],
        [point['error'] for point in trajectory],
        'o-', color=colors[r['category']], alpha=0.5, linewidth=1.5,
    )

ax.axhline(0.90, color='red', linestyle='--', alpha=0.3, label='90%')
ax.axhline(0.55, color='orange', linestyle='--', alpha=0.3, label='55%')
ax.set_xlabel('Epoch')
ax.set_ylabel('Test Error')
ax.set_title('Error Trajectory')
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

# 2-5. Per-epoch mean of one logged metric per panel:
#      (axis, metric column, y-axis label, title)
panel_specs = [
    (axes[0, 1], 'c_sv', 'cos(g_s, g_v)', 'Structure-Value Alignment'),
    (axes[0, 2], 'c_mc', 'cos(g_mix, g_clean)', 'Mix-Clean Alignment'),
    (axes[1, 0], 'norm_s', '||g_struct||', 'Structure Gradient Norm'),
    (axes[1, 1], 'loss_v', 'Loss (clean labels)', 'Clean Loss'),
]
for ax, metric, y_label, title in panel_specs:
    for cat, df in compared:
        if len(df) > 0:
            mean_by_epoch = df.groupby('epoch')[metric].mean()
            ax.plot(mean_by_epoch.index, mean_by_epoch.values, 'o-',
                    color=colors[cat], linewidth=2, label=labels[cat])

    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

# 6. Early divergence detection
ax = axes[1, 2]
# Check early epochs (0-20) for divergence
early_epochs = [0, 5, 10, 15, 20]
for cat, df in compared:
    if len(df) > 0:
        in_early = df['epoch'].isin(early_epochs)
        early_c_sv = df[in_early].groupby('epoch')['c_sv'].mean()
        ax.plot(early_c_sv.index, early_c_sv.values, 'o-',
                color=colors[cat], linewidth=2, markersize=8, label=labels[cat])

ax.set_xlabel('Epoch')
ax.set_ylabel('cos(g_s, g_v)')
ax.set_title('Early Divergence (epochs 0-20)')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{SAVE_DIR}/figures/{NOTEBOOK_ID}_analysis.png', dpi=150)
plt.show()

In [None]:
# Statistical comparison at early epochs: per-metric mean of the logged steps
# at a given epoch, trapped vs partial.
print('\n' + '='*60)
print('EARLY INDICATOR ANALYSIS')
print('='*60)

for epoch_check in [5, 10, 20]:
    print(f'\nüìä At epoch {epoch_check}:')

    # An empty category frame has no 'epoch' column, so check it first.
    if len(df_trapped) == 0 or len(df_partial) == 0:
        print('   Insufficient data for comparison')
        continue

    trap_at_ep = df_trapped[df_trapped['epoch'] == epoch_check]
    part_at_ep = df_partial[df_partial['epoch'] == epoch_check]

    # Both categories exist but one has no rows at this epoch. This case used
    # to print the header and nothing else; report it explicitly.
    if len(trap_at_ep) == 0 or len(part_at_ep) == 0:
        print('   Insufficient data for comparison')
        continue

    for metric in ['c_sv', 'c_mc', 'norm_s', 'loss_v']:
        trap_mean = trap_at_ep[metric].mean()
        part_mean = part_at_ep[metric].mean()
        diff = trap_mean - part_mean
        print(f'   {metric}: Trapped={trap_mean:.4f}, Partial={part_mean:.4f}, Diff={diff:+.4f}')

In [None]:
# Summary
print('\n' + '=' * 60)
print(f'{NOTEBOOK_ID} SUMMARY')
print('=' * 60)

# Share of seeds that ended in each outcome class.
print(f'\nüìä Outcome Distribution:')
print(f'   Trapped (90%):     {len(trapped)}/{N_SEEDS} ({len(trapped)/N_SEEDS*100:.0f}%)')
print(f'   Partial (40-55%):  {len(partial)}/{N_SEEDS} ({len(partial)/N_SEEDS*100:.0f}%)')
print(f'   Ordered (<40%):    {len(ordered)}/{N_SEEDS} ({len(ordered)/N_SEEDS*100:.0f}%)')

print(f'\nüìä Key Finding:')
if len(df_trapped) == 0 or len(df_partial) == 0:
    print(f'   Need more data (both trapped and partial outcomes)')
else:
    # Check if c_sv diverges early: mean over all logged steps with epoch <= 10
    early_trap_c_sv = df_trapped.loc[df_trapped['epoch'] <= 10, 'c_sv'].mean()
    early_part_c_sv = df_partial.loc[df_partial['epoch'] <= 10, 'c_sv'].mean()

    early_gap = abs(early_trap_c_sv - early_part_c_sv)
    if early_gap > 0.05:
        print(f'   ‚úÖ Early divergence detected in c_sv')
        print(f'      Trapped: {early_trap_c_sv:.4f}, Partial: {early_part_c_sv:.4f}')
    else:
        print(f'   ‚ö†Ô∏è No clear early divergence in c_sv')