<a href="https://colab.research.google.com/github/OneFineStarstuff/Pinn/blob/main/align_curriculum_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Title: Multimodal AGI Alignment Curriculum (InfoNCE + Prototypes)
Author: Copilot
Provenance: Self-contained demo for multimodal contrastive alignment with a simple curriculum.
License: MIT

Overview:
- Synthetic multimodal "class" alignment: text tokens <-> generated image sprites
- Curriculum stages increase class count and image noise
- InfoNCE (NT-Xent) symmetric loss with learnable temperature
- Optional prototype memory (simple world-model-ish prior) for class centroids
- CSV logging and checkpointing
"""

import os
import csv
import math
import time
import random
import argparse
from dataclasses import dataclass, asdict
from typing import Dict, Any, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------
# Utilities and configuration
# ---------------------------

def set_seed(seed: int):
    """Seed every RNG the script draws from with the same ``seed``.

    Covers Python's ``random``, NumPy's global RNG, PyTorch's CPU generator
    and the generators of all CUDA devices, in that order.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


@dataclass
class AgentConfig:
    """Hyperparameters and run settings for one curriculum training run.

    Field order matters: it defines the positional order of the generated
    ``__init__``. ``asdict(cfg)`` of this object is stored in the saved
    checkpoints.
    """
    name: str = "MyAGIAgent"  # run name; prefix of the per-run log directory
    notes: str = "Multimodal InfoNCE with curriculum"  # free-form description; only surfaces via repr/asdict
    device: str = "cpu"  # requested device; run_curriculum uses "cpu" unless this starts with "cuda" and CUDA is available
    seed: int = 1337  # passed to set_seed() at the start of run_curriculum

    # Model dims
    text_vocab_size: int = 1024  # upper bound on the text embedding table; curriculum stages pick <= this
    text_embed_dim: int = 128  # token embedding width
    text_hidden_dim: int = 256  # GRU hidden width
    img_channels: int = 1  # image-encoder input channels; sample_batch builds 1-channel sprites
    img_size: int = 28  # side length of the square synthetic sprites
    proj_dim: int = 128  # width of the shared, L2-normalized embedding space

    # Training
    lr: float = 3e-4  # AdamW learning rate (create_optimizer gives the temperature lr * 0.1)
    weight_decay: float = 1e-4  # AdamW weight decay
    max_grad_norm: float = 1.0  # gradient-norm clipping threshold per step
    temperature_init: float = 0.07  # initial contrastive temperature; logit_scale starts at log(1 / temperature_init)
    use_world_model: bool = True    # enable the prototype-consistency loss and EMA prototype updates
    proto_momentum: float = 0.9     # EMA factor m: new = m * old + (1 - m) * batch mean
    proto_weight: float = 0.1       # multiplier on the prototype-consistency loss

    # Curriculum
    stages: int = 3  # number of stages; per-stage difficulty comes from get_stage_params()
    steps_per_stage: int = 300  # training steps in each stage
    eval_interval: int = 50  # evaluate on a fresh batch every N training steps within a stage
    batch_size: int = 64  # samples per training/eval batch

    # Logging/checkpoints
    log_dir: str = "./runs"  # parent directory for the per-run folders
    save_every_stage: bool = True  # write checkpoint_stage{k}.pt at the end of every stage


