# 🧪 Cortex-3: Battle Arena & Systematic Search

## üö® Diagn√≥stico de Errores
El usuario report√≥ salida "basura" (``). Esto ocurre en modelos **Byte-Level** cuando:
1.  **Falta de Convergencia**: El modelo no ha aprendido las reglas b√°sicas de UTF-8 (que ciertos bytes siempre van juntos).
2.  **Temperatura Alta**: El muestreo es demasiado ca√≥tico.
3.  **Arquitectura Inestable**: Mamba o MoE pueden tener gradientes que explotan.

## ⚔️ La Solución: Battle Arena
En lugar de confiar en una evolución ciega, vamos a enfrentar a las arquitecturas en igualdad de condiciones.
Entrenaremos 3 modelos simultáneamente:
1.  🔵 **Transformer Puro** (Baseline)
2.  🟢 **Mamba Puro** (SSM)
3.  🔴 **Cortex Hybrid** (Nuestra apuesta)

Veremos cuál converge más rápido y cuál genera texto limpio.

---

In [None]:
# 0. Robust setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pandas as pd
import random
from IPython.display import clear_output, display

# Default to the CPU and upgrade to CUDA only when PyTorch reports a usable GPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f"üöÄ Cortex-3 Engine: {device.upper()}")

def set_seed(seed=42):
    """Seed Python's ``random`` module and PyTorch (CPU and, if present, CUDA).

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    # Covers every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)


set_seed(42)

In [None]:
# --- Components (same as before, with stability in mind) ---

class MambaBlock(nn.Module):
    """Gated depthwise-convolution block used as a lightweight Mamba stand-in.

    The block projects the input to two halves, runs one half through a
    depthwise 1-D convolution followed by SiLU, gates it with the sigmoid of
    the other half and projects back to ``d_model``. No state-space scan is
    implemented here.

    The convolution is *causal*: the sequence is padded on the left only, so
    the output at position ``t`` depends on inputs at positions ``<= t``. A
    symmetric padding would let each position read the next token, which is
    exactly the target in next-byte training.

    Args:
        d_model: Channel width of the input and output.
        kernel_size: Width of the depthwise convolution window. Defaults to 3.
    """

    def __init__(self, d_model, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size
        self.in_proj = nn.Linear(d_model, d_model * 2)
        self.out_proj = nn.Linear(d_model, d_model)
        # padding=0: the causal left padding is applied explicitly in forward().
        self.conv = nn.Conv1d(
            d_model, d_model, kernel_size=kernel_size, padding=0, groups=d_model
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, length, d_model).

        Returns:
            Tensor with the same shape as ``x``.
        """
        x_val, gate = self.in_proj(x).chunk(2, dim=-1)
        # Conv1d expects channels before length: (B, L, D) -> (B, D, L).
        x_val = x_val.transpose(1, 2)
        # Left-pad by kernel_size - 1 so the window never covers future steps
        # and the output length equals the input length.
        x_val = F.pad(x_val, (self.kernel_size - 1, 0))
        x_val = self.conv(x_val).transpose(1, 2)
        return self.out_proj(F.silu(x_val) * torch.sigmoid(gate))

class CortexOrganism(nn.Module):
    """Byte-level next-token model with a configurable backbone.

    The model embeds byte ids (vocabulary of 256), runs them through
    ``config['n_layers']`` layers, applies a final LayerNorm and projects to
    256 logits per position.

    Args:
        config: Mapping with the keys
            ``'d_model'``  - embedding / hidden width,
            ``'n_layers'`` - number of stacked layers,
            ``'backbone'`` - ``'mamba'``, ``'transformer'`` or ``'hybrid'``
            (hybrid alternates MambaBlock on even layers and a Transformer
            layer on odd layers),
            ``'n_heads'``  - attention heads; only read when a Transformer
            layer is built.

    Raises:
        ValueError: If ``config['backbone']`` is not one of the three
            supported names (previously this silently built a model with no
            layers at all).
    """

    _BACKBONES = ('mamba', 'transformer', 'hybrid')

    def __init__(self, config):
        super().__init__()
        self.config = config
        d_model = config['d_model']
        backbone = config['backbone']
        if backbone not in self._BACKBONES:
            raise ValueError(
                f"unknown backbone {backbone!r}; expected one of {self._BACKBONES}"
            )

        self.embedding = nn.Embedding(256, d_model)
        self.layers = nn.ModuleList()

        for i in range(config['n_layers']):
            use_mamba = backbone == 'mamba' or (backbone == 'hybrid' and i % 2 == 0)
            if use_mamba:
                self.layers.append(MambaBlock(d_model))
            else:
                self.layers.append(nn.TransformerEncoderLayer(
                    d_model=d_model, nhead=config['n_heads'],
                    dim_feedforward=4 * d_model, batch_first=True,
                ))

        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, 256)

    def forward(self, idx, targets=None):
        """Compute per-position logits and, optionally, the training loss.

        Args:
            idx: Long tensor of byte ids, shape (batch, time).
            targets: Optional long tensor of the same shape as ``idx`` holding
                the next-byte labels.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape (batch, time, 256)
            and ``loss`` is the mean cross-entropy, or ``None`` when
            ``targets`` is not given.
        """
        x = self.embedding(idx)
        seq_len = idx.size(1)
        causal_mask = None  # built lazily, only if a Transformer layer exists

        for layer in self.layers:
            if isinstance(layer, nn.TransformerEncoderLayer):
                if causal_mask is None:
                    # Additive mask: -inf strictly above the diagonal blocks
                    # attention to future positions. Without it the model can
                    # read the very byte it is asked to predict.
                    causal_mask = torch.triu(
                        torch.full(
                            (seq_len, seq_len), float('-inf'),
                            device=x.device, dtype=x.dtype,
                        ),
                        diagonal=1,
                    )
                x = layer(x, src_mask=causal_mask)
            else:
                x = layer(x)

        x = self.ln_f(x)
        logits = self.head(x)

        loss = None
        if targets is not None:
            n_classes = logits.size(-1)
            # reshape (not view) so non-contiguous targets are accepted too.
            loss = F.cross_entropy(
                logits.reshape(-1, n_classes), targets.reshape(-1)
            )
        return logits, loss

