# Hippocampal Transformer Training on Wikitext-2

This notebook trains a biologically-inspired Hippocampal Transformer with:
- **Theta-Gamma Positional Encoding** (8Hz/40Hz oscillations)
- **Place Cell Semantic Encoder** (sparse 3% activity)
- **Prosody-Modulated Attention** with episodic memory
- **Wake/Sleep Phase Training** with replay consolidation
- **EWC (Elastic Weight Consolidation)** for continual learning

## Setup

In [None]:
# Report the installed PyTorch build and, when one is visible, the first CUDA device
import torch

cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    # Device index 0 is the only device inspected here
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.2f} GB")

In [None]:
# Install dependencies
!pip install -q datasets transformers

## Upload Source Files

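Upload these files from your local `aura_clean` directory, keeping their relative paths under `/content` (the import cell below adds `/content` to `sys.path` and imports them as `src.core...` / `src.training...`):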
1. `src/core/hippocampal.py`
2. `src/core/language_zone/theta_gamma_encoding.py`
3. `src/core/language_zone/place_cell_encoder.py`
4. `src/core/language_zone/hippocampal_attention.py`
5. `src/core/language_zone/hippocampal_layer.py`
6. `src/core/language_zone/hippocampal_transformer.py`
7. `src/training/hippocampal_trainer.py`

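Alternatively, if the files are already in your Google Drive, run the next cell (Option 1) to mount it; otherwise skip that cell and use the manual upload cell (Option 2). Only one of the two options is needed.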

In [None]:
# Option 1: Mount Google Drive (only useful if the source files were uploaded there)
from google.colab import drive

# Directory inside the Colab VM where the Drive contents become visible
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

# If files are in Drive, add to path:
# import sys
# sys.path.insert(0, '/content/drive/MyDrive/aura_clean')

In [None]:
# Option 2: Manual file upload
from google.colab import files
import os

# Create the directory structure the import cell expects (relative to the
# current working directory). os.makedirs replaces the shell `mkdir -p` so the
# cell does not depend on a shell; exist_ok=True keeps re-runs harmless.
for package_dir in ('src/core/language_zone', 'src/training'):
    os.makedirs(package_dir, exist_ok=True)

print("Please upload the source files using the file browser on the left -->")
print("Or use the upload button below:")
# uploaded = files.upload()

## Model Components

In [None]:
# Import all components
import sys
sys.path.insert(0, '/content')

from src.core.hippocampal import HippocampalFormation
from src.core.language_zone.hippocampal_transformer import HippocampalTransformer
from src.training.hippocampal_trainer import HippocampalTransformerTrainer

print("All components imported successfully!")

## Configuration & Dataset

In [None]:
from dataclasses import dataclass
import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import GPT2Tokenizer
import math

@dataclass
class Config:
    """Hyperparameters for the Hippocampal Transformer and its training run.

    A single instance is handed to the model, the trainer and the batching
    helper, and is read directly by the training loop in this notebook.
    """
    # Model Config
    # Size of the logits' last dimension as assumed by the loss computation;
    # presumably the GPT-2 tokenizer vocabulary size — TODO confirm.
    vocab_size: int = 50257
    # Also passed to HippocampalFormation as its first positional argument.
    embedding_dim: int = 384
    # num_layers / num_heads / dropout / intermediate_size are not read in this
    # notebook; presumably consumed by HippocampalTransformer — verify there.
    num_layers: int = 4
    num_heads: int = 6
    dropout: float = 0.1
    # Every text is padded or truncated to this many tokens by the tokenizer call.
    max_seq_len: int = 128
    intermediate_size: int = 1536
    
    # Hippocampal
    # Oscillation frequencies for the theta-gamma positional encoding
    # (the notebook intro quotes 8Hz/40Hz).
    theta_frequency: float = 8.0
    gamma_frequency: float = 40.0
    # Passed to HippocampalFormation as its second positional argument.
    n_place_cells: int = 800
    
    # Training
    # Number of texts per batch produced by create_batches.
    batch_size: int = 32
    # Learning rate given to AdamW.
    lr: float = 3e-4
    # The training loop stops once this many steps have been taken.
    max_steps: int = 2000
    # Not read in this notebook; presumably the trainer uses it to schedule
    # sleep phases — verify in hippocampal_trainer.py.
    sleep_interval: int = 500
    # Number of replay iterations attempted during each sleep phase.
    sleep_steps: int = 10
    # Validation is run whenever the step count is a multiple of this value.
    eval_interval: int = 100

# Shared configuration instance used by every later cell
config = Config()
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

In [None]:
# Fetch the tokenizer and the Wikitext-2 (raw) corpus
print("Loading Wikitext-2...")
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Reuse the end-of-sequence token as the padding token
tokenizer.pad_token = tokenizer.eos_token
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')

n_train_rows = len(dataset['train'])
n_validation_rows = len(dataset['validation'])
print(f"Train samples: {n_train_rows}")
print(f"Validation samples: {n_validation_rows}")

In [None]:
# Data loading utilities
def create_batches(dataset, tokenizer, config, split='train', max_batches=None):
    """Yield full batches of tokenized text from one split of `dataset`.

    Args:
        dataset: Mapping from split name to an iterable of records that each
            have a 'text' entry.
        tokenizer: Callable invoked as ``tokenizer(texts, max_length=...,
            padding='max_length', truncation=True, return_tensors='pt')`` and
            returning a mapping with 'input_ids' and 'attention_mask'.
        config: Object providing ``batch_size`` and ``max_seq_len``.
        split: Name of the split to read.
        max_batches: Stop after this many batches. ``None`` (or 0) means no limit.

    Yields:
        ``(input_ids, labels, prosody)`` where ``labels`` is a copy of
        ``input_ids`` with every padded position replaced by -100, and
        ``prosody`` is a random placeholder tensor of shape
        ``(batch, seq_len, 4)``.
    """
    # Drop empty / near-empty lines (10 or fewer non-whitespace-trimmed characters)
    texts = [item['text'] for item in dataset[split] if len(item['text'].strip()) > 10]

    batch_count = 0
    # The stop bound makes the range cover full batches only; a trailing
    # partial batch is never produced.
    for start in range(0, len(texts) - config.batch_size + 1, config.batch_size):
        if max_batches and batch_count >= max_batches:
            break

        encoded = tokenizer(
            texts[start:start + config.batch_size],
            max_length=config.max_seq_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        input_ids = encoded['input_ids']
        # Padding must not count as a prediction target. The pad token can be
        # the same id as a real token (e.g. when it aliases EOS), so the
        # attention mask is used instead of comparing token ids. -100 is the
        # default ignore_index of torch.nn.CrossEntropyLoss.
        labels = input_ids.clone().masked_fill(encoded['attention_mask'] == 0, -100)
        # Random stand-in for prosody features: 4 values per token
        prosody = torch.rand(input_ids.size(0), input_ids.size(1), 4)

        batch_count += 1
        yield input_ids, labels, prosody

## Model Initialization

In [None]:
# Build the hippocampal module, the transformer around it, the trainer and the optimizer
print("Initializing Hippocampal Transformer...")

# NOTE(review): the trailing 50 and 100 are positional arguments whose meaning is
# defined by HippocampalFormation — confirm against src/core/hippocampal.py.
hippocampus = HippocampalFormation(config.embedding_dim, config.n_place_cells, 50, 100)

model = HippocampalTransformer(config, hippocampus).to(device)
trainer = HippocampalTransformerTrainer(model, config, hippocampus)
optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)

# Collect (element count, requires_grad) once, then derive both totals from it
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## Training Loop

In [None]:
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Training tracking (these lists are read by the results and save cells)
losses = []
eval_losses = []
perplexities = []
steps = []

# One loss module for the whole run. Its default ignore_index is -100, so any
# label positions carrying that value are skipped.
criterion = nn.CrossEntropyLoss()


def next_token_loss(logits, labels):
    """Cross-entropy of the logits at position t against the token at t + 1.

    The generation cell treats the logits of the last position as the
    prediction for the following token, so the training objective must use
    the same one-step offset.
    """
    predictions = logits[:, :-1, :].reshape(-1, config.vocab_size)
    targets = labels[:, 1:].reshape(-1)
    return criterion(predictions, targets)


def evaluate(max_batches=20):
    """Return the mean loss over up to `max_batches` validation batches.

    Returns None when the validation split yields no full batch. The model is
    put in eval mode for the pass and back in train mode afterwards.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for eval_input, eval_labels, eval_prosody in create_batches(
                dataset, tokenizer, config, 'validation', max_batches=max_batches):
            eval_logits, _ = model(eval_input.to(device), prosody=eval_prosody.to(device))
            total_loss += next_token_loss(eval_logits, eval_labels.to(device)).item()
            n_batches += 1
    model.train()
    return total_loss / n_batches if n_batches else None


def cycle_train_batches():
    """Yield training batches, starting a new pass whenever the split runs out.

    A single pass over the split can contain fewer than config.max_steps
    batches; cycling lets the loop below actually reach max_steps.
    """
    while True:
        yielded_any = False
        for batch in create_batches(dataset, tokenizer, config, 'train'):
            yielded_any = True
            yield batch
        if not yielded_any:
            # Nothing to train on: stop instead of spinning forever
            return


print("Starting training...")
global_step = 0
# The generation cell leaves the model in eval mode; make re-runs train correctly
model.train()

# Create training data generator
train_gen = cycle_train_batches()

# Progress bar
pbar = tqdm(total=config.max_steps, desc="Training")

for input_ids, labels, prosody in train_gen:
    global_step += 1
    # presumably advances the trainer's wake/sleep schedule and sets
    # trainer.phase — confirm in hippocampal_trainer.py
    trainer.step_counter()

    input_ids = input_ids.to(device)
    labels = labels.to(device)
    prosody = prosody.to(device)

    if trainer.phase == "wake":
        optimizer.zero_grad()

        logits, place_activity = model(input_ids, prosody=prosody)
        loss = next_token_loss(logits, labels)

        # The replay buffer receives the task loss, before any EWC penalty
        trainer.replay_buffer.add(input_ids, labels, loss.item())

        if trainer.ewc.fisher:
            loss = loss + trainer.ewc.penalty(model)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        losses.append(loss.item())
        pbar.set_postfix({'loss': f"{loss.item():.4f}", 'phase': 'Wake'})

    elif trainer.phase == "sleep":
        # Note: the batch fetched for this step is not trained on
        pbar.set_postfix({'phase': 'Sleep'})
        print(f"\nSleep phase at step {global_step}")

        # Estimate Fisher information once, from a few replayed samples
        if not trainer.ewc.fisher and len(trainer.replay_buffer) > 0:
            mock_loader = []
            samples = trainer.replay_buffer.sample(10)
            for item in samples:
                mock_loader.append((
                    item[0].unsqueeze(0).to(device),
                    item[1].unsqueeze(0).to(device)
                ))
            trainer.ewc.compute_fisher(mock_loader, device=device)

        replay_losses = []
        for _ in range(config.sleep_steps):
            optimizer.zero_grad()
            loss = trainer.train_step_sleep()
            if loss is not None:
                loss.backward()
                # Same gradient clipping as the wake updates
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                replay_losses.append(loss.item())

        avg_replay = sum(replay_losses) / len(replay_losses) if replay_losses else 0
        print(f"Replay Loss: {avg_replay:.4f}")

        trainer.phase = "wake"

    # Evaluation runs on every eval_interval step, whichever phase the step was in
    if global_step % config.eval_interval == 0:
        avg_eval_loss = evaluate(max_batches=20)
        if avg_eval_loss is not None:
            try:
                perplexity = math.exp(avg_eval_loss)
            except OverflowError:
                # A diverged loss is reported as infinite perplexity
                perplexity = float('inf')

            eval_losses.append(avg_eval_loss)
            perplexities.append(perplexity)
            steps.append(global_step)

            print(f"\nStep {global_step}: Eval Loss={avg_eval_loss:.4f}, Perplexity={perplexity:.2f}")

    pbar.update(1)

    if global_step >= config.max_steps:
        break

pbar.close()
print("\nTraining complete!")

## Results & Visualization

In [None]:
# Plot training curves
def moving_average(values, window):
    """Trailing mean of `values`: entry i averages the last min(i + 1, window) items."""
    averaged = []
    running_sum = 0.0
    for i, value in enumerate(values):
        running_sum += value
        if i >= window:
            # Drop the item that just left the window so that exactly
            # `window` items are summed and divided by `window`
            running_sum -= values[i - window]
        averaged.append(running_sum / min(i + 1, window))
    return averaged


fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Training loss (labelled so the legend always has at least one entry)
axes[0].plot(losses, alpha=0.3, label='Per step')
# Smooth
window = 50
if len(losses) > window:
    axes[0].plot(moving_average(losses, window), linewidth=2, label='Smoothed')
axes[0].set_xlabel('Step')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Perplexity
axes[1].plot(steps, perplexities, marker='o', linewidth=2)
axes[1].set_xlabel('Step')
axes[1].set_ylabel('Perplexity')
axes[1].set_title('Validation Perplexity')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nFinal Metrics:")
if perplexities:
    print(f"  Best Perplexity: {min(perplexities):.2f}")
    print(f"  Final Perplexity: {perplexities[-1]:.2f}")
else:
    # No evaluation was recorded (e.g. training stopped before the first eval step)
    print("  Perplexity: n/a (no evaluation was run)")
print(f"  Replay Buffer Size: {len(trainer.replay_buffer)}")
print(f"  Episodic Memories: {len(hippocampus.episodic_memories)}")

In [None]:
# Save model
from dataclasses import asdict

torch.save({
    'model_state_dict': model.state_dict(),
    # Stored as a plain dict rather than the Config instance: Config is defined
    # in this notebook, so a pickled instance could only be loaded where that
    # class exists and is rejected by torch.load(weights_only=True).
    # Rebuild it with Config(**checkpoint['config']).
    'config': asdict(config),
    'perplexity': perplexities[-1] if perplexities else None
}, 'hippocampal_transformer.pt')

print("Model saved to hippocampal_transformer.pt")
print("Download it using: files.download('hippocampal_transformer.pt')")

## Text Generation (Optional)

In [None]:
# Generate text sample
def generate_greedy(prompt, max_new_tokens=50):
    """Greedily extend `prompt` by `max_new_tokens` tokens and return the decoded text.

    Uses the notebook-level `model`, `tokenizer`, `config` and `device`.
    Prosody features are random placeholders, one 4-vector per token, so
    repeated calls may differ. The context fed to the model is cropped to the
    last config.max_seq_len tokens, the longest length used during training.
    """
    token_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    token_prosody = torch.rand(1, token_ids.size(1), 4, device=device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            context_ids = token_ids[:, -config.max_seq_len:]
            context_prosody = token_prosody[:, -config.max_seq_len:]
            logits, _ = model(context_ids, prosody=context_prosody)
            # Most likely token after the last position, shaped (1, 1) for concatenation
            next_token = logits[0, -1].argmax().view(1, 1)
            token_ids = torch.cat([token_ids, next_token], dim=1)
            token_prosody = torch.cat(
                [token_prosody, torch.rand(1, 1, 4, device=device)], dim=1)

    return tokenizer.decode(token_ids[0])


model.eval()

prompt = "The history of artificial intelligence"

print(f"Prompt: {prompt}")
print("\nGenerated continuation:")

print(generate_greedy(prompt, max_new_tokens=50))  # Generate 50 tokens