# 08 Design pattern and checklist
Hypothesis: choosing a transfer method should balance target accuracy against stability diagnostics (feature drift, source retention).

## Step 1: Imports and setup
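Import dependencies, locate the project root, and create the figure output directory. The rest of the notebook runs all strategies on one related-task setup and scores them.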

In [None]:
from pathlib import Path
import sys
import torch
import pandas as pd
import matplotlib.pyplot as plt

def _find_project_root(start):
    """Return the nearest directory at or above *start* that contains a ``src`` folder.

    Raises:
        FileNotFoundError: if neither *start* nor any ancestor has a ``src``
            directory. This replaces a silent fallback to the filesystem root,
            which would only surface later as confusing import/mkdir errors.
    """
    for candidate in (start, *start.parents):
        if (candidate / 'src').is_dir():
            return candidate
    raise FileNotFoundError(f"No 'src' directory found in {start} or any of its parents")


ROOT = _find_project_root(Path.cwd().resolve())

# Put the project's src directory first on the import path. Drop an earlier copy
# so re-running this cell keeps a single entry instead of stacking duplicates.
_SRC_DIR = str(ROOT / 'src')
if _SRC_DIR in sys.path:
    sys.path.remove(_SRC_DIR)
sys.path.insert(0, _SRC_DIR)

from utils.seed import set_seed
from data.cifar10_transfer import get_cifar10_transfer
from models.transfer_resnet import TransferResNet18
from methods.transfer_learning import pretrain_source, build_transferred_model, run_target_adaptation

# Output directory for saved figures; created up front (including parents) so
# later savefig calls never hit a missing path.
FIGS = ROOT.joinpath('outputs', 'figures')
FIGS.mkdir(exist_ok=True, parents=True)

## Step 2: Run methods and build summary table
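Train each strategy, summarize its final-epoch metrics, and recommend the most accurate method whose final feature drift is at most 0.25. If no method stays within that budget, fall back to the most accurate method overall.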

In [None]:
import copy

SEED = 0
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# A method counts as "stable" when its final feature drift is at most this value.
DRIFT_BUDGET = 0.25
# Epoch -> parameter-name prefixes to unfreeze. Passed to every strategy;
# presumably only 'gradual_unfreeze' reads it (TODO confirm in run_target_adaptation).
GRADUAL_SCHEDULE = {
    2: ['backbone.layer4'],
    5: ['backbone.layer3', 'backbone.layer2'],
    7: ['backbone.layer1', 'backbone.bn1', 'backbone.conv1'],
}

set_seed(SEED)
loaders = get_cifar10_transfer(
    data_dir='./data',
    source_classes=[2, 3, 4, 5, 6, 7],
    target_classes=[3, 5, 7],
    source_train_per_class=1000,
    source_test_per_class=300,
    target_train_per_class=80,
    target_test_per_class=300,
    probe_per_class=120,
    batch_size=128,
    num_workers=2,
    seed=SEED,
)

# Pretrain once on the source classes; every transfer strategy starts from these weights.
source_model = TransferResNet18(num_classes=loaders.source_num_classes)
pretrain_source(
    model=source_model,
    train_loader=loaders.source_train,
    test_loader=loaders.source_test,
    device=DEVICE,
    epochs=6,
    lr=0.03,
    weight_decay=5e-4,
    momentum=0.9,
    use_progress=True,
);

methods = ['scratch', 'feature_extraction', 'gradual_unfreeze', 'naive_finetune']
frames = {}  # method name -> per-epoch history DataFrame

for method in methods:
    # Re-seed per method so every run starts from the same RNG state (e.g. for
    # new-head initialisation); otherwise results depend on the order of `methods`.
    set_seed(SEED)

    if method == 'scratch':
        # Baseline: fresh network with no source head, so no source-retention evaluation.
        model = TransferResNet18(num_classes=loaders.target_num_classes)
        source_head = None
        source_test = None
    else:
        # Give each strategy its own copy of the pretrained weights. If
        # build_transferred_model already copies, this is merely redundant; if it
        # shares modules with its input, this stops one strategy's training from
        # leaking into the strategies that run after it.
        model, source_head = build_transferred_model(copy.deepcopy(source_model), loaders.target_num_classes)
        source_test = loaders.source_test

    result = run_target_adaptation(
        model=model,
        target_train=loaders.target_train,
        target_test=loaders.target_test,
        target_probe=loaders.target_probe,
        source_test=source_test,
        source_head=source_head,
        device=DEVICE,
        strategy=method,
        epochs=10,
        lr=0.01,
        weight_decay=5e-4,
        momentum=0.9,
        gradual_schedule=GRADUAL_SCHEDULE,
        use_progress=True,
    )
    frames[method] = pd.DataFrame(result.history)

# One row per method with its final-epoch metrics, most accurate first.
# kind='mergesort' is a stable sort, so ties keep the order of `methods`.
summary = pd.DataFrame([
    {
        'method': m,
        'final_target_acc': float(df['target_test_acc'].iloc[-1]),
        'final_feature_drift': float(df['feature_drift'].iloc[-1]),
        'final_retention': float(df['source_retention_acc'].iloc[-1]) if 'source_retention_acc' in df else float('nan'),
        'final_grad_norm': float(df['grad_norm'].iloc[-1]),
    }
    for m, df in frames.items()
]).sort_values('final_target_acc', ascending=False, kind='mergesort')

# Recommend the most accurate method within the drift budget, else the most
# accurate overall. Boolean filtering preserves `summary`'s accuracy ordering,
# and a NaN drift compares False, so such a method never counts as stable.
stable = summary[summary['final_feature_drift'] <= DRIFT_BUDGET]
recommended = (summary if stable.empty else stable).iloc[0]['method']
summary

## Step 3: Plot decision visuals
Use trajectory and frontier views to justify the recommendation.

In [None]:
# One colour per method, keyed by name, so both figures agree. Plotting `frames`
# (insertion order) and `summary` (sorted by accuracy) with the default colour
# cycle would otherwise give the same method different colours in the two plots.
method_colors = {method: f'C{i}' for i, method in enumerate(frames)}

# Figure 1: per-epoch target accuracy for every method.
fig, ax = plt.subplots(figsize=(6.6, 3.7))
for method, df in frames.items():
    ax.plot(df['epoch'], df['target_test_acc'], marker='o', label=method, color=method_colors[method])
ax.set_title('Target accuracy overview')
ax.set_xlabel('epoch')
ax.set_ylabel('target_test_acc')
ax.grid(alpha=0.25)
ax.legend(frameon=False)
fig.savefig(FIGS / '08_target_accuracy_overview.png', dpi=150, bbox_inches='tight')

# Figure 2: final drift vs final accuracy; the recommended method is drawn
# larger with a black outline.
fig, ax = plt.subplots(figsize=(6.2, 3.8))
for _, row in summary.iterrows():
    is_recommended = row['method'] == recommended
    ax.scatter(
        row['final_feature_drift'],
        row['final_target_acc'],
        s=90 if is_recommended else 55,
        color=method_colors[row['method']],
        edgecolors='black' if is_recommended else 'none',
    )
    # Offset the label in screen points rather than data units, so spacing does
    # not depend on the axis ranges.
    ax.annotate(
        row['method'],
        (row['final_feature_drift'], row['final_target_acc']),
        xytext=(4, 3),
        textcoords='offset points',
        fontsize=9,
    )
ax.set_title('Accuracy-stability frontier')
ax.set_xlabel('final_feature_drift (lower is better)')
ax.set_ylabel('final_target_acc')
ax.grid(alpha=0.25)
fig.savefig(FIGS / '08_accuracy_stability_frontier.png', dpi=150, bbox_inches='tight')

In [None]:
# Cell output: the method chosen by the recommendation rule in Step 2.
'Recommended default: {}'.format(recommended)

### Expected Outcome
A good recommendation should have strong target accuracy with bounded drift.

## Final Checklist
1. Validate task relatedness.
2. Check retention and drift, not just endpoint accuracy.
3. Compare against scratch before claiming transfer gains.