# Tδ2: Escape from the 90% Trap

**Purpose**: Test interventions to escape from the 90% trapped state

**Background**: Some seeds get trapped at ~90% test error, which is chance level for 10-class CIFAR-10. Can we rescue them?

**Interventions to test**:
1. **λ→0**: Pure structure gradient (the value gradient is ignored completely)
2. **LR Boost**: Increase the learning rate significantly
3. **λ→0 + LR Boost**: Combined intervention
4. **Warm Restart**: Reset the optimizer momentum

**Key Question**: What breaks the 90% trap?

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os, glob, json, time
from datetime import datetime

EXP_NAME = 'exp_Td2_escape_trap'
NOTEBOOK_ID = 'Td2'
BASE_DIR = '/content/drive/MyDrive/dual-gradient-learning/Paper-A'

# Reuse the newest existing run directory for this experiment, else create a
# timestamped one. Only directories are candidates. A stray file that matches
# the pattern (e.g. '<EXP_NAME>_notes.txt') would otherwise be chosen as
# SAVE_DIR and make the os.makedirs calls below fail. The YYYYMMDD_HHMMSS
# suffix makes lexicographic order chronological.
# NOTE(review): "resuming" only reuses the directory. The result files written
# later in this notebook are opened for writing and overwritten.
existing = [p for p in glob.glob(f'{BASE_DIR}/{EXP_NAME}_*') if os.path.isdir(p)]
if existing:
    SAVE_DIR = sorted(existing)[-1]
    print(f'üîÑ Resuming: {SAVE_DIR}')
else:
    TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M%S')
    SAVE_DIR = f'{BASE_DIR}/{EXP_NAME}_{TIMESTAMP}'
    os.makedirs(SAVE_DIR, exist_ok=True)
    print(f'üÜï New: {SAVE_DIR}')

os.makedirs(f'{SAVE_DIR}/checkpoints', exist_ok=True)
os.makedirs(f'{SAVE_DIR}/figures', exist_ok=True)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils import parameters_to_vector
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import numpy as np

# Run on the first CUDA GPU when one is visible, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')

In [None]:
# Core parameters
BATCH_SIZE = 256
NUM_WORKERS = 4
LR = 0.1
K = 16            # the clean-label ("value") gradient is recomputed every K steps
NOISE_RATE = 0.4  # fraction of training labels flipped to a wrong class

# Create trapped state
TRAP_LAMBDA = 0.60  # λ that tends to produce 90% traps
TRAP_EPOCHS = 100   # Long enough to get trapped
TRAP_THRESHOLD = 0.85  # A final test error >= 85% counts as "trapped"

# Escape interventions
ESCAPE_EPOCHS = 50  # Epochs to attempt escape
EVAL_FREQ = 5       # Evaluate test error every EVAL_FREQ escape epochs

# Intervention configurations.
#   lambda    -- weight of the clean-label gradient in the gradient mix
#                (0 = pure structure / noisy-label gradient)
#   lr_mult   -- escape LR = LR * lr_mult. 0.01 gives 1e-3, the LR the trap run
#                ends with after its two x0.1 MultiStepLR decays (epochs 50, 80).
#   reset_opt -- request a fresh optimizer (momentum discarded) instead of
#                continuing the trapped one; check run_escape_intervention for
#                how it is honoured.
INTERVENTIONS = [
    {'name': 'baseline', 'lambda': 0.60, 'lr_mult': 0.01, 'reset_opt': False},  # Continue as-is (control)
    {'name': 'lambda_zero', 'lambda': 0.00, 'lr_mult': 0.01, 'reset_opt': False},  # Pure structure
    {'name': 'lambda_low', 'lambda': 0.20, 'lr_mult': 0.01, 'reset_opt': False},  # Low λ
    {'name': 'lr_boost', 'lambda': 0.60, 'lr_mult': 0.1, 'reset_opt': False},  # 10x LR
    {'name': 'lr_boost_high', 'lambda': 0.60, 'lr_mult': 1.0, 'reset_opt': True},  # Full LR + reset
    {'name': 'combined', 'lambda': 0.00, 'lr_mult': 0.1, 'reset_opt': True},  # λ=0 + LR boost + reset
]

