<a href="https://colab.research.google.com/github/OneFineStarstuff/Pinn/blob/main/maml_synthetic_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# maml_synthetic.py
# Minimal, reproducible MAML training on synthetic N-way K-shot Gaussian classification.
# - Gaussian task sampler (new task each episode)
# - Second-order (full) or first-order MAML
# - Meta-train/val loop with accuracy and checkpointing
# - Deterministic seeding and safe defaults

import os
import math
import time
import argparse
from collections import OrderedDict

import torch
from torch import nn, optim
import torch.nn.functional as F


# ---------------------------
# Utilities
# ---------------------------

def set_seed(seed: int = 42, deterministic: bool = True):
    """Seed PyTorch's CPU and CUDA RNGs and optionally pin cuDNN to deterministic kernels.

    Args:
        seed: Value handed to both ``torch.manual_seed`` and
            ``torch.cuda.manual_seed_all``.
        deterministic: When True, ask cuDNN for deterministic algorithms and turn
            off its autotuner (``benchmark``), trading speed for repeatability.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    if not deterministic:
        return
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def accuracy_from_logits(logits: torch.Tensor, targets: torch.Tensor) -> float:
    """Return the fraction of rows whose highest-scoring class equals the target label.

    The arg-max is taken over the last dimension of ``logits``; the result is a
    plain Python float.
    """
    hits = torch.eq(torch.argmax(logits, dim=-1), targets)
    return hits.to(torch.float32).mean().item()


# ---------------------------
# Model: functional forward
# ---------------------------

class SimpleNet(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) with a functional forward path.

    ``forward_with_weights`` evaluates the same architecture with an explicit
    mapping of parameters keyed by the names ``named_parameters()`` produces
    (``net.0.*`` for the first Linear, ``net.2.*`` for the second). This lets
    MAML run the network with adapted "fast" weights.
    """

    def __init__(self, in_dim=2, hidden=40, out_dim=2):
        super().__init__()
        layers = [nn.Linear(in_dim, hidden), nn.ReLU(), nn.Linear(hidden, out_dim)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

    def forward_with_weights(self, x, weights: "OrderedDict[str, torch.Tensor]"):
        # Index 1 of the Sequential is the parameter-free ReLU, hence keys 0 and 2.
        hidden = F.relu(F.linear(x, weights["net.0.weight"], weights["net.0.bias"]))
        return F.linear(hidden, weights["net.2.weight"], weights["net.2.bias"])


# ---------------------------
# MAML meta-learner
# ---------------------------

class MAML(nn.Module):
    """Model-Agnostic Meta-Learning wrapper around a model with a functional forward.

    The wrapped ``model`` must provide ``forward_with_weights(x, weights)`` that
    evaluates it with an explicit ``OrderedDict`` of parameters keyed like
    ``model.named_parameters()``.

    Args:
        model: Network whose initial parameters are meta-learned.
        lr_inner: Step size of the inner-loop SGD updates.
        first_order: Default used by ``forward`` when ``create_graph`` is not
            given. If True, inner-loop gradients are treated as constants
            (first-order MAML). If False, the full second-order graph is built.
    """

    def __init__(self, model: nn.Module, lr_inner: float = 0.01, first_order: bool = False):
        super().__init__()
        self.model = model
        self.lr_inner = lr_inner
        self.first_order = first_order

    def adapt(self, support_x, support_y, n_inner_steps: int, create_graph: bool):
        """Run ``n_inner_steps`` of SGD on the support set and return the fast weights.

        Autograd is force-enabled for the inner loop. Adaptation always needs
        gradients, even when the caller evaluates under ``torch.no_grad()``.
        Without this, ``torch.autograd.grad`` would fail because the support loss
        would have no grad_fn.

        Args:
            support_x: Support inputs fed to ``model.forward_with_weights``.
            support_y: Integer class targets for ``F.cross_entropy``.
            n_inner_steps: Number of gradient steps; 0 returns the meta-parameters
                themselves.
            create_graph: If True, the gradient computation is itself recorded, so
                the result is differentiable to second order w.r.t. the
                meta-parameters.

        Returns:
            OrderedDict mapping parameter names to adapted tensors.
        """
        # Start from the current meta-parameters.
        fast_weights = OrderedDict(self.model.named_parameters())

        with torch.enable_grad():
            for _ in range(n_inner_steps):
                logits = self.model.forward_with_weights(support_x, fast_weights)
                loss = F.cross_entropy(logits, support_y)
                grads = torch.autograd.grad(
                    loss,
                    list(fast_weights.values()),
                    create_graph=create_graph,  # second-order if True
                    retain_graph=create_graph,
                )
                fast_weights = OrderedDict(
                    (name, param - self.lr_inner * g)
                    for (name, param), g in zip(fast_weights.items(), grads)
                )
        return fast_weights

    def forward(self, task_data, n_inner_steps: int = 1, create_graph: "bool | None" = None):
        """Adapt to each task's support set and score the adapted model on its query set.

        Args:
            task_data: Iterable of ``(support_x, support_y, query_x, query_y)`` tuples.
            n_inner_steps: Inner-loop steps per task.
            create_graph: Whether to build second-order graphs. ``None`` means
                ``not self.first_order``.

        Returns:
            ``(meta_loss, meta_acc)``: the mean query cross-entropy over tasks (a
            scalar tensor) and the mean query accuracy over tasks (a float).

        Raises:
            ValueError: if ``task_data`` yields no tasks.
        """
        if create_graph is None:
            create_graph = not self.first_order

        query_losses = []
        query_accs = []

        for support_x, support_y, query_x, query_y in task_data:
            fast_weights = self.adapt(support_x, support_y, n_inner_steps, create_graph=create_graph)
            query_logits = self.model.forward_with_weights(query_x, fast_weights)
            query_losses.append(F.cross_entropy(query_logits, query_y))
            query_accs.append(accuracy_from_logits(query_logits, query_y))

        if not query_losses:
            raise ValueError("task_data must contain at least one task")

        meta_loss = torch.stack(query_losses).mean()
        meta_acc = float(sum(query_accs) / len(query_accs))
        return meta_loss, meta_acc


# ---------------------------
# Synthetic task sampler
# ---------------------------

class GaussianNWayKShot:
    """
    N-way K-shot classification tasks in R^in_dim.

    - Each task draws fresh class centers and per-class scales.
    - Centers are ``radius * (cos a, sin a)`` for a uniformly random angle ``a``.
      They are truncated to the first coordinate when ``in_dim == 1`` and
      zero-padded when ``in_dim > 2``. Gaussian jitter with std 0.2 is then added
      to every coordinate.
    - Each per-class scale sigma_c is uniform in ``scale_range``. Class c
      samples are drawn from N(mu_c, sigma_c^2 I).
    - ``sample_task`` returns ``(support_x, support_y, query_x, query_y)``, each
      pair shuffled independently and placed on ``device``.

    All randomness comes from a private CPU ``torch.Generator`` seeded with
    ``seed``. Two samplers built with the same arguments therefore yield the
    same task stream, independent of the global torch RNG.
    """
    def __init__(
        self,
        n_way: int,
        k_shot: int,
        q_queries: int,
        in_dim: int = 2,
        radius: float = 5.0,
        scale_range=(0.6, 1.2),
        device: str = "cpu",
        seed: int = 1234
    ):
        if n_way < 1:
            raise ValueError(f"n_way must be >= 1, got {n_way}")
        if in_dim < 1:
            raise ValueError(f"in_dim must be >= 1, got {in_dim}")
        if k_shot < 0 or q_queries < 0:
            raise ValueError(
                f"k_shot and q_queries must be >= 0, got k_shot={k_shot}, q_queries={q_queries}"
            )
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_queries = q_queries
        self.in_dim = in_dim
        self.radius = radius
        self.scale_range = scale_range
        self.device = device
        self._g = torch.Generator(device="cpu").manual_seed(seed)

    def _sample_centers(self):
        """Return (n_way, in_dim) class centers on a ring of ``radius``, with jitter."""
        angles = torch.rand(self.n_way, generator=self._g) * 2 * math.pi
        base = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)  # (n_way, 2)
        if self.in_dim > 2:
            pad = torch.zeros(self.n_way, self.in_dim - 2)
            centers = torch.cat([base, pad], dim=1)
        else:
            centers = base[:, :self.in_dim]
        centers = centers * self.radius
        # Small jitter drawn from the sampler's own generator. torch.randn is
        # used instead of randn_like because older PyTorch releases do not
        # accept `generator` in randn_like.
        jitter = torch.randn(centers.shape, generator=self._g)
        return centers + 0.2 * jitter

    def _sample_sigma(self):
        """Return (n_way,) per-class std devs drawn uniformly from ``scale_range``."""
        low, high = self.scale_range
        return low + (high - low) * torch.rand(self.n_way, generator=self._g)

    def _shuffle(self, x, y):
        """Apply the same random permutation (from the private generator) to x and y."""
        idx = torch.randperm(x.size(0), generator=self._g)
        return x[idx], y[idx]

    def sample_task(self):
        """Draw one task; returns (support_x, support_y, query_x, query_y) on ``device``."""
        centers = self._sample_centers()  # (n_way, in_dim)
        sigmas = self._sample_sigma()     # (n_way,)

        n_support = self.k_shot
        n_query = self.q_queries
        total = n_support + n_query

        xs_support, ys_support, xs_query, ys_query = [], [], [], []
        for c in range(self.n_way):
            pts = centers[c] + sigmas[c] * torch.randn(total, self.in_dim, generator=self._g)
            xs_support.append(pts[:n_support])
            xs_query.append(pts[n_support:])
            ys_support.append(torch.full((n_support,), c, dtype=torch.long))
            ys_query.append(torch.full((n_query,), c, dtype=torch.long))

        # Shuffle on CPU (where the generator and the index tensor live), then move
        # every tensor to the target device exactly once.
        support_x, support_y = self._shuffle(torch.cat(xs_support, dim=0), torch.cat(ys_support, dim=0))
        query_x, query_y = self._shuffle(torch.cat(xs_query, dim=0), torch.cat(ys_query, dim=0))

        return (
            support_x.to(self.device),
            support_y.to(self.device),
            query_x.to(self.device),
            query_y.to(self.device),
        )

    def sample_meta_batch(self, batch_size: int):
        """Return a list of ``batch_size`` independently sampled tasks."""
        return [self.sample_task() for _ in range(batch_size)]


