# 08 Hybrid teacher + threshold
Hypothesis: the hybrid is more stable early in training and, once its teacher has stabilized, can safely make use of more unlabeled data.

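Set `FAST_DEV_RUN = True` to use MNIST with a small CNN for a quick 2-epoch run; the default (`False`) trains a ResNet-18 on CIFAR-10 for 100 epochs.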


### Early‑state noise and the hybrid design

The hybrid combines a smoothed teacher with a threshold so early noise doesn’t flood the state.

That’s why it can be more stable in the first few epochs.



### Expected Outcome

Hybrid should match or exceed FixMatch and Mean Teacher on test accuracy.
It should also be more selective than FixMatch early on (lower accept rate), then accept more unlabeled samples as the teacher stabilizes.


In [None]:
from pathlib import Path
import sys
import torch
import matplotlib.pyplot as plt

sys.path.append(str(Path.cwd().parent / 'src'))

from utils.seed import set_seed
from data.mnist import get_mnist_ssl
from data.cifar10 import get_cifar10_ssl
from models.small_cnn import SmallCNN
from models.resnet18 import build_resnet18
from methods.fixmatch import run_fixmatch
from methods.mean_teacher import run_mean_teacher
from methods.hybrid_teacher_threshold import run_hybrid


In [None]:
# Call the project's seeding helper before any loaders or models are created.
set_seed(0)

# Use the GPU whenever torch reports one, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = torch.device(device_name)

# False -> CIFAR-10 / ResNet-18 / 100 epochs; True -> MNIST / SmallCNN / 2 epochs.
FAST_DEV_RUN = False

# The threshold is the same in both configurations.
tau = 0.95

if not FAST_DEV_RUN:
    # Full experiment.
    loaders = get_cifar10_ssl(
        'data',
        labeled_per_class=40,
        batch_size=128,
        num_workers=2,
        seed=0,
    )
    build, epochs = build_resnet18, 100
    lambda_u, ema_decay = 10.0, 0.999
else:
    # Quick smoke-test configuration.
    loaders = get_mnist_ssl(
        'data',
        labeled_per_class=50,
        batch_size=128,
        num_workers=2,
        seed=0,
    )
    build, epochs = SmallCNN, 2
    lambda_u, ema_decay = 1.0, 0.99

# FixMatch: a single model trained with the shared epochs / tau / lambda_u.
model_f = build()
opt_f = torch.optim.SGD(
    model_f.parameters(),
    lr=0.03,
    momentum=0.9,
    weight_decay=5e-4,
)
fixmatch_settings = dict(epochs=epochs, tau=tau, lambda_u=lambda_u)
# Positional order: model, labeled, unlabeled, unlabeled_eval, test, optimizer, device.
fix = run_fixmatch(
    model_f,
    loaders.labeled,
    loaders.unlabeled,
    loaders.unlabeled_eval,
    loaders.test,
    opt_f,
    DEVICE,
    **fixmatch_settings,
)

# Mean Teacher: two separately built networks; only the student's parameters
# are handed to the optimizer.
# NOTE(review): the teacher is constructed independently of the student here --
# confirm run_mean_teacher synchronizes/initializes the teacher weights itself.
student = build()
teacher = build()
opt_m = torch.optim.SGD(
    student.parameters(),
    lr=0.03,
    momentum=0.9,
    weight_decay=5e-4,
)
mean_teacher_settings = dict(
    epochs=epochs,
    ema_decay=ema_decay,
    lambda_u=lambda_u,
    # No warmup in the quick dev configuration, 10 epochs otherwise.
    warmup_epochs=0 if FAST_DEV_RUN else 10,
)
# Positional order: student, teacher, labeled, unlabeled, test, optimizer, device.
mt = run_mean_teacher(
    student,
    teacher,
    loaders.labeled,
    loaders.unlabeled,
    loaders.test,
    opt_m,
    DEVICE,
    **mean_teacher_settings,
)

# Hybrid: student/teacher pair that receives both ema_decay and tau.
# NOTE(review): the teacher is constructed independently of the student here --
# confirm run_hybrid synchronizes/initializes the teacher weights itself.
student_h = build()
teacher_h = build()
opt_h = torch.optim.SGD(
    student_h.parameters(),
    lr=0.03,
    momentum=0.9,
    weight_decay=5e-4,
)
hybrid_settings = dict(
    epochs=epochs,
    ema_decay=ema_decay,
    tau=tau,
    lambda_u=lambda_u,
)
# Positional order: student, teacher, labeled, unlabeled, unlabeled_eval, test,
# optimizer, device.
hyb = run_hybrid(
    student_h,
    teacher_h,
    loaders.labeled,
    loaders.unlabeled,
    loaders.unlabeled_eval,
    loaders.test,
    opt_h,
    DEVICE,
    **hybrid_settings,
)


In [None]:
# Figure 1: per-epoch test accuracy for all three runs.
plt.figure(figsize=(5, 3.2))
for label, run in (('FixMatch', fix), ('Mean Teacher', mt), ('Hybrid', hyb)):
    xs = [r['epoch'] for r in run.history]
    ys = [r['test_acc'] for r in run.history]
    plt.plot(xs, ys, marker='o', label=label)
plt.title('Test accuracy comparison')
plt.xlabel('epoch')
plt.ylabel('acc')
plt.legend(frameon=False)

# Figure 2: per-epoch acceptance rate; only the FixMatch and Hybrid histories
# are read for 'accept_rate', so Mean Teacher is not plotted here.
plt.figure(figsize=(5, 3.2))
for label, run in (('FixMatch accept', fix), ('Hybrid accept', hyb)):
    xs = [r['epoch'] for r in run.history]
    ys = [r['accept_rate'] for r in run.history]
    plt.plot(xs, ys, marker='o', label=label)
plt.title('Unlabeled acceptance rate')
plt.xlabel('epoch')
plt.ylabel('accept rate')
plt.legend(frameon=False)