# No pre-identified trap seeds are stored here. Phase 1 tries seeds
# 0..N_TRAP_ATTEMPTS-1 and keeps those whose final error is >= TRAP_THRESHOLD.
N_TRAP_ATTEMPTS = 10

print(f'Trap Œª: {TRAP_LAMBDA}')
print(f'Interventions: {[i["name"] for i in INTERVENTIONS]}')

In [None]:
def get_resnet18():
    """Build a randomly initialised 10-class ResNet-18 with a small-image stem.

    The torchvision stem is swapped out: conv1 becomes a 3x3, stride-1,
    padding-1 convolution without bias, and the max-pool is replaced by an
    identity. This is a common adaptation for 32x32 inputs.
    """
    net = resnet18(weights=None, num_classes=10)
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()
    return net

class IndexedDataset(Dataset):
    """Dataset wrapper whose items also carry their integer index.

    ``ds[i]`` yields ``(img, label, i)``. The training loop uses the index to
    look up per-sample clean and noisy labels.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, target = self.dataset[idx]
        return sample, target, idx

def set_seed(seed):
    """Seed torch (CPU and every CUDA device) and NumPy's global RNG.

    Also forces deterministic cuDNN kernels and disables cuDNN autotuning,
    so runs with the same seed are reproducible.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def inject_label_noise(labels, noise_rate, seed, num_classes=10):
    """Return a copy of ``labels`` with a fraction ``noise_rate`` flipped to a wrong class.

    Exactly ``int(noise_rate * len(labels))`` distinct positions are chosen
    uniformly without replacement. Each chosen label is replaced by a class
    drawn uniformly from ``range(num_classes)``, excluding its original value.

    Args:
        labels: Array-like supporting ``.copy()`` and integer indexing (a
            NumPy integer array in this notebook). It is not modified.
        noise_rate: Fraction of labels to corrupt, in [0, 1].
        seed: Seed for NumPy's *global* RNG. NOTE: this reseeds ``np.random``
            as a side effect.
        num_classes: Number of classes; defaults to 10 (CIFAR-10).

    Returns:
        A new array containing the corrupted labels.

    Raises:
        ValueError: if ``noise_rate`` is outside [0, 1].
    """
    if not 0 <= noise_rate <= 1:
        raise ValueError(f'noise_rate must be in [0, 1], got {noise_rate}')
    np.random.seed(seed)
    noisy = labels.copy()
    n_noisy = int(noise_rate * len(labels))
    idx = np.random.choice(len(labels), n_noisy, replace=False)
    for i in idx:
        # Keep this exact sequence of global-RNG calls so that a given seed
        # reproduces the same noisy labels as earlier runs.
        candidates = [c for c in range(num_classes) if c != labels[i]]
        noisy[i] = np.random.choice(candidates)
    return noisy