# ---------------------------
# Training and evaluation
# ---------------------------

def train_maml(
    n_way=3,
    k_shot=5,
    q_queries=15,
    in_dim=2,
    inner_steps=5,
    inner_lr=0.4,
    first_order=False,
    outer_lr=1e-3,
    meta_batch=16,
    epochs=3000,
    val_every=50,
    seed=42,
    device="cuda" if torch.cuda.is_available() else "cpu",
    checkpoint_path="./checkpoints/maml_synthetic.pt",
):
    """Meta-train MAML on synthetic Gaussian N-way K-shot tasks and checkpoint the best model.

    Each iteration ("epoch") samples ``meta_batch`` fresh training tasks, runs
    the inner loop on each support set, and takes one Adam step on the mean
    query loss, with gradients clipped to norm 5.0. On the first iteration and
    every ``val_every`` iterations after that, the model is scored on a freshly
    sampled batch of validation tasks. ``{"epoch", "model_state", "config"}`` is
    saved to ``checkpoint_path`` whenever validation accuracy improves.

    Args:
        n_way, k_shot, q_queries, in_dim: Task geometry passed to the samplers.
        inner_steps, inner_lr: Inner-loop adaptation settings.
        first_order: Use first-order MAML (no second-order graph).
        outer_lr: Adam learning rate for the meta-update.
        meta_batch: Tasks per meta-update and per validation round (>= 1).
        epochs: Number of meta-training iterations.
        val_every: Validation interval in iterations (>= 1).
        seed: Global seed; the samplers use ``seed + 1`` (train) and
            ``seed + 999`` (val).
        device: Torch device string.
        checkpoint_path: Output file. A bare filename is written to the current
            directory.

    Raises:
        ValueError: if ``val_every`` or ``meta_batch`` is less than 1.
    """
    if val_every < 1:
        raise ValueError(f"val_every must be >= 1, got {val_every}")
    if meta_batch < 1:
        raise ValueError(f"meta_batch must be >= 1, got {meta_batch}")

    set_seed(seed, deterministic=True)
    ckpt_dir = os.path.dirname(checkpoint_path)
    if ckpt_dir:  # a bare filename has no directory part; os.makedirs("") would raise
        os.makedirs(ckpt_dir, exist_ok=True)

    # Build components
    model = SimpleNet(in_dim=in_dim, hidden=40, out_dim=n_way).to(device)
    maml = MAML(model, lr_inner=inner_lr, first_order=first_order).to(device)
    outer_opt = optim.Adam(maml.parameters(), lr=outer_lr)

    # Task samplers (distinct RNG seeds)
    train_sampler = GaussianNWayKShot(
        n_way=n_way, k_shot=k_shot, q_queries=q_queries,
        in_dim=in_dim, radius=5.0, scale_range=(0.6, 1.2),
        device=device, seed=seed + 1
    )
    val_sampler = GaussianNWayKShot(
        n_way=n_way, k_shot=k_shot, q_queries=q_queries,
        in_dim=in_dim, radius=5.0, scale_range=(0.6, 1.2),
        device=device, seed=seed + 999
    )

    # -inf ensures the first validation always writes a checkpoint, even at 0% accuracy.
    best_val_acc = float("-inf")
    ema_loss = None

    t0 = time.time()
    for epoch in range(1, epochs + 1):
        maml.train()
        task_batch = train_sampler.sample_meta_batch(meta_batch)
        outer_opt.zero_grad()

        meta_loss, meta_acc = maml(task_batch, n_inner_steps=inner_steps, create_graph=not first_order)
        meta_loss.backward()

        torch.nn.utils.clip_grad_norm_(maml.parameters(), max_norm=5.0)
        outer_opt.step()

        # Smooth loss for readability
        loss_value = meta_loss.item()
        ema_loss = loss_value if ema_loss is None else 0.9 * ema_loss + 0.1 * loss_value

        if epoch % val_every == 0 or epoch == 1:
            maml.eval()
            val_tasks = val_sampler.sample_meta_batch(meta_batch)
            # Deliberately NOT under torch.no_grad(): inner-loop adaptation calls
            # torch.autograd.grad on the support loss and needs autograd enabled.
            # No second-order graph is built, no backward/step is taken, and the
            # outer graph is dropped immediately by detaching.
            val_loss, val_acc = maml(val_tasks, n_inner_steps=inner_steps, create_graph=False)
            val_loss = val_loss.detach()

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save({
                    "epoch": epoch,
                    "model_state": maml.state_dict(),
                    "config": {
                        "n_way": n_way, "k_shot": k_shot, "q_queries": q_queries,
                        "in_dim": in_dim, "inner_steps": inner_steps, "inner_lr": inner_lr,
                        "first_order": first_order, "outer_lr": outer_lr,
                    }
                }, checkpoint_path)

            print(
                f"[{epoch:04d}] "
                f"train_loss(EMA)={ema_loss:.4f} | "
                f"meta_acc={meta_acc:.3f} | "
                f"val_loss={val_loss.item():.4f} | "
                f"val_acc={val_acc:.3f} | "
                f"best_val_acc={best_val_acc:.3f}"
            )

    dt = time.time() - t0
    print(f"Done. Best val acc={best_val_acc:.3f}. Elapsed {dt/60:.1f} min. Checkpoint: {checkpoint_path}")


