In [None]:
# protonet_mobilenet_final_tqdm.py
"""
Prototypical Network with MobileNetV2 backbone (ImageNet pretrained),
two-stage fine-tuning, cosine-prototype similarity with learnable temperature,
and episodic training on CIFAR-100.

This variant adds tqdm progress bars for training episodes, validation episodes,
and meta-test episodes so you can observe progress in real time.

Usage:
    python protonet_mobilenet_final_tqdm.py
"""

import os
import random
import math
from collections import defaultdict
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, models

# ----------------------------
# Configuration
# ----------------------------
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed both RNGs used by this script: `random` drives episode sampling,
# torch drives weight initialisation of the embedding head.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Episode params
# Each training episode draws its number of classes uniformly from
# [TRAIN_N_WAY_MIN, TRAIN_N_WAY_MAX] (inclusive, via random.randint).
TRAIN_N_WAY_MIN = 5
TRAIN_N_WAY_MAX = 10
# Support / query images sampled per class in training and validation episodes.
K_SHOT = 1
Q_QUERY = 15
EPISODES_PER_EPOCH = 200
EPOCHS = 10
# The backbone is unfrozen at the end of this epoch (Stage 2 starts on the next).
STAGE1_EPOCHS = 8

# Output size of the embedding head.
EMBED_DIM = 256
# LR_HEAD is used for the head and the temperature in both stages;
# LR_BACKBONE is used for the backbone once it is unfrozen.
LR_HEAD = 1e-3
LR_BACKBONE = 1e-4
WEIGHT_DECAY = 1e-4
# Directory that receives protonet_best.pth / protonet_final.pth; created on import.
CHECKPOINT_DIR = "./checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Meta-test settings
# N-way / K-shot / queries-per-class / number of episodes used by meta_test().
META_TEST_N = 10
META_TEST_K = 1
META_TEST_Q = 10
META_TEST_EPISODES = 600

# ----------------------------
# Dataset: episodic sampler for CIFAR-100
# ----------------------------
class EpisodicCIFAR100:
    """Episode sampler over a class-disjoint subset of CIFAR-100.

    The images always come from the official CIFAR-100 *training* set; the
    ``train`` flag only selects which classes are exposed. Classes are sorted
    by label, the first ``train_class_split`` fraction forms the training
    subset and the remainder forms the validation subset, so the two subsets
    never share a class.

    Args:
        root: Directory where CIFAR-100 is stored / downloaded.
        train: If true, expose the training classes, otherwise the held-out
            validation classes.
        transform: Optional callable applied to every sampled PIL image.
        train_class_split: Fraction of classes assigned to the training subset.
    """

    def __init__(self, root='./data', train=True, transform=None, train_class_split=0.8):
        self.cifar = torchvision.datasets.CIFAR100(root=root, train=True, download=True)
        self.transform = transform

        # Group sample indices by label. Reading `targets` directly avoids
        # iterating the dataset, which would decode every image to PIL just
        # to discard it.
        by_class = defaultdict(list)
        for idx, label in enumerate(self.cifar.targets):
            by_class[label].append(idx)

        all_classes = sorted(by_class)
        n_train_classes = int(len(all_classes) * train_class_split)
        if train:
            classes_use = all_classes[:n_train_classes]
        else:
            classes_use = all_classes[n_train_classes:]

        # Re-index the selected classes to 0..len-1; values are dataset indices.
        self.by_class = {
            new_label: by_class[cls] for new_label, cls in enumerate(classes_use)
        }
        self.classes = list(self.by_class.keys())

    def _load(self, index):
        """Return the (optionally transformed) image stored at ``index``."""
        img, _ = self.cifar[index]
        return self.transform(img) if self.transform else img

    def sample_episode(self, n_way, k_shot, q_query):
        """Sample one N-way episode with disjoint support and query images.

        Args:
            n_way: Number of classes in the episode.
            k_shot: Support images per class.
            q_query: Query images per class.

        Returns:
            ``(support_x, support_y, query_x, query_y)``. The image tensors are
            stacked along a new first dimension, grouped class by class. The
            label tensors are ``torch.long`` and hold episode-local labels
            ``0..n_way-1`` (position of the class in the episode, not the
            CIFAR label).

        Raises:
            ValueError: Propagated from ``random.sample`` when more classes or
                images are requested than are available.
        """
        chosen_classes = random.sample(self.classes, n_way)
        support_x, support_y, query_x, query_y = [], [], [], []
        for episode_label, cls in enumerate(chosen_classes):
            # One draw without replacement guarantees support/query disjointness.
            indices = random.sample(self.by_class[cls], k_shot + q_query)
            support_x.extend(self._load(si) for si in indices[:k_shot])
            support_y.extend([episode_label] * k_shot)
            query_x.extend(self._load(qi) for qi in indices[k_shot:])
            query_y.extend([episode_label] * q_query)
        return (
            torch.stack(support_x),
            torch.tensor(support_y, dtype=torch.long),
            torch.stack(query_x),
            torch.tensor(query_y, dtype=torch.long),
        )

