# V2: Realistic Value Signal (Trusted Subset)

**Purpose**: Show that hysteresis persists even without a perfect oracle (i.e., without clean labels for every training sample)

**Protocol**:
- Use only a small fraction (1%, 5%) of clean labels as trusted subset
- g_value computed from trusted subset only
- Compare with full oracle results

**Key Question**: Does the two-branch structure survive with a realistic value signal?

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os, glob, json, time
from datetime import datetime

# Experiment identity and the Drive folder that holds every run of it.
EXP_NAME = 'exp_V2_trusted_subset'
NOTEBOOK_ID = 'V2'
BASE_DIR = '/content/drive/MyDrive/dual-gradient-learning/Paper-A'

# Resume into the latest existing run directory, otherwise start a new one.
# Only directories qualify: a stray file matching the pattern (e.g. an
# exported archive) must never be picked up as SAVE_DIR. Run names end in a
# zero-padded timestamp, so lexicographic order is chronological order.
existing = sorted(p for p in glob.glob(f'{BASE_DIR}/{EXP_NAME}_*') if os.path.isdir(p))
if existing:
    SAVE_DIR = existing[-1]
    print(f'🔄 Resuming: {SAVE_DIR}')
else:
    TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M%S')
    SAVE_DIR = f'{BASE_DIR}/{EXP_NAME}_{TIMESTAMP}'
    os.makedirs(SAVE_DIR, exist_ok=True)
    print(f'🆕 New: {SAVE_DIR}')

# Sub-folders used by the rest of the notebook.
os.makedirs(f'{SAVE_DIR}/checkpoints', exist_ok=True)
os.makedirs(f'{SAVE_DIR}/figures', exist_ok=True)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils import parameters_to_vector
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import numpy as np

# Use the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

In [None]:
# --- Optimisation ---
BATCH_SIZE = 256
NUM_WORKERS = 4
LR = 0.1
K = 16  # the value gradient is recomputed every K optimizer steps

# --- Data corruption / trusted subset ---
NOISE_RATE = 0.4
TRUSTED_RATIOS = [0.01, 0.05]  # 1% and 5%

# --- Checkpoint training (one model per target state) ---
ORDERED_LAMBDA = 0.35
ORDERED_EPOCHS = 50
COLLAPSE_LAMBDA = 0.60
COLLAPSE_EPOCHS = 100

# --- λ sweep ---
LAMBDA_START = 0.30
LAMBDA_END = 0.70
LAMBDA_STEP = 0.04
EPOCHS_PER_LAMBDA = 3

# The half-step padding keeps LAMBDA_END in the grid despite float error in
# arange; rounding gives clean two-decimal λ values.
LAMBDA_GRID_UP = np.round(np.arange(LAMBDA_START, LAMBDA_END + LAMBDA_STEP/2, LAMBDA_STEP), 2)
# The downward sweep visits exactly the same λ values in reverse. Deriving it
# from the upward grid (instead of a second arange) guarantees both branches
# share identical λ keys even when the step does not divide the range evenly,
# which the per-λ gap computed at the end of the notebook relies on.
LAMBDA_GRID_DOWN = LAMBDA_GRID_UP[::-1].copy()

N_SEEDS = 3

print(f'Trusted ratios: {TRUSTED_RATIOS}')
print(f'λ points: {len(LAMBDA_GRID_UP)}')

In [None]:
def get_resnet18():
    """Build an untrained 10-class ResNet-18 with a modified input stem.

    The stock first convolution is replaced by a 3x3, stride-1 convolution and
    the stem max-pool is replaced by an identity layer.
    """
    net = resnet18(weights=None, num_classes=10)
    net.maxpool = nn.Identity()
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    return net

