# Mamba3-JEPA Training Notebook

Train the **Mamba3-JEPA** world model — a State-Space Vision-Language architecture
that replaces the Transformer components of [VL-JEPA](https://arxiv.org/abs/2512.10942)
with **Mamba-3** SSM blocks from `kore-mamba`.

This notebook covers:
1. Architecture overview
2. Building the model from config
3. Two-stage training pipeline
4. Evaluation (embedding, classification, decoding)
5. Selective decoding for streaming
6. Scaling to production

## 1. Architecture

```
┌──────────────┐     ┌──────────────────┐     ┌──────────────┐
│  X-Encoder   │     │    Predictor     │     │  Y-Encoder   │
│ (ViT,frozen) │────▶│  (Mamba-3 SSM)   │     │  (Mamba-3)   │
│  image → Sv  │     │  ⟨Sv,Xq⟩ → Ŝy    │     │  text → Sy   │
└──────────────┘     └────────┬─────────┘     └──────┬───────┘
                              │                      │
                              ▼                      ▼
                     ┌───────────────────────────────────────┐
                     │      InfoNCE Loss: align Ŝy ↔ Sy      │
                     └───────────────────────────────────────┘
                              │
                              ▼ (inference only)
                     ┌──────────────────┐
                     │    Y-Decoder     │
                     │   (Mamba-3 LM)   │
                     │    Ŝy → text     │
                     └──────────────────┘
```

| Component | Original VL-JEPA | Mamba3-JEPA (Ours) |
|-----------|-----------------|--------------------|
| **X-Encoder** | V-JEPA 2 ViT-L (304M, frozen) | Same (ViT, frozen) |
| **Predictor** | Llama-3.2-1B (8 Transformer layers) | Mamba-3 MixerModel (12 SSM layers) |
| **Y-Encoder** | EmbeddingGemma-300M | Mamba-3 MixerModel (text encoder) |
| **Y-Decoder** | Lightweight Transformer LM | Mamba-3 LM (MambaLMHeadModel) |

### Why Mamba-3?

| Feature | Benefit for VL-JEPA |
|---------|--------------------|
| **Complex-valued A** (data-dependent RoPE) | Tracks rotational dynamics in visual latent space |
| **Trapezoidal discretization** | 2nd-order ODE accuracy for continuous video |
| **MIMO multi-head** | High-throughput processing of ViT patch embeddings |
| **O(N) training, O(1)-per-token decode** | Training cost linear in sequence length; constant per-token time and a fixed-size state during generation |
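
The trapezoidal row can be made concrete on the scalar linear ODE h'(t) = a·h(t) + b·x(t). The sketch below compares forward Euler with the classical trapezoidal (bilinear) rule. It is a textbook illustration of why a second-order rule is more accurate, not Mamba-3's exact data-dependent update, and the function names are illustrative.

```rust
/// One forward-Euler step for h' = a*h + b*x: first-order accurate.
fn euler_step(h: f64, a: f64, b: f64, x_prev: f64, dt: f64) -> f64 {
    h + dt * (a * h + b * x_prev)
}

/// One trapezoidal (bilinear) step: averages the derivative at both ends of the
/// interval, h_next = h + dt/2 * (f(h, x_prev) + f(h_next, x_curr)), solved for h_next.
fn trapezoidal_step(h: f64, a: f64, b: f64, x_prev: f64, x_curr: f64, dt: f64) -> f64 {
    let half = 0.5 * dt * a;
    ((1.0 + half) * h + 0.5 * dt * b * (x_prev + x_curr)) / (1.0 - half)
}

/// Integrate h' = a*h (zero input) from h(0) = 1 to t = 1.
/// Returns (euler_error, trapezoidal_error) against the exact value e^a.
fn compare_on_decay(a: f64, steps: usize) -> (f64, f64) {
    let dt = 1.0 / steps as f64;
    let (mut h_e, mut h_t) = (1.0_f64, 1.0_f64);
    for _ in 0..steps {
        h_e = euler_step(h_e, a, 0.0, 0.0, dt);
        h_t = trapezoidal_step(h_t, a, 0.0, 0.0, 0.0, dt);
    }
    let exact = a.exp();
    ((h_e - exact).abs(), (h_t - exact).abs())
}
```

With a = -1 and ten steps, Euler's error is about 2·10⁻² while the trapezoidal error is about 3·10⁻⁴. Halving the step size shrinks the trapezoidal error roughly 4× versus roughly 2× for Euler, which is what second-order accuracy means.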

## 2. Setup

Build the `kore-vljepa` crate and run the training example.

In [None]:
# Clone and build Kore (if not already done)
!git clone https://github.com/KidIkaros/KORE.git 2>/dev/null || true
%cd KORE

# Verify the kore-vljepa crate compiles
!cargo check -p kore-vljepa 2>&1 | tail -5

In [None]:
# Run the full training example (tiny config, ~7 seconds)
!cargo run --example mamba3_jepa_train -p kore-vljepa --release 2>&1

## 3. Model Configuration

The `Mamba3JepaConfig` struct controls all four components.

### Tiny Preset (for testing)

```rust
let config = Mamba3JepaConfig::tiny();
// ViT:       d_model=32,  layers=2,  patch=4×4,  image=16×16
// Predictor: d_model=64,  layers=2,  d_state=16, headdim=16
// Y-Encoder: d_model=64,  layers=2,  vocab=256
// Y-Decoder: d_model=64,  layers=2,  prefix_len=4
// Shared embed dim: 32
```

### Small Preset (~150M trainable params)

```rust
let config = Mamba3JepaConfig::small();
// ViT:       d_model=1024, layers=24, patch=14×14, image=224×224 (V-JEPA 2 ViT-L)
// Predictor: d_model=1024, layers=12, d_state=128, headdim=64
// Y-Encoder: d_model=768,  layers=12, vocab=32000
// Y-Decoder: d_model=512,  layers=6,  prefix_len=8
// Shared embed dim: 1536
```
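
The presets determine how many visual tokens the predictor has to scan: one per ViT patch, so the tiny preset gives (16/4)² = 16 and the small preset (224/14)² = 256. A small sketch of that arithmetic, assuming (from the `concat` step in the forward pass) that query tokens are appended after the visual tokens, and ignoring any CLS token or video tubelets the real encoder may add:

```rust
/// Number of ViT patch tokens for a square image split into square patches.
fn n_patches(image_size: usize, patch_size: usize) -> usize {
    assert!(patch_size > 0 && image_size % patch_size == 0, "image_size must be divisible by patch_size");
    let per_side = image_size / patch_size;
    per_side * per_side
}

/// Assumed predictor input length: visual tokens followed by `n_qry` query tokens.
fn predictor_seq_len(image_size: usize, patch_size: usize, n_qry: usize) -> usize {
    n_patches(image_size, patch_size) + n_qry
}
```

Because the predictor is an SSM, its cost grows linearly with this sequence length rather than quadratically.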

### Building the Model

```rust
use kore_vljepa::{Mamba3Jepa, Mamba3JepaConfig};

let config = Mamba3JepaConfig::small();
let model = Mamba3Jepa::new(config);
// model.x_encoder  — VisionEncoder (ViT, frozen)
// model.predictor   — Mamba3Predictor
// model.y_encoder   — Mamba3TextEncoder
// model.y_decoder   — Mamba3Decoder
```

## 4. Training Pipeline

VL-JEPA training has **two stages** (§3.2 of the paper):

### Stage 1: Query-Free Captioning Pretrain

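The model learns to predict caption embeddings from the visual input. In this example the query tokens are set equal to the target tokens (self-supervised); since that exposes the caption text to the predictor, a strictly query-free setup would pass an empty or fixed prompt instead.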

```rust
// Stage 1: query = target (self-supervised)
let query_tokens = target_tokens.clone();

let output = model.train_forward(
    &images,         // (batch, C, H, W) flattened
    &query_tokens,   // (batch * n_qry,) flattened
    &target_tokens,  // (batch * n_tgt,) flattened
    batch, h, w,
    n_qry, n_tgt,
    temperature,     // τ = 0.07
);
// output.loss      — InfoNCE loss (scalar)
// output.predicted — (batch, embed_dim)
// output.target    — (batch, embed_dim)
```

### Stage 2: Query-Conditioned SFT

The model learns to predict the target embedding from the visual input **and** a query prompt.
As in Stage 1, the output is a continuous embedding rather than a token sequence; predicting embeddings instead of tokens is the key VL-JEPA innovation.

```rust
// Stage 2: query ≠ target (supervised)
let output = model.train_forward(
    &images, &query_tokens, &target_tokens,
    batch, h, w, n_qry, n_tgt, temperature,
);
```

### Forward Pass Internals

```
train_forward(images, query_tokens, target_tokens):
  1. X-Encoder:  images → visual_tokens     (ViT, frozen)
  2. Embed:      query_tokens → query_embeds (shared embedding table)
  3. Predictor:  (visual_tokens, query_embeds) → predicted_embedding
     └─ VisionProjection → QueryProjection → concat → Mamba-3 layers → pool → proj
  4. Y-Encoder:  target_tokens → target_embedding
     └─ TokenEmbed → Mamba-3 layers → pool → proj
  5. InfoNCE:    loss = -log(exp(sim(Ŝy,Sy)/τ) / Σ exp(sim(Ŝy,Sj)/τ))
```
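
The InfoNCE step can be sketched in plain Rust over the same flattened `(batch, embed_dim)` layout that `train_forward` returns. This minimal version assumes `sim` is cosine similarity and that the negatives `Sj` are the other targets in the batch; the crate's `info_nce_loss` may differ (for example, by also adding the target-to-prediction direction).

```rust
/// Cosine similarity between two equal-length vectors.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (na * nb).max(1e-8)
}

/// InfoNCE over a batch: row i of `pred` should match row i of `target`,
/// and every other target in the batch acts as a negative.
/// `pred` and `target` are flattened (batch, embed_dim) buffers.
fn info_nce(pred: &[f32], target: &[f32], batch: usize, embed_dim: usize, tau: f32) -> f32 {
    assert_eq!(pred.len(), batch * embed_dim);
    assert_eq!(target.len(), batch * embed_dim);
    let mut total = 0.0;
    for i in 0..batch {
        let p = &pred[i * embed_dim..(i + 1) * embed_dim];
        // Logits: similarity of prediction i to every target j, scaled by 1/tau
        let logits: Vec<f32> = (0..batch)
            .map(|j| cosine(p, &target[j * embed_dim..(j + 1) * embed_dim]) / tau)
            .collect();
        // Numerically stable log-sum-exp
        let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let lse = max + logits.iter().map(|l| (l - max).exp()).sum::<f32>().ln();
        total += lse - logits[i]; // -log softmax(logits)[i]
    }
    total / batch as f32
}
```

When all predictions and targets are identical, every logit ties and the loss is exactly ln(batch), which makes a useful sanity check.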

## 5. Training Hyperparameters

From VL-JEPA paper §3.2, adapted for Mamba-3:

| Parameter | Stage 1 (Pretrain) | Stage 2 (SFT) |
|-----------|-------------------|----------------|
| **Steps** | 50,000 | 30,000 |
| **Batch size** | 2,048 | 512 |
| **Base LR** | 1e-4 | 5e-5 |
| **Y-Encoder LR** | ×0.05 | ×0.05 |
| **X-Encoder** | Frozen | Frozen |
| **Optimizer** | AdamW | AdamW |
| **AdamW β (β₁, β₂)** | (0.9, 0.95) | (0.9, 0.95) |
| **Weight decay** | 0.05 | 0.05 |
| **LR schedule** | Cosine + 2K warmup | Cosine + 2K warmup |
| **Temperature τ** | 0.07 | 0.07 |

### Cosine LR Schedule

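```rust
use std::f32::consts::PI;

/// Linear warmup for `warmup` steps, then cosine decay from `base_lr` to 0 at `total`.
fn cosine_lr(step: usize, warmup: usize, total: usize, base_lr: f32) -> f32 {
    if step < warmup {
        // Linear ramp from 0 to base_lr over the warmup steps
        base_lr * (step as f32 / warmup as f32)
    } else {
        // Guard against total <= warmup, and clamp so the LR stays at 0 past `total`
        let decay_steps = total.saturating_sub(warmup).max(1);
        let progress = ((step - warmup) as f32 / decay_steps as f32).min(1.0);
        base_lr * 0.5 * (1.0 + (PI * progress).cos())
    }
}
```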

### Differential Learning Rates

The Y-Encoder trains at **5% of the base LR** to avoid catastrophic forgetting when it is
initialized from pretrained text representations. The X-Encoder (ViT) is completely frozen.

```rust
// Parameter groups for AdamW:
// Group 1: Predictor params     → lr = base_lr
// Group 2: Y-Encoder params     → lr = base_lr * 0.05
// Group 3: Y-Decoder params     → lr = base_lr (InfoNCE alone gives it no gradient)
// Group 4: X-Encoder params     → lr = 0 (frozen)
```
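
The grouping can be illustrated with a self-contained AdamW update in which each group carries a learning-rate multiplier applied to the scheduled LR. `ParamGroup` and `adamw_step` here are hypothetical sketches, not the `kore-optim` API; groups with a zero multiplier are skipped entirely.

```rust
/// A group of parameters sharing one LR multiplier (hypothetical type).
struct ParamGroup {
    params: Vec<f32>,
    grads: Vec<f32>,
    lr_scale: f32, // 1.0 = base LR, 0.05 = Y-Encoder, 0.0 = frozen
    m: Vec<f32>,   // first-moment estimate
    v: Vec<f32>,   // second-moment estimate
}

impl ParamGroup {
    fn new(params: Vec<f32>, lr_scale: f32) -> Self {
        let n = params.len();
        Self { params, grads: vec![0.0; n], lr_scale, m: vec![0.0; n], v: vec![0.0; n] }
    }
}

/// One AdamW step with decoupled weight decay over all groups; `t` is the 1-based step.
fn adamw_step(groups: &mut [ParamGroup], base_lr: f32, t: i32, betas: (f32, f32), wd: f32) {
    let (b1, b2) = betas;
    let eps = 1e-8;
    for g in groups.iter_mut() {
        let lr = base_lr * g.lr_scale;
        if lr == 0.0 {
            continue; // frozen group: no update and no weight decay
        }
        for i in 0..g.params.len() {
            g.m[i] = b1 * g.m[i] + (1.0 - b1) * g.grads[i];
            g.v[i] = b2 * g.v[i] + (1.0 - b2) * g.grads[i] * g.grads[i];
            let m_hat = g.m[i] / (1.0 - b1.powi(t)); // bias correction
            let v_hat = g.v[i] / (1.0 - b2.powi(t));
            // Decoupled weight decay acts on the weights directly, not through the gradient
            g.params[i] -= lr * (m_hat / (v_hat.sqrt() + eps) + wd * g.params[i]);
        }
    }
}
```

Skipping zero-multiplier groups matters for AdamW in particular: because the decay term acts on the weights directly, a frozen group left in the optimizer with a non-zero LR would still shrink even with zero gradients.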

## 6. Data Pipeline

Training data consists of `(visual_frames, query_text, target_text)` triplets.

### Recommended Datasets

| Dataset | Type | Size | Use |
|---------|------|------|-----|
| **CC3M** | Image-caption | 3M pairs | Stage 1 pretrain |
| **CC12M** | Image-caption | 12M pairs | Stage 1 pretrain |
| **WebVid-10M** | Video-caption | 10M clips | Stage 1 pretrain (video) |
| **VQAv2** | Image-QA | 1.1M QA pairs | Stage 2 SFT |
| **TextVQA** | Image-QA (OCR) | 45K QA pairs (28K images) | Stage 2 SFT |
| **Kinetics-400** | Video classification | 300K clips | Evaluation |

### Data Format

```rust
// Each training sample:
struct TrainSample {
    image: Vec<f32>,           // (C, H, W), normalized with ImageNet mean/std
    query_tokens: Vec<usize>,  // tokenized query text
    target_tokens: Vec<usize>, // tokenized target text
}

// Stage 1: query_tokens == target_tokens (caption)
// Stage 2: query_tokens from "What is happening?", target_tokens from "A cat is sleeping"
```
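
Because `train_forward` takes flattened `(batch * n_qry,)` and `(batch * n_tgt,)` token buffers, variable-length samples have to be padded or truncated to fixed lengths before batching. A minimal collate sketch, assuming right-padding with a `pad_id` chosen by your tokenizer; the crate may handle masking or variable lengths differently.

```rust
/// Same fields as the `TrainSample` struct above.
struct TrainSample {
    image: Vec<f32>,
    query_tokens: Vec<usize>,
    target_tokens: Vec<usize>,
}

/// Copy up to `len` tokens from `seq` into `out`, then right-pad with `pad_id` to exactly `len`.
fn pad_into(seq: &[usize], len: usize, pad_id: usize, out: &mut Vec<usize>) {
    let kept = seq.len().min(len);
    out.extend_from_slice(&seq[..kept]);
    out.extend(std::iter::repeat(pad_id).take(len - kept));
}

/// Collate samples into flattened buffers:
/// images (batch*C*H*W,), queries (batch*n_qry,), targets (batch*n_tgt,).
fn collate(
    samples: &[TrainSample],
    n_qry: usize,
    n_tgt: usize,
    pad_id: usize,
) -> (Vec<f32>, Vec<usize>, Vec<usize>) {
    let mut images = Vec::new();
    let mut queries = Vec::with_capacity(samples.len() * n_qry);
    let mut targets = Vec::with_capacity(samples.len() * n_tgt);
    for s in samples {
        images.extend_from_slice(&s.image);
        pad_into(&s.query_tokens, n_qry, pad_id, &mut queries);
        pad_into(&s.target_tokens, n_tgt, pad_id, &mut targets);
    }
    (images, queries, targets)
}
```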

### Image Preprocessing

```rust
// 1. Resize to (image_size × image_size) — e.g., 224×224
// 2. Normalize: pixel = (pixel / 255.0 - mean) / std
//    mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]
// 3. Reorder HWC → CHW and flatten: (C, H, W) = (3, 224, 224)
```
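
Steps 2 and 3 map onto a small function. This sketch assumes the image has already been resized to `image_size × image_size` (resizing needs an image library and is omitted) and that pixels arrive as interleaved RGB bytes in `(H, W, 3)` order. Note that ImageNet normalization produces values roughly in [-2.12, 2.64], not [-1, 1].

```rust
/// ImageNet statistics from the steps above.
const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const STD: [f32; 3] = [0.229, 0.224, 0.225];

/// Convert an already-resized, interleaved RGB image (H, W, 3) of u8 pixels
/// into a normalized, channel-first (3, H, W) buffer flattened row-major.
fn preprocess(rgb: &[u8], h: usize, w: usize) -> Vec<f32> {
    assert_eq!(rgb.len(), h * w * 3, "expected H*W*3 interleaved RGB bytes");
    let mut out = vec![0.0f32; 3 * h * w];
    for y in 0..h {
        for x in 0..w {
            for c in 0..3 {
                let pixel = rgb[(y * w + x) * 3 + c] as f32 / 255.0;
                // Channel-first index: plane c, then row y, then column x
                out[c * h * w + y * w + x] = (pixel - MEAN[c]) / STD[c];
            }
        }
    }
    out
}
```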

## 7. Evaluation

After training, the model can be used in three ways at inference time:

### 7a. Embedding Inference (Retrieval / Classification)

```rust
use kore_vljepa::{InferenceMode, InferenceOutput};

let output = model.infer(&image, &query, h, w, InferenceMode::Embedding);
if let InferenceOutput::Embedding(emb) = output {
    // emb: Vec<f32> of shape (shared_embed_dim,)
    // Use for nearest-neighbor classification or similarity ranking
}
```

### 7b. Zero-Shot Classification

Compare predicted embedding against candidate text embeddings:

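```rust
// Encode class names through the Y-Encoder.
// `tokenize` and `n_tokens` are placeholders for your tokenizer and its padded length;
// `predicted_emb` is the embedding returned in 7a.
let cat_emb = model.y_encoder.forward(&tokenize("cat"), 1, n_tokens);
let dog_emb = model.y_encoder.forward(&tokenize("dog"), 1, n_tokens);
let candidates = [cat_emb, dog_emb].concat(); // flattened (2, embed_dim)

let class_idx = model.classify(&predicted_emb, &candidates, 2);
// class_idx: 0 = cat, 1 = dog
```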
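
Conceptually, `classify` picks the candidate whose embedding is closest to the predicted one. A sketch of that rule, assuming cosine similarity over a flattened `(n_candidates, embed_dim)` buffer; the crate's `classify` may score candidates differently.

```rust
/// L2-normalize a vector into a new Vec.
fn normalize(v: &[f32]) -> Vec<f32> {
    let n = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
    v.iter().map(|x| x / n).collect()
}

/// Index of the candidate with the highest cosine similarity to `predicted`.
/// `candidates` is a flattened (n_candidates, embed_dim) buffer.
fn classify(predicted: &[f32], candidates: &[f32], n_candidates: usize) -> usize {
    let d = predicted.len();
    assert_eq!(candidates.len(), n_candidates * d);
    let p = normalize(predicted);
    (0..n_candidates)
        .map(|k| {
            let c = normalize(&candidates[k * d..(k + 1) * d]);
            let sim: f32 = p.iter().zip(&c).map(|(a, b)| a * b).sum();
            (k, sim)
        })
        .max_by(|a, b| a.1.total_cmp(&b.1))
        .map(|(k, _)| k)
        .expect("n_candidates must be > 0")
}
```

Normalizing both sides makes the decision depend only on direction, so a candidate is not favored merely because its embedding has a larger norm.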

### 7c. Text Decoding (VQA)

```rust
let output = model.infer(
    &image, &query, h, w,
    InferenceMode::Decode { max_tokens: 50, bos_token: 1 },
);
if let InferenceOutput::Tokens(tokens) = output {
    let text = tokenizer.decode(&tokens);
    println!("Answer: {}", text);
}
```
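
`InferenceMode::Decode { max_tokens, bos_token }` suggests a standard autoregressive loop: start from `bos_token`, pick the next token repeatedly, and stop after `max_tokens`. The sketch below shows greedy decoding with the model abstracted as a closure that returns next-token logits; the closure interface, the optional `eos_token` stop, and greedy (argmax) selection are assumptions, not documented behavior of the crate.

```rust
/// Greedy autoregressive decoding.
/// `next_logits(prefix)` returns vocabulary logits for the token after `prefix`
/// (in Mamba3-JEPA this role would be played by the Y-Decoder conditioned on Ŝy).
fn greedy_decode<F>(mut next_logits: F, bos_token: usize, max_tokens: usize, eos_token: Option<usize>) -> Vec<usize>
where
    F: FnMut(&[usize]) -> Vec<f32>,
{
    let mut seq = vec![bos_token];
    for _ in 0..max_tokens {
        let logits = next_logits(&seq);
        // Argmax over the vocabulary
        let next = logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, _)| i)
            .expect("empty logits");
        if Some(next) == eos_token {
            break;
        }
        seq.push(next);
    }
    seq[1..].to_vec() // drop the BOS token
}
```

A Mamba decoder would normally carry its recurrent state forward and consume only the newest token at each step instead of re-reading the prefix as this closure interface allows; that is where the O(1)-per-token decode cost comes from.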

## 8. Selective Decoding (Streaming)

Mamba-3's hidden state `h_t` carries a running summary of the stream, so its drift is a cheap proxy for semantic change. The `SelectiveDecoder`
monitors `‖h_t - h_{t-k}‖` and triggers the expensive Y-Decoder only when the state
drifts past a threshold. This suits streaming video, where decoding is only needed at scene changes.

```rust
use kore_vljepa::{SelectiveDecoder, SelectiveDecodeConfig};

let config = SelectiveDecodeConfig {
    window_size: 8,       // compare against state from 8 frames ago
    drift_threshold: 0.5, // decode when drift exceeds this
    smoothing: 0.3,       // exponential smoothing factor
};
let mut selective = SelectiveDecoder::new(config);

// For each video frame (`video_stream`, `get_predictor_state`, and `decode`
// are placeholders for your frame source and model plumbing):
for frame in video_stream {
    let ssm_state = get_predictor_state(&model, &frame);
    if selective.should_decode(&ssm_state) {
        let text = decode(&model, &frame, &query);
        println!("Scene change detected: {}", text);
    }
}
```

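This avoids running the decoder on every frame. On typical video streams, where most frames
are visually similar, this is estimated to cut decoder compute by 5-10×; the actual saving
depends on the threshold and on how often the content changes.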
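
A drift monitor along these lines can be sketched with a ring buffer of the last `window_size` states, an L2 distance against the oldest one, and exponential smoothing of that distance. `DriftConfig`, `DriftMonitor`, and the exact update rule are assumptions for illustration; the fields mirror `SelectiveDecodeConfig`, but the crate's `SelectiveDecoder` may smooth, normalize, or reset differently.

```rust
use std::collections::VecDeque;

/// Hypothetical stand-in for `SelectiveDecodeConfig`.
struct DriftConfig {
    window_size: usize,
    drift_threshold: f32,
    smoothing: f32, // weight of the newest drift in the moving average
}

struct DriftMonitor {
    cfg: DriftConfig,
    history: VecDeque<Vec<f32>>, // the last `window_size` states
    smoothed: f32,               // exponential moving average of the raw drift
}

impl DriftMonitor {
    fn new(cfg: DriftConfig) -> Self {
        Self { history: VecDeque::with_capacity(cfg.window_size + 1), smoothed: 0.0, cfg }
    }

    /// Returns true when the smoothed drift ||h_t - h_{t-k}|| exceeds the threshold.
    fn should_decode(&mut self, state: &[f32]) -> bool {
        let mut fire = false;
        if self.history.len() == self.cfg.window_size {
            // Oldest buffered state is h_{t-k} with k = window_size
            let old = self.history.front().unwrap();
            let drift = state.iter().zip(old).map(|(a, b)| (a - b).powi(2)).sum::<f32>().sqrt();
            let s = self.cfg.smoothing;
            self.smoothed = s * drift + (1.0 - s) * self.smoothed;
            fire = self.smoothed > self.cfg.drift_threshold;
        }
        self.history.push_back(state.to_vec());
        if self.history.len() > self.cfg.window_size {
            self.history.pop_front();
        }
        fire
    }
}
```

Because each frame is compared with the state from `window_size` frames earlier, a single scene cut keeps the drift high for up to `window_size` frames; a production version would likely add a cooldown or reset after each decode.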

## 9. Loading Pretrained Weights

### V-JEPA 2 ViT Weights (X-Encoder)

```rust
use kore_vljepa::{load_safetensors, load_vit_weights};

// Load V-JEPA 2 ViT-L weights from a safetensors file.
// `model` must be declared `let mut model`, and `?` needs an enclosing fn returning Result.
let tensors = load_safetensors("vjepa2_vit_l.safetensors")?;
load_vit_weights(&mut model.x_encoder, &tensors)?;
// The X-Encoder now holds the pretrained weights; it stays frozen during training.
```

### Checkpoint Save/Load

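```rust
// Save a training checkpoint (predictor + y_encoder + y_decoder)
// in safetensors format for interoperability.
use safetensors::serialize;

// `collect_state_dict` is a placeholder: it should return a map from parameter
// name to a safetensors `TensorView` for every tensor you want to save.
let state = collect_state_dict(&model);
let bytes = serialize(&state, &None)?;
std::fs::write("mamba3_jepa_checkpoint.safetensors", bytes)?;
```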

## 10. Scaling to Production

### Step-by-Step

1. **Replace synthetic data** with real image-text pairs (CC3M, WebVid)
2. **Load V-JEPA 2 ViT weights** via `load_vit_weights()`
3. **Use AdamW** from `kore-optim` with differential LR groups
4. **Scale config** to `Mamba3JepaConfig::small()` (150M trainable params)
5. **Add autograd** via `kore-autograd` for automatic differentiation
6. **Multi-GPU** via data parallelism (split batches across devices)
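
A data-parallel step splits each batch across devices, computes gradients per shard, and averages them before the optimizer update. The sketch shows only that bookkeeping on one host; `shard_ranges` and `all_reduce_mean` are hypothetical helpers, not a Kore API. One caveat specific to this model: InfoNCE draws its negatives from the batch, so if each device computes the loss on its own shard only, every prediction sees `batch / n_devices` candidates instead of the full batch. Contrastive training setups commonly all-gather embeddings across devices before computing the loss to avoid this.

```rust
use std::ops::Range;

/// Split `batch` examples into `n_devices` contiguous, near-equal shards.
fn shard_ranges(batch: usize, n_devices: usize) -> Vec<Range<usize>> {
    assert!(n_devices > 0);
    let base = batch / n_devices;
    let extra = batch % n_devices; // the first `extra` shards get one more example
    let mut start = 0;
    (0..n_devices)
        .map(|d| {
            let len = base + usize::from(d < extra);
            let r = start..start + len;
            start += len;
            r
        })
        .collect()
}

/// Element-wise mean of per-device gradients (what an all-reduce mean computes).
fn all_reduce_mean(grads: &[Vec<f32>]) -> Vec<f32> {
    assert!(!grads.is_empty());
    let n = grads.len() as f32;
    let mut out = vec![0.0; grads[0].len()];
    for g in grads {
        for (o, x) in out.iter_mut().zip(g) {
            *o += x / n;
        }
    }
    out
}
```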

### Full Training Script Template

```rust
use kore_vljepa::*;

// Placeholders for your data pipeline: `load_batch(dataset, batch_size)` returning
// flattened (images, queries, targets), and the padded token lengths `n_qry` / `n_tgt`.
// `cosine_lr` is the schedule defined in Section 5.

fn main() {
    // 1. Config
    let config = Mamba3JepaConfig::small();
    let mut model = Mamba3Jepa::new(config.clone()); // `mut`: ViT weights are loaded in place

    // 2. Load pretrained ViT
    let tensors = load_safetensors("vjepa2_vit_l.safetensors").unwrap();
    load_vit_weights(&mut model.x_encoder, &tensors).unwrap();

    // 3. Optimizer with differential LR multipliers (sketch; the kore-optim API may differ)
    // let mut optimizer = AdamW::new(
    //     vec![
    //         ParamGroup::new(model.predictor.params(), 1.0),
    //         ParamGroup::new(model.y_encoder.params(), 0.05),
    //         ParamGroup::new(model.y_decoder.params(), 1.0),
    //     ],
    //     0.05,        // weight decay
    //     (0.9, 0.95), // betas
    // );

    // 4. Stage 1: Query-free pretrain
    for step in 0..50_000 {
        let (images, _, targets) = load_batch("cc3m", 2048);
        let queries = targets.clone(); // self-supervised: query = target
        let lr = cosine_lr(step, 2000, 50_000, 1e-4);

        let output = model.train_forward(
            &images, &queries, &targets,
            2048, 224, 224, n_qry, n_tgt, 0.07,
        );
        // optimizer.step(output.loss, lr);
    }

    // 5. Stage 2: Query-conditioned SFT (warmup restarts from step 0)
    for step in 0..30_000 {
        let (images, queries, targets) = load_batch("vqav2", 512);
        let lr = cosine_lr(step, 2000, 30_000, 5e-5);

        let output = model.train_forward(
            &images, &queries, &targets,
            512, 224, 224, n_qry, n_tgt, 0.07,
        );
        // optimizer.step(output.loss, lr);
    }

    // 6. Save checkpoint
    // save_checkpoint(&model, "mamba3_jepa_small.safetensors");
}
```

## API Reference

### Core Types

| Type | Description |
|------|-------------|
| `Mamba3Jepa` | Full model: x_encoder + predictor + y_encoder + y_decoder |
| `Mamba3JepaConfig` | Top-level config (presets: `tiny()`, `small()`) |
| `TrainOutput` | Training forward result: `{ loss, predicted, target }` |
| `InferenceMode` | `Embedding` or `Decode { max_tokens, bos_token }` |
| `InferenceOutput` | `Embedding(Vec<f32>)` or `Tokens(Vec<usize>)` |
| `SelectiveDecoder` | SSM state drift monitor for streaming decode |

### Key Methods

| Method | Signature |
|--------|-----------|
| `Mamba3Jepa::new` | `(config: Mamba3JepaConfig) -> Self` |
| `train_forward` | `(&self, images, query_tokens, target_tokens, batch, h, w, n_qry, n_tgt, temperature) -> TrainOutput` |
| `infer` | `(&self, image, query_tokens, h, w, mode) -> InferenceOutput` |
| `classify` | `(&self, predicted, candidates, n_candidates) -> usize` |
| `info_nce_loss` | `(predictions, targets, batch, embed_dim, temperature) -> f32` |
| `load_vit_weights` | `(&mut VisionEncoder, &SafeTensors) -> Result<()>` |

### Config Structs

| Config | Key Fields |
|--------|------------|
| `VitConfig` | `patch_size, image_size, d_model, n_heads, n_layers, d_ff` |
| `Mamba3PredictorConfig` | `d_model, n_layers, d_state, expand, headdim, trapezoidal_alpha, embed_dim` |
| `Mamba3TextEncoderConfig` | `vocab_size, d_model, n_layers, d_state, embed_dim` |
| `Mamba3DecoderConfig` | `d_model, n_layers, vocab_size, prefix_len, embed_dim` |

In [None]:
# Run the tests to verify everything works
!cargo test -p kore-vljepa 2>&1 | tail -15