## ⚔️ Battle Arena: Comparativa Directa
Entrenamos los 3 modelos a la vez con los mismos datos.

In [None]:
# Dummy data: random byte ids (replace with scraped text when a network is available).
dummy_data = torch.randint(0, 256, (5000,), dtype=torch.long)


def get_batch(batch_size=32, block_size=64):
    """Sample a random batch of (input, next-byte target) windows.

    Args:
        batch_size: Number of windows per batch. Defaults to 32.
        block_size: Length of each window. Defaults to 64.

    Returns:
        ``(x, y)`` long tensors of shape (batch_size, block_size) on
        ``device``; ``y`` is ``x`` shifted one position to the right in
        ``dummy_data``.

    Raises:
        ValueError: If ``dummy_data`` is too short to hold one window plus
            its shifted target.
    """
    max_start = len(dummy_data) - block_size
    if max_start <= 0:
        raise ValueError(
            f"dummy_data has {len(dummy_data)} items; need more than {block_size}"
        )
    ix = torch.randint(max_start, (batch_size,))
    # (batch, 1) start offsets + (block,) relative offsets -> (batch, block) indices.
    window = ix.unsqueeze(1) + torch.arange(block_size)
    x = dummy_data[window].to(device)
    y = dummy_data[window + 1].to(device)
    return x, y

# Contender configurations: identical hyper-parameters, only the backbone differs.
base_config = {'n_layers': 4, 'd_model': 256, 'n_heads': 4, 'learning_rate': 1e-3}

models = {
    'Transformer': CortexOrganism({**base_config, 'backbone': 'transformer'}).to(device),
    'Mamba': CortexOrganism({**base_config, 'backbone': 'mamba'}).to(device),
    'Hybrid': CortexOrganism({**base_config, 'backbone': 'hybrid'}).to(device)
}

# Take the learning rate from the shared config instead of repeating the literal,
# so changing base_config actually changes training.
optimizers = {
    name: torch.optim.AdamW(m.parameters(), lr=base_config['learning_rate'])
    for name, m in models.items()
}
history = {name: [] for name in models}

n_steps = 500
plot_every = 50

print("üîî ¬°Que comience la batalla!")

for step in range(n_steps):
    # One batch per step, shared by every contender for a fair comparison.
    xb, yb = get_batch()

    for name, model in models.items():
        _, loss = model(xb, yb)

        optimizers[name].zero_grad()
        loss.backward()
        # Gradient clipping to avoid explosions (crucial for Mamba).
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizers[name].step()

        history[name].append(loss.item())

    # Refresh the chart periodically and once more on the final step, so the
    # last displayed curve covers the whole run.
    if step % plot_every == 0 or step == n_steps - 1:
        clear_output(wait=True)
        plt.figure(figsize=(10, 6))
        for name, losses in history.items():
            plt.plot(losses, label=name)
        plt.title("Battle Arena: Loss Comparison")
        plt.xlabel("Iteraciones")
        plt.ylabel("Loss")
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.show()
        print(f"Step {step}: T={history['Transformer'][-1]:.3f}, M={history['Mamba'][-1]:.3f}, H={history['Hybrid'][-1]:.3f}")

## 🔍 Búsqueda Sistemática (Grid Search)
En lugar de azar, probamos combinaciones específicas para encontrar estabilidad.

