# Policy Gradient Active Learning Demo

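This notebook demonstrates the REINFORCE-trained active-learning policy for Dogs vs Cats binary classification. It loads the trained classifier and selection policy, inspects the distribution of classifier uncertainty and policy scores on the validation set, compares the policy's learning curve against the Week 3 baseline sampling strategies, and checks whether the policy prefers uncertain samples.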

In [None]:
import json
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

# The parent directory holds the local `src` package imported below.
sys.path.append('..')

from src.models.cnn import ResNet18Binary
from src.rl.policy import PolicyNetwork
from src.datasets.loaders import get_dataloaders
from src.utils.confidence import get_confidence_metrics
from src.utils.evaluation import track_learning_curve

## Load Trained Models

In [None]:
# Load the trained classifier and selection policy from their checkpoints.
device = "cuda" if torch.cuda.is_available() else "cpu"

CHECKPOINT_DIR = Path('../checkpoints')
CLASSIFIER_CKPT = CHECKPOINT_DIR / 'week2' / 'resnet18_epoch5.pth'
POLICY_CKPT = CHECKPOINT_DIR / 'week4' / 'policy_epoch5.pth'

# Load classifier.
# NOTE(review): the checkpoint overwrites these weights, so pretrained=True may only
# cause an unnecessary weight download — confirm and consider pretrained=False.
classifier = ResNet18Binary(pretrained=True).to(device)
classifier.load_state_dict(torch.load(CLASSIFIER_CKPT, map_location=device))

# Load policy
policy = PolicyNetwork(in_dim=512, hidden=256).to(device)
policy.load_state_dict(torch.load(POLICY_CKPT, map_location=device))

# Everything below is inference only: eval() puts layers such as BatchNorm/Dropout
# (if the models contain them) into inference mode, so metrics are deterministic
# and running statistics are not updated by the analysis.
classifier.eval()
policy.eval();

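## Visualize Sample Selection Process

Distributions of the classifier's per-sample entropy and confidence margin on the validation set, alongside the scores the policy assigns to (a subset of) the same samples.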

In [None]:
# Load data (only the validation split and loader are used below).
_, val_loader, _, val_ds = get_dataloaders('../data/processed/catsdogs_128', batch_size=32)

# Per-sample confidence metrics of the classifier over the validation loader.
metrics = get_confidence_metrics(classifier, val_loader, device)

# The policy panel scores only the first N samples of the features tensor,
# whereas the entropy/margin panels use every validation sample.
N_POLICY_SAMPLES = 1000

with torch.no_grad():
    policy_scores = policy(metrics['features'][:N_POLICY_SAMPLES].to(device)).cpu()
# Flatten so the histogram is the same whether the policy returns (N,) or (N, 1).
policy_scores = policy_scores.flatten()

fig, (ax_entropy, ax_margin, ax_policy) = plt.subplots(1, 3, figsize=(12, 4))

ax_entropy.hist(metrics['entropy'].cpu().numpy().ravel(), bins=50, alpha=0.7)
ax_entropy.set(xlabel='Entropy', ylabel='Frequency', title='Uncertainty Distribution')

ax_margin.hist(metrics['margin'].cpu().numpy().ravel(), bins=50, alpha=0.7)
ax_margin.set(xlabel='Margin', ylabel='Frequency', title='Confidence Margin Distribution')

ax_policy.hist(policy_scores.numpy(), bins=50, alpha=0.7)
ax_policy.set(
    xlabel='Policy Score',
    ylabel='Frequency',
    title=f'Policy Selection Scores (first {len(policy_scores)} samples)',
)

fig.tight_layout()
plt.show()

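## Compare Learning Curves

Baseline curves (Week 3) are plotted against the number of labeled samples, assuming an initial labeled pool of 1,000 and 500 new labels per acquisition round. The REINFORCE curve (Week 4) has more points than rounds, so its x-axis is only an approximate mapping from curve steps to rounds.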

In [None]:
# Load saved learning curves. Each JSON maps a method name to its list of
# validation accuracies (one entry per curve point).
INITIAL_LABELED = 1000    # labeled samples at round 0 of the baseline curves
LABELS_PER_ROUND = 500    # labeled samples added per acquisition round
RL_STEPS_PER_ROUND = 9    # approximate RL curve points per round — TODO confirm vs Week 4 training loop
TARGET_ACCURACY = 0.90    # threshold for the sample-efficiency table


def labeled_count(round_idx: int) -> int:
    """Number of labeled samples after `round_idx` acquisition rounds."""
    return INITIAL_LABELED + round_idx * LABELS_PER_ROUND


def samples_to_accuracy(labeled_counts, accuracies, target=TARGET_ACCURACY):
    """Return the first labeled count whose accuracy is >= `target`, or None if never reached."""
    return next((n for n, acc in zip(labeled_counts, accuracies) if acc >= target), None)


# Load Week 3 results (baselines)
with open('../outputs/week3/curves.json', 'r') as f:
    week3_curves = json.load(f)

# Load Week 4 results (RL)
with open('../outputs/week4/rl_curve.json', 'r') as f:
    week4_curves = json.load(f)

# x-axis positions (number of labeled samples) for every curve.
baseline_counts = {
    method: [labeled_count(r) for r in range(len(accuracies))]
    for method, accuracies in week3_curves.items()
}
rl_accuracies = week4_curves['RL']
# Approximate mapping: several RL curve points fall within one acquisition round.
rl_labeled_counts = [labeled_count(s // RL_STEPS_PER_ROUND) for s in range(len(rl_accuracies))]

# Plot comparison
fig, ax = plt.subplots(figsize=(10, 6))
for method, accuracies in week3_curves.items():
    ax.plot(baseline_counts[method], accuracies, label=f'{method} Sampling', marker='o')
ax.plot(rl_labeled_counts, rl_accuracies, label='REINFORCE Policy', marker='s', linewidth=2)
ax.set(
    xlabel='Number of Labeled Samples',
    ylabel='Validation Accuracy',
    title='Active Learning Comparison: Policy Gradient vs Traditional Methods',
)
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

# Sample efficiency: include the RL policy alongside the baselines.
efficiency_rows = [(method, baseline_counts[method], accs) for method, accs in week3_curves.items()]
efficiency_rows.append(('REINFORCE (approx.)', rl_labeled_counts, rl_accuracies))

header = f"{'Method':<24}Samples to {TARGET_ACCURACY:.0%} Accuracy"
print("\nSample Efficiency Analysis:")
print(header)
print("-" * len(header))
for method, counts, accuracies in efficiency_rows:
    reached = samples_to_accuracy(counts, accuracies)
    print(f"{method:<24}{reached if reached is not None else 'Not reached'}")

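## Policy Behavior Analysis

Do the policy's scores track classifier uncertainty? We compare its scores on the 100 highest-entropy and the 100 lowest-entropy validation samples.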

In [None]:
# Analyze what the policy has learned: compare its scores on the most and the
# least uncertain (by classifier entropy) validation samples.
N_EXTREME = 100  # samples taken from each end of the entropy ranking

entropy = metrics['entropy'].cpu().flatten()
k = min(N_EXTREME, entropy.numel())  # avoid topk errors on small sets

with torch.no_grad():
    # topk returns indices sorted by value (descending for largest=True,
    # ascending for largest=False), matching the previous argsort slices.
    high_entropy_idx = torch.topk(entropy, k, largest=True).indices
    low_entropy_idx = torch.topk(entropy, k, largest=False).indices

    high_entropy_features = metrics['features'][high_entropy_idx].to(device)
    low_entropy_features = metrics['features'][low_entropy_idx].to(device)

    # Flatten so scores are 1-D whether the policy returns (N,) or (N, 1);
    # boxplot rejects 2-D arrays inside a list of datasets.
    high_entropy_scores = policy(high_entropy_features).cpu().flatten()
    low_entropy_scores = policy(low_entropy_features).cpu().flatten()

fig, (ax_scatter, ax_box) = plt.subplots(1, 2, figsize=(10, 4))

ax_scatter.scatter(entropy[high_entropy_idx], high_entropy_scores, alpha=0.6, label='High Entropy')
ax_scatter.scatter(entropy[low_entropy_idx], low_entropy_scores, alpha=0.6, label='Low Entropy')
ax_scatter.set(xlabel='Sample Entropy', ylabel='Policy Score', title='Policy Score vs Sample Uncertainty')
ax_scatter.legend()

# Tick labels are set explicitly: boxplot's `labels=` argument is deprecated
# (renamed `tick_labels`) in Matplotlib 3.9, and this form works on all versions.
ax_box.boxplot([high_entropy_scores.numpy(), low_entropy_scores.numpy()])
ax_box.set_xticks([1, 2])
ax_box.set_xticklabels(['High Entropy\n(Uncertain)', 'Low Entropy\n(Confident)'])
ax_box.set(ylabel='Policy Score', title='Policy Preference Distribution')

fig.tight_layout()
plt.show()

high_mean = high_entropy_scores.mean().item()
low_mean = low_entropy_scores.mean().item()
print(f"Mean policy score for high entropy samples: {high_mean:.3f}")
print(f"Mean policy score for low entropy samples: {low_mean:.3f}")
# A ratio of means is only interpretable for positive scores; fall back to the
# difference if scores can be zero/negative (e.g. if the policy outputs raw logits).
if low_mean > 0:
    print(f"Policy preference ratio (high/low): {high_mean / low_mean:.2f}")
else:
    print(f"Policy preference difference (high - low): {high_mean - low_mean:.3f}")