# ---------------------------
# CLI
# ---------------------------

def parse_args(argv=None):
    """Parse command-line options for ``train_maml``.

    Args:
        argv: Argument list to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before; passing a
            list makes the parser usable from tests and other scripts.

    Returns:
        An ``argparse.Namespace`` with one attribute per option (dashes become
        underscores, e.g. ``--n-way`` -> ``n_way``).
    """
    p = argparse.ArgumentParser(description="MAML on synthetic Gaussian N-way K-shot classification.")
    p.add_argument("--n-way", type=int, default=3, help="Number of classes per task")
    p.add_argument("--k-shot", type=int, default=5, help="Support examples per class")
    p.add_argument("--q-queries", type=int, default=15, help="Query examples per class")
    p.add_argument("--in-dim", type=int, default=2, help="Input dimensionality")
    p.add_argument("--inner-steps", type=int, default=5, help="Inner adaptation steps")
    p.add_argument("--inner-lr", type=float, default=0.4, help="Inner adaptation learning rate")
    p.add_argument("--first-order", action="store_true", help="Use first-order MAML (no second-order grads)")
    p.add_argument("--outer-lr", type=float, default=1e-3, help="Outer learning rate")
    p.add_argument("--meta-batch", type=int, default=16, help="Tasks per meta-update")
    p.add_argument("--epochs", type=int, default=3000, help="Number of meta-training iterations")
    p.add_argument("--val-every", type=int, default=50, help="Validation interval")
    p.add_argument("--seed", type=int, default=42, help="Random seed")
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device")
    p.add_argument("--checkpoint", type=str, default="./checkpoints/maml_synthetic.pt", help="Checkpoint path")
    return p.parse_args(argv)


if __name__ == "__main__":
    cli = parse_args()
    # CLI option names match train_maml's keyword names one-to-one, except
    # --checkpoint, which maps to checkpoint_path.
    passthrough = (
        "n_way", "k_shot", "q_queries", "in_dim",
        "inner_steps", "inner_lr", "first_order", "outer_lr",
        "meta_batch", "epochs", "val_every", "seed", "device",
    )
    train_maml(
        checkpoint_path=cli.checkpoint,
        **{name: getattr(cli, name) for name in passthrough},
    )