class IndexedDataset(Dataset):
    """Wrap a dataset of ``(image, label)`` pairs so each item also carries its index.

    Items come back as ``(image, label, idx)``, which lets a training loop look
    up per-sample information (e.g. replacement labels) by dataset position.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        # Exactly two fields are expected from the wrapped dataset.
        image, target = sample
        return image, target, idx

def set_seed(seed):
    """Seed NumPy and PyTorch (CPU and all CUDA devices) and pin cuDNN settings.

    cuDNN is switched to deterministic mode with benchmarking turned off.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def inject_label_noise(labels, noise_rate, seed, num_classes=10):
    """Return a copy of ``labels`` with a fixed fraction flipped to a wrong class.

    ``int(noise_rate * len(labels))`` distinct positions are drawn without
    replacement, and each is reassigned to a class drawn uniformly from the
    ``num_classes - 1`` classes that differ from its original label.

    A private ``RandomState(seed)`` is used, so the process-wide NumPy RNG is
    left untouched. The draws are the same ones the legacy
    ``np.random.seed(seed)`` + ``np.random.choice`` sequence produces, so
    results match earlier runs for the same seed.

    Args:
        labels: array of integer class labels; must support ``.copy()`` and
            item assignment. It is not modified.
        noise_rate: fraction of labels to corrupt.
        seed: seed for the private random generator.
        num_classes: number of classes; labels are drawn from
            ``range(num_classes)``.

    Returns:
        A copy of ``labels`` with the selected entries replaced.
    """
    rng = np.random.RandomState(seed)
    noisy = labels.copy()
    n_noisy = int(noise_rate * len(labels))
    flip_positions = rng.choice(len(labels), n_noisy, replace=False)
    for pos in flip_positions:
        # Exclude the true class so every selected label really changes.
        wrong_classes = [c for c in range(num_classes) if c != labels[pos]]
        noisy[pos] = rng.choice(wrong_classes)
    return noisy

def create_trusted_subset(n_total, trusted_ratio, seed):
    """Pick the dataset indices whose clean labels may be used as a value signal.

    ``int(n_total * trusted_ratio)`` distinct indices are drawn from
    ``range(n_total)`` with a private ``RandomState(seed + 999)``; the offset
    keeps this draw on a different stream from one seeded with ``seed``
    directly. The process-wide NumPy RNG is not reseeded or advanced, and the
    selected indices equal those of the legacy ``np.random.seed(seed + 999)``
    + ``np.random.choice`` sequence.

    Args:
        n_total: size of the dataset.
        trusted_ratio: fraction of the dataset to mark as trusted.
        seed: base seed for the selection.

    Returns:
        A ``set`` of plain ``int`` indices. NumPy integer scalars with the same
        value still test as members.
    """
    rng = np.random.RandomState(seed + 999)
    n_trusted = int(n_total * trusted_ratio)
    return set(rng.choice(n_total, n_trusted, replace=False).tolist())

def load_cifar10():
    """Download (if needed) and return the CIFAR-10 ``(train, test)`` datasets.

    Training images get random 32x32 crops (4-pixel padding) and random
    horizontal flips before tensor conversion; test images are only converted.
    Both splits are normalised with the same per-channel mean and std.
    """
    normalize = transforms.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010))
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([transforms.ToTensor(), normalize])

    train_split = torchvision.datasets.CIFAR10(root='./data', train=True, transform=train_tf, download=True)
    test_split = torchvision.datasets.CIFAR10(root='./data', train=False, transform=test_tf, download=True)
    return train_split, test_split