# ----------------------------
# Model
# ----------------------------
class MobileNetProto(nn.Module):
    """MobileNetV2 feature extractor with an embedding head and a learnable temperature.

    Args:
        embed_dim: Size of the output embedding.
        pretrained: If true, initialise the backbone with ImageNet weights.
    """

    def __init__(self, embed_dim=EMBED_DIM, pretrained=True):
        super().__init__()
        # `pretrained=` is deprecated in torchvision >= 0.13 in favour of
        # `weights=`. IMAGENET1K_V1 is the checkpoint `pretrained=True` used to
        # load, so the initial weights are unchanged. Older torchvision
        # releases without the weights enum still go through the legacy flag.
        weights_enum = getattr(models, "MobileNet_V2_Weights", None)
        if weights_enum is not None:
            m = models.mobilenet_v2(
                weights=weights_enum.IMAGENET1K_V1 if pretrained else None
            )
        else:
            m = models.mobilenet_v2(pretrained=pretrained)
        self.features = m.features
        last_channels = m.last_channel
        self.pool = nn.AdaptiveAvgPool2d(1)
        # Two-layer MLP head mapping pooled backbone features to the embedding.
        self.fc = nn.Sequential(
            nn.Linear(last_channels, last_channels // 2),
            nn.ReLU(inplace=True),
            nn.Linear(last_channels // 2, embed_dim)
        )
        # Temperature is stored in log-space so exp() keeps it positive;
        # it starts at exp(0) = 1.
        self.log_scale = nn.Parameter(torch.tensor(0.0))

    def forward(self, x):
        """Embed a batch of images.

        Returns:
            ``(emb, scale)`` where ``emb`` holds one L2-normalised embedding
            per input row and ``scale`` is the current temperature
            ``exp(log_scale)`` as a 0-dim tensor.
        """
        f = self.features(x)
        # Global average pool, then drop the trailing 1x1 spatial dims.
        pooled = self.pool(f).view(f.size(0), -1)
        emb = self.fc(pooled)
        # Unit-length embeddings make the dot product a cosine similarity.
        emb = F.normalize(emb, p=2, dim=1)
        scale = torch.exp(self.log_scale)
        return emb, scale

# ----------------------------
# Loss (cosine prototypical)
# ----------------------------
def prototypical_loss_cosine(support_emb, support_y, query_emb, query_y, n_way, k_shot, scale):
    """Cosine-similarity prototypical loss and accuracy for one episode.

    Each class prototype is the mean of that class's support embeddings,
    re-normalised to unit length. Query logits are the dot products between
    query embeddings and prototypes, multiplied by ``scale``.

    Args:
        support_emb: Support embeddings, one row per support example.
        support_y: Episode-local labels (``0..n_way-1``) of the support rows.
        query_emb: Query embeddings, one row per query example.
        query_y: Episode-local labels of the query rows.
        n_way: Number of classes in the episode.
        k_shot: Unused; kept so existing call sites stay valid.
        scale: Multiplier (temperature) applied to the similarities.

    Returns:
        ``(loss, acc)``: the mean negative log-likelihood as a tensor, and
        the query accuracy as a Python float.
    """
    class_protos = []
    for label in range(n_way):
        member_rows = torch.where(support_y == label)[0]
        centroid = support_emb[member_rows].mean(dim=0)
        class_protos.append(F.normalize(centroid, p=2, dim=0))
    proto_matrix = torch.stack(class_protos)

    similarities = query_emb @ proto_matrix.t()
    log_probs = F.log_softmax(scale * similarities, dim=1)

    targets = query_y.to(log_probs.device)
    loss = F.nll_loss(log_probs, targets)

    predicted = log_probs.argmax(dim=1)
    hits = (predicted == targets).float()
    return loss, hits.mean().item()

# ----------------------------
# Training with progress bars
# ----------------------------
def train_protonet():
    """Train the prototypical network on CIFAR-100 episodes in two stages.

    Stage 1 (epochs ``1..STAGE1_EPOCHS``) optimises only the embedding head
    and the temperature while the MobileNetV2 backbone is frozen. At the end
    of epoch ``STAGE1_EPOCHS`` the backbone is unfrozen and a new optimizer
    with separate learning rates for backbone and head takes over (Stage 2).

    After every epoch the model is evaluated on 100 validation episodes drawn
    from the held-out classes; the best-scoring weights are written to
    ``protonet_best.pth`` and the final weights to ``protonet_final.pth`` in
    ``CHECKPOINT_DIR``.

    Returns:
        The model with its final-epoch weights (which are not necessarily
        the weights of the best checkpoint).
    """
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.05),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    episodic_train = EpisodicCIFAR100(train=True, transform=train_transform, train_class_split=0.8)
    episodic_val = EpisodicCIFAR100(train=False, transform=val_transform, train_class_split=0.8)

    model = MobileNetProto(embed_dim=EMBED_DIM, pretrained=True).to(device)

    # Stage 1: freeze backbone
    for p in model.features.parameters():
        p.requires_grad = False
    backbone_frozen = True
    opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR_HEAD, weight_decay=WEIGHT_DECAY)
    # NOTE(review): with step_size=12 and the scheduler being recreated at the
    # stage switch, the decay never fires for the current EPOCHS /
    # STAGE1_EPOCHS values — confirm this is intended.
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=12, gamma=0.5)

    best_val = 0.0

    for epoch in range(1, EPOCHS + 1):
        model.train()
        if backbone_frozen:
            # model.train() switches *every* submodule to training mode,
            # including the frozen backbone. Its BatchNorm layers would then
            # normalise with tiny-episode batch statistics and overwrite the
            # pretrained running stats, so put the backbone back in eval mode.
            model.features.eval()
        losses = []
        accs = []

        # inner episode loop with tqdm progress bar
        ep_bar = tqdm(range(EPISODES_PER_EPOCH), desc=f"Epoch {epoch}/{EPOCHS} episodes", leave=False)
        for _ in ep_bar:
            # Vary the number of classes per episode.
            n_way = random.randint(TRAIN_N_WAY_MIN, TRAIN_N_WAY_MAX)
            support_x, support_y, query_x, query_y = episodic_train.sample_episode(n_way, K_SHOT, Q_QUERY)
            support_x = support_x.to(device)
            support_y = support_y.to(device)
            query_x = query_x.to(device)
            query_y = query_y.to(device)

            support_emb, scale = model(support_x)
            query_emb, _ = model(query_x)

            loss, acc = prototypical_loss_cosine(support_emb, support_y, query_emb, query_y, n_way, K_SHOT, scale)

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), max_norm=5.0)
            opt.step()

            losses.append(loss.item())
            accs.append(acc)
            # update tqdm postfix
            ep_bar.set_postfix({"loss": f"{loss.item():.3f}", "acc": f"{acc*100:.2f}%"})

        scheduler.step()

        avg_loss = sum(losses) / len(losses)
        avg_acc = sum(accs) / len(accs)
        tqdm.write(f"Epoch {epoch}/{EPOCHS}  train_loss={avg_loss:.4f}  train_acc={avg_acc*100:.2f}%")

        # Validation with a progress bar
        model.eval()
        val_accs = []
        val_episodes = 100
        val_bar = tqdm(range(val_episodes), desc=f"Epoch {epoch} validation", leave=False)
        with torch.no_grad():
            for _ in val_bar:
                n_way_val = random.choice([5, 10])
                s_x, s_y, q_x, q_y = episodic_val.sample_episode(n_way_val, K_SHOT, Q_QUERY)
                s_x = s_x.to(device)
                q_x = q_x.to(device)
                s_y = s_y.to(device)
                q_y = q_y.to(device)
                s_emb, scale = model(s_x)
                q_emb, _ = model(q_x)
                _, vacc = prototypical_loss_cosine(s_emb, s_y, q_emb, q_y, n_way_val, K_SHOT, scale)
                val_accs.append(vacc)
                val_bar.set_postfix({"v_acc": f"{vacc*100:.2f}%"})

        mean_val = float(torch.tensor(val_accs).mean().item())
        tqdm.write(f"  val_acc (mean over {val_episodes} episodes) = {mean_val*100:.2f}%")

        # checkpoint best
        if mean_val > best_val:
            best_val = mean_val
            ckpt = {"epoch": epoch, "model_state": model.state_dict(), "opt_state": opt.state_dict(), "val_acc": best_val}
            torch.save(ckpt, os.path.join(CHECKPOINT_DIR, "protonet_best.pth"))
            tqdm.write(f"  --> New best val acc: {best_val*100:.2f}%  (checkpoint saved)")

        # Unfreeze after stage1 epochs
        if epoch == STAGE1_EPOCHS:
            tqdm.write("Switching to Stage 2: unfreezing backbone and creating new optimizer with param groups.")
            for p in model.features.parameters():
                p.requires_grad = True
            backbone_frozen = False
            model.features.train()
            # Fresh optimizer: the backbone gets its own (smaller) learning rate.
            opt = torch.optim.Adam([
                {'params': model.features.parameters(), 'lr': LR_BACKBONE},
                {'params': model.fc.parameters(), 'lr': LR_HEAD},
                {'params': [model.log_scale], 'lr': LR_HEAD},
            ], weight_decay=WEIGHT_DECAY)
            scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=12, gamma=0.5)

    torch.save({"model_state": model.state_dict(), "best_val": best_val}, os.path.join(CHECKPOINT_DIR, "protonet_final.pth"))
    tqdm.write("Training complete. Best val acc: {:.2f}%".format(best_val*100))
    return model