In [None]:
def grid_search(layers_options=(2, 4), dim_options=(128, 256), steps=50):
    """Train a short sprint for every (layers, dim) pair and plot the losses.

    For each combination a fresh ``'hybrid'`` CortexOrganism is trained for
    ``steps`` AdamW updates (lr=1e-3, gradient norm clipped to 1.0, matching
    the arena loop). The score is the mean loss over the last 10 steps (or
    fewer, if ``steps`` < 10). Results are printed and drawn as an annotated
    heatmap with matplotlib.

    Args:
        layers_options: Layer counts to try. Defaults to ``(2, 4)``.
        dim_options: Model widths to try. Defaults to ``(128, 256)``.
        steps: Training steps per configuration. Defaults to 50.

    Raises:
        ValueError: If ``steps`` is smaller than 1.
    """
    if steps < 1:
        raise ValueError("steps must be >= 1")

    results = []

    print("üîç Iniciando Grid Search...")
    for n_layers in layers_options:
        for d_model in dim_options:
            config = {'n_layers': n_layers, 'd_model': d_model, 'n_heads': 4, 'backbone': 'hybrid'}
            model = CortexOrganism(config).to(device)
            optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

            # Short sprint
            losses = []
            for _ in range(steps):
                xb, yb = get_batch()
                _, loss = model(xb, yb)
                optim.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optim.step()
                losses.append(loss.item())

            tail = losses[-10:]
            final_loss = sum(tail) / len(tail)
            results.append({'layers': n_layers, 'dim': d_model, 'loss': final_loss})
            print(f"   Config [L={n_layers}, D={d_model}] -> Loss: {final_loss:.4f}")

    # Heatmap of the results, drawn with matplotlib only (seaborn is not
    # imported anywhere in this notebook).
    df = pd.DataFrame(results)
    pivot = df.pivot(index='layers', columns='dim', values='loss')

    fig, ax = plt.subplots(figsize=(6, 4))
    # Reversed viridis: lighter (yellow) cells mean lower loss.
    image = ax.imshow(pivot.values, cmap='viridis_r', aspect='auto')
    ax.set_xticks(range(len(pivot.columns)))
    ax.set_xticklabels(list(pivot.columns))
    ax.set_yticks(range(len(pivot.index)))
    ax.set_yticklabels(list(pivot.index))
    ax.set_xlabel('dim')
    ax.set_ylabel('layers')
    for row, row_values in enumerate(pivot.values):
        for col, value in enumerate(row_values):
            ax.text(
                col, row, f"{value:.3f}", ha='center', va='center', color='black',
                bbox={'facecolor': 'white', 'alpha': 0.6, 'edgecolor': 'none'},
            )
    fig.colorbar(image, ax=ax)
    ax.set_title("Grid Search Results (Loss)")
    plt.show()

grid_search()

## üõ†Ô∏è Generaci√≥n Robusta (Fixing the Garbage Output)
Aqu√≠ solucionamos el problema de los caracteres extra√±os (``).

In [None]:
def safe_generate(model, prompt, max_len=100, temperature=0.7, top_k=10):
    """Autoregressively sample bytes from ``model`` and decode them as UTF-8.

    Args:
        model: Module whose ``forward(idx)`` returns ``(logits, loss)`` with
            ``logits`` of shape (batch, time, vocab).
        prompt: Text to condition on; it is encoded to UTF-8 bytes.
        max_len: Number of bytes to generate after the prompt.
        temperature: Softmax temperature. Values ``<= 0`` switch to greedy
            (argmax) decoding instead of dividing by zero.
        top_k: Keep only the ``top_k`` most probable bytes before sampling.
            ``None`` or a value outside ``(0, vocab)`` disables the filter.

    Returns:
        The prompt plus the generated continuation, decoded with
        ``errors='replace'`` so invalid byte sequences become U+FFFD instead
        of raising.

    Raises:
        ValueError: If ``prompt`` encodes to zero bytes (there would be no
            position to predict from).
    """
    prompt_bytes = list(prompt.encode('utf-8'))
    if not prompt_bytes:
        raise ValueError("prompt must contain at least one byte to condition on")

    # Run on whatever device the model lives on rather than a global setting.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')
    idx = torch.tensor(prompt_bytes, dtype=torch.long, device=target_device).unsqueeze(0)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_len):
                logits, _ = model(idx)
                logits = logits[:, -1, :]  # only the last position predicts the next byte

                if temperature <= 0:
                    next_token = logits.argmax(dim=-1, keepdim=True)
                else:
                    probs = F.softmax(logits / temperature, dim=-1)
                    if top_k is not None and 0 < top_k < probs.size(-1):
                        # Conservative sampling: zero everything below the
                        # k-th largest probability, then renormalise.
                        kth = torch.topk(probs, top_k).values[:, [-1]]
                        probs = probs.masked_fill(probs < kth, 0.0)
                        probs = probs / probs.sum(dim=-1, keepdim=True)
                    next_token = torch.multinomial(probs, 1)

                idx = torch.cat((idx, next_token), dim=1)
    finally:
        # Restore the caller's train/eval mode so later training is unaffected.
        model.train(was_training)

    # Resilient decoding: invalid UTF-8 is replaced, never raised.
    return bytes(idx[0].tolist()).decode('utf-8', errors='replace')

# Sample a continuation of the same prompt from two of the trained contenders.
samples = (
    ("ü§ñ Generaci√≥n (Transformer):", 'Transformer'),
    ("ü§ñ Generaci√≥n (Hybrid):", 'Hybrid'),
)
for label, contender in samples:
    print(label, safe_generate(models[contender], "AI is"))