# W2: Causal Intervention

**Purpose**: Demonstrate that hysteresis is truly history-dependent via causal intervention

**Protocol**:
- Start from Ordered checkpoint (low error) and Collapse checkpoint (high error)
- Apply sudden λ jump to the SAME target λ value
- Observe: Do they converge to same state, or remain distinct?

**Key Question**: At the same λ, does history matter? (Causal test of bistability)

**Interventions**:
1. Ordered @ λ=0.35 → Jump to λ=0.50 → Train N epochs
2. Collapse @ λ=0.60 → Jump to λ=0.50 → Train N epochs
3. Compare final states

In [None]:
# Mount Google Drive at /content/drive; the experiment output directory
# configured below (BASE_DIR) lives under this mount point.
from google.colab import drive
drive.mount('/content/drive')

import os, glob, json, time
from datetime import datetime

# Experiment identity and the output root on the mounted Drive.
EXP_NAME = 'exp_W2_causal_intervention'
NOTEBOOK_ID = 'W2'
BASE_DIR = '/content/drive/MyDrive/dual-gradient-learning/Paper-A'

# Resume into the last (lexicographically greatest, i.e. newest timestamp)
# existing run directory if there is one; otherwise create a fresh
# timestamped directory.
existing = glob.glob(f'{BASE_DIR}/{EXP_NAME}_*')
if not existing:
    TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M%S')
    SAVE_DIR = f'{BASE_DIR}/{EXP_NAME}_{TIMESTAMP}'
    os.makedirs(SAVE_DIR, exist_ok=True)
    print(f'üÜï New: {SAVE_DIR}')
else:
    SAVE_DIR = max(existing)
    print(f'üîÑ Resuming: {SAVE_DIR}')

# Sub-directories used by the rest of the notebook.
for subdir in ('checkpoints', 'figures'):
    os.makedirs(f'{SAVE_DIR}/{subdir}', exist_ok=True)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils import parameters_to_vector
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import numpy as np

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# `device` is read as a module-level global by the training/eval helpers below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

In [None]:
# Core parameters
BATCH_SIZE = 256
NUM_WORKERS = 4
LR = 0.1  # base SGD learning rate for checkpoint creation; interventions use LR * 0.01
K = 16  # the clean-label gradient is recomputed every K steps and reused in between
NOISE_RATE = 0.4  # fraction of training labels replaced by a different class

# Checkpoint creation
ORDERED_LAMBDA = 0.35
ORDERED_EPOCHS = 50
# NOTE(review): ORDERED_THRESHOLD is not referenced anywhere in the visible
# code — presumably the maximum error expected of an "ordered" checkpoint; confirm.
ORDERED_THRESHOLD = 0.25

COLLAPSE_LAMBDA = 0.60
COLLAPSE_EPOCHS = 80
# NOTE(review): COLLAPSE_THRESHOLD is likewise unused in the visible code —
# presumably the minimum error expected of a "collapse" checkpoint; confirm.
COLLAPSE_THRESHOLD = 0.45

# Intervention settings
# Jump both checkpoints to these target λ values
TARGET_LAMBDAS = [0.40, 0.50, 0.55]  # Multiple intervention points
POST_INTERVENTION_EPOCHS = 30  # Train this many epochs after jump
EVAL_FREQ = 5  # Evaluate every N epochs

N_SEEDS = 3  # seeds 0 .. N_SEEDS-1 are run

print(f'Intervention targets: {TARGET_LAMBDAS}')
print(f'Post-intervention epochs: {POST_INTERVENTION_EPOCHS}')

In [None]:
def get_resnet18():
    """Build an untrained 10-class ResNet-18 with a modified input stem.

    The stock first convolution is replaced by a 3x3, stride-1, padding-1
    convolution without bias, and the max-pool layer is replaced by an
    identity.

    Returns:
        The freshly initialised model (not yet moved to any device).
    """
    net = resnet18(weights=None, num_classes=10)
    # nn.Identity has no parameters, so swapping it in first does not affect
    # the random initialisation of the new convolution.
    net.maxpool = nn.Identity()
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    return net