# ----------------------------
# Meta-test with progress bar
# ----------------------------
class EpisodicDatasetFromFolders:
    """Episode sampler over an image-folder dataset (one sub-directory per class).

    Args:
        root: Directory whose immediate sub-directories are the classes.
        transform: Callable applied to every loaded RGB PIL image; its outputs
            are stacked with ``torch.stack`` when building an episode.

    Raises:
        ValueError: If no sub-directory of ``root`` contains a ``.jpg``,
            ``.jpeg`` or ``.png`` file.
    """

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.classes = sorted([d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))])
        # Maps the class's position in `self.classes` to its image paths.
        self.by_class = {}
        for i, c in enumerate(self.classes):
            folder = os.path.join(root, c)
            # os.listdir order is filesystem-dependent; sorting keeps seeded
            # episode sampling reproducible across machines.
            imgs = sorted(
                os.path.join(folder, f)
                for f in os.listdir(folder)
                if f.lower().endswith(('.jpg', '.jpeg', '.png'))
            )
            if imgs:
                self.by_class[i] = imgs
        self.class_ids = list(self.by_class.keys())
        if not self.class_ids:
            raise ValueError("No classes found in novel root: " + root)

    def _load(self, path):
        """Open ``path``, convert it to RGB and apply the transform."""
        from PIL import Image
        # convert() returns an independent image, so the file can be closed
        # as soon as the conversion is done.
        with Image.open(path) as im:
            rgb = im.convert("RGB")
        return self.transform(rgb)

    def sample_episode(self, n_way, k_shot, q_query):
        """Sample one N-way episode with disjoint support and query images.

        Only classes holding at least ``k_shot + q_query`` images can be drawn.

        Returns:
            ``(support_x, support_y, query_x, query_y)``; labels are
            episode-local (``0..n_way-1``).

        Raises:
            ValueError: If fewer than ``n_way`` classes have enough images.
        """
        per_class = k_shot + q_query
        eligible = [cid for cid in self.class_ids if len(self.by_class[cid]) >= per_class]
        if len(eligible) < n_way:
            raise ValueError(
                f"Need {n_way} classes with at least {per_class} images each, "
                f"but only {len(eligible)} found in {self.root}"
            )
        chosen = random.sample(eligible, n_way)
        support_x, support_y, query_x, query_y = [], [], [], []
        for i, cid in enumerate(chosen):
            # One draw without replacement keeps support and query disjoint.
            imgs = random.sample(self.by_class[cid], per_class)
            for si in imgs[:k_shot]:
                support_x.append(self._load(si))
                support_y.append(i)
            for qi in imgs[k_shot:]:
                query_x.append(self._load(qi))
                query_y.append(i)
        return torch.stack(support_x), torch.tensor(support_y), torch.stack(query_x), torch.tensor(query_y)