def evaluate(model, loader):
    """Return the accuracy (fraction correct) of ``model`` over ``loader``.

    The model is switched to eval mode and left there; gradients are disabled
    during the pass. ``loader`` must yield ``(inputs, targets)`` batches.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predictions = model(inputs).argmax(1)
            n_correct += (predictions == targets).sum().item()
            n_seen += targets.size(0)
    return n_correct / n_seen

In [None]:
def train_one_epoch_trusted(model, train_loader, opt, clean_t, noisy_t, lam, cached_gv_ref, trusted_mask):
    """Run one epoch of mixed-gradient training with a trusted-subset value signal.

    For every batch the supervised gradient ``gs`` is computed from the noisy
    labels. The value gradient ``gv`` is computed from the clean labels of only
    those batch samples whose dataset index is in ``trusted_mask``. Both are
    flattened and L2-normalised, mixed as ``(1 - lam) * gs + lam * gv``,
    written into ``p.grad`` and applied with ``opt.step()``.

    ``gv`` is refreshed when ``step % K == 0`` (``K`` is a module-level
    constant) or when no value gradient exists yet, and is reused in between.
    If a refresh batch contains no trusted sample, the previous ``gv`` is kept.
    If none exists yet, ``gs`` stands in for that step only and is NOT stored
    in the cache, so the next batch retries; the noisy-label gradient is never
    reused as if it were a trusted value signal.

    Args:
        model: network to train; it is put into train mode.
        train_loader: iterable of ``(x, label, idx)`` batches. ``label`` is
            ignored; ``idx`` holds dataset indices into the label tensors.
        opt: optimizer; its ``step()`` consumes the mixed gradient.
        clean_t: clean labels indexed by dataset index (the callers in this
            file create it on ``device``).
        noisy_t: noisy labels indexed by dataset index (same placement).
        lam: mixing weight of the value gradient; NumPy scalars are accepted.
        cached_gv_ref: dict with keys ``'step'`` (global step counter) and
            ``'gv'`` (cached value gradient or ``None``); updated in place so
            the state carries across epochs.
        trusted_mask: container of trusted dataset indices supporting ``in``.

    Returns:
        None.
    """
    crit = nn.CrossEntropyLoss()
    model.train()
    params = list(model.parameters())
    lam = float(lam)
    step = cached_gv_ref['step']
    cached_gv = cached_gv_ref['gv']

    for x, _, idx_batch in train_loader:
        x, idx = x.to(device), idx_batch.to(device)
        bn = noisy_t[idx]

        # Supervised gradient on the noisy labels. No retain_graph: the value
        # pass below runs its own forward pass and builds a separate graph.
        opt.zero_grad()
        loss_s = crit(model(x), bn)
        loss_s.backward()
        gs = parameters_to_vector([p.grad for p in params]).clone()

        if step % K == 0 or cached_gv is None:
            # Membership is tested on the loader's own index tensor rather
            # than on the copy that was just moved to `device`.
            trusted_in_batch = [
                pos for pos, global_idx in enumerate(idx_batch.tolist())
                if global_idx in trusted_mask
            ]

            if trusted_in_batch:
                trusted_idx = torch.tensor(trusted_in_batch, device=device)
                x_trusted = x[trusted_idx]
                bc_trusted = clean_t[idx][trusted_idx]

                # NOTE(review): this extra forward pass runs in train mode, so
                # any layer that tracks batch statistics also sees the
                # trusted-only mini-batch - confirm this is intended.
                opt.zero_grad()
                loss_v = crit(model(x_trusted), bc_trusted)
                loss_v.backward()
                cached_gv = parameters_to_vector([p.grad for p in params]).clone()

        # One-step fallback only; deliberately not written back to cached_gv.
        gv = gs if cached_gv is None else cached_gv

        # Normalise both directions so `lam` alone sets their relative weight.
        gs_n = gs / (gs.norm() + 1e-12)
        gv_n = gv / (gv.norm() + 1e-12)

        g_mix = (1 - lam) * gs_n + lam * gv_n

        # Scatter the flat mixed gradient back into per-parameter .grad slots.
        opt.zero_grad()
        offset = 0
        for p in params:
            n = p.numel()
            p.grad = g_mix[offset:offset + n].view(p.shape).clone()
            offset += n
        opt.step()
        step += 1

    cached_gv_ref['step'] = step
    cached_gv_ref['gv'] = cached_gv

In [None]:
# Datasets and loaders shared by every experiment in this notebook.
trainset, testset = load_cifar10()
clean_labels = np.array(trainset.targets)

# Training batches also carry dataset indices (via IndexedDataset).
train_loader = DataLoader(
    IndexedDataset(trainset),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
test_loader = DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Ten throw-away forward passes on random input, then discard the model
# (presumably a GPU/cuDNN warm-up so later timings are not skewed - confirm).
warmup_model = get_resnet18().to(device)
for _ in range(10):
    _ = warmup_model(torch.randn(BATCH_SIZE,3,32,32,device=device))
del warmup_model
torch.cuda.empty_cache()
print('Ready')

In [None]:
def create_checkpoint(seed, target_state, trusted_ratio, train_loader, test_loader, clean_labels):
    """Train a fresh model at a fixed λ and return it as an in-memory checkpoint.

    With ``target_state == 'ordered'`` the model is trained at
    ``ORDERED_LAMBDA`` for ``ORDERED_EPOCHS`` epochs with LR milestones
    ``[30, 40]``. Any other value selects the collapse setting:
    ``COLLAPSE_LAMBDA``, ``COLLAPSE_EPOCHS`` and milestones ``[50, 80]``.

    Args:
        seed: seeds the RNGs, the label noise and the trusted-subset selection.
        target_state: ``'ordered'`` or ``'collapse'`` (see above).
        trusted_ratio: fraction of the training set with trusted clean labels.
        train_loader: loader yielding ``(x, label, idx)`` batches.
        test_loader: loader yielding ``(x, y)`` batches, used for test error.
        clean_labels: array of clean training labels, indexed like the dataset.

    Returns:
        dict with keys ``model_state_dict``, ``optimizer_state_dict``,
        ``seed``, ``trusted_ratio``, ``target_state``, ``final_error``
        (1 - test accuracy), ``lambda`` and ``epochs``.
    """
    is_ordered = (target_state == 'ordered')
    lam = ORDERED_LAMBDA if is_ordered else COLLAPSE_LAMBDA
    epochs = ORDERED_EPOCHS if is_ordered else COLLAPSE_EPOCHS

    # Seed first so noise, trusted subset and model initialisation are all
    # reproducible for a given seed.
    set_seed(seed)
    noisy_labels = inject_label_noise(clean_labels, NOISE_RATE, seed)
    trusted_mask = create_trusted_subset(len(clean_labels), trusted_ratio, seed)

    # Label tensors live on the device so they can be indexed per batch.
    clean_t = torch.tensor(clean_labels, device=device)
    noisy_t = torch.tensor(noisy_labels, device=device)

    model = get_resnet18().to(device)
    opt = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
    milestones = [30, 40] if is_ordered else [50, 80]
    sched = optim.lr_scheduler.MultiStepLR(opt, milestones, gamma=0.1)
    # Step counter and cached value gradient, carried across epochs.
    cached_gv_ref = {'step': 0, 'gv': None}

    print(f'  Training {target_state} at λ={lam}, trusted={trusted_ratio*100:.0f}%...')

    for ep in range(epochs):
        train_one_epoch_trusted(model, train_loader, opt, clean_t, noisy_t, lam, cached_gv_ref, trusted_mask)
        sched.step()
        # Progress report every 20 epochs.
        if (ep + 1) % 20 == 0:
            err = 1 - evaluate(model, test_loader)
            print(f'    Epoch {ep+1}: error={err:.4f}')

    final_error = 1 - evaluate(model, test_loader)

    return {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'seed': seed,
        'trusted_ratio': trusted_ratio,
        'target_state': target_state,
        'final_error': final_error,
        'lambda': lam,
        'epochs': epochs
    }

In [None]:
def run_sweep_trusted(ckpt_data, train_loader, test_loader, clean_labels, direction='up'):
    """Sweep λ quasi-statically from a checkpoint and record test error per λ.

    The model from ``ckpt_data`` is trained for ``EPOCHS_PER_LAMBDA`` epochs at
    each λ of the grid, in order, with a fresh SGD optimizer at ``LR * 0.01``.
    Test error is measured after each λ. ``direction == 'up'`` walks
    ``LAMBDA_GRID_UP`` (branch ``'ordered_up'``); any other value walks
    ``LAMBDA_GRID_DOWN`` (branch ``'collapse_down'``).

    Label noise and the trusted subset are regenerated from the checkpoint's
    own seed and trusted ratio.

    Args:
        ckpt_data: dict providing ``seed``, ``trusted_ratio``, ``final_error``
            and ``model_state_dict``.
        train_loader: loader yielding ``(x, label, idx)`` batches.
        test_loader: loader yielding ``(x, y)`` batches.
        clean_labels: array of clean training labels, indexed like the dataset.
        direction: ``'up'`` or ``'down'``.

    Returns:
        dict with keys ``seed``, ``trusted_ratio``, ``branch``, ``init_error``,
        ``final_error`` and ``trajectory`` (a list of
        ``{'lambda': float, 'error': float}`` in sweep order).
    """
    seed = ckpt_data['seed']
    trusted_ratio = ckpt_data['trusted_ratio']
    init_error = ckpt_data['final_error']

    lambda_grid = LAMBDA_GRID_UP if direction == 'up' else LAMBDA_GRID_DOWN
    branch = 'ordered_up' if direction == 'up' else 'collapse_down'

    noisy_labels = inject_label_noise(clean_labels, NOISE_RATE, seed)
    trusted_mask = create_trusted_subset(len(clean_labels), trusted_ratio, seed)

    clean_t = torch.tensor(clean_labels, device=device)
    noisy_t = torch.tensor(noisy_labels, device=device)

    # Seed offset so the sweep's randomness differs from checkpoint training.
    set_seed(seed + 7000)
    model = get_resnet18().to(device)
    model.load_state_dict(ckpt_data['model_state_dict'])

    # Fresh optimizer: the checkpoint's optimizer state is not restored.
    opt = optim.SGD(model.parameters(), lr=LR * 0.01, momentum=0.9, weight_decay=5e-4)
    cached_gv_ref = {'step': 0, 'gv': None}
    trajectory = []

    for lam in lambda_grid:
        # Grid entries are NumPy scalars; hand a plain float to the trainer and
        # the JSON-bound trajectory.
        lam = float(lam)
        for _ in range(EPOCHS_PER_LAMBDA):
            train_one_epoch_trusted(model, train_loader, opt, clean_t, noisy_t, lam, cached_gv_ref, trusted_mask)

        err = 1 - evaluate(model, test_loader)
        trajectory.append({'lambda': lam, 'error': err})
        print(f'      λ={lam:.2f}: err={err:.4f}')

    return {
        'seed': seed,
        'trusted_ratio': trusted_ratio,
        'branch': branch,
        'init_error': init_error,
        'final_error': trajectory[-1]['error'],
        'trajectory': trajectory
    }

In [None]:
all_results = []
ckpt_file = f'{SAVE_DIR}/{NOTEBOOK_ID}_checkpoint.json'

# The collapse checkpoint is trained with `seed + COLLAPSE_SEED_OFFSET`, and
# that shifted seed is what ends up in the stored result's 'seed' field.
COLLAPSE_SEED_OFFSET = 100


def _done_key(result):
    """Map a stored result to the (trusted_ratio, loop seed, branch) resume key.

    Collapse results carry the shifted seed, so the offset is removed here.
    Without this, results reloaded from disk never match the loop's keys and
    every resume would re-run (and duplicate) the collapse branch.
    """
    loop_seed = result['seed']
    if result['branch'] == 'collapse_down':
        loop_seed -= COLLAPSE_SEED_OFFSET
    return (result['trusted_ratio'], loop_seed, result['branch'])


def _save_progress(results, path):
    """Write all results collected so far to ``path`` as JSON."""
    with open(path, 'w') as fh:
        json.dump(results, fh, indent=2, default=str)


if os.path.exists(ckpt_file):
    with open(ckpt_file) as fh:
        all_results = json.load(fh)
    print(f'Loaded {len(all_results)} previous results')

done_keys = {_done_key(r) for r in all_results}

for trusted_ratio in TRUSTED_RATIOS:
    print(f'\n{"="*60}')
    print(f'TRUSTED RATIO: {trusted_ratio*100:.0f}%')
    print(f'{"="*60}')

    for seed in range(N_SEEDS):
        ord_key = (trusted_ratio, seed, 'ordered_up')
        col_key = (trusted_ratio, seed, 'collapse_down')

        if ord_key in done_keys and col_key in done_keys:
            print(f'\nSeed {seed}: Already complete')
            continue

        print(f'\n--- Seed {seed} ---')

        if ord_key not in done_keys:
            print('\n[Ordered Checkpoint]')
            t0 = time.time()
            ckpt_ord = create_checkpoint(seed, 'ordered', trusted_ratio, train_loader, test_loader, clean_labels)
            print(f'  Created in {(time.time()-t0)/60:.1f}min, error={ckpt_ord["final_error"]:.4f}')

            print('\n[Ordered Sweep (λ ↑)]')
            result_ord = run_sweep_trusted(ckpt_ord, train_loader, test_loader, clean_labels, 'up')
            result_ord['experiment_id'] = f'{NOTEBOOK_ID}-tr{int(trusted_ratio*100):02d}-ord-s{seed}'
            all_results.append(result_ord)
            done_keys.add(ord_key)
            # Persist after every finished branch so an interrupted session
            # loses at most one branch of work.
            _save_progress(all_results, ckpt_file)
            torch.cuda.empty_cache()

        if col_key not in done_keys:
            print('\n[Collapse Checkpoint]')
            t0 = time.time()
            ckpt_col = create_checkpoint(seed + COLLAPSE_SEED_OFFSET, 'collapse', trusted_ratio, train_loader, test_loader, clean_labels)
            print(f'  Created in {(time.time()-t0)/60:.1f}min, error={ckpt_col["final_error"]:.4f}')

            print('\n[Collapse Sweep (λ ↓)]')
            result_col = run_sweep_trusted(ckpt_col, train_loader, test_loader, clean_labels, 'down')
            result_col['experiment_id'] = f'{NOTEBOOK_ID}-tr{int(trusted_ratio*100):02d}-col-s{seed}'
            all_results.append(result_col)
            done_keys.add(col_key)
            _save_progress(all_results, ckpt_file)
            torch.cuda.empty_cache()

print('\n' + '='*60)
print(f'{NOTEBOOK_ID} COMPLETE')
print('='*60)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Final copy of the raw results.
with open(f'{SAVE_DIR}/{NOTEBOOK_ID}_results.json', 'w') as fh:
    json.dump(all_results, fh, indent=2, default=str)

# Flatten every trajectory into one row per (run, λ) measurement.
all_data = [
    {'seed': r['seed'], 'trusted_ratio': r['trusted_ratio'], 'branch': r['branch'],
     'lambda': t['lambda'], 'error': t['error']}
    for r in all_results
    for t in r['trajectory']
]
# Explicit columns keep the filters below valid even when no results exist.
df = pd.DataFrame(all_data, columns=['seed', 'trusted_ratio', 'branch', 'lambda', 'error'])
df.to_csv(f'{SAVE_DIR}/{NOTEBOOK_ID}_results.csv', index=False)

# Visualization: one panel per trusted ratio, mean ± std over seeds per branch.
fig, axes = plt.subplots(1, len(TRUSTED_RATIOS), figsize=(6*len(TRUSTED_RATIOS), 5))
if len(TRUSTED_RATIOS) == 1: axes = [axes]

# (branch key, band colour, line format, legend label), in drawing order.
branch_styles = [
    ('ordered_up', 'blue', 'b-o', 'Ordered (λ↑)'),
    ('collapse_down', 'red', 'r-s', 'Collapse (λ↓)'),
]

for i, tr in enumerate(TRUSTED_RATIOS):
    ax = axes[i]
    df_tr = df[df['trusted_ratio'] == tr]

    for branch_name, band_color, line_fmt, legend_label in branch_styles:
        df_branch = df_tr[df_tr['branch'] == branch_name]
        if len(df_branch) == 0:
            continue
        stats = df_branch.groupby('lambda')['error'].agg(['mean', 'std']).reset_index()
        ax.fill_between(stats['lambda'], stats['mean'] - stats['std'], stats['mean'] + stats['std'], alpha=0.3, color=band_color)
        ax.plot(stats['lambda'], stats['mean'], line_fmt, linewidth=2, markersize=5, label=legend_label)

    # Reference levels at 40% and 20% test error (0.40 equals NOISE_RATE -
    # presumably intentional; confirm).
    ax.axhline(0.40, color='orange', linestyle='--', alpha=0.5)
    ax.axhline(0.20, color='green', linestyle='--', alpha=0.5)
    ax.set_xlabel('λ', fontsize=12)
    ax.set_ylabel('Test Error', fontsize=12)
    ax.set_title(f'Trusted Subset: {tr*100:.0f}%', fontsize=14, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0.28, 0.72)

plt.tight_layout()
plt.savefig(f'{SAVE_DIR}/figures/{NOTEBOOK_ID}_hysteresis_by_trusted.png', dpi=150)
plt.show()

# Summary: mean over λ of (collapse-branch error - ordered-branch error).
print('\n' + '='*60)
print(f'{NOTEBOOK_ID} SUMMARY')
print('='*60)
for tr in TRUSTED_RATIOS:
    df_tr = df[df['trusted_ratio'] == tr]
    df_ord = df_tr[df_tr['branch'] == 'ordered_up']
    df_col = df_tr[df_tr['branch'] == 'collapse_down']
    if len(df_ord) > 0 and len(df_col) > 0:
        # The subtraction aligns the two branches on their λ index.
        gap = (df_col.groupby('lambda')['error'].mean() - df_ord.groupby('lambda')['error'].mean()).mean()
        if gap > 0.10:
            verdict = 'YES'
        elif gap > 0.05:
            verdict = 'WEAK'
        else:
            verdict = 'NO'
        print(f'\n📊 Trusted {tr*100:.0f}%: Mean gap = {gap*100:.1f}%, Two-branch = {verdict}')