# 07 Label budget vs method choice
Hypothesis: the best transfer strategy depends on the target label budget.

## Step 1: Imports and setup
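Import dependencies, locate the project root (the nearest ancestor directory containing `src/`), and create the figure output directory. The rest of the notebook sweeps the target label budget with direct transfer runs.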

In [None]:
from pathlib import Path
import sys
import torch
import pandas as pd
import matplotlib.pyplot as plt

# Locate the project root: the working directory or its nearest ancestor that
# contains a `src/` folder. If none is found, fall back to the filesystem root.
# Then make `src/` importable by putting it first on sys.path.
_here = Path.cwd().resolve()
_lineage = [_here, *_here.parents]
ROOT = next(
    (candidate for candidate in _lineage if (candidate / 'src').is_dir()),
    _lineage[-1],
)
sys.path.insert(0, str(ROOT / 'src'))
del _here, _lineage

from utils.seed import set_seed
from data.cifar10_transfer import get_cifar10_transfer
from models.transfer_resnet import TransferResNet18
from methods.transfer_learning import pretrain_source, build_transferred_model, run_target_adaptation

# Directory for this notebook's saved figures; created (with parents) if missing.
FIGS = ROOT.joinpath('outputs', 'figures')
FIGS.mkdir(exist_ok=True, parents=True)

## Step 2: Run budget sweep
For each budget, pretrain the source model once and compare three target-adaptation methods: training from scratch, feature extraction, and gradual unfreezing.

In [None]:
import copy

# Experiment configuration.
SEED = 0
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
METHODS = ['scratch', 'feature_extraction', 'gradual_unfreeze']
BUDGETS = [20, 40, 80, 120]  # passed as target_train_per_class
rows = []  # one summary dict per (budget, method) run

for budget in BUDGETS:
    set_seed(SEED)
    # NOTE(review): data_dir is relative to the current working directory,
    # not ROOT, so the dataset location depends on where the notebook runs.
    loaders = get_cifar10_transfer(
        data_dir='./data',
        source_classes=[2, 3, 4, 5, 6, 7],
        target_classes=[3, 5, 7],
        source_train_per_class=1000,
        source_test_per_class=300,
        target_train_per_class=budget,
        target_test_per_class=300,
        probe_per_class=120,
        batch_size=128,
        num_workers=2,
        seed=SEED,
    )

    # None of the source-side loader arguments depend on `budget`, so this
    # pretraining repeats with identical settings for every budget.
    # TODO(review): hoist it out of the loop if get_cifar10_transfer's source
    # split is independent of target_train_per_class.
    source_model = TransferResNet18(num_classes=loaders.source_num_classes)
    pretrain_source(
        model=source_model,
        train_loader=loaders.source_train,
        test_loader=loaders.source_test,
        device=DEVICE,
        epochs=6,
        lr=0.03,
        weight_decay=5e-4,
        momentum=0.9,
        use_progress=True,
    )

    for method in METHODS:
        # Re-seed per method so model/head initialisation and loader shuffling
        # do not depend on how much randomness earlier entries of METHODS
        # consumed. This assumes set_seed covers the torch RNGs used for
        # those steps.
        set_seed(SEED)
        if method == 'scratch':
            model = TransferResNet18(num_classes=loaders.target_num_classes)
            source_head = None
            source_test = None
        else:
            # Give each transfer method its own copy of the pretrained network,
            # in case build_transferred_model reuses the modules it is given.
            # Otherwise fine-tuning in one method could alter the next
            # method's starting point.
            model, source_head = build_transferred_model(
                copy.deepcopy(source_model), loaders.target_num_classes
            )
            source_test = loaders.source_test

        result = run_target_adaptation(
            model=model,
            target_train=loaders.target_train,
            target_test=loaders.target_test,
            target_probe=loaders.target_probe,
            source_test=source_test,
            source_head=source_head,
            device=DEVICE,
            strategy=method,
            epochs=10,
            lr=0.01,
            weight_decay=5e-4,
            momentum=0.9,
            # Epoch -> module-name prefixes to unfreeze. A fresh dict is built
            # per call so no run can observe another run's mutations.
            gradual_schedule={
                2: ['backbone.layer4'],
                5: ['backbone.layer3', 'backbone.layer2'],
                7: ['backbone.layer1', 'backbone.bn1', 'backbone.conv1'],
            },
            use_progress=True,
        )
        df = pd.DataFrame(result.history)
        rows.append(
            {
                'budget': budget,
                'method': method,
                'best_target_acc': float(df['target_test_acc'].max()),
                'final_target_acc': float(df['target_test_acc'].iloc[-1]),
                'final_feature_drift': float(df['feature_drift'].iloc[-1]),
            }
        )

results = pd.DataFrame(rows)
# Display runs grouped by budget with the best method first. `results` itself
# keeps insertion order (budget-major, METHODS order).
results.sort_values(['budget', 'best_target_acc'], ascending=[True, False])

## Step 3: Plot budget trends
Show how the method ranking changes as the amount of target supervision (labels per class) grows.

In [None]:
def _plot_metric_by_budget(frame, metric, title, filename):
    """Plot `metric` against label budget, one line per method, and save it.

    Args:
        frame: sweep results with 'budget', 'method' and `metric` columns.
        metric: column plotted on the y-axis; also used as the y-axis label.
        title: axes title.
        filename: file name written under FIGS.

    Returns:
        The created (fig, ax) pair.
    """
    fig, ax = plt.subplots(figsize=(6.8, 3.8))
    for method, group in frame.groupby('method'):
        # Sort explicitly so each line is drawn left-to-right in budget order
        # regardless of the row order in `frame`.
        group = group.sort_values('budget')
        ax.plot(group['budget'], group[metric], marker='o', label=method)
    # Tick exactly the swept budgets so every data point sits on a labelled x.
    ax.set_xticks(sorted(frame['budget'].unique()))
    ax.set_title(title)
    ax.set_xlabel('target labels per class')
    ax.set_ylabel(metric)
    ax.grid(alpha=0.25)
    ax.legend(frameon=False)
    fig.savefig(FIGS / filename, dpi=150, bbox_inches='tight')
    return fig, ax


# Assign the return values so the cell does not echo a (fig, ax) tuple.
fig, ax = _plot_metric_by_budget(
    results,
    'best_target_acc',
    'Best target accuracy by label budget',
    '07_budget_vs_best_target_acc.png',
)
fig, ax = _plot_metric_by_budget(
    results,
    'final_feature_drift',
    'Final feature drift by label budget',
    '07_budget_vs_feature_drift.png',
)

In [None]:
# Wide table: one row per budget, one column per method, cells = best target accuracy.
results.set_index(['budget', 'method'])['best_target_acc'].unstack('method')

### Expected Outcome
Feature extraction is expected to be strong at low label budgets, while gradual unfreezing should improve as the budget grows.

## Interpretation
Method selection should take the available label budget into account rather than being fixed in advance.