def meta_test(model, novel_root):
    """Evaluate ``model`` on episodes sampled from a folder of novel classes.

    Runs ``META_TEST_EPISODES`` episodes of ``META_TEST_N``-way,
    ``META_TEST_K``-shot classification with ``META_TEST_Q`` queries per
    class, with the model in eval mode and gradients disabled.

    Args:
        model: Network returning ``(embeddings, scale)`` for an image batch.
        novel_root: Directory containing one sub-directory per novel class.

    Returns:
        ``(mean, ci95)``: mean episode accuracy and the half-width of its
        95% confidence interval (1.96 * std / sqrt(number of episodes)).
    """
    eval_pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
    ])
    novel_ds = EpisodicDatasetFromFolders(novel_root, transform=eval_pipeline)

    model.eval()
    episode_accs = []
    progress = tqdm(range(META_TEST_EPISODES), desc="Meta-test episodes")
    with torch.no_grad():
        for _ in progress:
            episode = novel_ds.sample_episode(META_TEST_N, META_TEST_K, META_TEST_Q)
            sup_x, sup_y, qry_x, qry_y = (part.to(device) for part in episode)

            sup_emb, scale = model(sup_x)
            qry_emb, _ = model(qry_x)

            _, episode_acc = prototypical_loss_cosine(
                sup_emb, sup_y, qry_emb, qry_y, META_TEST_N, META_TEST_K, scale
            )
            episode_accs.append(episode_acc)
            progress.set_postfix({"acc": f"{episode_acc*100:.2f}%"})

    # Aggregate per-episode accuracies into a mean and a normal-approximation CI.
    acc_tensor = torch.tensor(episode_accs)
    mean = acc_tensor.mean().item()
    spread = acc_tensor.std().item()
    ci95 = 1.96 * spread / math.sqrt(len(acc_tensor))
    tqdm.write(f"Meta-test {META_TEST_N}-way {META_TEST_K}-shot: mean={mean*100:.2f}%, ±{ci95*100:.2f}% (95% CI)")
    return mean, ci95