def load_cifar10():
    """Download (if needed) and return the CIFAR-10 ``(train, test)`` datasets under ./data.

    Training images get random-crop (padding 4) and horizontal-flip
    augmentation. Both splits are converted to tensors and normalised with the
    per-channel mean/std constants below.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([transforms.ToTensor(), normalize])
    train_set = torchvision.datasets.CIFAR10(root='./data', train=True, transform=train_tf, download=True)
    test_set = torchvision.datasets.CIFAR10(root='./data', train=False, transform=test_tf, download=True)
    return train_set, test_set

def evaluate(model, loader, device=None):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    The model is switched to eval mode (and left there), and autograd is
    disabled during the pass.

    Args:
        model: Classifier whose output's dim 1 indexes classes (logits).
        loader: Iterable of ``(inputs, targets)`` batches.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter, or CPU for a parameter-less model.

    Returns:
        Fraction of correctly classified samples, in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            correct += (model(x).argmax(1) == y).sum().item()
            total += y.size(0)
    if total == 0:
        raise ValueError('evaluate() received a loader with no samples')
    return correct / total

In [None]:
def train_one_epoch(model, train_loader, opt, clean_t, noisy_t, lam, state):
    """Run one epoch of dual-gradient training with a lam-weighted gradient mix.

    Each step builds two gradients:
      * g_s: cross-entropy gradient w.r.t. the noisy labels (``noisy_t``),
        computed every step;
      * g_v: cross-entropy gradient w.r.t. the clean labels (``clean_t``),
        recomputed every ``K`` steps (or when no cached value exists) and
        otherwise reused from cache.
    Both are L2-normalised and mixed as ``(1 - lam) * g_s + lam * g_v``. The
    mix is written into each ``p.grad`` and ``opt.step()`` is applied.

    Args:
        model: Network being trained; switched to train mode.
        train_loader: Yields ``(x, _, idx)`` batches. ``idx`` indexes the label
            tensors.
        opt: Optimizer over ``model.parameters()``.
        clean_t: Label tensor indexed by dataset index (created on ``device``
            in this notebook).
        noisy_t: Label tensor indexed by dataset index (created on ``device``
            in this notebook).
        lam: Weight of the clean-label gradient (0 = noisy/structure only).
        state: Dict with ``'step'`` (global step counter) and ``'gv'`` (cached
            clean-label gradient or None). It is updated in place, so the
            K-step cache carries across epochs.
    """
    crit = nn.CrossEntropyLoss()
    model.train()
    step = state['step']
    cached_gv = state['gv']
    params = list(model.parameters())  # fixed for the whole epoch

    for x, _, idx in train_loader:
        x, idx = x.to(device), idx.to(device)
        bn, bc = noisy_t[idx], clean_t[idx]

        # Structure gradient (noisy labels). No retain_graph: the value loss
        # below runs its own forward pass, so this graph is never reused and
        # its saved activations can be freed right away.
        opt.zero_grad()
        loss_s = crit(model(x), bn)
        loss_s.backward()
        # parameters_to_vector concatenates into fresh memory; no clone needed.
        g_s = parameters_to_vector([p.grad for p in params])

        if step % K == 0 or cached_gv is None:
            # Value gradient (clean labels), refreshed every K steps.
            # NOTE: this extra train-mode forward also updates BatchNorm running stats.
            opt.zero_grad()
            loss_v = crit(model(x), bc)
            loss_v.backward()
            cached_gv = parameters_to_vector([p.grad for p in params])

        g_s_n = g_s / (g_s.norm() + 1e-12)
        g_v_n = cached_gv / (cached_gv.norm() + 1e-12)
        g_mix = (1 - lam) * g_s_n + lam * g_v_n

        # Scatter the flat mixed gradient back into per-parameter .grad tensors.
        opt.zero_grad()
        offset = 0
        for p in params:
            n = p.numel()
            p.grad = g_mix[offset:offset + n].view(p.shape).clone()
            offset += n
        opt.step()
        step += 1

    state['step'] = step
    state['gv'] = cached_gv

In [None]:
trainset, testset = load_cifar10()
clean_labels = np.array(trainset.targets)
train_loader = DataLoader(IndexedDataset(trainset), BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
test_loader = DataLoader(testset, BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

# GPU warm-up: a few throw-away forward passes, presumably so CUDA context and
# kernel initialisation happen before the experiments. On CPU it would only
# cost time, so it is skipped there. It runs under no_grad so no autograd
# graph is built.
if device.type == 'cuda':
    m = get_resnet18().to(device)
    with torch.no_grad():
        for _ in range(5):
            _ = m(torch.randn(BATCH_SIZE, 3, 32, 32, device=device))
    del m
    torch.cuda.empty_cache()
print('Ready')

In [None]:
def create_trapped_state(seed, clean_t, noisy_t):
    """Train a fresh ResNet-18 at TRAP_LAMBDA and report whether it ends up trapped.

    Trains for TRAP_EPOCHS epochs with SGD (momentum 0.9, weight decay 5e-4,
    LR decayed x0.1 at epochs 50 and 80) and logs the test error every 20
    epochs. A run counts as trapped when its final test error is
    >= TRAP_THRESHOLD.

    Returns:
        dict with these keys:
          * 'seed', 'final_error', 'is_trapped';
          * 'error_history': list of {'epoch', 'error'};
          * 'state': CPU copy of the model state_dict, or None when not trapped;
          * 'opt_state': CPU copy of the optimizer state_dict (momentum
            buffers and param groups), or None when not trapped. This lets
            an escape run continue with the trapped optimizer instead of a
            fresh one.
    """
    import copy

    set_seed(seed)
    model = get_resnet18().to(device)
    opt = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
    sched = optim.lr_scheduler.MultiStepLR(opt, [50, 80], gamma=0.1)
    state = {'step': 0, 'gv': None}

    error_history = []

    for ep in range(TRAP_EPOCHS):
        train_one_epoch(model, train_loader, opt, clean_t, noisy_t, TRAP_LAMBDA, state)
        sched.step()

        if (ep + 1) % 20 == 0:
            err = 1 - evaluate(model, test_loader)
            error_history.append({'epoch': ep + 1, 'error': err})
            print(f'    Epoch {ep+1}: {err:.4f}')

    # Reuse the last logged evaluation when it was taken after the final epoch
    # (always true when TRAP_EPOCHS is a multiple of 20). The model has not
    # changed since, so a second full test pass would only repeat it.
    if error_history and error_history[-1]['epoch'] == TRAP_EPOCHS:
        final_error = error_history[-1]['error']
    else:
        final_error = 1 - evaluate(model, test_loader)
    is_trapped = final_error >= TRAP_THRESHOLD

    trapped_state = None
    opt_state = None
    if is_trapped:
        trapped_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        opt_sd = opt.state_dict()
        opt_state = {
            # Momentum buffers are live tensors on `device`; snapshot them on CPU.
            'state': {
                pid: {key: (val.detach().cpu().clone() if torch.is_tensor(val) else val)
                      for key, val in pstate.items()}
                for pid, pstate in opt_sd['state'].items()
            },
            'param_groups': copy.deepcopy(opt_sd['param_groups']),
        }

    del model, opt, sched
    torch.cuda.empty_cache()

    return {
        'seed': seed,
        'final_error': final_error,
        'is_trapped': is_trapped,
        'error_history': error_history,
        'state': trapped_state,
        'opt_state': opt_state,
    }

In [None]:
def run_escape_intervention(trapped_state, intervention, seed, clean_t, noisy_t, init_error,
                            opt_state=None, escape_threshold=0.60):
    """Resume training from a trapped checkpoint under one intervention and track test error.

    Args:
        trapped_state: Model state_dict of the trapped network (CPU tensors
            are fine; they are moved to ``device``).
        intervention: Dict with 'name', 'lambda' (gradient-mix weight passed to
            ``train_one_epoch``), 'lr_mult' (escape LR = LR * lr_mult) and
            'reset_opt'.
        seed: Base seed of the trapped run. It is combined with a stable hash
            of the intervention name, so every (seed, intervention) pair is
            reproducible across processes.
        clean_t: Clean-label tensor passed through to ``train_one_epoch``.
        noisy_t: Noisy-label tensor passed through to ``train_one_epoch``.
        init_error: Test error of the trapped model, recorded as epoch 0.
        opt_state: Optional optimizer state_dict saved with the trapped model.
            When given and ``intervention['reset_opt']`` is False, training
            continues from its momentum buffers. Otherwise a fresh SGD
            optimizer (zero momentum) is used.
        escape_threshold: Final test error below which the run counts as escaped.

    Returns:
        dict with 'intervention', 'config', 'init_error', 'final_error',
        'escaped', 'improvement' (init_error - final_error) and 'trajectory'
        (list of {'epoch', 'error'}, evaluated every EVAL_FREQ epochs).
    """
    import copy
    import warnings
    import zlib

    # Built-in hash() of a str is salted per interpreter process
    # (PYTHONHASHSEED), so it cannot give reproducible seeds. CRC32 is stable.
    name_offset = zlib.crc32(intervention['name'].encode('utf-8')) % 1000
    set_seed(seed + 1000 + name_offset)

    model = get_resnet18().to(device)
    model.load_state_dict({k: v.to(device) for k, v in trapped_state.items()})

    actual_lr = LR * intervention['lr_mult']
    opt = optim.SGD(model.parameters(), lr=actual_lr, momentum=0.9, weight_decay=5e-4)
    if not intervention['reset_opt']:
        if opt_state is not None:
            # Continue from the trapped optimizer's momentum. Deep-copy first:
            # load_state_dict may alias CPU tensors, and in-place momentum
            # updates would then corrupt opt_state for later interventions.
            # Loading also restores the saved hyper-parameters (including the
            # decayed LR), so the intervention's LR is re-applied afterwards.
            opt.load_state_dict(copy.deepcopy(opt_state))
            for group in opt.param_groups:
                group['lr'] = actual_lr
        else:
            warnings.warn(
                f"Intervention {intervention['name']!r} has reset_opt=False but no "
                "opt_state was supplied; starting from a fresh optimizer (momentum reset).",
                stacklevel=2,
            )

    state = {'step': 0, 'gv': None}
    trajectory = [{'epoch': 0, 'error': init_error}]

    for ep in range(ESCAPE_EPOCHS):
        train_one_epoch(model, train_loader, opt, clean_t, noisy_t, intervention['lambda'], state)

        if (ep + 1) % EVAL_FREQ == 0:
            err = 1 - evaluate(model, test_loader)
            trajectory.append({'epoch': ep + 1, 'error': err})

    # Reuse the evaluation the loop just took after the last epoch, if any.
    if len(trajectory) > 1 and trajectory[-1]['epoch'] == ESCAPE_EPOCHS:
        final_error = trajectory[-1]['error']
    else:
        final_error = 1 - evaluate(model, test_loader)
    escaped = final_error < escape_threshold

    # Drop the optimizer too (it holds GPU momentum buffers) before releasing cached memory.
    del model, opt
    torch.cuda.empty_cache()

    return {
        'intervention': intervention['name'],
        'config': intervention,
        'init_error': init_error,
        'final_error': final_error,
        'escaped': escaped,
        'improvement': init_error - final_error,
        'trajectory': trajectory
    }

In [None]:
# Phase 1: Find trapped states
print('='*60)
print('PHASE 1: Finding Trapped States')
print('='*60)

N_TRAPS_NEEDED = 3  # stop searching once this many trapped seeds are found

trapped_states = []
# The clean labels are identical for every seed; build the tensor once.
clean_t = torch.tensor(clean_labels, device=device)

for seed in range(N_TRAP_ATTEMPTS):
    print(f'\nSeed {seed}:')

    noisy_labels = inject_label_noise(clean_labels, NOISE_RATE, seed)
    noisy_t = torch.tensor(noisy_labels, device=device)

    result = create_trapped_state(seed, clean_t, noisy_t)

    if result['is_trapped']:
        print(f'  üî¥ TRAPPED at {result["final_error"]:.2%}')
        trapped_states.append({
            'seed': seed,
            'final_error': result['final_error'],
            'state': result['state'],
            # Optimizer snapshot of the trapped run, if create_trapped_state provides one.
            'opt_state': result.get('opt_state'),
            'noisy_labels': noisy_labels,
        })
    else:
        print(f'  üü¢ Not trapped ({result["final_error"]:.2%})')

    # Stop if we have enough trapped states
    if len(trapped_states) >= N_TRAPS_NEEDED:
        print(f'\nFound {len(trapped_states)} trapped states, proceeding to interventions')
        break

print(f'\nTotal trapped states found: {len(trapped_states)}')
if not trapped_states:
    print('No trapped states found; Phase 2 will have nothing to run.')

In [None]:
# Phase 2: Run escape interventions
print('\n' + '='*60)
print('PHASE 2: Escape Interventions')
print('='*60)

all_results = []
# The clean labels are identical for every trapped seed; build the tensor once.
clean_t = torch.tensor(clean_labels, device=device)

for trap_info in trapped_states:
    seed = trap_info['seed']
    print(f'\n{"="*50}')
    print(f'Trapped Seed {seed} (error={trap_info["final_error"]:.2%})')
    print(f'{"="*50}')

    noisy_t = torch.tensor(trap_info['noisy_labels'], device=device)

    seed_results = {
        'seed': seed,
        'trap_error': trap_info['final_error'],
        'interventions': []
    }

    for intervention in INTERVENTIONS:
        print(f'\n  Intervention: {intervention["name"]}')
        print(f'    Œª={intervention["lambda"]}, LR√ó{intervention["lr_mult"]}, reset={intervention["reset_opt"]}')

        result = run_escape_intervention(
            trap_info['state'],
            intervention,
            seed,
            clean_t,
            noisy_t,
            trap_info['final_error']
        )

        status = '‚úÖ ESCAPED' if result['escaped'] else '‚ùå Still trapped'
        print(f'    {trap_info["final_error"]:.2%} ‚Üí {result["final_error"]:.2%} ({status})')
        print(f'    Improvement: {result["improvement"]*100:.1f}%')

        seed_results['interventions'].append(result)

    seed_results['experiment_id'] = f'{NOTEBOOK_ID}-seed{seed:02d}'
    all_results.append(seed_results)

# Save results. The context manager guarantees the file is flushed and closed,
# even if serialisation fails partway through.
with open(f'{SAVE_DIR}/{NOTEBOOK_ID}_results.json', 'w') as results_file:
    json.dump(all_results, results_file, indent=2, default=str)

print(f'\n{"="*60}')
print(f'{NOTEBOOK_ID} COMPLETE')
print(f'{"="*60}')

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Visualization: Escape trajectories by intervention
# One panel per intervention on a 3-column grid sized from INTERVENTIONS, so
# adding or removing interventions neither raises IndexError nor leaves blank panels.
n_interventions = len(INTERVENTIONS)
n_cols = 3
n_rows = max(1, (n_interventions + n_cols - 1) // n_cols)  # ceil division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
axes = axes.flatten()

# One colour per trapped seed, indexed in the same order as all_results.
colors = plt.cm.tab10(np.linspace(0, 1, len(all_results)))

for i, intervention in enumerate(INTERVENTIONS):
    ax = axes[i]

    for j, seed_result in enumerate(all_results):
        intv_result = next((r for r in seed_result['interventions'] if r['intervention'] == intervention['name']), None)
        if intv_result:
            epochs = [t['epoch'] for t in intv_result['trajectory']]
            errors = [t['error'] for t in intv_result['trajectory']]
            ax.plot(epochs, errors, 'o-', color=colors[j], linewidth=2, markersize=5,
                    label=f'Seed {seed_result["seed"]}')

    ax.axhline(0.60, color='orange', linestyle='--', alpha=0.5, label='Escape threshold')
    ax.axhline(0.90, color='red', linestyle='--', alpha=0.3)
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Test Error')
    ax.set_title(f'{intervention["name"]}\n(Œª={intervention["lambda"]}, LR√ó{intervention["lr_mult"]})', fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 1)
    if i == 0:
        ax.legend(fontsize=8)

# Hide the panels left over in the last row.
for ax in axes[n_interventions:]:
    ax.set_visible(False)

plt.tight_layout()
plt.savefig(f'{SAVE_DIR}/figures/{NOTEBOOK_ID}_escape_trajectories.png', dpi=150)
plt.savefig(f'{SAVE_DIR}/figures/{NOTEBOOK_ID}_escape_trajectories.pdf')
plt.show()

In [None]:
# Bar plot: Success rate by intervention
# Group every per-seed intervention result by intervention name once, then
# derive both panels (escape rate, mean improvement) from that grouping.
names = [spec['name'] for spec in INTERVENTIONS]
results_by_name = {name: [] for name in names}
for seed_res in all_results:
    for intv in seed_res['interventions']:
        if intv['intervention'] in results_by_name:
            results_by_name[intv['intervention']].append(intv)

escape_rates = []
improvements = []
for name in names:
    group = results_by_name[name]
    n_escaped = sum(1 for entry in group if entry['escaped'])
    escape_rates.append(n_escaped / len(group) if len(group) > 0 else 0)
    gains = [entry['improvement'] for entry in group]
    improvements.append(np.mean(gains) if gains else 0)

x = np.arange(len(INTERVENTIONS))
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_rate, ax_imp = axes

# Left panel: escape rate (%)
ax_rate.bar(x, [rate * 100 for rate in escape_rates], color='steelblue', alpha=0.8, edgecolor='navy')
ax_rate.axhline(50, color='orange', linestyle='--', alpha=0.5)
ax_rate.set_ylabel('Escape Rate (%)')
ax_rate.set_title('Escape Success Rate by Intervention')
ax_rate.set_ylim(0, 100)

# Right panel: average error reduction (%)
ax_imp.bar(x, [gain * 100 for gain in improvements], color='green', alpha=0.8, edgecolor='darkgreen')
ax_imp.axhline(0, color='black', linestyle='-', alpha=0.3)
ax_imp.set_ylabel('Average Improvement (%)')
ax_imp.set_title('Average Error Reduction by Intervention')

# Shared x-axis formatting for both panels
for panel in (ax_rate, ax_imp):
    panel.set_xlabel('Intervention')
    panel.set_xticks(x)
    panel.set_xticklabels(names, rotation=45, ha='right')
    panel.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig(f'{SAVE_DIR}/figures/{NOTEBOOK_ID}_intervention_comparison.png', dpi=150)
plt.savefig(f'{SAVE_DIR}/figures/{NOTEBOOK_ID}_intervention_comparison.pdf')
plt.show()

In [None]:
# Summary
# Print a per-intervention table (escape rate, mean improvement, best final
# error), then a conclusion based on the highest escape rate found.
print('='*60)
print(f'{NOTEBOOK_ID} SUMMARY: Escape from 90% Trap')
print('='*60)

print(f'\nüìä Trapped states tested: {len(all_results)}')

print(f'\nüìä Results by Intervention:')
print(f'{"Intervention":<20} {"Escape Rate":<15} {"Avg Improvement":<15} {"Best Final":<12}')
print('-' * 62)

best_intervention = None
best_escape_rate = 0

for spec in INTERVENTIONS:
    name = spec['name']
    rows = [intv for res in all_results for intv in res['interventions']
            if intv['intervention'] == name]
    if not rows:
        continue

    escape_rate = sum(1 for row in rows if row['escaped']) / len(rows)
    avg_improvement = np.mean([row['improvement'] for row in rows])
    best_final = min(row['final_error'] for row in rows)

    print(f'{name:<20} {escape_rate*100:>6.1f}%        {avg_improvement*100:>6.1f}%          {best_final*100:>5.1f}%')

    # Strict '>' keeps the first intervention on ties.
    if escape_rate > best_escape_rate:
        best_escape_rate, best_intervention = escape_rate, name

print(f'\n{"="*60}')
print(f'CONCLUSION:')
if best_escape_rate > 0.5:
    conclusion_lines = [
        f'  ‚úÖ Best intervention: {best_intervention}',
        f'  ‚úÖ Escape rate: {best_escape_rate*100:.0f}%',
        f'  ‚úÖ 90% trap CAN be escaped with proper intervention',
    ]
elif best_escape_rate > 0:
    conclusion_lines = [
        f'  ‚ö†Ô∏è Best intervention: {best_intervention}',
        f'  ‚ö†Ô∏è Escape rate: {best_escape_rate*100:.0f}% (partial success)',
        f'  ‚ö†Ô∏è 90% trap is difficult but not impossible to escape',
    ]
else:
    conclusion_lines = [
        f'  ‚ùå No intervention successfully escaped the trap',
        f'  ‚ùå 90% trap is a stable attractor',
    ]
for line in conclusion_lines:
    print(line)
print(f'{"="*60}')

In [None]:
# Create summary DataFrame: one row per (trapped seed, intervention) pair,
# saved as CSV next to the JSON results and printed in full.
summary_data = [
    {
        'seed': seed_res['seed'],
        'trap_error': seed_res['trap_error'],
        'intervention': intv['intervention'],
        'lambda': intv['config']['lambda'],
        'lr_mult': intv['config']['lr_mult'],
        'final_error': intv['final_error'],
        'improvement': intv['improvement'],
        'escaped': intv['escaped'],
    }
    for seed_res in all_results
    for intv in seed_res['interventions']
]

df_summary = pd.DataFrame(summary_data)
df_summary.to_csv(f'{SAVE_DIR}/{NOTEBOOK_ID}_summary.csv', index=False)
print('Summary saved')
print(df_summary.to_string())