# ALSI-T: Trajectory-Aware Projector Training (GPU Optimized)

### **1. Setup Environment**

In [None]:
!pip install -q transformers torch matplotlib
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, AutoModelForCausalLM
import pickle
import os
import matplotlib.pyplot as plt

# Prefer the GPU when one is visible; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Hugging Face checkpoint id of the Mamba-2 base model.
model_id = 'AntonV/mamba2-130m-hf'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model = model.to(device)
model.eval()  # the base model is only used for inference in this notebook

### **2. Functional Mamba-2 Logic** (omitted for brevity; reuse the previous implementation)

### **3. Dataset Generation** (run this first; the cell then moves the model to the CPU)

In [None]:
# ... (Run your current Dataset Generation cell here) ...
print('Dataset generated. Freeing VRAM...')
# Move the base model's weights off the GPU so device memory is available
# for the projector training below.
model.to('cpu')
import gc
# Collect unreachable Python objects first so any CUDA tensors they hold are
# actually freed; only then can empty_cache() hand the allocator's now-unused
# cached blocks back to the driver. (Calling it in the opposite order leaves
# memory owned by garbage cycles cached until the next empty_cache() call.)
gc.collect()
torch.cuda.empty_cache()

### **4. Phi-T Projector & Training**
Memory-efficient training: the base Mamba-2 model now lives on the CPU, so only the projector and the current sample are kept in GPU memory.

In [None]:
class PhiFieldProjector(nn.Module):
    """Map a flattened state plus a conditioning vector to a per-layer output.

    The input state ``h`` (width ``s_dim``), the conditioning vector ``t``
    (width ``e_dim``) and a learned embedding of the layer index ``l`` are
    concatenated, projected to ``h_dim``, passed through a small SiLU/LayerNorm
    MLP and projected back to width ``s_dim``.

    The output projection is zero-initialised, so a freshly constructed
    projector returns all zeros.

    Args:
        s_dim: Width of the input state ``h`` and of the output.
        e_dim: Width of the conditioning vector ``t``.
        num_l: Number of distinct layer indices that can be embedded.
        h_dim: Hidden width of the MLP.
        layer_emb_dim: Width of the learned layer-index embedding. Defaults to
            64, the previously hard-coded value, so existing checkpoints load.
    """

    def __init__(self, s_dim, e_dim, num_l=12, h_dim=1024, layer_emb_dim=64):
        super().__init__()
        self.layer_embed = nn.Embedding(num_l, layer_emb_dim)
        self.input_proj = nn.Linear(s_dim + e_dim + layer_emb_dim, h_dim)
        self.norm = nn.LayerNorm(h_dim)
        self.net = nn.Sequential(nn.SiLU(), nn.Linear(h_dim, h_dim), nn.LayerNorm(h_dim), nn.SiLU())
        self.output_proj = nn.Linear(h_dim, s_dim)
        # Zero-init the last layer so training starts from an all-zero prediction.
        nn.init.zeros_(self.output_proj.weight)
        nn.init.zeros_(self.output_proj.bias)

    def forward(self, h, t, l):
        """Predict the output for layer ``l``.

        Args:
            h: State tensor of shape (batch, s_dim).
            t: Conditioning tensor of shape (batch, e_dim).
            l: Integer layer index in ``[0, num_l)``; the same index is used
                for every row of the batch.

        Returns:
            Tensor of shape (batch, s_dim).

        Raises:
            IndexError: if ``l`` is outside ``[0, num_l)``.
        """
        num_l = self.layer_embed.num_embeddings
        # Check on the host: on CUDA an out-of-range embedding index becomes a
        # device-side assert that leaves the CUDA context unusable.
        if not 0 <= l < num_l:
            raise IndexError(f'layer index {l} out of range [0, {num_l})')
        layer_ids = torch.full((h.size(0),), l, device=h.device, dtype=torch.long)
        l_e = self.layer_embed(layer_ids)
        x = self.norm(self.input_proj(torch.cat([h, t, l_e], dim=-1)))
        x = self.net(x)
        return self.output_proj(x)

# 1. Load Dataset from local disk
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a phi_t_dataset.pkl that comes from a trusted source.
with open('phi_t_dataset.pkl', 'rb') as f:
    dataset = pickle.load(f)
if not dataset:
    # Fail with a clear message instead of an IndexError on dataset[0] below
    # and a ZeroDivisionError when averaging the loss.
    raise ValueError('phi_t_dataset.pkl contains no samples')

# 2. Setup
# Each sample's state is flattened with .view(1, -1) in the loop, so the
# projector width is the total element count of one state tensor.
h_dim_val = dataset[0]['h'].numel()
e_dim_val = 768  # width of the conditioning vector t_e
phi_t = PhiFieldProjector(h_dim_val, e_dim_val).to(device)
# Drive the per-sample layer loop from the projector's own layer table so the
# two layer counts cannot drift apart.
num_layers = phi_t.layer_embed.num_embeddings
opt = optim.Adam(phi_t.parameters(), lr=1e-3)
crit = nn.MSELoss()

# 3. Training Loop
# One optimizer step per sample, with gradients from every layer's loss
# accumulated first. Calling backward() per layer (instead of once on the
# summed loss) frees each layer's graph immediately, keeping peak memory low.
epochs = 100
for epoch in range(epochs):
    total_loss = 0.0
    for s in dataset:
        # Load sample to GPU on the fly
        h = s['h'].to(device).view(1, -1)
        # NOTE(review): placeholder conditioning. Fresh random noise each step
        # carries no information about the sample, so the projector can only
        # learn to ignore it. Replace with a real embedding lookup.
        t_e = torch.randn(1, e_dim_val, device=device)

        opt.zero_grad()
        sample_loss = 0.0
        for l in range(num_layers):
            pred = phi_t(h, t_e, l)
            target = s['field'][l].to(device).view(1, -1)
            loss = crit(pred, target)
            loss.backward()
            sample_loss += loss.item()

        opt.step()
        total_loss += sample_loss

    avg = total_loss / (len(dataset) * num_layers)
    # Log every 10 epochs and always report the final epoch.
    if epoch % 10 == 0 or epoch == epochs - 1:
        print(f'Epoch {epoch} | Loss: {avg:.6f}')

torch.save(phi_t.state_dict(), 'phi_t_model.pt')