# ----------------------------
# Run
# ----------------------------
if __name__ == "__main__":
    # Run the full two-stage training only when this file is executed directly;
    # the returned model can be passed straight to meta_test().
    model = train_protonet()
    # Example meta-test usage:
    # novel_root = "/path/to/novel_classes_folder"
    # meta_test(model, novel_root)




Epoch 1/10  train_loss=1.7088  train_acc=40.11%




  val_acc (mean over 100 episodes) = 40.37%
  --> New best val acc: 40.37%  (checkpoint saved)




Epoch 2/10  train_loss=1.5966  train_acc=43.42%




  val_acc (mean over 100 episodes) = 46.03%
  --> New best val acc: 46.03%  (checkpoint saved)




Epoch 3/10  train_loss=1.5646  train_acc=44.26%




  val_acc (mean over 100 episodes) = 45.39%




Epoch 4/10  train_loss=1.5094  train_acc=46.26%




  val_acc (mean over 100 episodes) = 45.93%




Epoch 5/10  train_loss=1.4429  train_acc=48.18%




  val_acc (mean over 100 episodes) = 46.81%
  --> New best val acc: 46.81%  (checkpoint saved)




Epoch 6/10  train_loss=1.3974  train_acc=50.00%




  val_acc (mean over 100 episodes) = 49.47%
  --> New best val acc: 49.47%  (checkpoint saved)




Epoch 7/10  train_loss=1.4049  train_acc=48.78%




  val_acc (mean over 100 episodes) = 48.97%




Epoch 8/10  train_loss=1.3611  train_acc=50.46%




  val_acc (mean over 100 episodes) = 46.93%
Switching to Stage 2: unfreezing backbone and creating new optimizer with param groups.




Epoch 9/10  train_loss=1.2235  train_acc=55.09%




  val_acc (mean over 100 episodes) = 54.17%
  --> New best val acc: 54.17%  (checkpoint saved)




Epoch 10/10  train_loss=1.0605  train_acc=61.73%




  val_acc (mean over 100 episodes) = 58.69%
  --> New best val acc: 58.69%  (checkpoint saved)
Training complete. Best val acc: 58.69%
