# Tβ2_v3: Two-Branch Existence Test at η=0.8

**Purpose**: Test whether bistability exists under extreme noise (80%)

**Version History**:
- v1: threshold 25% → All seeds failed
- v2: threshold 45% → Seeds reached 46-65% but still failed
- **v3: No absolute threshold** (Sofia's recommendation)

**Key Design Change** (per Sofia):
> „Äåordered-initÔºàÔºùlow-error initÔºâÁµ∂ÂØæÈñæÂÄ§„ÅØ‰Ωø„Çè„Å™„ÅÑ„Äç
> „Äåordered„Åå‰Ωú„Çå„Å™„ÅÑÔºùÁõ∏„ÅåÊ∂à„Åà„Çã„Äç„Å®„ÅÑ„ÅÜÁµêÊûú„Å®„Åó„Å¶„ÄÅË®≠Ë®àÁõÆÁöÑ„ÇíÂàá„ÇäÊõø„Åà„Çå„Å∞Ë´ñÊñáÂåñ„Åß„Åç„Çã

**Protocol**:
1. Train multiple seeds at λ=0.50 for 100 epochs
2. Take the **best performing seed** as "low-error init" (whatever error it achieves)
3. Create collapse init (90%) as usual
4. Compare the two branches - do they separate?

**Possible Outcomes**:
- Branches separate → Bistability persists at η=0.8
- Branches converge → Bistability breaks at η=0.8 (boundary discovery)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os, glob, json, time
from datetime import datetime

EXP_NAME = 'exp_Tb2_eta08_v3'
NOTEBOOK_ID = 'Tb2_v3'
BASE_DIR = '/content/drive/MyDrive/dual-gradient-learning/Paper-A'

# Resume into the most recent run directory of this experiment when one
# exists. Only directories are considered so that a stray file sharing the
# prefix can never be selected as SAVE_DIR. Timestamps in the directory name
# sort lexicographically, so the last entry is the newest run.
existing = sorted(
    path for path in glob.glob(f'{BASE_DIR}/{EXP_NAME}_*') if os.path.isdir(path)
)
if existing:
    SAVE_DIR = existing[-1]
    print(f'🔄 Resuming: {SAVE_DIR}')
else:
    TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M%S')
    SAVE_DIR = f'{BASE_DIR}/{EXP_NAME}_{TIMESTAMP}'
    os.makedirs(SAVE_DIR, exist_ok=True)
    print(f'🆕 New: {SAVE_DIR}')

# Sub-directories used by the later cells (created on first run and on resume).
os.makedirs(f'{SAVE_DIR}/checkpoints', exist_ok=True)
os.makedirs(f'{SAVE_DIR}/figures', exist_ok=True)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils import parameters_to_vector
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import numpy as np

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

In [None]:
# Core parameters
BATCH_SIZE = 256
NUM_WORKERS = 4
LR = 0.1
K = 16  # refresh interval (in optimizer steps) for the cached clean-label gradient

# HIGH NOISE
NOISE_RATE = 0.8  # 80% noise

# v3: NO ABSOLUTE THRESHOLD for low-error init
# Instead, we take the best we can get
LOW_ERROR_LAMBDA = 0.50  # λ for creating low-error state
LOW_ERROR_EPOCHS = 100

# Collapse init (high-error)
COLLAPSE_LAMBDA = 0.70
COLLAPSE_EPOCHS = 100

# Sweep settings
LAMBDA_START = 0.40
LAMBDA_END = 0.80
LAMBDA_STEP = 0.05
EPOCHS_PER_LAMBDA = 5

# The half-step margin makes the end point inclusive despite float error in
# np.arange; rounding to 2 decimals snaps each grid value to a clean number.
LAMBDA_GRID_UP = np.round(np.arange(LAMBDA_START, LAMBDA_END + LAMBDA_STEP/2, LAMBDA_STEP), 2)
LAMBDA_GRID_DOWN = np.round(np.arange(LAMBDA_END, LAMBDA_START - LAMBDA_STEP/2, -LAMBDA_STEP), 2)

# Phase 1: Find best low-error seed
N_SEARCH_SEEDS = 10  # Search this many seeds to find the best

# Phase 2: Run comparison with best seed
N_COMPARISON_SEEDS = 3  # How many seeds to run full comparison

print(f'Noise rate: {NOISE_RATE*100:.0f}%')
print(f'Strategy: No threshold, take best achievable error as low-error init')
print(f'Sweep range: λ ∈ [{LAMBDA_START}, {LAMBDA_END}]')

In [None]:
def get_resnet18():
    """Build an untrained 10-class ResNet-18 with a modified input stem.

    The stock first convolution is replaced by a 3x3, stride-1, padding-1
    convolution without bias, and the max-pool layer is replaced by an
    identity so it performs no down-sampling.
    """
    net = resnet18(weights=None, num_classes=10)
    stem_conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.conv1 = stem_conv
    net.maxpool = nn.Identity()
    return net

class IndexedDataset(Dataset):
    """Dataset wrapper that returns each item's index alongside the item.

    The wrapped dataset must yield ``(image, label)`` pairs; this wrapper
    yields ``(image, label, index)`` so callers can look up per-sample data
    by position.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, target = self.dataset[idx]
        return (sample, target, idx)

def set_seed(seed):
    """Seed PyTorch (CPU and all CUDA devices) and NumPy, and pin cuDNN.

    cuDNN is switched to deterministic mode with auto-tuning (benchmark)
    disabled.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def inject_label_noise(labels, noise_rate, seed, num_classes=10):
    """Return a copy of ``labels`` with a fixed fraction flipped to a wrong class.

    Exactly ``int(noise_rate * len(labels))`` distinct positions are chosen
    and each is reassigned to a class drawn uniformly from the other
    ``num_classes - 1`` classes, so every selected label really changes.

    A private ``RandomState(seed)`` is used instead of reseeding the global
    NumPy generator. It produces the same stream that
    ``np.random.seed(seed)`` followed by ``np.random.choice`` would, so the
    result for a given seed is unchanged, but the caller's global RNG state
    is no longer clobbered.

    Args:
        labels: Array of integer class labels; must support ``.copy()``.
        noise_rate: Fraction of labels to corrupt, in [0, 1].
        seed: Seed for the private random generator.
        num_classes: Number of classes; labels are assumed to lie in
            ``range(num_classes)``.

    Returns:
        A new array of the same shape as ``labels``; the input is not modified.
    """
    rng = np.random.RandomState(seed)
    noisy = labels.copy()
    n_noisy = int(noise_rate * len(labels))
    if n_noisy == 0:
        # Nothing to flip (zero rate or empty input).
        return noisy
    flip_positions = rng.choice(len(labels), n_noisy, replace=False)
    for i in flip_positions:
        wrong_classes = [c for c in range(num_classes) if c != labels[i]]
        noisy[i] = rng.choice(wrong_classes)
    return noisy

def load_cifar10():
    """Return the ``(train, test)`` CIFAR-10 datasets stored under ``./data``.

    Both splits are downloaded if missing. The training split applies a
    random 32-pixel crop with 4-pixel padding and a random horizontal flip
    before tensor conversion; both splits are normalized with the same
    per-channel mean and std.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    train_split = torchvision.datasets.CIFAR10('./data', True, train_tf, download=True)
    test_split = torchvision.datasets.CIFAR10('./data', False, test_tf, download=True)
    return train_split, test_split

def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over ``loader`` as a fraction.

    The model is switched to eval mode and left in that mode on return.
    ``loader`` must yield ``(inputs, targets)`` batches; batches are moved to
    the module-level ``device`` before the forward pass.
    """
    model.eval()
    n_hit, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predictions = model(inputs).argmax(1)
            n_hit += (predictions == targets).sum().item()
            n_seen += targets.size(0)
    return n_hit / n_seen

In [None]:
def train_one_epoch(model, train_loader, opt, clean_t, noisy_t, lam, state):
    """Run one epoch of training on a mix of noisy- and clean-label gradients.

    For every batch the gradient of the cross-entropy loss on the noisy
    labels is computed. Every ``K`` optimizer steps (module-level ``K``), or
    when no cached value exists yet, the gradient on the clean labels of the
    current batch is recomputed and cached. Both gradients are L2-normalized
    and combined as ``(1 - lam) * g_noisy + lam * g_clean``; the result is
    written into ``p.grad`` and applied with ``opt.step()``.

    Args:
        model: Network to train; put into train mode by this function.
        train_loader: Yields ``(inputs, _, sample_index)`` batches.
        opt: Optimizer over ``model.parameters()``.
        clean_t: Tensor of clean labels, indexed by sample index.
        noisy_t: Tensor of noisy labels, indexed by sample index.
        lam: Mixing weight given to the (cached) clean-label gradient.
        state: Dict with keys ``'step'`` (global step counter) and ``'gv'``
            (cached clean-label gradient vector or ``None``); updated in
            place so the cache and counter carry over between epochs.
    """
    crit = nn.CrossEntropyLoss()
    model.train()
    step = state['step']
    cached_gv = state['gv']

    for x, _, idx in train_loader:
        x, idx = x.to(device), idx.to(device)
        bn, bc = noisy_t[idx], clean_t[idx]

        # Noisy-label gradient. The graph is not retained: the clean-label
        # loss below runs its own forward pass, so nothing backpropagates
        # through this graph a second time.
        opt.zero_grad()
        loss_s = crit(model(x), bn)
        loss_s.backward()
        g_s = parameters_to_vector([p.grad for p in model.parameters()]).clone()

        # Refresh the cached clean-label gradient every K steps.
        if step % K == 0 or cached_gv is None:
            opt.zero_grad()
            loss_v = crit(model(x), bc)
            loss_v.backward()
            cached_gv = parameters_to_vector([p.grad for p in model.parameters()]).clone()

        # Normalize each direction (epsilon guards against a zero norm) and mix.
        g_s_n = g_s / (g_s.norm() + 1e-12)
        g_v_n = cached_gv / (cached_gv.norm() + 1e-12)
        g_mix = (1 - lam) * g_s_n + lam * g_v_n

        # Scatter the flat mixed gradient back onto the parameters.
        opt.zero_grad()
        offset = 0
        for p in model.parameters():
            n = p.numel()
            p.grad = g_mix[offset:offset + n].view(p.shape).clone()
            offset += n
        opt.step()
        step += 1

    state['step'] = step
    state['gv'] = cached_gv

In [None]:
# Datasets, clean label array, and the two loaders shared by every later cell.
trainset, testset = load_cifar10()
clean_labels = np.array(trainset.targets)

# The training loader wraps the dataset so each batch also carries sample
# indices, which are used to look up per-sample clean/noisy labels.
train_loader = DataLoader(
    IndexedDataset(trainset),
    BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
test_loader = DataLoader(
    testset,
    BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Five throwaway forward passes on random input (presumably a GPU warm-up),
# after which the temporary model is dropped and cached GPU memory released.
warmup_net = get_resnet18().to(device)
for _ in range(5):
    _ = warmup_net(torch.randn(BATCH_SIZE, 3, 32, 32, device=device))
del warmup_net
torch.cuda.empty_cache()
print('Ready')

In [None]:
# ============================================================
# PHASE 1: Find best low-error seeds
# ============================================================
# Trains one model per seed at LOW_ERROR_LAMBDA and keeps each final state so
# that the lowest-error runs can serve as "low-error init" in Phase 2.
print('='*60)
print('PHASE 1: Finding Best Low-Error Seeds')
print('='*60)

seed_results = []

# Clean labels do not depend on the seed, so move them to the device once.
clean_t = torch.tensor(clean_labels, device=device)

for seed in range(N_SEARCH_SEEDS):
    print(f'\nSeed {seed}...')

    # Each seed gets its own noisy label set, generated from the same seed.
    noisy_labels = inject_label_noise(clean_labels, NOISE_RATE, seed)
    noisy_t = torch.tensor(noisy_labels, device=device)

    # Seed after noise injection so model init and training are reproducible.
    set_seed(seed)
    model = get_resnet18().to(device)
    opt = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
    sched = optim.lr_scheduler.MultiStepLR(opt, [50, 80], gamma=0.1)
    state = {'step': 0, 'gv': None}

    for ep in range(LOW_ERROR_EPOCHS):
        train_one_epoch(model, train_loader, opt, clean_t, noisy_t, LOW_ERROR_LAMBDA, state)
        sched.step()
        # Progress report every 20 epochs.
        if (ep + 1) % 20 == 0:
            err = 1 - evaluate(model, test_loader)
            print(f'  Epoch {ep+1}: error={err:.4f}')

    final_error = 1 - evaluate(model, test_loader)
    seed_results.append({
        'seed': seed,
        'final_error': final_error,
        # Keep weights on the CPU so the stored states do not occupy GPU memory.
        'state': {k: v.cpu().clone() for k, v in model.state_dict().items()},
        'noisy_labels': noisy_labels
    })
    print(f'  Final: {final_error:.4f}')

    del model; torch.cuda.empty_cache()

# Sort by error (lowest first)
seed_results.sort(key=lambda x: x['final_error'])

print(f'\n{"="*60}')
print('Phase 1 Results (sorted by error):')
for i, r in enumerate(seed_results):
    # Star the seeds that will be carried into the Phase 2 comparison.
    marker = '⭐' if i < N_COMPARISON_SEEDS else '  '
    print(f'{marker} Seed {r["seed"]}: {r["final_error"]*100:.1f}%')
print(f'{"="*60}')

In [None]:
# ============================================================
# PHASE 2: Two-Branch Comparison
# ============================================================
print('\n' + '='*60)
print('PHASE 2: Two-Branch Comparison')
print('='*60)


def _save_results(path, payload):
    """Write ``payload`` to ``path`` as indented JSON and close the file."""
    with open(path, 'w') as f:
        json.dump(payload, f, indent=2, default=str)


def _run_lambda_sweep(init_state, lam_grid, rng_seed, clean_t, noisy_t):
    """Fine-tune from ``init_state`` through ``lam_grid``; return the error trajectory.

    A fresh model is seeded with ``rng_seed``, loaded with ``init_state`` and
    trained for EPOCHS_PER_LAMBDA epochs at each λ in order (lr = LR * 0.01,
    no scheduler). The test error is recorded after each λ.
    """
    set_seed(rng_seed)
    model = get_resnet18().to(device)
    model.load_state_dict({k: v.to(device) for k, v in init_state.items()})
    opt = optim.SGD(model.parameters(), lr=LR * 0.01, momentum=0.9, weight_decay=5e-4)
    state = {'step': 0, 'gv': None}

    trajectory = []
    for lam in lam_grid:
        for _ in range(EPOCHS_PER_LAMBDA):
            train_one_epoch(model, train_loader, opt, clean_t, noisy_t, lam, state)
        err = 1 - evaluate(model, test_loader)
        trajectory.append({'lambda': float(lam), 'error': err})
        print(f'  λ={lam:.2f}: {err:.4f}')

    # Drop every reference to the parameters before releasing cached memory.
    del model, opt, state
    torch.cuda.empty_cache()
    return trajectory


# Take best N seeds
selected_seeds = seed_results[:N_COMPARISON_SEEDS]

results = []
results_path = f'{SAVE_DIR}/{NOTEBOOK_ID}_results.json'

# Clean labels are the same for every seed.
clean_t = torch.tensor(clean_labels, device=device)

for seed_info in selected_seeds:
    seed = seed_info['seed']
    low_error_init = seed_info['final_error']

    print(f'\n{"="*50}')
    print(f'Seed {seed} (low-error init: {low_error_init*100:.1f}%)')
    print(f'{"="*50}')

    # Reuse the exact noisy labels this seed was trained on in Phase 1.
    noisy_t = torch.tensor(seed_info['noisy_labels'], device=device)

    # === Create Collapse Init (high-error) ===
    print(f'\n[Creating Collapse Init (λ={COLLAPSE_LAMBDA})]...')
    set_seed(seed + 100)
    model = get_resnet18().to(device)
    opt = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
    sched = optim.lr_scheduler.MultiStepLR(opt, [50, 80], gamma=0.1)
    state = {'step': 0, 'gv': None}

    for ep in range(COLLAPSE_EPOCHS):
        train_one_epoch(model, train_loader, opt, clean_t, noisy_t, COLLAPSE_LAMBDA, state)
        sched.step()
        if (ep + 1) % 25 == 0:
            err = 1 - evaluate(model, test_loader)
            print(f'  Epoch {ep+1}: error={err:.4f}')

    collapse_error = 1 - evaluate(model, test_loader)
    collapse_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
    print(f'  Collapse init: {collapse_error:.4f}')
    del model; torch.cuda.empty_cache()

    # === Low-Error Branch Sweep (λ↑) ===
    print(f'\n[Low-Error Branch Sweep (λ↑)]...')
    low_error_traj = _run_lambda_sweep(
        seed_info['state'], LAMBDA_GRID_UP, seed + 200, clean_t, noisy_t)

    # === Collapse Branch Sweep (λ↓) ===
    print(f'\n[Collapse Branch Sweep (λ↓)]...')
    collapse_traj = _run_lambda_sweep(
        collapse_state, LAMBDA_GRID_DOWN, seed + 300, clean_t, noisy_t)

    results.append({
        'seed': seed,
        'eta': NOISE_RATE,
        'low_error_init': low_error_init,
        'collapse_init': collapse_error,
        'low_error_trajectory': low_error_traj,
        'collapse_trajectory': collapse_traj,
        'experiment_id': f'{NOTEBOOK_ID}-seed{seed:02d}'
    })

    # Persist after every seed so an interrupted session keeps finished seeds.
    _save_results(results_path, results)

# Save results (also covers the case where no seed was selected)
_save_results(results_path, results)

print(f'\n{"="*60}')
print(f'{NOTEBOOK_ID} COMPLETE')
print(f'{"="*60}')

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Create DataFrame: one row per (seed, branch, λ) with the measured test error.
all_data = []
for r in results:
    for branch, traj_key in (('low_error', 'low_error_trajectory'),
                             ('collapse', 'collapse_trajectory')):
        for t in r[traj_key]:
            all_data.append({'seed': r['seed'], 'branch': branch,
                             'lambda': t['lambda'], 'error': t['error']})
# Explicit columns keep the frame well-formed even when `results` is empty.
df = pd.DataFrame(all_data, columns=['seed', 'branch', 'lambda', 'error'])
df.to_csv(f'{SAVE_DIR}/{NOTEBOOK_ID}_results.csv', index=False)

# Visualization
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Individual trajectories
ax = axes[0]
for i, r in enumerate(results):
    lams_l = [t['lambda'] for t in r['low_error_trajectory']]
    errs_l = [t['error'] for t in r['low_error_trajectory']]
    lams_c = [t['lambda'] for t in r['collapse_trajectory']]
    errs_c = [t['error'] for t in r['collapse_trajectory']]
    # Only the first seed contributes legend entries; later seeds share them.
    label_l = 'Low-Error (λ↑)' if i == 0 else '_nolegend_'
    label_c = 'Collapse (λ↓)' if i == 0 else '_nolegend_'
    ax.plot(lams_l, errs_l, 'b-o', alpha=0.6, linewidth=2, markersize=5, label=label_l)
    ax.plot(lams_c, errs_c, 'r-s', alpha=0.6, linewidth=2, markersize=5, label=label_c)

# Reference line at 90% error.
ax.axhline(0.90, color='red', linestyle='--', alpha=0.3, label='Random')
ax.set_xlabel('λ', fontsize=12)
ax.set_ylabel('Test Error', fontsize=12)
ax.set_title(f'η={NOISE_RATE}: Individual Trajectories', fontsize=14)
# Labels come from the artists themselves, so 'Random' is shown as well and
# no label can be attached to the wrong line.
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_xlim(LAMBDA_START - 0.02, LAMBDA_END + 0.02)
ax.set_ylim(0, 1)

# Mean with std
ax = axes[1]
df_low = df[df['branch'] == 'low_error']
df_col = df[df['branch'] == 'collapse']

for branch_df, color, fmt, label in (
    (df_low, 'blue', 'b-o', 'Low-Error (λ↑)'),
    (df_col, 'red', 'r-s', 'Collapse (λ↓)'),
):
    if len(branch_df) > 0:
        stats = branch_df.groupby('lambda')['error'].agg(['mean', 'std']).reset_index()
        ax.fill_between(stats['lambda'], stats['mean'] - stats['std'],
                        stats['mean'] + stats['std'], alpha=0.3, color=color)
        ax.plot(stats['lambda'], stats['mean'], fmt, linewidth=2, markersize=6, label=label)

ax.axhline(0.90, color='red', linestyle='--', alpha=0.3)
ax.set_xlabel('λ', fontsize=12)
ax.set_ylabel('Test Error', fontsize=12)
ax.set_title(f'Two-Branch Test at η={NOISE_RATE}', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_xlim(LAMBDA_START - 0.02, LAMBDA_END + 0.02)
ax.set_ylim(0, 1)

plt.tight_layout()
plt.savefig(f'{SAVE_DIR}/figures/{NOTEBOOK_ID}_two_branch_eta08.png', dpi=150)
plt.savefig(f'{SAVE_DIR}/figures/{NOTEBOOK_ID}_two_branch_eta08.pdf')
plt.show()

In [None]:
# Summary and Conclusion
print('='*60)
print(f'{NOTEBOOK_ID} SUMMARY: Two-Branch Test at η={NOISE_RATE}')
print('='*60)

# Initial conditions
print(f'\n📊 Initial Conditions:')
for r in results:
    print(f'   Seed {r["seed"]}: Low-error={r["low_error_init"]*100:.1f}%, Collapse={r["collapse_init"]*100:.1f}%')

# Gap analysis at multiple λ points
print(f'\n📊 Gap Analysis:')
gaps_at_lambda = {}

for check_lam in [0.50, 0.60, 0.70]:
    # Tolerant comparison: the λ grid is produced by float arithmetic, so an
    # exact == match against a literal is fragile.
    low_at_lam = df_low[np.isclose(df_low['lambda'], check_lam)]['error']
    col_at_lam = df_col[np.isclose(df_col['lambda'], check_lam)]['error']

    if len(low_at_lam) > 0 and len(col_at_lam) > 0:
        low_err = low_at_lam.mean()
        col_err = col_at_lam.mean()
        # Positive gap: the collapse branch has higher error than the low-error branch.
        gap = col_err - low_err
        gaps_at_lambda[check_lam] = gap

        print(f'\n   At λ={check_lam}:')
        print(f'     Low-error: {low_err*100:.1f}%')
        print(f'     Collapse:  {col_err*100:.1f}%')
        print(f'     Gap:       {gap*100:.1f}%')

# Conclusion: classify by the gap averaged over the checked λ values
# (> 15 points: persists, > 5 points: weak, otherwise no clear structure).
if gaps_at_lambda:
    avg_gap = np.mean(list(gaps_at_lambda.values()))

    print(f'\n{"="*60}')
    print(f'CONCLUSION (η={NOISE_RATE}):')

    # The noise level is taken from NOISE_RATE so the messages stay truthful
    # if the constant is changed.
    if avg_gap > 0.15:
        print(f'  ✅ TWO-BRANCH STRUCTURE PERSISTS at η={NOISE_RATE}')
        print(f'  ✅ Average gap: {avg_gap*100:.1f}%')
        print(f'  ✅ Bistability is robust to extreme noise')
    elif avg_gap > 0.05:
        print(f'  ⚠️ WEAK TWO-BRANCH STRUCTURE at η={NOISE_RATE}')
        print(f'  ⚠️ Average gap: {avg_gap*100:.1f}%')
        print(f'  ⚠️ Bistability partially persists')
    else:
        print(f'  🔍 NO CLEAR TWO-BRANCH STRUCTURE at η={NOISE_RATE}')
        print(f'  🔍 Average gap: {avg_gap*100:.1f}%')
        print(f'  🔍 This suggests η={NOISE_RATE} is at or beyond the bistability boundary')
        print(f'  🔍 → This is a FINDING: noise-bistability phase boundary')

    print(f'{"="*60}')