class AlignmentLogger:
    """Append per-step metrics to a CSV file and print throttled console summaries.

    Each instance owns a fresh run directory
    ``<log_dir>/<run_name>_<YYYYmmdd-HHMMSS>``, suffixed ``-1``, ``-2``, ... if
    that name is already taken. The directory contains ``metrics.csv``.
    """

    _NAN = float("nan")
    # (metrics key, default when the key is missing) for each CSV column after
    # "step" and "stage"; must stay in the same order as the header row.
    _COLUMNS = (
        ("loss", _NAN),
        ("loss_infonce", _NAN),
        ("loss_proto", 0.0),
        ("acc_t2i", _NAN),
        ("acc_i2t", _NAN),
        ("temperature", _NAN),
        ("lr", _NAN),
    )

    def __init__(self, log_dir: str, run_name: str):
        os.makedirs(log_dir, exist_ok=True)
        ts = time.strftime("%Y%m%d-%H%M%S")
        base = os.path.join(log_dir, f"{run_name}_{ts}")
        self.dir = base
        suffix = 0
        # Two loggers created within the same second would otherwise share a
        # directory, and _init_csv would truncate the earlier run's metrics.
        while True:
            try:
                os.makedirs(self.dir)
                break
            except FileExistsError:
                suffix += 1
                self.dir = f"{base}-{suffix}"
        self.csv_path = os.path.join(self.dir, "metrics.csv")
        self._init_csv()
        # -inf so the very first console() call is never throttled away.
        self.last_print = float("-inf")

    def _init_csv(self):
        """Create metrics.csv containing only the header row."""
        with open(self.csv_path, "w", newline="") as f:
            w = csv.writer(f)
            w.writerow([
                "step", "stage", "loss", "loss_infonce", "loss_proto",
                "acc@1_text2img", "acc@1_img2text", "temperature", "lr"
            ])

    def log(self, step: int, stage: int, metrics: Dict[str, Any]):
        """Append one row; missing metrics are written as NaN (loss_proto as 0.0)."""
        row = [step, stage] + [float(metrics.get(key, default)) for key, default in self._COLUMNS]
        with open(self.csv_path, "a", newline="") as f:
            csv.writer(f).writerow(row)

    @staticmethod
    def _fmt(value: Any, spec: str) -> str:
        """Format ``value`` as a float with ``spec``, or return ``'NA'`` if it is not numeric."""
        try:
            return format(float(value), spec)
        except (TypeError, ValueError):
            return "NA"

    def console(self, step: int, stage: int, metrics: Dict[str, Any], throttle_sec: float = 1.0):
        """Print a one-line summary unless one was printed less than ``throttle_sec`` ago.

        Missing or non-numeric metrics are shown as ``NA`` instead of raising.
        """
        now = time.time()
        if now - self.last_print < throttle_sec:
            return
        fmt = self._fmt
        msg = (
            f"[stage {stage} | step {step}] "
            f"loss={fmt(metrics.get('loss'), '.4f')} "
            f"(infonce={fmt(metrics.get('loss_infonce'), '.4f')}, proto={fmt(metrics.get('loss_proto', 0.0), '.4f')}) "
            f"acc@1 t2i={fmt(metrics.get('acc_t2i'), '.3f')} i2t={fmt(metrics.get('acc_i2t'), '.3f')} "
            f"T={fmt(metrics.get('temperature'), '.4f')} lr={fmt(metrics.get('lr'), '.2e')}"
        )
        print(msg, flush=True)
        self.last_print = now


# ---------------------------
# Synthetic multimodal dataset
# ---------------------------

def generate_sprite(class_id: int, size: int, variant: int, noise: float, device: str) -> torch.Tensor:
    """Draw a (size, size) sprite for ``class_id``/``variant``, add noise, standardize it.

    The pattern family is ``class_id % 8``:
    0 vertical line, 1 horizontal line, 2 and 3 wrapped diagonals, 4 a 2x2 block,
    5 a two-row band, 6 a two-column band, 7 dots on every other column of the
    even rows. The pattern's position ``offset`` lies in ``[1, size - 2]`` and
    depends on both ``class_id`` and ``variant``.

    Gaussian noise scaled by ``noise`` is then added, drawn from the global torch
    RNG. The image is shifted and scaled to zero mean and unit std, and moved to
    ``device``. Requires ``size >= 3``.

    Note: since ``offset`` also varies with ``variant``, class ids that are
    congruent modulo 8 share a pattern family and may yield the same noise-free
    pattern.
    """
    canvas = torch.zeros((size, size), dtype=torch.float32)
    family = class_id % 8
    offset = (class_id * 131 + variant * 17) % (size - 2) + 1
    idx = torch.arange(size)

    if family == 0:
        canvas[:, offset] = 1.0
    elif family == 1:
        canvas[offset, :] = 1.0
    elif family == 2:
        canvas[idx, (idx + offset) % size] = 1.0
    elif family == 3:
        canvas[idx, (offset - idx) % size] = 1.0
    elif family == 4:
        canvas[offset - 1:offset + 1, offset - 1:offset + 1] = 1.0
    elif family == 5:
        canvas[offset - 1:offset + 1, :] = 1.0
    elif family == 6:
        canvas[:, offset - 1:offset + 1] = 1.0
    else:
        # Even rows only; the starting column's parity is fixed by the offset.
        canvas[0::2, (offset % 2)::2] = 1.0

    noisy = canvas + noise * torch.randn_like(canvas)
    standardized = (noisy - noisy.mean()) / (noisy.std() + 1e-6)
    return standardized.to(device)


