# 04 Naive transfer and forgetting
Hypothesis: aggressive full-model updates increase drift and forgetting risk.

## Step 1: Imports and setup
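Import dependencies, locate the project root (the nearest ancestor directory containing `src`), and create the figure output directory. The rest of the notebook runs a single aggressive adaptation to inspect failure signals.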

In [None]:
from pathlib import Path
import sys
import torch
import pandas as pd
import matplotlib.pyplot as plt

# Walk upward from the working directory until a folder containing ``src`` is
# found; if no ancestor has one, ROOT ends up at the filesystem root.
ROOT = Path.cwd().resolve()
for ROOT in (ROOT, *ROOT.parents):
    if (ROOT / 'src').is_dir():
        break
# Put the project's ``src`` folder first on the import path.
sys.path.insert(0, str(ROOT.joinpath('src')))

from utils.seed import set_seed
from data.cifar10_transfer import get_cifar10_transfer
from models.transfer_resnet import TransferResNet18
from methods.transfer_learning import pretrain_source, build_transferred_model, run_target_adaptation

# Figures are saved under <ROOT>/outputs/figures; create the folder (and any
# missing parents) if it does not exist yet.
FIGS = ROOT.joinpath('outputs', 'figures')
FIGS.mkdir(exist_ok=True, parents=True)

## Step 2: Build data and pretrain source model
Use the same task split as Notebook 03.

In [None]:
# Fixed seed and compute device (CUDA when available, otherwise CPU).
SEED = 1
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

set_seed(SEED)

# Source task uses class ids 2-7; the target task uses the subset {2, 4, 6}
# with only 40 training samples per class.
loaders = get_cifar10_transfer(
    data_dir='./data',
    seed=SEED,
    batch_size=128,
    num_workers=2,
    # class splits
    source_classes=[2, 3, 4, 5, 6, 7],
    target_classes=[2, 4, 6],
    # per-class sample budgets
    source_train_per_class=1000,
    source_test_per_class=300,
    target_train_per_class=40,
    target_test_per_class=300,
    probe_per_class=120,
)

# Pretrain on the source task; the trailing ';' keeps the notebook from
# echoing the call's return value.
source_model = TransferResNet18(num_classes=loaders.source_num_classes)
pretrain_source(
    model=source_model,
    device=DEVICE,
    train_loader=loaders.source_train,
    test_loader=loaders.source_test,
    # optimisation settings
    epochs=6,
    lr=0.03,
    momentum=0.9,
    weight_decay=5e-4,
    use_progress=True,
);

## Step 3: Run naive fine-tuning
Increase the target learning rate (here `lr=0.04`) to expose instability.

In [None]:
# build_transferred_model returns the model to adapt together with the source
# head, which is handed to the adaptation run below.
model, source_head = build_transferred_model(source_model, loaders.target_num_classes)

# NOTE(review): `gradual_schedule` is passed even though the strategy is
# 'naive_finetune' -- presumably it is ignored for this strategy; confirm in
# methods.transfer_learning.
result = run_target_adaptation(
    model=model,
    source_head=source_head,
    device=DEVICE,
    # data
    target_train=loaders.target_train,
    target_test=loaders.target_test,
    target_probe=loaders.target_probe,
    source_test=loaders.source_test,
    # adaptation strategy
    strategy='naive_finetune',
    gradual_schedule={
        4: ['backbone.layer4'],
        8: ['backbone.layer3', 'backbone.layer2'],
        11: ['backbone.layer1', 'backbone.bn1', 'backbone.conv1'],
    },
    # optimisation settings
    epochs=16,
    lr=0.04,
    momentum=0.9,
    weight_decay=5e-4,
    use_progress=True,
)

# Tabulate the run history and display its last five rows for the tracked
# metrics.
naive = pd.DataFrame(result.history)
naive.tail()[['epoch', 'target_test_acc', 'source_retention_acc', 'feature_drift', 'grad_norm']]

## Step 4: Plot forgetting signatures
Look for three signals together: rising feature drift, falling source retention, and unstable gradient norms.

In [None]:
# Dashboard: one panel per signal, described as (history column, panel title).
panels = (
    ('target_test_acc', 'target accuracy'),
    ('source_retention_acc', 'source retention'),
    ('feature_drift', 'feature drift'),
)
fig, axes = plt.subplots(1, len(panels), figsize=(12.2, 3.3))
for ax, (column, title) in zip(axes, panels):
    ax.plot(naive['epoch'], naive[column], marker='o')
    ax.set_title(title)
    ax.set_xlabel('epoch')
    ax.grid(alpha=0.25)
fig.savefig(FIGS / '04_naive_failure_dashboard.png', dpi=150, bbox_inches='tight')

# The gradient norm gets its own figure.
fig, ax = plt.subplots(figsize=(6.0, 3.3))
ax.plot(naive['epoch'], naive['grad_norm'], color='#E45756', marker='o')
ax.set(title='Naive gradient norm', xlabel='epoch', ylabel='grad_norm')
ax.grid(alpha=0.25)
fig.savefig(FIGS / '04_naive_grad_norm.png', dpi=150, bbox_inches='tight')

### Expected Outcome
Under aggressive updates, drift should rise and source retention should drop.

## Reading This Pattern
Rising drift combined with falling source retention is the classic catastrophic-forgetting profile in transfer learning.