class IndexedDataset(Dataset):
    """Dataset wrapper whose items also carry their own index.

    Indexing returns ``(img, label, idx)`` where ``(img, label)`` is the
    wrapped dataset's item at ``idx``. The index lets the training loop look
    up per-sample labels stored outside the dataset.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, target = self.dataset[idx]
        return sample, target, idx

def set_seed(seed):
    """Seed NumPy and PyTorch (CPU and all CUDA devices) and pin cuDNN.

    cuDNN is put in deterministic mode with benchmarking disabled.

    Args:
        seed: integer seed applied to every generator.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def inject_label_noise(labels, noise_rate, seed, num_classes=10):
    """Return a copy of `labels` with a fixed fraction flipped to a wrong class.

    Exactly ``int(noise_rate * len(labels))`` distinct positions are chosen
    and each is reassigned to a class drawn uniformly from the classes other
    than its original label.

    A private ``RandomState(seed)`` is used instead of reseeding the global
    NumPy generator, so calling this function no longer disturbs random state
    used elsewhere. The draws, and therefore the returned labels, are the same
    as those produced after ``np.random.seed(seed)`` on the global generator.

    Args:
        labels: array of integer class labels; it is not modified.
        noise_rate: fraction of labels to corrupt.
        seed: seed for the private random generator.
        num_classes: number of classes; labels are assumed to lie in
            ``range(num_classes)``.

    Returns:
        A new array of the same shape containing the noisy labels.
    """
    rng = np.random.RandomState(seed)
    noisy = labels.copy()
    n_noisy = int(noise_rate * len(labels))
    flip_idx = rng.choice(len(labels), n_noisy, replace=False)
    for i in flip_idx:
        # Exclude the true class so every selected label really changes.
        candidates = [c for c in range(num_classes) if c != labels[i]]
        noisy[i] = rng.choice(candidates)
    return noisy

def load_cifar10():
    """Create the CIFAR-10 train and test datasets, downloading to ./data.

    The training split uses random 32x32 crops (padding 4) and random
    horizontal flips before tensor conversion and normalisation; the test
    split is only converted and normalised.

    Returns:
        A ``(trainset, testset)`` tuple of torchvision CIFAR10 datasets.
    """
    normalize = transforms.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010))

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    train_split = torchvision.datasets.CIFAR10(
        root='./data', train=True, transform=train_transform, download=True)
    test_split = torchvision.datasets.CIFAR10(
        root='./data', train=False, transform=test_transform, download=True)
    return train_split, test_split