def make_text_sequence(class_id: int, vocab_offset: int = 2) -> torch.Tensor:
    """Return the 3-token "caption" for a class as a long tensor.

    Layout is ``[BOS, class token, EOS]``. BOS is id 0 and EOS is id 1. The
    class token is ``vocab_offset + class_id``, so ids below ``vocab_offset``
    are reserved.
    """
    bos_id, eos_id = 0, 1
    class_token = vocab_offset + class_id
    return torch.tensor((bos_id, class_token, eos_id), dtype=torch.long)


def sample_batch(
    batch_size: int,
    num_classes: int,
    img_size: int,
    noise: float,
    device: str,
    vocab_offset: int = 2
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Draw a random synthetic batch of (caption, sprite) pairs.

    Labels are sampled uniformly from ``[0, num_classes)`` using the global
    torch RNG. Each sprite's ``variant`` is its position in the batch.

    Returns:
      text_tokens: (B, 3) long tensor of ``[BOS, class token, EOS]`` ids
      images: (B, 1, H, W) float tensor
      labels: (B,) class ids in [0..num_classes-1]

    Raises:
      ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    labels = torch.randint(low=0, high=num_classes, size=(batch_size,), device=device)
    # A single host transfer, reused for both modalities.
    class_list = labels.cpu().tolist()

    text_tokens = torch.stack(
        [make_text_sequence(c, vocab_offset) for c in class_list], dim=0
    ).to(device)  # (B, L)

    # Sprites are already created on `device`; variant = index keeps them distinct within the batch.
    images = torch.stack(
        [
            generate_sprite(c, size=img_size, variant=i, noise=noise, device=device).unsqueeze(0)
            for i, c in enumerate(class_list)
        ],
        dim=0,
    )  # (B, 1, H, W)
    return text_tokens, images, labels


# ---------------------------
# Encoders and agent
# ---------------------------

class TextEncoder(nn.Module):
    """Token ids -> embedding -> single-layer GRU -> linear projection -> unit vector.

    The sequence is summarized by the GRU's final hidden state.
    """

    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int, proj_dim: int):
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.gru = nn.GRU(input_size=embed_dim, hidden_size=hidden_dim, batch_first=True)
        self.proj = nn.Linear(in_features=hidden_dim, out_features=proj_dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (B, L) -> embeddings (B, L, E); final hidden state is (1, B, H).
        _, last_hidden = self.gru(self.embed(tokens))
        projected = self.proj(last_hidden[-1])  # (B, D)
        return F.normalize(projected, dim=-1)


class ImageEncoder(nn.Module):
    """Small CNN: three conv+ReLU stages, global average pool, linear projection, unit vector.

    The first two stages halve the resolution with 2x2 max-pooling (28 -> 14 -> 7).
    The last stage pools globally, so the input size is not fixed.
    """

    def __init__(self, in_ch: int, proj_dim: int):
        super().__init__()
        widths = (in_ch, 32, 64, 128)
        layers = []
        n_stages = len(widths) - 1
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, 3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            is_last = stage == n_stages - 1
            layers.append(nn.AdaptiveAvgPool2d((1, 1)) if is_last else nn.MaxPool2d(2))
        self.net = nn.Sequential(*layers)
        self.proj = nn.Linear(widths[-1], proj_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W) -> pooled features (B, 128) -> projection (B, D)
        pooled = torch.flatten(self.net(x), start_dim=1)
        return F.normalize(self.proj(pooled), dim=-1)


class MyAGIAgent(nn.Module):
    """Text/image dual encoder trained with a symmetric InfoNCE loss.

    Optionally keeps one EMA "prototype" vector per class. A consistency term
    then pulls both modality embeddings toward their class prototype.
    """

    def __init__(self, cfg: "AgentConfig", vocab_size_cap: int):
        super().__init__()
        self.cfg = cfg
        self.text_enc = TextEncoder(
            vocab_size=vocab_size_cap,
            embed_dim=cfg.text_embed_dim,
            hidden_dim=cfg.text_hidden_dim,
            proj_dim=cfg.proj_dim
        )
        self.img_enc = ImageEncoder(
            in_ch=cfg.img_channels,
            proj_dim=cfg.proj_dim
        )
        # Learnable log inverse temperature; temperature() = exp(-logit_scale), clamped.
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1.0 / cfg.temperature_init)))
        # Prototype memory: row c is the unit vector for class c. This is a plain
        # attribute, not a registered buffer, so it is neither in state_dict()
        # nor moved by .to(); reset_prototypes() takes the device explicitly.
        self.prototypes: Optional[torch.Tensor] = None  # (num_classes, D)

    def reset_prototypes(self, num_classes: int, device: str):
        """Replace the prototype memory with ``num_classes`` random unit vectors on ``device``."""
        self.prototypes = F.normalize(torch.randn(num_classes, self.cfg.proj_dim, device=device), dim=-1)

    def temperature(self) -> torch.Tensor:
        """Return the softmax temperature exp(-logit_scale), clamped to [1e-3, 1.0]."""
        return torch.exp(-self.logit_scale).clamp(1e-3, 1.0)

    def infonce_loss(
        self,
        z_t: torch.Tensor,
        z_i: torch.Tensor,
        class_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Symmetric InfoNCE between text rows ``z_t`` and image rows ``z_i`` (both (B, D)).

        Without ``class_ids``, row i's only positive is column i (standard
        instance-level InfoNCE). With ``class_ids``, every same-class column is a
        positive and the target probability is split evenly among them.

        This matters because same-class samples share identical text tokens, so
        instance-level targets would treat indistinguishable pairs as negatives.
        With all-distinct ids, both modes give the same loss.

        Accuracy counts a row as correct when its arg-max column is a positive.
        """
        T = self.temperature()
        logits = (z_t @ z_i.t()) / T  # (B, B): text rows, image columns
        n = z_t.size(0)
        if class_ids is None:
            positives = torch.eye(n, dtype=torch.bool, device=z_t.device)
        else:
            ids = class_ids.to(z_t.device).reshape(-1)
            positives = ids.unsqueeze(1) == ids.unsqueeze(0)
        targets = positives.float()
        targets = targets / targets.sum(dim=1, keepdim=True)
        # "Same class" is symmetric, so one target matrix serves both directions.
        loss_t2i = -(targets * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
        loss_i2t = -(targets * F.log_softmax(logits.t(), dim=1)).sum(dim=1).mean()
        loss = 0.5 * (loss_t2i + loss_i2t)

        with torch.no_grad():
            rows = torch.arange(n, device=z_t.device)
            acc_t2i = positives[rows, logits.argmax(dim=1)].float().mean().item()
            acc_i2t = positives[rows, logits.t().argmax(dim=1)].float().mean().item()

        return loss, {
            "loss_infonce": float(loss.item()),
            "acc_t2i": float(acc_t2i),
            "acc_i2t": float(acc_i2t),
            "temperature": float(T.item()),
        }

    def prototype_loss(self, z_t: torch.Tensor, z_i: torch.Tensor, class_ids: torch.Tensor) -> torch.Tensor:
        """Mean cosine distance of both modalities to their class prototype (0 if no prototypes)."""
        if self.prototypes is None:
            return torch.tensor(0.0, device=z_t.device)

        p = self.prototypes[class_ids]  # (B, D)
        # Inputs and prototypes are unit vectors, so 1 - dot product is the cosine distance.
        d_t = 1.0 - (z_t * p).sum(dim=-1)
        d_i = 1.0 - (z_i * p).sum(dim=-1)
        return (d_t.mean() + d_i.mean()) * 0.5

    @torch.no_grad()
    def update_prototypes(self, class_ids: torch.Tensor, z_t: torch.Tensor, z_i: torch.Tensor):
        """EMA-update the prototype of every class present in the batch.

        ``new = normalize(m * old + (1 - m) * normalize(mean of fused embeddings))``
        with ``m = cfg.proto_momentum``. This is a no-op when no prototypes exist.
        """
        if self.prototypes is None:
            return
        momentum = self.cfg.proto_momentum
        z = F.normalize((z_t + z_i) * 0.5, dim=-1)  # fused text/image representation
        for cls in class_ids.unique():  # every id from unique() has at least one member
            mean_z = F.normalize(z[class_ids == cls].mean(dim=0), dim=-1)
            self.prototypes[cls] = F.normalize(
                momentum * self.prototypes[cls] + (1 - momentum) * mean_z, dim=-1
            )

    def forward(self, tokens: torch.Tensor, images: torch.Tensor, class_ids: torch.Tensor) -> Dict[str, Any]:
        """Encode both modalities and return losses and summary statistics.

        Returns a dict with these entries:
        - tensors ``loss``, ``loss_infonce`` and ``loss_proto`` (loss_proto is
          already scaled by ``cfg.proto_weight``);
        - floats ``acc_t2i``, ``acc_i2t`` and ``temperature``.
        """
        z_t = self.text_enc(tokens)     # (B, D)
        z_i = self.img_enc(images)      # (B, D)

        loss_infonce, stats = self.infonce_loss(z_t, z_i, class_ids)
        if self.cfg.use_world_model:
            loss_proto = self.prototype_loss(z_t, z_i, class_ids) * self.cfg.proto_weight
        else:
            loss_proto = torch.tensor(0.0, device=images.device)

        out: Dict[str, Any] = {
            "loss": loss_infonce + loss_proto,
            "loss_infonce": loss_infonce,
            "loss_proto": loss_proto,
        }
        # stats also carries a float "loss_infonce"; never let it replace the tensor above.
        for key, value in stats.items():
            out.setdefault(key, value)
        return out


# ---------------------------
# Adapter
# ---------------------------

class AgentAdapter:
    """Bundle an agent with its optimizer and optional LR scheduler behind train/eval calls.

    Both ``train_step`` and ``eval_batch`` return plain-float metrics with the keys
    loss, loss_infonce, loss_proto, acc_t2i, acc_i2t, temperature and lr.
    """

    # Keys copied from the agent's forward() output, in logging order.
    _METRIC_KEYS = ("loss", "loss_infonce", "loss_proto", "acc_t2i", "acc_i2t", "temperature")

    def __init__(self, agent: "MyAGIAgent", cfg: "AgentConfig", optimizer: torch.optim.Optimizer, scheduler=None):
        self.agent = agent
        self.cfg = cfg
        self.optimizer = optimizer
        self.scheduler = scheduler

    @staticmethod
    def _to_float(value: Any) -> float:
        """Convert a 0-d tensor or a plain Python number to ``float``.

        The agent's output dict may hold a metric either way; calling ``.item()``
        unconditionally would crash on plain floats.
        """
        if isinstance(value, torch.Tensor):
            return float(value.detach().item())
        return float(value)

    def _collect_metrics(self, out: Dict[str, Any]) -> Dict[str, float]:
        """Extract the logged scalars from ``out`` plus the current LR of param group 0."""
        metrics = {key: self._to_float(out[key]) for key in self._METRIC_KEYS}
        metrics["lr"] = float(self.optimizer.param_groups[0]["lr"])
        return metrics

    def train_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
        """Run one optimization step on ``(tokens, images, labels)`` and return float metrics."""
        self.agent.train()
        tokens, images, labels = batch
        out = self.agent(tokens, images, labels)
        loss: torch.Tensor = out["loss"]
        self.optimizer.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(self.agent.parameters(), self.cfg.max_grad_norm)
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()
        # Prototypes track the *updated* encoders, so re-embed the batch after the step.
        if self.agent.cfg.use_world_model:
            with torch.no_grad():
                z_t = self.agent.text_enc(tokens)
                z_i = self.agent.img_enc(images)
                self.agent.update_prototypes(labels, z_t, z_i)
        return self._collect_metrics(out)

    @torch.no_grad()
    def eval_batch(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
        """Evaluate ``(tokens, images, labels)`` without gradients; leaves the agent in eval mode."""
        self.agent.eval()
        tokens, images, labels = batch
        out = self.agent(tokens, images, labels)
        return self._collect_metrics(out)


# ---------------------------
# Curriculum trainer
# ---------------------------

def get_stage_params(stage_idx: int) -> Dict[str, Any]:
    """Return ``{"num_classes": ..., "noise": ...}`` for curriculum stage ``stage_idx``.

    Stages 0-2 are fixed at 4/8/16 classes with noise 0.02/0.06/0.10. After
    that, the class count keeps doubling up to a cap of 128, and the noise grows
    by 0.02 per stage up to a cap of 0.20.
    """
    fixed = (
        (4, 0.02),
        (8, 0.06),
        (16, 0.10),
    )
    if stage_idx < len(fixed):
        classes, sigma = fixed[stage_idx]
        return {"num_classes": classes, "noise": sigma}

    extra = stage_idx - 2  # stages past the last fixed one
    return {
        "num_classes": min(128, 16 * (2 ** extra)),
        "noise": min(0.20, 0.10 + 0.02 * extra),
    }


def create_optimizer(agent: "MyAGIAgent", cfg: "AgentConfig"):
    """Build an AdamW optimizer with a dedicated parameter group for the temperature.

    - Group 0: every parameter except ``agent.logit_scale``, at ``cfg.lr`` with
      ``cfg.weight_decay``.
    - Group 1: ``agent.logit_scale`` alone, at ``cfg.lr * 0.1`` with no weight
      decay. AdamW's decoupled decay shrinks parameters toward 0, which for the
      log inverse temperature would push the temperature toward 1 regardless of
      the loss.
    """
    temperature_param = agent.logit_scale
    # Identity comparison, so no other parameter is excluded by an accidental name match.
    other_params = [p for p in agent.parameters() if p is not temperature_param]
    params = [
        {"params": other_params},
        {"params": [temperature_param], "lr": cfg.lr * 0.1, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay)


def _grow_prototypes(agent: MyAGIAgent, num_classes: int, device: str) -> None:
    """Resize ``agent.prototypes`` to ``num_classes`` rows, keeping rows already learned."""
    old = agent.prototypes
    if old is not None and old.size(0) == num_classes:
        return
    agent.reset_prototypes(num_classes=num_classes, device=device)
    if old is not None:
        keep = min(old.size(0), num_classes)
        # A class id keeps the same text token and sprite family across stages,
        # so centroids learned in earlier stages remain valid.
        agent.prototypes[:keep] = old[:keep].to(agent.prototypes.device)


def _prototype_state(agent: MyAGIAgent) -> Optional[torch.Tensor]:
    """CPU copy of the prototype memory, which ``state_dict()`` does not include (plain attribute)."""
    if agent.prototypes is None:
        return None
    return agent.prototypes.detach().cpu().clone()


def run_curriculum(cfg: AgentConfig):
    """Train a single agent through ``cfg.stages`` curriculum stages.

    Each stage's difficulty comes from ``get_stage_params``. Every training step
    draws a fresh synthetic batch. Every ``cfg.eval_interval`` steps (if > 0) one
    extra batch is evaluated.

    Train and eval rows go to the same ``metrics.csv``; an eval row carries the
    step index of the next training step. A checkpoint is optionally written
    after each stage, and a final model is always saved. Both include the
    prototype memory.

    Raises:
        ValueError: if ``cfg.stages < 1``, or the largest stage needs more text
            token ids than ``cfg.text_vocab_size``.
    """
    if cfg.stages < 1:
        raise ValueError(f"cfg.stages must be >= 1, got {cfg.stages}")

    set_seed(cfg.seed)
    wants_cuda = cfg.device.startswith("cuda")
    device = cfg.device if wants_cuda and torch.cuda.is_available() else "cpu"
    if wants_cuda and device == "cpu":
        print(f"CUDA device {cfg.device!r} requested but unavailable; using cpu.")

    stage_specs = [get_stage_params(s) for s in range(cfg.stages)]
    # Token ids are BOS=0, EOS=1, then class tokens 2..1+num_classes (sample_batch's
    # default vocab_offset=2). The agent is built once and its embedding table is
    # never resized, so it must cover the largest stage from the start.
    max_classes = max(spec["num_classes"] for spec in stage_specs)
    vocab_size = max(2 + max_classes, 4)
    if vocab_size > cfg.text_vocab_size:
        raise ValueError(
            f"curriculum needs {vocab_size} text tokens but text_vocab_size={cfg.text_vocab_size}"
        )

    # logging setup
    logger = AlignmentLogger(cfg.log_dir, cfg.name)

    agent = MyAGIAgent(cfg, vocab_size_cap=vocab_size).to(device)
    agent.reset_prototypes(num_classes=stage_specs[0]["num_classes"], device=device)
    adapter = AgentAdapter(agent, cfg, create_optimizer(agent, cfg))

    global_step = 0
    for stage, spec in enumerate(stage_specs):
        num_classes = spec["num_classes"]
        noise = spec["noise"]
        _grow_prototypes(agent, num_classes, device)

        for step in range(cfg.steps_per_stage):
            batch = sample_batch(
                batch_size=cfg.batch_size,
                num_classes=num_classes,
                img_size=cfg.img_size,
                noise=noise,
                device=device
            )
            metrics = adapter.train_step(batch)
            logger.log(global_step, stage, metrics)
            logger.console(global_step, stage, metrics, throttle_sec=0.5)
            global_step += 1

            # periodic eval (eval_batch already disables autograd)
            if cfg.eval_interval > 0 and (step + 1) % cfg.eval_interval == 0:
                eval_batch = sample_batch(
                    batch_size=cfg.batch_size,
                    num_classes=num_classes,
                    img_size=cfg.img_size,
                    noise=noise,
                    device=device
                )
                eval_metrics = adapter.eval_batch(eval_batch)
                logger.log(global_step, stage, {**eval_metrics, "lr": metrics["lr"]})
                logger.console(global_step, stage, eval_metrics, throttle_sec=0.0)

        # checkpoint end of stage
        if cfg.save_every_stage:
            ckpt = {
                "config": asdict(cfg),
                "stage": stage,
                "model_state": agent.state_dict(),
                "optimizer_state": adapter.optimizer.state_dict(),
                "vocab_size_cap": vocab_size,
                "prototypes": _prototype_state(agent),
            }
            path = os.path.join(logger.dir, f"checkpoint_stage{stage}.pt")
            torch.save(ckpt, path)
            print(f"Saved checkpoint: {path}")

    # final save
    final_path = os.path.join(logger.dir, "final_model.pt")
    torch.save({
        "config": asdict(cfg),
        "model_state": agent.state_dict(),
        "prototypes": _prototype_state(agent),
    }, final_path)
    print(f"Saved final model: {final_path}")
    print(f"CSV metrics: {logger.csv_path}")


# ---------------------------
# CLI
# ---------------------------

def parse_args(argv: Optional[list] = None) -> AgentConfig:
    """Parse command-line flags into an ``AgentConfig``.

    Args:
        argv: arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``; passing a list allows programmatic use and tests.

    Flag defaults come from ``AgentConfig`` itself, so the CLI and the
    dataclass cannot drift apart. Fields without a flag keep their dataclass
    default.
    """
    defaults = AgentConfig()
    p = argparse.ArgumentParser(description="Multimodal AGI Alignment Curriculum (InfoNCE + Prototypes)")
    p.add_argument("--name", type=str, default=defaults.name, help="Run name for logging")
    p.add_argument("--device", type=str, default=defaults.device, help="cpu or cuda:0")
    p.add_argument("--seed", type=int, default=defaults.seed)
    p.add_argument("--lr", type=float, default=defaults.lr)
    p.add_argument("--batch_size", type=int, default=defaults.batch_size)
    p.add_argument("--stages", type=int, default=defaults.stages)
    p.add_argument("--steps_per_stage", type=int, default=defaults.steps_per_stage)
    p.add_argument("--eval_interval", type=int, default=defaults.eval_interval)
    # Paired on/off flags share one destination; argparse applies them in order, so the last one wins.
    p.add_argument("--use_world_model", action="store_true")
    p.add_argument("--no_world_model", dest="use_world_model", action="store_false")
    p.set_defaults(use_world_model=defaults.use_world_model)
    p.add_argument("--proto_weight", type=float, default=defaults.proto_weight)
    p.add_argument("--proto_momentum", type=float, default=defaults.proto_momentum)
    p.add_argument("--log_dir", type=str, default=defaults.log_dir)
    p.add_argument("--img_size", type=int, default=defaults.img_size)
    p.add_argument("--text_vocab_size", type=int, default=defaults.text_vocab_size)
    p.add_argument("--temperature_init", type=float, default=defaults.temperature_init)
    args = p.parse_args(argv)

    return AgentConfig(
        name=args.name,
        device=args.device,
        seed=args.seed,
        lr=args.lr,
        batch_size=args.batch_size,
        stages=args.stages,
        steps_per_stage=args.steps_per_stage,
        eval_interval=args.eval_interval,
        use_world_model=args.use_world_model,
        proto_weight=args.proto_weight,
        proto_momentum=args.proto_momentum,
        log_dir=args.log_dir,
        img_size=args.img_size,
        text_vocab_size=args.text_vocab_size,
        temperature_init=args.temperature_init,
    )


def main():
    """Command-line entry point: build the config from argv, echo it, then train."""
    config = parse_args()
    print("Config:", config)
    run_curriculum(config)


if __name__ == "__main__":
    # Train only when executed as a script, never on import.
    main()