def evaluate(model, loader):
    """Return the top-1 accuracy of `model` over `loader` as a fraction.

    The model is switched to eval mode (and left there) and gradients are
    disabled during the pass. Batches are moved to the module-level `device`.

    Args:
        model: network mapping an input batch to per-class scores.
        loader: iterable of ``(inputs, targets)`` batches.

    Returns:
        Number of correct predictions divided by the number of samples seen.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predictions = model(inputs).argmax(1)
            n_correct += (predictions == targets).sum().item()
            n_seen += targets.size(0)
    return n_correct / n_seen

In [None]:
def train_one_epoch(model, train_loader, opt, clean_t, noisy_t, lam, state):
    """Train for one epoch on a mix of noisy-label and clean-label gradients.

    For every batch the flattened gradient of the cross-entropy loss on the
    noisy labels (``g_s``) is computed. The gradient on the clean labels
    (``g_v``) is recomputed only every ``K`` steps (module-level ``K``) or
    when no cached value exists, and reused otherwise. Both are scaled to
    unit norm and combined as ``(1 - lam) * g_s + lam * g_v``; the result is
    written into ``p.grad`` and applied with ``opt.step()``.

    Args:
        model: network updated in place.
        train_loader: yields ``(x, _, idx)`` batches; the dataset's own label
            is ignored and ``idx`` selects the labels from the tensors below.
        opt: optimizer over ``model.parameters()``.
        clean_t: tensor of clean labels, indexed by sample index.
        noisy_t: tensor of noisy labels, indexed by sample index.
        lam: mixing weight given to the clean-label gradient direction.
        state: dict with keys ``'step'`` (global step counter) and ``'gv'``
            (cached clean-label gradient or None). Both are read on entry and
            written back on exit so they carry over between epochs.

    Returns:
        None. `model`, `opt` and `state` are mutated in place.
    """
    crit = nn.CrossEntropyLoss()
    model.train()
    step = state['step']
    cached_gv = state['gv']

    for x, _, idx in train_loader:
        x, idx = x.to(device), idx.to(device)
        bn, bc = noisy_t[idx], clean_t[idx]

        opt.zero_grad()
        loss_s = crit(model(x), bn)
        # No retain_graph: this graph is never backpropagated again (the
        # clean-label branch below runs its own forward pass), so keeping it
        # would only hold activations in memory into the next step.
        loss_s.backward()
        g_s = parameters_to_vector([p.grad for p in model.parameters()]).clone()

        if step % K == 0 or cached_gv is None:
            opt.zero_grad()
            loss_v = crit(model(x), bc)
            loss_v.backward()
            cached_gv = parameters_to_vector([p.grad for p in model.parameters()]).clone()

        # Unit-normalise both directions; the epsilon guards a zero gradient.
        g_s_n = g_s / (g_s.norm() + 1e-12)
        g_v_n = cached_gv / (cached_gv.norm() + 1e-12)
        g_mix = (1 - lam) * g_s_n + lam * g_v_n

        # Scatter the flat mixed gradient back into each parameter's .grad.
        opt.zero_grad()
        i = 0
        for p in model.parameters():
            n = p.numel()
            p.grad = g_mix[i:i+n].view(p.shape).clone()
            i += n
        opt.step()
        step += 1

    state['step'] = step
    state['gv'] = cached_gv

In [None]:
# Datasets and loaders are built once and shared by every seed and phase.
trainset, testset = load_cifar10()
clean_labels = np.array(trainset.targets)
train_loader = DataLoader(
    IndexedDataset(trainset),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
test_loader = DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Warm-up: five throwaway forward passes on random input, then discard the
# model and release cached GPU memory.
m = get_resnet18().to(device)
for _ in range(5):
    _ = m(torch.randn(BATCH_SIZE, 3, 32, 32, device=device))
del m
torch.cuda.empty_cache()
print('Ready')

In [None]:
def run_intervention(checkpoint_state, checkpoint_type, target_lambda, seed, clean_t, noisy_t):
    """Run one causal intervention: resume a checkpoint at `target_lambda`.

    A fresh model is loaded from `checkpoint_state` and trained for
    ``POST_INTERVENTION_EPOCHS`` epochs at `target_lambda` with a new SGD
    optimizer (learning rate ``LR * 0.01``) and a reset step/gradient cache.
    Test error is recorded before training and every ``EVAL_FREQ`` epochs.

    The random seed depends only on `seed` and `target_lambda`, so the
    ordered and collapse runs for the same target share the same seed.

    Args:
        checkpoint_state: model ``state_dict`` to start from.
        checkpoint_type: label stored under ``'source'`` in the result.
        target_lambda: mixing weight used for all post-intervention epochs.
        seed: base seed of the current experiment repetition.
        clean_t: tensor of clean training labels.
        noisy_t: tensor of noisy training labels.

    Returns:
        Dict with keys ``'source'``, ``'target_lambda'``, ``'init_error'``,
        ``'final_error'`` and ``'trajectory'`` (list of
        ``{'epoch', 'error'}`` dicts, starting at epoch 0).
    """
    # round(), not int(): int() truncates float error (int(0.57 * 100) == 56),
    # which would give distinct targets such as 0.56 and 0.57 the same seed.
    # For the configured targets both forms yield the same value.
    set_seed(seed + 500 + round(target_lambda * 100))

    model = get_resnet18().to(device)
    model.load_state_dict({k: v.to(device) for k, v in checkpoint_state.items()})

    # Use lower LR for intervention phase (like sweep)
    opt = optim.SGD(model.parameters(), lr=LR * 0.01, momentum=0.9, weight_decay=5e-4)
    state = {'step': 0, 'gv': None}

    init_error = 1 - evaluate(model, test_loader)
    trajectory = [{'epoch': 0, 'error': init_error}]

    # Train at target lambda
    for ep in range(POST_INTERVENTION_EPOCHS):
        train_one_epoch(model, train_loader, opt, clean_t, noisy_t, target_lambda, state)

        if (ep + 1) % EVAL_FREQ == 0:
            err = 1 - evaluate(model, test_loader)
            trajectory.append({'epoch': ep + 1, 'error': err})

    final_error = 1 - evaluate(model, test_loader)
    # Only append the final point if the periodic evaluation did not already
    # land on the last epoch.
    if trajectory[-1]['epoch'] != POST_INTERVENTION_EPOCHS:
        trajectory.append({'epoch': POST_INTERVENTION_EPOCHS, 'error': final_error})

    del model; torch.cuda.empty_cache()

    return {
        'source': checkpoint_type,
        'target_lambda': target_lambda,
        'init_error': init_error,
        'final_error': final_error,
        'trajectory': trajectory
    }

In [None]:
# Main experiment: for each seed, train an ordered and a collapse checkpoint,
# then jump both to every target lambda and record the resulting error gap.
# Results are written to a JSON checkpoint after each seed so that an
# interrupted run can resume and skip the seeds already completed.
results = []
ckpt_file = f'{SAVE_DIR}/{NOTEBOOK_ID}_checkpoint.json'

if os.path.exists(ckpt_file):
    # Context manager so the file handle is closed deterministically.
    with open(ckpt_file) as fh:
        results = json.load(fh)
    done_seeds = {r['seed'] for r in results}
    print(f'Loaded: {len(done_seeds)} seeds done')
else:
    done_seeds = set()

for seed in range(N_SEEDS):
    if seed in done_seeds:
        print(f'Seed {seed}: Already done')
        continue

    print(f'\n{"="*60}')
    print(f'SEED {seed}')
    print(f'{"="*60}')

    t0 = time.time()

    # Setup: per-seed noisy labels, kept on the device for fast indexing.
    noisy_labels = inject_label_noise(clean_labels, NOISE_RATE, seed)
    clean_t = torch.tensor(clean_labels, device=device)
    noisy_t = torch.tensor(noisy_labels, device=device)

    seed_result = {'seed': seed, 'interventions': []}

    # === Phase 1: Create Ordered Checkpoint ===
    print(f'\n[Phase 1] Creating Ordered Checkpoint (Œª={ORDERED_LAMBDA})...')
    set_seed(seed)
    model = get_resnet18().to(device)
    opt = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
    sched = optim.lr_scheduler.MultiStepLR(opt, [30, 40], gamma=0.1)
    state = {'step': 0, 'gv': None}

    for ep in range(ORDERED_EPOCHS):
        train_one_epoch(model, train_loader, opt, clean_t, noisy_t, ORDERED_LAMBDA, state)
        sched.step()
        if (ep + 1) % 10 == 0:
            err = 1 - evaluate(model, test_loader)
            print(f'  Epoch {ep+1}: error={err:.4f}')

    ordered_error = 1 - evaluate(model, test_loader)
    # Keep the weights on the CPU so the GPU model can be released.
    ordered_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
    seed_result['ordered_init_error'] = ordered_error
    print(f'  ‚úÖ Ordered checkpoint: {ordered_error:.2%}')
    del model; torch.cuda.empty_cache()

    # === Phase 2: Create Collapse Checkpoint ===
    # A different seed offset (+100) is used here, so the collapse model does
    # not share the ordered model's initialisation.
    print(f'\n[Phase 2] Creating Collapse Checkpoint (Œª={COLLAPSE_LAMBDA})...')
    set_seed(seed + 100)
    model = get_resnet18().to(device)
    opt = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
    sched = optim.lr_scheduler.MultiStepLR(opt, [40, 60], gamma=0.1)
    state = {'step': 0, 'gv': None}

    for ep in range(COLLAPSE_EPOCHS):
        train_one_epoch(model, train_loader, opt, clean_t, noisy_t, COLLAPSE_LAMBDA, state)
        sched.step()
        if (ep + 1) % 20 == 0:
            err = 1 - evaluate(model, test_loader)
            print(f'  Epoch {ep+1}: error={err:.4f}')

    collapse_error = 1 - evaluate(model, test_loader)
    collapse_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
    seed_result['collapse_init_error'] = collapse_error
    print(f'  üíÄ Collapse checkpoint: {collapse_error:.2%}')
    del model; torch.cuda.empty_cache()

    # === Phase 3: Causal Interventions ===
    print(f'\n[Phase 3] Causal Interventions...')

    for target_lam in TARGET_LAMBDAS:
        print(f'\n  --- Target Œª = {target_lam} ---')

        # Intervention from Ordered
        print(f'    From Ordered ({ordered_error:.2%}) ‚Üí Œª={target_lam}...')
        ord_result = run_intervention(ordered_state, 'ordered', target_lam, seed, clean_t, noisy_t)
        print(f'      Final: {ord_result["final_error"]:.2%}')

        # Intervention from Collapse
        print(f'    From Collapse ({collapse_error:.2%}) ‚Üí Œª={target_lam}...')
        col_result = run_intervention(collapse_state, 'collapse', target_lam, seed, clean_t, noisy_t)
        print(f'      Final: {col_result["final_error"]:.2%}')

        # Gap after intervention (positive when the collapse run is worse)
        gap = col_result['final_error'] - ord_result['final_error']
        print(f'    üìä Gap after {POST_INTERVENTION_EPOCHS} epochs: {gap*100:.1f}%')

        seed_result['interventions'].append({
            'target_lambda': target_lam,
            'ordered': ord_result,
            'collapse': col_result,
            'final_gap': gap
        })

    elapsed = time.time() - t0
    seed_result['time_seconds'] = elapsed
    seed_result['experiment_id'] = f'{NOTEBOOK_ID}-seed{seed:02d}'

    results.append(seed_result)
    # Close the file explicitly so the checkpoint is flushed before the next
    # (long-running) seed starts.
    with open(ckpt_file, 'w') as fh:
        json.dump(results, fh, indent=2, default=str)
    done_seeds.add(seed)
    print(f'\n  ‚è±Ô∏è Time: {elapsed/60:.1f} min')

print(f'\n{"="*60}')
print(f'{NOTEBOOK_ID} COMPLETE')
print(f'{"="*60}')

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Persist the final results; the context manager guarantees the file is
# flushed and closed.
with open(f'{SAVE_DIR}/{NOTEBOOK_ID}_results.json', 'w') as fh:
    json.dump(results, fh, indent=2, default=str)

# Visualization: Intervention trajectories for each target lambda
# (one panel per target; blue = started from ordered, red = from collapse).
n_targets = len(TARGET_LAMBDAS)
fig, axes = plt.subplots(1, n_targets, figsize=(6*n_targets, 5))
if n_targets == 1:
    # plt.subplots returns a bare Axes for a single panel; wrap it for zip().
    axes = [axes]

for ax, target_lam in zip(axes, TARGET_LAMBDAS):
    has_curves = False
    # Collect trajectories for this target
    for r in results:
        for intv in r['interventions']:
            if intv['target_lambda'] != target_lam:
                continue

            # Label only the first pair of curves so the legend has exactly
            # one entry per source and never picks up the threshold line.
            if has_curves:
                ordered_label = collapse_label = '_nolegend_'
            else:
                ordered_label, collapse_label = 'From Ordered', 'From Collapse'

            # Ordered trajectory
            epochs_o = [t['epoch'] for t in intv['ordered']['trajectory']]
            errors_o = [t['error'] for t in intv['ordered']['trajectory']]
            ax.plot(epochs_o, errors_o, 'b-o', alpha=0.6, linewidth=2, markersize=5,
                    label=ordered_label)

            # Collapse trajectory
            epochs_c = [t['epoch'] for t in intv['collapse']['trajectory']]
            errors_c = [t['error'] for t in intv['collapse']['trajectory']]
            ax.plot(epochs_c, errors_c, 'r-s', alpha=0.6, linewidth=2, markersize=5,
                    label=collapse_label)
            has_curves = True

    ax.axhline(0.40, color='orange', linestyle='--', alpha=0.5)
    ax.set_xlabel('Epochs after intervention', fontsize=12)
    ax.set_ylabel('Test Error', fontsize=12)
    ax.set_title(f'Intervention to Œª={target_lam}', fontsize=14, fontweight='bold')
    if has_curves:
        ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(-1, POST_INTERVENTION_EPOCHS + 1)

plt.tight_layout()
plt.savefig(f'{SAVE_DIR}/figures/{NOTEBOOK_ID}_intervention_trajectories.png', dpi=150)
plt.savefig(f'{SAVE_DIR}/figures/{NOTEBOOK_ID}_intervention_trajectories.pdf')
plt.show()

In [None]:
# Bar plot: mean final gap (collapse error minus ordered error) per target,
# with the standard deviation across seeds as error bars.
fig, ax = plt.subplots(figsize=(10, 6))

# Group the per-seed final gaps by the target they were measured at.
gap_data = {lam: [] for lam in TARGET_LAMBDAS}
for r in results:
    for intv in r['interventions']:
        gap_data[intv['target_lambda']].append(intv['final_gap'])

x = np.arange(len(TARGET_LAMBDAS))
means, stds = [], []
for lam in TARGET_LAMBDAS:
    means.append(np.mean(gap_data[lam]))
    stds.append(np.std(gap_data[lam]))

# Heights and error bars are shown in percentage points.
bars = ax.bar(
    x,
    [m * 100 for m in means],
    yerr=[s * 100 for s in stds],
    capsize=5,
    color='steelblue',
    alpha=0.8,
    edgecolor='navy',
)

# Reference lines: the 10% "strong gap" threshold and the zero baseline.
ax.axhline(10, color='green', linestyle='--', alpha=0.7, label='Strong gap threshold (10%)')
ax.axhline(0, color='black', linestyle='-', alpha=0.3)

ax.set_xticks(x)
ax.set_xticklabels([f'Œª={lam}' for lam in TARGET_LAMBDAS])
ax.set_xlabel('Target Œª', fontsize=12)
ax.set_ylabel('Remaining Gap after Intervention (%)', fontsize=12)
ax.set_title(f'Causal Intervention: History-Dependence Test\n({POST_INTERVENTION_EPOCHS} epochs at each target Œª)', fontsize=14)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, axis='y')

# Add value labels just above the top of each error bar
for i, (m, s) in enumerate(zip(means, stds)):
    ax.text(i, m*100 + s*100 + 1, f'{m*100:.1f}%', ha='center', fontsize=11, fontweight='bold')

plt.tight_layout()
for ext, save_kwargs in (('png', {'dpi': 150}), ('pdf', {})):
    plt.savefig(f'{SAVE_DIR}/figures/{NOTEBOOK_ID}_gap_after_intervention.{ext}', **save_kwargs)
plt.show()

In [None]:
# Summary: print the design, the per-target gap statistics and an overall
# verdict based on the mean gap (thresholds: >10% strong, >5% moderate).
print('='*60)
print(f'{NOTEBOOK_ID} SUMMARY: Causal Intervention')
print('='*60)

print(f'\nüìä Experimental Design:')
print(f'   Ordered checkpoint: Œª={ORDERED_LAMBDA}')
print(f'   Collapse checkpoint: Œª={COLLAPSE_LAMBDA}')
print(f'   Post-intervention epochs: {POST_INTERVENTION_EPOCHS}')
print(f'   Target Œª values: {TARGET_LAMBDAS}')

print(f'\nüìä Results by target Œª:')
for lam in TARGET_LAMBDAS:
    gaps = gap_data[lam]
    if not gaps:
        # Nothing recorded for this target; skip it silently.
        continue

    mean_gap = np.mean(gaps)
    std_gap = np.std(gaps)
    print(f'\n   Œª={lam}:')
    print(f'     Gap after intervention: {mean_gap*100:.1f} ¬± {std_gap*100:.1f}%')

    if mean_gap > 0.10:
        verdict = f'     ‚Üí ‚úÖ Strong history-dependence (bistability confirmed)'
    elif mean_gap > 0.05:
        verdict = f'     ‚Üí ‚ö†Ô∏è Moderate history-dependence'
    else:
        verdict = f'     ‚Üí ‚ùå Weak/no history-dependence (converging)'
    print(verdict)

# Overall conclusion: pool every gap across targets and seeds.
all_gaps = []
for lam_gaps in gap_data.values():
    all_gaps.extend(lam_gaps)
overall_mean = np.mean(all_gaps) if all_gaps else 0

print(f'\n{"="*60}')
print(f'CONCLUSION:')
if overall_mean > 0.10:
    conclusion_lines = [
        f'  ‚úÖ BISTABILITY CONFIRMED',
        f'  ‚úÖ History determines state even at same Œª',
        f'  ‚úÖ Average persistent gap: {overall_mean*100:.1f}%',
    ]
elif overall_mean > 0.05:
    conclusion_lines = [
        f'  ‚ö†Ô∏è PARTIAL BISTABILITY',
        f'  ‚ö†Ô∏è Some history-dependence persists',
        f'  ‚ö†Ô∏è Average persistent gap: {overall_mean*100:.1f}%',
    ]
else:
    conclusion_lines = [
        f'  ‚ùå NO BISTABILITY',
        f'  ‚ùå States converge regardless of history',
        f'  ‚ùå Average persistent gap: {overall_mean*100:.1f}%',
    ]
for line in conclusion_lines:
    print(line)
print(f'{"="*60}')

In [None]:
# Create summary CSV: one row per (seed, target lambda) with the initial and
# final error of both runs and the final gap between them.
summary_data = [
    {
        'seed': r['seed'],
        'target_lambda': intv['target_lambda'],
        'ordered_init': intv['ordered']['init_error'],
        'ordered_final': intv['ordered']['final_error'],
        'collapse_init': intv['collapse']['init_error'],
        'collapse_final': intv['collapse']['final_error'],
        'final_gap': intv['final_gap'],
    }
    for r in results
    for intv in r['interventions']
]

df_summary = pd.DataFrame(summary_data)
csv_path = f'{SAVE_DIR}/{NOTEBOOK_ID}_summary.csv'
df_summary.to_csv(csv_path, index=False)
print('Summary saved to CSV')
print(df_summary.to_string())