Cell 1: Initial Setup and GPU Check

In [None]:
#@title Initial Setup and GPU Check
import os
import sys
import time
import torch
import numpy as np
from datetime import datetime

# Fail fast when the runtime has no CUDA device: every later cell assumes one.
if not torch.cuda.is_available():
    raise RuntimeError("❌ No GPU available! Please enable GPU in Runtime > Change runtime type")

# Report which device this session was given.
gpu_name = torch.cuda.get_device_name(0)
print(f"✓ GPU Available: {gpu_name}")
print(f"✓ CUDA Version: {torch.version.cuda}")
print(f"✓ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Any GPU is accepted; a non-A100 device only triggers a warning.
if "A100" not in gpu_name:
    print("⚠️  Warning: Not running on A100. Performance may vary.")

# Seed the PyTorch CPU generator, NumPy, and every CUDA device (in that order)
# with the same value so runs are repeatable.
for seed_fn in (torch.manual_seed, np.random.seed, torch.cuda.manual_seed_all):
    seed_fn(42)

print("\n✓ Initial setup complete!")

Cell 2: Install Dependencies

In [None]:
#@title Install DeepSpeed from GitHub
%%time

# First ensure build tools
!pip install ninja packaging

# Clone and install deepspeed from source
!git clone https://github.com/microsoft/DeepSpeed.git
%cd DeepSpeed
!pip install -e .
%cd ..

# Install other deps
!pip install -U pytorch-lightning transformers datasets safetensors einops wandb huggingface-hub torchmetrics msgpack

print("✓ DeepSpeed built from source")

Cell 3: Clone and Setup RWKV-LM Repository

In [None]:
#@title Clone RWKV-LM and Setup Environment
import os
import shutil

# Clone RWKV-LM repository
if os.path.exists('RWKV-LM'):
    shutil.rmtree('RWKV-LM')

!git clone https://github.com/BlinkDL/RWKV-LM.git
%cd RWKV-LM

# Set environment variables for RWKV
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = "1"

# Create necessary directories
os.makedirs("gazelle_implementation", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("logs", exist_ok=True)

print("✓ RWKV-LM repository cloned and environment configured!")

Cell 4: Mount Google Drive for Persistent Storage

In [None]:
#@title Mount Google Drive for Checkpoints
import os  # imported here so the cell does not depend on an earlier cell's import

from google.colab import drive

# Mount drive
drive.mount('/content/drive', force_remount=True)

# Create the checkpoint directory tree in Drive
checkpoint_path = '/content/drive/MyDrive/gazelle_0.5b_checkpoints'
os.makedirs(checkpoint_path, exist_ok=True)
for sub_dir in ("models", "logs", "configs"):
    os.makedirs(f"{checkpoint_path}/{sub_dir}", exist_ok=True)

print(f"✓ Google Drive mounted!")
print(f"✓ Checkpoint directory: {checkpoint_path}")

# Create symlink for easy access.
# `lexists` is used instead of `exists` because `exists` follows the link and
# reports False for a dangling symlink, after which `os.symlink` would fail
# with FileExistsError.
link_path = '/content/gazelle_checkpoints'
if os.path.lexists(link_path):
    os.unlink(link_path)
os.symlink(checkpoint_path, link_path)

Cell 5: Download and Prepare Dolphin Distill Dataset

In [None]:
#@title Download and Prepare Dolphin Distill Dataset (FIXED)
%%time

import os
from datasets import load_dataset
import numpy as np
import json
from tqdm import tqdm

# Create data directory
os.makedirs("data", exist_ok=True)
%cd data

print("Loading Dolphin Distill dataset from HuggingFace...")
dataset = load_dataset("cognitivecomputations/dolphin-distill", split="train")
print(f"✓ Dataset loaded! Total examples: {len(dataset):,}")

# First, let's SEE what's happening with wget
print("\nDownloading RWKV 20B tokenizer...")
print("Checking available tokenizer files...")

# Remove -q to see actual error
!wget https://github.com/BlinkDL/RWKV-LM/raw/main/RWKV-v5/20B_tokenizer.json

# If that fails, try alternative URLs
if not os.path.exists("20B_tokenizer.json"):
    print("\nTrying alternative URL...")
    # Try the v4 tokenizer which is more stable
    !wget https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/tokenizer/rwkv_vocab_v20230424.json -O 20B_tokenizer.json

# If still failing, try the model repo
if not os.path.exists("20B_tokenizer.json"):
    print("\nTrying HuggingFace model repo...")
    !wget https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/tokenizer.json -O 20B_tokenizer.json

# Check what we got
!ls -la *.json

# Continue with rest of code...

In [None]:
# Return to the workspace root, make sure the implementation directory exists
# inside the RWKV-LM checkout, then continue working from that checkout.
import os

workspace_root = '/content'
os.chdir(workspace_root)
os.makedirs('RWKV-LM/gazelle_implementation', exist_ok=True)
os.chdir(workspace_root + '/RWKV-LM')
print("Current directory: " + os.getcwd())

Cell 6: Core Gazelle Architecture Implementation

In [None]:
#@title Gazelle 0.5B Core Architecture

# Go back to the Colab workspace root first so the paths below cannot end up
# nested inside whatever directory a previous cell left us in.
%cd /content

# Show the current directory
!pwd

# Create the implementation directory (and any missing parents); `-p` makes
# this a no-op when it already exists.
!mkdir -p /content/RWKV-LM/gazelle_implementation

# Enter the RWKV-LM checkout; the following `%%writefile` cell uses a path
# relative to this directory.
%cd /content/RWKV-LM

# Show the directory again to confirm the change took effect
!pwd

In [None]:
%%writefile gazelle_implementation/gazelle_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple
from dataclasses import dataclass

@dataclass
class GazelleConfig:
    """Hyper-parameters for the Gazelle model.

    Besides the plain fields, three read-only properties derive the per-head
    width and how that width is divided into an "intrinsic" and a "ghost"
    part according to ``ghost_ratio``.
    """

    # Model size
    n_layer: int = 12
    n_embd: int = 1536
    n_head: int = 24
    vocab_size: int = 65536
    ctx_len: int = 512

    # GhostRNN parameters
    ghost_ratio: float = 0.375  # fraction of each head's width assigned to the ghost part
    use_ghost: bool = True

    # Thinking parameters
    enable_thinking: bool = False
    max_think_steps: int = 5
    think_threshold: float = 0.3

    # RWKV-7 specific
    # NOTE(review): no consumer of these two fields is visible next to this
    # class; confirm they are still needed.
    head_qk: int = 0
    time_decay_init: str = "log"

    @property
    def head_dim(self):
        """Width of one head: ``n_embd`` floor-divided by ``n_head``."""
        return self.n_embd // self.n_head

    @property
    def intrinsic_dim(self):
        """Part of ``head_dim`` left after removing the ghost share (truncated to int)."""
        kept_fraction = 1 - self.ghost_ratio
        return int(self.head_dim * kept_fraction)

    @property
    def ghost_dim(self):
        """Remainder of ``head_dim`` once ``intrinsic_dim`` is taken out."""
        return self.head_dim - self.intrinsic_dim

class GhostRWKV7State(nn.Module):
    """GhostRNN-style decomposition of a per-head square state matrix.

    A head's full state is a ``head_dim x head_dim`` matrix.  Only its
    top-left ``intrinsic_dim x intrinsic_dim`` block is kept as the real
    ("intrinsic") state.  The bottom-right ``ghost_dim x ghost_dim`` block is
    regenerated from the intrinsic block with one small bias-free linear map
    per head, and the two off-diagonal blocks are filled with scaled row means.
    """

    def __init__(self, config: GazelleConfig):
        super().__init__()
        self.config = config

        # One cheap projection (intrinsic_dim -> ghost_dim) per head.
        self.ghost_linear = nn.ModuleList([
            nn.Linear(config.intrinsic_dim, config.ghost_dim, bias=False)
            for _ in range(config.n_head)
        ])

        # Small-gain Xavier init so the generated ghost block starts out small.
        for linear in self.ghost_linear:
            nn.init.xavier_uniform_(linear.weight, gain=0.1)

    def split_state(self, state):
        """Return the intrinsic (top-left) block of a full state.

        Args:
            state: tensor of shape ``[B, H, D1, D2]``.

        Returns:
            A view of shape ``[B, H, intrinsic_dim, intrinsic_dim]`` (smaller if
            ``state`` itself is smaller than ``intrinsic_dim`` on an axis).
        """
        i_dim = self.config.intrinsic_dim
        return state[:, :, :i_dim, :i_dim]

    def generate_ghost(self, intrinsic, head_idx):
        """Build the ghost block of one head from its intrinsic block.

        With ``W`` the ``[ghost_dim, intrinsic_dim]`` weight of this head's
        projection and ``S`` the intrinsic block, the result is ``W @ S @ W^T``.
        The previous implementation projected only one side and returned a
        ``[B, ghost_dim, intrinsic_dim]`` tensor, which ``combine_state`` cannot
        place into its ``ghost_dim x ghost_dim`` slot.

        Args:
            intrinsic: tensor of shape ``[B, intrinsic_dim, intrinsic_dim]``.
            head_idx: index of the head whose projection is used.

        Returns:
            Tensor of shape ``[B, ghost_dim, ghost_dim]``.
        """
        project = self.ghost_linear[head_idx]

        # nn.Linear multiplies by W^T on the right: S -> S @ W^T, shape [B, I, G].
        right_projected = project(intrinsic)
        # Transpose and project again: (S W^T)^T W^T = W S^T W^T, shape [B, G, G].
        both_projected = project(right_projected.transpose(-1, -2))
        # A final transpose yields W S W^T.
        return both_projected.transpose(-1, -2)

    def combine_state(self, intrinsic, ghost):
        """Assemble a full ``[B, H, head_dim, head_dim]`` state from its blocks.

        Args:
            intrinsic: ``[B, H, intrinsic_dim, intrinsic_dim]`` block.
            ghost: ``[B, H, ghost_dim, ghost_dim]`` block.

        Returns:
            A new tensor with ``intrinsic`` in the top-left block, ``ghost`` in
            the bottom-right block, and each off-diagonal block filled with
            0.1 times the row means of the block that shares its rows.
        """
        B, H = intrinsic.shape[:2]
        D = self.config.head_dim
        i_dim = self.config.intrinsic_dim
        g_dim = self.config.ghost_dim

        full_state = torch.zeros(B, H, D, D, device=intrinsic.device, dtype=intrinsic.dtype)

        # Diagonal blocks
        full_state[:, :, :i_dim, :i_dim] = intrinsic
        full_state[:, :, i_dim:, i_dim:] = ghost

        # Cross terms (can be refined later): broadcast each row's mean across
        # the columns of the neighbouring off-diagonal block.
        full_state[:, :, :i_dim, i_dim:] = intrinsic.mean(dim=-1, keepdim=True).expand(-1, -1, -1, g_dim) * 0.1
        full_state[:, :, i_dim:, :i_dim] = ghost.mean(dim=-1, keepdim=True).expand(-1, -1, -1, i_dim) * 0.1

        return full_state

class ThinkingModule(nn.Module):
    """Adaptive "thinking" mechanism: score an input, pick a step count, refine a state."""

    def __init__(self, config: GazelleConfig):
        super().__init__()
        self.config = config

        # Complexity estimator: n_embd features -> one score in (0, 1).
        self.complexity_net = nn.Sequential(
            nn.Linear(config.n_embd, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

        # Learned key/value/decay used by `thinking_step`.
        self.think_k = nn.Parameter(torch.zeros(1, config.n_head, 1, config.head_dim))
        self.think_v = nn.Parameter(torch.zeros(1, config.n_head, 1, config.head_dim))
        self.think_decay = nn.Parameter(torch.ones(1, config.n_head, config.head_dim))

        # Small random start for key and value; the decay stays at ones.
        nn.init.normal_(self.think_k, std=0.02)
        nn.init.normal_(self.think_v, std=0.02)

    def estimate_complexity(self, x, state):
        """Return one complexity score in (0, 1) per batch element.

        The previous version first built a summary of ``state`` and
        concatenated it with ``x``.  That failed when ``state`` was ``None``,
        failed on the 3-D/2-D concatenation, and its result was never used,
        so the dead code is removed.

        Args:
            x: features of shape ``[B, n_embd]`` or ``[B, ..., n_embd]``.
            state: accepted for interface compatibility; not used, because the
                estimator network only takes ``n_embd`` input features.

        Returns:
            Tensor of shape ``[B]``.  When ``x`` has extra axes between batch
            and feature, the per-position scores are averaged over them.
        """
        scores = self.complexity_net(x).squeeze(-1)
        if scores.dim() > 1:
            # Collapse every non-batch axis so callers get exactly one score
            # per sample (they use it to mask along the batch axis).
            scores = scores.reshape(scores.shape[0], -1).mean(dim=1)
        return scores

    def compute_think_steps(self, complexity, training=False):
        """Map complexity scores to an integer number of thinking steps.

        Args:
            complexity: tensor of scores.
            training: accepted for a future curriculum; currently has no effect
                (both modes use ``config.max_think_steps``).

        Returns:
            Integer tensor shaped like ``complexity``.  Scores at or below
            ``config.think_threshold`` get 1 step; higher scores get
            ``int(score * max_steps) + 1`` clamped to ``[1, max_steps]``.  With
            thinking disabled every entry is 1 (returned as a tensor so the
            result type does not depend on the configuration).
        """
        single_step = torch.ones_like(complexity, dtype=torch.int)
        if not self.config.enable_thinking:
            return single_step

        max_steps = self.config.max_think_steps
        scaled_steps = torch.clamp((complexity * max_steps).int() + 1, 1, max_steps)
        return torch.where(complexity > self.config.think_threshold, scaled_steps, single_step)

    def thinking_step(self, state, wkv_func):
        """Apply one refinement step to ``state``.

        This is a simplified placeholder update, not the RWKV-7 WKV operator:
        the state is scaled by ``exp(-think_decay)`` along its third axis and a
        rank-one term built from ``think_k`` and ``think_v`` is added.  Because
        the key and value rows are expanded ``D1`` times before the matrix
        product, that term equals ``D1`` times their outer product.

        Args:
            state: tensor of shape ``[B, H, D1, D2]``; the broadcasting below
                requires both ``D1`` and ``D2`` to equal ``config.head_dim``.
            wkv_func: accepted for interface compatibility; currently unused.

        Returns:
            Tensor with the same shape as ``state``.
        """
        B, H, D1, D2 = state.shape

        k = self.think_k.expand(B, -1, D1, -1)
        v = self.think_v.expand(B, -1, D1, -1)

        refined_state = state * torch.exp(-self.think_decay.unsqueeze(-1)) + k.transpose(-1, -2) @ v

        return refined_state

class GazelleRWKV7Layer(nn.Module):
    """Single Gazelle layer: time mixing + channel mixing, with optional ghost state and thinking."""

    def __init__(self, config: GazelleConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # Pre-norms for the two residual branches
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        # RWKV-7 Time Mixing
        self.time_mixing = RWKV7TimeMixing(config, layer_idx)

        # RWKV-7 Channel Mixing
        self.channel_mixing = RWKV7ChannelMixing(config, layer_idx)

        # Ghost state handler (attribute only exists when ghost states are on)
        if config.use_ghost:
            self.ghost_state = GhostRWKV7State(config)

        # Thinking module, only for layers 3 .. n_layer - 3 inclusive
        if config.enable_thinking and 3 <= layer_idx <= config.n_layer - 3:
            self.thinking = ThinkingModule(config)
        else:
            self.thinking = None

    def forward(self, x, state=None, use_thinking=True):
        """Run the layer on ``x``.

        Args:
            x: activations of shape ``[B, T, n_embd]``.
            state: optional ``[B, H, head_dim, head_dim]`` state, or ``None``.
            use_thinking: set False to skip the thinking steps even when this
                layer owns a thinking module.

        Returns:
            ``(x, current_state, thinking_info)``.  ``current_state`` is the
            state that was handed to time mixing (after ghost recombination and
            thinking).  It is not the state after the sequence has been
            consumed, because time mixing only returns its output.
            ``thinking_info`` is empty unless thinking ran; otherwise it holds
            the mean complexity, the mean step count and the list of
            intermediate states.
        """
        thinking_info = {}

        # Rebuild the ghost part of the state from its intrinsic block.
        current_state = state
        if self.config.use_ghost and current_state is not None:
            intrinsic = self.ghost_state.split_state(current_state)
            ghost = torch.stack([
                self.ghost_state.generate_ghost(intrinsic[:, h], h)
                for h in range(self.config.n_head)
            ], dim=1)
            current_state = self.ghost_state.combine_state(intrinsic, ghost)

        if self.thinking is not None and use_thinking:
            if current_state is None:
                # Thinking refines a state, so materialise the all-zero state
                # that time mixing would otherwise assume for `None`.
                current_state = x.new_zeros(
                    x.shape[0], self.config.n_head,
                    self.config.head_dim, self.config.head_dim,
                )

            complexity = self.thinking.estimate_complexity(x, current_state)
            if complexity.dim() > 1:
                # The step mask below indexes the batch axis of the state, so
                # reduce any per-position scores to one score per sample.
                complexity = complexity.reshape(complexity.shape[0], -1).mean(dim=1)
            think_steps = self.thinking.compute_think_steps(complexity, self.training)

            thinking_info['complexities'] = complexity.mean().item()
            thinking_info['steps'] = think_steps.float().mean().item()
            thinking_states_list = [current_state.clone()]  # initial state

            for step in range(int(think_steps.max().item())):
                # Samples that still have steps left at this iteration.
                mask = think_steps > step
                if mask.any():
                    refined = self.thinking.thinking_step(current_state, self.time_mixing.wkv)
                    # Out-of-place select: the caller's tensor is never
                    # mutated and autograd sees no in-place masked write.
                    current_state = torch.where(mask.view(-1, 1, 1, 1), refined, current_state)
                    thinking_states_list.append(current_state.clone())

            thinking_info['states'] = thinking_states_list

        # Standard pre-norm residual blocks
        x = x + self.time_mixing(self.ln1(x), current_state)
        x = x + self.channel_mixing(self.ln2(x))

        return x, current_state, thinking_info

class RWKV7TimeMixing(nn.Module):
    """Time-mixing block (placeholder recurrence, not the real RWKV-7 kernel)."""

    def __init__(self, config: GazelleConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        width = config.n_embd

        # Bias-free projections: receptance, key, value, gate, then output.
        self.receptance = nn.Linear(width, width, bias=False)
        self.key = nn.Linear(width, width, bias=False)
        self.value = nn.Linear(width, width, bias=False)
        self.gate = nn.Linear(width, width, bias=False)
        self.output = nn.Linear(width, width, bias=False)

        # Per-head, per-channel decay and bonus, both shaped [H, D].
        per_head_shape = (config.n_head, config.head_dim)
        self.time_decay = nn.Parameter(torch.zeros(*per_head_shape))
        self.time_bonus = nn.Parameter(torch.zeros(*per_head_shape))

        # Relative depth of this layer: 0 for the first layer, 1 for the last.
        depth_fraction = layer_idx / max(config.n_layer - 1, 1)
        # Computed as in the original code but not used anywhere below.
        inverse_depth_fraction = 1.0 - (layer_idx / config.n_layer)

        # Orthogonal init for the input projections; the output projection
        # starts at zero, so the block contributes nothing at initialisation.
        for projection, gain in (
            (self.receptance, 1),
            (self.key, 0.1),
            (self.value, 1),
            (self.gate, 0.1),
        ):
            nn.init.orthogonal_(projection.weight, gain=gain)
        nn.init.zeros_(self.output.weight)

        # Decay logits are drawn uniformly within +/-0.5 of a centre that moves
        # from -5 (first layer) to 3 (last layer).
        with torch.no_grad():
            centre = -5 + 8 * depth_fraction
            self.time_decay.uniform_(centre - 0.5, centre + 0.5)

    def wkv(self, state, k, v, decay):
        """Sequential placeholder for the WKV recurrence.

        This is a slow, generic RNN-style update and NOT the actual RWKV-7
        state update; real training needs the official RWKV CUDA kernel:
        https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v5/cuda/wkv_cuda.py

        Args:
            state: ``[B, H, D, D]`` starting state, or ``None`` for all zeros.
            k, v: tensors of shape ``[B, H, T, D]``.
            decay: multiplicative decay, broadcast against the state after a
                trailing axis is appended.

        Returns:
            Tensor shaped like ``v``.  Only the outputs are returned; the final
            state is discarded.
        """
        n_batch, n_head, n_steps, dim = k.shape

        outputs = torch.zeros_like(v)
        if state is None:
            running = torch.zeros(n_batch, n_head, dim, dim, device=k.device, dtype=k.dtype)
        else:
            running = state

        decay_column = decay.unsqueeze(-1)
        for t in range(n_steps):
            key_column = k[:, :, t:t+1].transpose(-1, -2)  # [B, H, D, 1]
            value_row = v[:, :, t:t+1]                     # [B, H, 1, D]
            # Decay the state, then add the key/value outer product.
            running = running * decay_column + key_column @ value_row
            # Read the updated state out with the same key.
            outputs[:, :, t] = (running @ key_column).squeeze(-1)

        return outputs

    def forward(self, x, state=None):
        """Apply time mixing to ``x`` of shape ``[B, T, C]``; returns the same shape."""
        B, T, C = x.shape
        H = self.config.n_head
        D = self.config.head_dim

        def split_heads(tensor):
            # [B, T, C] -> [B, H, T, D]
            return tensor.view(B, T, H, D).transpose(1, 2)

        r = split_heads(self.receptance(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))
        g = split_heads(torch.sigmoid(self.gate(x)))

        # Double exponential keeps the decay factor strictly inside (0, 1).
        decay = torch.exp(-torch.exp(self.time_decay))
        mixed = self.wkv(state, k, v, decay)

        # Gate the recurrence output with sigmoid(receptance) and the gate.
        mixed = mixed * torch.sigmoid(r) * g

        # Merge heads back to [B, T, C] and project.
        merged = mixed.transpose(1, 2).reshape(B, T, C)
        return self.output(merged)

class RWKV7ChannelMixing(nn.Module):
    """Channel-mixing feed-forward block with a sigmoid receptance gate."""

    def __init__(self, config: GazelleConfig, layer_idx: int):
        super().__init__()
        self.config = config

        width = config.n_embd
        hidden_width = width * 4

        # Expand to 4x width, project back, plus a same-width gate projection.
        self.key = nn.Linear(width, hidden_width, bias=False)
        self.value = nn.Linear(hidden_width, width, bias=False)
        self.receptance = nn.Linear(width, width, bias=False)

        # Only the expansion gets a non-zero start; with value and receptance
        # at zero the block outputs zeros at initialisation.
        nn.init.orthogonal_(self.key.weight, gain=1)
        nn.init.zeros_(self.value.weight)
        nn.init.zeros_(self.receptance.weight)

    def forward(self, x):
        """Return ``sigmoid(receptance(x)) * value(relu(key(x)) ** 2)``."""
        gate = torch.sigmoid(self.receptance(x))
        hidden = torch.relu(self.key(x)).pow(2)  # squared-ReLU activation
        return gate * self.value(hidden)

class GazelleModel(nn.Module):
    """Complete Gazelle 0.5B language model.

    Pipeline: token embedding -> LayerNorm -> ``n_layer`` Gazelle layers ->
    LayerNorm -> linear vocabulary head.  ``self.states`` keeps one state slot
    per layer and is fed back into the layers on the next ``forward`` call.
    """

    def __init__(self, config: GazelleConfig):
        super().__init__()
        self.config = config

        # Token embeddings
        self.embedding = nn.Embedding(config.vocab_size, config.n_embd)
        self.ln_emb = nn.LayerNorm(config.n_embd)

        # Gazelle layers
        self.layers = nn.ModuleList([
            GazelleRWKV7Layer(config, idx) for idx in range(config.n_layer)
        ])

        # Output head
        self.ln_head = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Tiny uniform embeddings; orthogonal head scaled by vocab/width ratio
        nn.init.uniform_(self.embedding.weight, -1e-4, 1e-4)
        nn.init.orthogonal_(self.head.weight, gain=0.5 * math.sqrt(config.vocab_size / config.n_embd))

        # State management - list of states, one per layer
        self.states = [None] * self.config.n_layer

        # Thinking info storage
        self.thinking_info = {}

        print(f"✓ Gazelle Model initialized with {self.num_parameters():.2f}M parameters")
        if config.use_ghost:
            print(f"  - Ghost ratio: {config.ghost_ratio:.1%} reduction")
        if config.enable_thinking:
            print(f"  - Thinking enabled: max {config.max_think_steps} steps")

    def num_parameters(self):
        """Return the number of trainable parameters, in millions."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6

    def forward(self, idx, targets=None):
        """Compute logits (and optionally the loss) for a batch of token ids.

        Args:
            idx: integer tensor of shape ``[B, T]``.
            targets: optional integer tensor of shape ``[B, T]``.

        Returns:
            ``(logits, loss)`` where ``logits`` is ``[B, T, vocab_size]`` and
            ``loss`` is the mean cross-entropy over all positions, or ``None``
            when no targets are given.

        Side effects:
            Replaces ``self.states`` with the states returned by the layers and
            resets/fills ``self.thinking_info``.  The states persist across
            calls; reset ``self.states`` before feeding unrelated sequences.
        """
        B, T = idx.shape

        # Token embeddings
        x = self.embedding(idx)
        x = self.ln_emb(x)

        # Clear thinking info
        self.thinking_info = {'complexities': [], 'steps': [], 'state_changes': []}

        # Forward through layers
        new_states = []
        for i, layer in enumerate(self.layers):
            x, new_state, layer_thinking_info = layer(x, self.states[i])
            new_states.append(new_state)

            # Accumulate thinking info from layers that own a thinking module
            if self.config.enable_thinking and layer.thinking is not None:
                self.thinking_info['complexities'].append(layer_thinking_info.get('complexities', 0.0))
                self.thinking_info['steps'].append(layer_thinking_info.get('steps', 0.0))
                # 'state_changes' is left empty here; the original comment
                # attributes it to a ThinkingLossCalculator defined elsewhere.

        self.states = new_states  # Update states

        # Output projection
        x = self.ln_head(x)
        logits = self.head(x)

        # Calculate loss if targets provided
        loss = None
        if targets is not None:
            # Reshape for cross_entropy: [N, C] and [N]
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    def generate(self, idx, max_tokens=100, temperature=1.0, top_p=0.9):
        """Sample a continuation of a single prompt.

        Fixes over the previous version:

        * The prompt is actually used.  Before, only the last token was fed to
          the model (reshaped to ``[1, B]``), and because the layers return the
          state they were given rather than an updated one, nothing else could
          influence the prediction.  Each step now re-runs the context (cropped
          to the last ``config.ctx_len`` tokens) from a clean state.  Once the
          layers return real updated states this can become a prompt prefill
          followed by single-token steps.
        * Nucleus sampling keeps the token that crosses ``top_p`` (the smallest
          set whose mass reaches ``top_p``) instead of dropping it.
        * ``temperature <= 0`` selects the arg-max token instead of dividing by
          zero.
        * The module's train/eval mode is restored on exit and the per-layer
          states are cleared so they cannot leak into a later ``forward``.

        Args:
            idx: integer tensor of shape ``[1, T]`` with ``T >= 1``.
            max_tokens: maximum number of tokens to sample.
            temperature: softmax temperature; ``<= 0`` means greedy decoding.
            top_p: cumulative-probability cut-off for nucleus sampling.

        Returns:
            ``(generated, thinking_log)``: the sampled token ids as a list of
            ints, and one dict per token with the averaged thinking statistics
            (empty when thinking is disabled).

        Raises:
            ValueError: if ``idx`` is not a non-empty ``[1, T]`` tensor.
        """
        if idx.dim() != 2 or idx.size(0) != 1 or idx.size(1) == 0:
            raise ValueError(
                f"generate() expects a non-empty prompt of shape [1, T], got {tuple(idx.shape)}"
            )

        was_training = self.training
        self.eval()
        generated = []
        thinking_log = []
        current_idx = idx.clone()

        try:
            with torch.no_grad():
                for _ in range(max_tokens):
                    # Fresh state + (cropped) full context on every step.
                    self.states = [None] * self.config.n_layer
                    context = current_idx[:, -self.config.ctx_len:]
                    logits, _ = self(context)
                    logits = logits[:, -1, :]

                    if temperature <= 0:
                        # Greedy decoding
                        next_token = logits.argmax(dim=-1, keepdim=True)
                    else:
                        # Top-p (nucleus) sampling
                        probs = F.softmax(logits / temperature, dim=-1)
                        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                        cumsum = torch.cumsum(sorted_probs, dim=-1)
                        # Drop a token only when the mass *before* it already
                        # exceeds top_p, so the crossing token is kept.
                        mask = (cumsum - sorted_probs) > top_p
                        mask[:, 0] = False  # Keep at least one token
                        sorted_probs = sorted_probs.masked_fill(mask, 0.0)
                        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

                        choice = torch.multinomial(sorted_probs, 1)
                        next_token = sorted_indices.gather(-1, choice)

                    token_id = next_token.item()

                    # Track thinking if enabled (aggregate across layers)
                    if self.config.enable_thinking and self.thinking_info:
                        complexities = self.thinking_info.get('complexities', [])
                        steps = self.thinking_info.get('steps', [])
                        thinking_log.append({
                            'token_id': token_id,
                            'avg_complexity': sum(complexities) / max(1, len(complexities)),
                            'avg_think_steps': sum(steps) / max(1, len(steps))
                        })
                        # Clear thinking info after processing one token
                        self.thinking_info = {'complexities': [], 'steps': [], 'state_changes': []}

                    # Append to sequence
                    current_idx = torch.cat([current_idx, next_token], dim=1)
                    generated.append(token_id)

                    # Stop on EOS token (assuming 0 is EOS, adjust if needed)
                    if token_id == 0:
                        break
        finally:
            self.states = [None] * self.config.n_layer
            self.train(was_training)

        return generated, thinking_log

# Build a module-level configuration with all default values.
# Nothing is written to disk here; this only instantiates the dataclass.
config = GazelleConfig()

In [None]:
#@title Gazelle 0.5B Core Architecture
%cd /content/RWKV-LM # Ensure we are in the correct directory

%%writefile gazelle_implementation/gazelle_model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple
from dataclasses import dataclass

@dataclass
class GazelleConfig:
    """Hyper-parameters for the Gazelle model.

    Besides the plain fields, three read-only properties derive the per-head
    width and how that width is divided into an "intrinsic" and a "ghost"
    part according to ``ghost_ratio``.
    """

    # Model size
    n_layer: int = 12
    n_embd: int = 1536
    n_head: int = 24
    vocab_size: int = 65536
    ctx_len: int = 512

    # GhostRNN parameters
    ghost_ratio: float = 0.375  # fraction of each head's width assigned to the ghost part
    use_ghost: bool = True

    # Thinking parameters
    enable_thinking: bool = False
    max_think_steps: int = 5
    think_threshold: float = 0.3

    # RWKV-7 specific
    # NOTE(review): no consumer of these two fields is visible next to this
    # class; confirm they are still needed.
    head_qk: int = 0
    time_decay_init: str = "log"

    @property
    def head_dim(self):
        """Width of one head: ``n_embd`` floor-divided by ``n_head``."""
        return self.n_embd // self.n_head

    @property
    def intrinsic_dim(self):
        """Part of ``head_dim`` left after removing the ghost share (truncated to int)."""
        kept_fraction = 1 - self.ghost_ratio
        return int(self.head_dim * kept_fraction)

    @property
    def ghost_dim(self):
        """Remainder of ``head_dim`` once ``intrinsic_dim`` is taken out."""
        return self.head_dim - self.intrinsic_dim

class GhostRWKV7State(nn.Module):
    """GhostRNN-style decomposition of a per-head square state matrix.

    A head's full state is a ``head_dim x head_dim`` matrix.  Only its
    top-left ``intrinsic_dim x intrinsic_dim`` block is kept as the real
    ("intrinsic") state.  The bottom-right ``ghost_dim x ghost_dim`` block is
    regenerated from the intrinsic block with one small bias-free linear map
    per head, and the two off-diagonal blocks are filled with scaled row means.
    """

    def __init__(self, config: GazelleConfig):
        super().__init__()
        self.config = config

        # One cheap projection (intrinsic_dim -> ghost_dim) per head.
        self.ghost_linear = nn.ModuleList([
            nn.Linear(config.intrinsic_dim, config.ghost_dim, bias=False)
            for _ in range(config.n_head)
        ])

        # Small-gain Xavier init so the generated ghost block starts out small.
        for linear in self.ghost_linear:
            nn.init.xavier_uniform_(linear.weight, gain=0.1)

    def split_state(self, state):
        """Return the intrinsic (top-left) block of a full state.

        Args:
            state: tensor of shape ``[B, H, D1, D2]``.

        Returns:
            A view of shape ``[B, H, intrinsic_dim, intrinsic_dim]`` (smaller if
            ``state`` itself is smaller than ``intrinsic_dim`` on an axis).
        """
        i_dim = self.config.intrinsic_dim
        return state[:, :, :i_dim, :i_dim]

    def generate_ghost(self, intrinsic, head_idx):
        """Build the ghost block of one head from its intrinsic block.

        With ``W`` the ``[ghost_dim, intrinsic_dim]`` weight of this head's
        projection and ``S`` the intrinsic block, the result is ``W @ S @ W^T``.
        The previous implementation projected only one side and returned a
        ``[B, ghost_dim, intrinsic_dim]`` tensor, which ``combine_state`` cannot
        place into its ``ghost_dim x ghost_dim`` slot.

        Args:
            intrinsic: tensor of shape ``[B, intrinsic_dim, intrinsic_dim]``.
            head_idx: index of the head whose projection is used.

        Returns:
            Tensor of shape ``[B, ghost_dim, ghost_dim]``.
        """
        project = self.ghost_linear[head_idx]

        # nn.Linear multiplies by W^T on the right: S -> S @ W^T, shape [B, I, G].
        right_projected = project(intrinsic)
        # Transpose and project again: (S W^T)^T W^T = W S^T W^T, shape [B, G, G].
        both_projected = project(right_projected.transpose(-1, -2))
        # A final transpose yields W S W^T.
        return both_projected.transpose(-1, -2)

    def combine_state(self, intrinsic, ghost):
        """Assemble a full ``[B, H, head_dim, head_dim]`` state from its blocks.

        Args:
            intrinsic: ``[B, H, intrinsic_dim, intrinsic_dim]`` block.
            ghost: ``[B, H, ghost_dim, ghost_dim]`` block.

        Returns:
            A new tensor with ``intrinsic`` in the top-left block, ``ghost`` in
            the bottom-right block, and each off-diagonal block filled with
            0.1 times the row means of the block that shares its rows.
        """
        B, H = intrinsic.shape[:2]
        D = self.config.head_dim
        i_dim = self.config.intrinsic_dim
        g_dim = self.config.ghost_dim

        full_state = torch.zeros(B, H, D, D, device=intrinsic.device, dtype=intrinsic.dtype)

        # Diagonal blocks
        full_state[:, :, :i_dim, :i_dim] = intrinsic
        full_state[:, :, i_dim:, i_dim:] = ghost

        # Cross terms (can be refined later): broadcast each row's mean across
        # the columns of the neighbouring off-diagonal block.
        full_state[:, :, :i_dim, i_dim:] = intrinsic.mean(dim=-1, keepdim=True).expand(-1, -1, -1, g_dim) * 0.1
        full_state[:, :, i_dim:, :i_dim] = ghost.mean(dim=-1, keepdim=True).expand(-1, -1, -1, i_dim) * 0.1

        return full_state

class ThinkingModule(nn.Module):
    """Adaptive "thinking" mechanism: score an input, pick a step count, refine a state."""

    def __init__(self, config: GazelleConfig):
        super().__init__()
        self.config = config

        # Complexity estimator: n_embd features -> one score in (0, 1).
        self.complexity_net = nn.Sequential(
            nn.Linear(config.n_embd, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

        # Learned key/value/decay used by `thinking_step`.
        self.think_k = nn.Parameter(torch.zeros(1, config.n_head, 1, config.head_dim))
        self.think_v = nn.Parameter(torch.zeros(1, config.n_head, 1, config.head_dim))
        self.think_decay = nn.Parameter(torch.ones(1, config.n_head, config.head_dim))

        # Small random start for key and value; the decay stays at ones.
        nn.init.normal_(self.think_k, std=0.02)
        nn.init.normal_(self.think_v, std=0.02)

    def estimate_complexity(self, x, state):
        """Return one complexity score in (0, 1) per batch element.

        The previous version first built a summary of ``state`` and
        concatenated it with ``x``.  That failed when ``state`` was ``None``,
        failed on the 3-D/2-D concatenation, and its result was never used,
        so the dead code is removed.

        Args:
            x: features of shape ``[B, n_embd]`` or ``[B, ..., n_embd]``.
            state: accepted for interface compatibility; not used, because the
                estimator network only takes ``n_embd`` input features.

        Returns:
            Tensor of shape ``[B]``.  When ``x`` has extra axes between batch
            and feature, the per-position scores are averaged over them.
        """
        scores = self.complexity_net(x).squeeze(-1)
        if scores.dim() > 1:
            # Collapse every non-batch axis so callers get exactly one score
            # per sample (they use it to mask along the batch axis).
            scores = scores.reshape(scores.shape[0], -1).mean(dim=1)
        return scores

    def compute_think_steps(self, complexity, training=False):
        """Map complexity scores to an integer number of thinking steps.

        Args:
            complexity: tensor of scores.
            training: accepted for a future curriculum; currently has no effect
                (both modes use ``config.max_think_steps``).

        Returns:
            Integer tensor shaped like ``complexity``.  Scores at or below
            ``config.think_threshold`` get 1 step; higher scores get
            ``int(score * max_steps) + 1`` clamped to ``[1, max_steps]``.  With
            thinking disabled every entry is 1 (returned as a tensor so the
            result type does not depend on the configuration).
        """
        single_step = torch.ones_like(complexity, dtype=torch.int)
        if not self.config.enable_thinking:
            return single_step

        max_steps = self.config.max_think_steps
        scaled_steps = torch.clamp((complexity * max_steps).int() + 1, 1, max_steps)
        return torch.where(complexity > self.config.think_threshold, scaled_steps, single_step)

    def thinking_step(self, state, wkv_func):
        """Apply one refinement step to ``state``.

        This is a simplified placeholder update, not the RWKV-7 WKV operator:
        the state is scaled by ``exp(-think_decay)`` along its third axis and a
        rank-one term built from ``think_k`` and ``think_v`` is added.  Because
        the key and value rows are expanded ``D1`` times before the matrix
        product, that term equals ``D1`` times their outer product.

        Args:
            state: tensor of shape ``[B, H, D1, D2]``; the broadcasting below
                requires both ``D1`` and ``D2`` to equal ``config.head_dim``.
            wkv_func: accepted for interface compatibility; currently unused.

        Returns:
            Tensor with the same shape as ``state``.
        """
        B, H, D1, D2 = state.shape

        k = self.think_k.expand(B, -1, D1, -1)
        v = self.think_v.expand(B, -1, D1, -1)

        refined_state = state * torch.exp(-self.think_decay.unsqueeze(-1)) + k.transpose(-1, -2) @ v

        return refined_state

class GazelleRWKV7Layer(nn.Module):
    """Single Gazelle layer: time mixing + channel mixing, with optional ghost state and thinking."""

    def __init__(self, config: GazelleConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # Pre-norms for the two residual branches
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        # RWKV-7 Time Mixing
        self.time_mixing = RWKV7TimeMixing(config, layer_idx)

        # RWKV-7 Channel Mixing
        self.channel_mixing = RWKV7ChannelMixing(config, layer_idx)

        # Ghost state handler (attribute only exists when ghost states are on)
        if config.use_ghost:
            self.ghost_state = GhostRWKV7State(config)

        # Thinking module, only for layers 3 .. n_layer - 3 inclusive
        if config.enable_thinking and 3 <= layer_idx <= config.n_layer - 3:
            self.thinking = ThinkingModule(config)
        else:
            self.thinking = None

    def forward(self, x, state=None, use_thinking=True):
        """Run the layer on ``x``.

        Args:
            x: activations of shape ``[B, T, n_embd]``.
            state: optional ``[B, H, head_dim, head_dim]`` state, or ``None``.
            use_thinking: set False to skip the thinking steps even when this
                layer owns a thinking module.

        Returns:
            ``(x, current_state, thinking_info)``.  ``current_state`` is the
            state that was handed to time mixing (after ghost recombination and
            thinking).  It is not the state after the sequence has been
            consumed, because time mixing only returns its output.
            ``thinking_info`` is empty unless thinking ran; otherwise it holds
            the mean complexity, the mean step count and the list of
            intermediate states.
        """
        thinking_info = {}

        # Rebuild the ghost part of the state from its intrinsic block.
        current_state = state
        if self.config.use_ghost and current_state is not None:
            intrinsic = self.ghost_state.split_state(current_state)
            ghost = torch.stack([
                self.ghost_state.generate_ghost(intrinsic[:, h], h)
                for h in range(self.config.n_head)
            ], dim=1)
            current_state = self.ghost_state.combine_state(intrinsic, ghost)

        if self.thinking is not None and use_thinking:
            if current_state is None:
                # Thinking refines a state, so materialise the all-zero state
                # that time mixing would otherwise assume for `None`.
                current_state = x.new_zeros(
                    x.shape[0], self.config.n_head,
                    self.config.head_dim, self.config.head_dim,
                )

            complexity = self.thinking.estimate_complexity(x, current_state)
            if complexity.dim() > 1:
                # The step mask below indexes the batch axis of the state, so
                # reduce any per-position scores to one score per sample.
                complexity = complexity.reshape(complexity.shape[0], -1).mean(dim=1)
            think_steps = self.thinking.compute_think_steps(complexity, self.training)

            thinking_info['complexities'] = complexity.mean().item()
            thinking_info['steps'] = think_steps.float().mean().item()
            thinking_states_list = [current_state.clone()]  # initial state

            for step in range(int(think_steps.max().item())):
                # Samples that still have steps left at this iteration.
                mask = think_steps > step
                if mask.any():
                    refined = self.thinking.thinking_step(current_state, self.time_mixing.wkv)
                    # Out-of-place select: the caller's tensor is never
                    # mutated and autograd sees no in-place masked write.
                    current_state = torch.where(mask.view(-1, 1, 1, 1), refined, current_state)
                    thinking_states_list.append(current_state.clone())

            thinking_info['states'] = thinking_states_list

        # Standard pre-norm residual blocks
        x = x + self.time_mixing(self.ln1(x), current_state)
        x = x + self.channel_mixing(self.ln2(x))

        return x, current_state, thinking_info

class RWKV7TimeMixing(nn.Module):
    """RWKV-7 style time mixing built on a simplified, placeholder WKV recurrence.

    The module projects the input into receptance / key / value / gate streams,
    runs the per-head recurrence in :meth:`wkv`, gates the result and projects
    it back to the embedding width. The output projection is zero-initialized,
    so a freshly constructed module returns zeros.

    ``time_bonus`` is registered as a parameter but is not read anywhere in
    this class.
    """

    def __init__(self, config: GazelleConfig, layer_idx: int):
        """Create the projections and the per-head decay parameters.

        Args:
            config: Must expose ``n_embd``, ``n_head``, ``head_dim`` and ``n_layer``.
            layer_idx: Zero-based index of the owning layer; it shifts the
                decay initialization range.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # Receptance, Key, Value, Gate projections
        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.key = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.value = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.gate = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.output = nn.Linear(config.n_embd, config.n_embd, bias=False)

        # Time decay and bonus parameters
        # Shape should be [H, D]
        self.time_decay = nn.Parameter(torch.zeros(config.n_head, config.head_dim))
        self.time_bonus = nn.Parameter(torch.zeros(config.n_head, config.head_dim))

        # Depth ratio in [0, 1]: 0 for the first layer, 1 for the last.
        ratio_0_to_1 = layer_idx / max(config.n_layer - 1, 1)

        # Initialize weights
        nn.init.orthogonal_(self.receptance.weight, gain=1)
        nn.init.orthogonal_(self.key.weight, gain=0.1)
        nn.init.orthogonal_(self.value.weight, gain=1)
        nn.init.orthogonal_(self.gate.weight, gain=0.1)
        nn.init.zeros_(self.output.weight)

        # Initialize time decay
        with torch.no_grad():
            # Centre moves from -5 (first layer) to +3 (last layer), +/- 0.5 jitter.
            decay_values = -5 + 8 * ratio_0_to_1
            self.time_decay.uniform_(decay_values - 0.5, decay_values + 0.5)

    def wkv(self, state, k, v, decay):
        """Simplified placeholder WKV recurrence (not the real RWKV-7 update).

        For every time step ``t`` the per-head matrix state is updated as
        ``state = state * decay[..., None] + k_t^T @ v_t`` and the step output
        is ``state @ k_t^T``. The loop is sequential Python and therefore slow;
        real training needs the official RWKV CUDA kernel:
        https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v5/cuda/wkv_cuda.py

        Args:
            state: Initial state of shape [B, H, D, D], or None for zeros.
            k: Keys, shape [B, H, T, D].
            v: Values, same shape as ``k``.
            decay: Multiplicative decay broadcastable to [H, D].

        Returns:
            Tensor shaped like ``v``. The final state is NOT returned, so
            callers cannot carry the recurrence across calls.
        """
        B, H, T, D = k.shape

        output = torch.zeros_like(v)
        current_state = state if state is not None else torch.zeros(B, H, D, D, device=k.device, dtype=k.dtype)

        for t in range(T):
            k_t = k[:, :, t:t+1]  # [B, H, 1, D]
            v_t = v[:, :, t:t+1]  # [B, H, 1, D]
            # This is NOT the correct RWKV-7 state update
            # This is a generic RNN-like update for the placeholder
            current_state = current_state * decay.unsqueeze(-1) + k_t.transpose(-1, -2) @ v_t
            output[:, :, t] = (current_state @ k_t.transpose(-1, -2)).squeeze(-1)

        return output

    def forward(self, x, state=None):
        """Apply time mixing to ``x`` of shape [B, T, C]; returns a tensor of the same shape.

        ``state`` is forwarded to :meth:`wkv` as the initial recurrent state.
        """
        B, T, C = x.shape
        H = self.config.n_head
        D = self.config.head_dim

        # Compute RKWG, each reshaped to [B, H, T, D]
        r = self.receptance(x).view(B, T, H, D).transpose(1, 2)
        k = self.key(x).view(B, T, H, D).transpose(1, 2)
        v = self.value(x).view(B, T, H, D).transpose(1, 2)
        g = torch.sigmoid(self.gate(x)).view(B, T, H, D).transpose(1, 2)

        # Apply WKV; exp(-exp(w)) keeps the decay strictly inside (0, 1)
        x = self.wkv(state, k, v, torch.exp(-torch.exp(self.time_decay)))

        # Apply receptance and gate
        x = x * torch.sigmoid(r) * g

        # Combine heads and output
        x = x.transpose(1, 2).reshape(B, T, C)
        x = self.output(x)

        return x

class RWKV7ChannelMixing(nn.Module):
    """RWKV-7 channel mixing: a squared-ReLU feed-forward block behind a sigmoid receptance gate.

    The value and receptance projections start at zero, so a new module
    outputs zeros until it has been trained.
    """

    def __init__(self, config: GazelleConfig, layer_idx: int):
        """Build the three bias-free projections; ``layer_idx`` is accepted but not used here."""
        super().__init__()
        self.config = config

        # Expand to 4x the embedding width, then project back down.
        self.key = nn.Linear(config.n_embd, config.n_embd * 4, bias=False)
        self.value = nn.Linear(config.n_embd * 4, config.n_embd, bias=False)
        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)

        nn.init.orthogonal_(self.key.weight, gain=1)
        for zero_started in (self.value, self.receptance):
            nn.init.zeros_(zero_started.weight)

    def forward(self, x):
        """Return ``sigmoid(receptance(x)) * value(relu(key(x)) ** 2)``."""
        hidden = torch.relu(self.key(x))
        projected = self.value(hidden ** 2)  # squared-ReLU activation
        gate = torch.sigmoid(self.receptance(x))
        return gate * projected

class GazelleModel(nn.Module):
    """Complete Gazelle 0.5B Model.

    Embedding -> LayerNorm -> ``n_layer`` GazelleRWKV7Layer blocks -> LayerNorm
    -> vocabulary head. Per-layer recurrent states are kept on the instance in
    ``self.states`` and refreshed by every forward pass; per-layer thinking
    statistics of the most recent forward pass live in ``self.thinking_info``.
    """

    def __init__(self, config: GazelleConfig):
        """Build all sub-modules from ``config`` and print a short summary."""
        super().__init__()
        self.config = config

        # Token embeddings
        self.embedding = nn.Embedding(config.vocab_size, config.n_embd)
        self.ln_emb = nn.LayerNorm(config.n_embd)

        # Gazelle layers
        self.layers = nn.ModuleList([
            GazelleRWKV7Layer(config, idx) for idx in range(config.n_layer)
        ])

        # Output head
        self.ln_head = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Initialize embeddings
        nn.init.uniform_(self.embedding.weight, -1e-4, 1e-4)
        nn.init.orthogonal_(self.head.weight, gain=0.5 * math.sqrt(config.vocab_size / config.n_embd))

        # State management - list of states, one per layer
        self.states = [None] * self.config.n_layer

        # Thinking info storage
        self.thinking_info = {}

        print(f"✓ Gazelle Model initialized with {self.num_parameters():.2f}M parameters")
        if config.use_ghost:
            print(f"  - Ghost ratio: {config.ghost_ratio:.1%} reduction")
        if config.enable_thinking:
            print(f"  - Thinking enabled: max {config.max_think_steps} steps")

    def num_parameters(self):
        """Count trainable parameters, in millions."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6

    def forward(self, idx, targets=None):
        """Compute logits (and optionally the LM loss) for token ids ``idx`` of shape [B, T].

        Side effects: replaces ``self.states`` with the states returned by the
        layers and resets/refills ``self.thinking_info``.

        Returns:
            ``(logits, loss)`` where ``logits`` is [B, T, vocab_size] and
            ``loss`` is the mean cross-entropy against ``targets`` or None.
        """
        B, T = idx.shape

        # Token embeddings
        x = self.embedding(idx)
        x = self.ln_emb(x)

        # Clear thinking info
        self.thinking_info = {'complexities': [], 'steps': [], 'state_changes': []}

        # Forward through layers
        new_states = []
        for i, layer in enumerate(self.layers):
            x, new_state, layer_thinking_info = layer(x, self.states[i])
            new_states.append(new_state)

            # Accumulate thinking info
            if self.config.enable_thinking and layer.thinking is not None:
                self.thinking_info['complexities'].append(layer_thinking_info.get('complexities', 0.0))
                self.thinking_info['steps'].append(layer_thinking_info.get('steps', 0.0))
                # State change info is calculated in ThinkingLossCalculator

        # NOTE(review): states are stored as returned, without detaching. If a
        # layer ever returns a tensor that requires grad, the next batch would
        # backpropagate into a stale graph -- confirm before enabling states
        # during training.
        self.states = new_states # Update states

        # Output projection
        x = self.ln_head(x)
        logits = self.head(x)

        # Calculate loss if targets provided
        loss = None
        if targets is not None:
            # Reshape for cross_entropy: [N, C] and [N]
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    @staticmethod
    def _top_p_filter(probs, top_p):
        """Restrict ``probs`` ([B, V]) to the nucleus of mass ``top_p`` and renormalize.

        Returns ``(sorted_probs, sorted_indices)``: probabilities in descending
        order with everything outside the nucleus zeroed, plus the vocabulary
        ids for each sorted position.
        """
        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
        cumsum = torch.cumsum(sorted_probs, dim=-1)
        # Drop a token only when the mass accumulated BEFORE it already exceeds
        # top_p. Comparing the inclusive cumsum instead would also discard the
        # token that crosses the threshold and leave less than top_p mass.
        mask = (cumsum - sorted_probs) > top_p
        mask[:, 0] = False  # Keep at least one token
        sorted_probs = sorted_probs.masked_fill(mask, 0.0)
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
        return sorted_probs, sorted_indices

    def generate(self, idx, max_tokens=100, temperature=1.0, top_p=0.9, eos_token_id=0):
        """Generate text with thinking steps tracking.

        Args:
            idx: Prompt token ids, shape [1, T]. Only a single sequence is
                supported because sampled tokens are read with ``.item()``.
            max_tokens: Upper bound on the number of sampled tokens.
            temperature: Softmax temperature; must be positive.
            top_p: Nucleus sampling mass.
            eos_token_id: Sampling stops after this id is produced; pass None
                to always generate ``max_tokens`` tokens.

        Returns:
            ``(generated, thinking_log)``: the sampled token ids and, when
            thinking is enabled, one dict per token with layer-averaged stats.

        The model is switched to eval mode and ``self.states`` is reset.
        """
        self.eval()
        generated = []
        thinking_log = []
        current_idx = idx.clone()

        # Initialize states for generation
        self.states = [None] * self.config.n_layer

        # The first step consumes the whole prompt so the first sampled token
        # is conditioned on it; later steps feed only the newly sampled token
        # and rely on the per-layer states for context.
        # NOTE(review): GazelleRWKV7Layer currently hands back the state it was
        # given, so no context survives past the first step until the layers
        # return an updated state -- confirm.
        step_input = current_idx

        with torch.no_grad():
            for _ in range(max_tokens):
                logits, _ = self(step_input)

                logits = logits[:, -1, :] / temperature

                # Top-p sampling
                probs = F.softmax(logits, dim=-1)
                sorted_probs, sorted_indices = self._top_p_filter(probs, top_p)

                # Sample a sorted position, then map it back to a vocabulary id
                next_token = torch.multinomial(sorted_probs, 1)
                next_token = sorted_indices.gather(-1, next_token)

                # Track thinking if enabled (aggregate across layers)
                if self.config.enable_thinking and self.thinking_info:
                    complexities = self.thinking_info.get('complexities', [])
                    steps = self.thinking_info.get('steps', [])
                    avg_complexity = sum(complexities) / max(1, len(complexities))
                    avg_steps = sum(steps) / max(1, len(steps))

                    thinking_log.append({
                        'token_id': next_token.item(),
                        'avg_complexity': avg_complexity,
                        'avg_think_steps': avg_steps
                    })
                    # Clear thinking info after processing one token
                    self.thinking_info = {'complexities': [], 'steps': [], 'state_changes': []}

                # Append to sequence
                current_idx = torch.cat([current_idx, next_token], dim=1)
                generated.append(next_token.item())

                # next_token is already [B, 1], the shape forward() expects
                step_input = next_token

                # Stop on EOS token
                if eos_token_id is not None and next_token.item() == eos_token_id:
                    break

        return generated, thinking_log

# Save the configuration
# Module-level statements: a default GazelleConfig is built and written to a
# fixed Colab path as soon as this code is executed.
# NOTE(review): if this cell is written out as a module, the save and the print
# below would also run on every import of that module -- confirm that is intended.
config = GazelleConfig()
# Ensure the directory exists before saving
config_save_dir = '/content/gazelle_checkpoints/configs'
os.makedirs(config_save_dir, exist_ok=True)
# torch.save serializes the config object itself, not a plain dict of its fields.
torch.save(config, f'{config_save_dir}/base_config.pt')
print("✓ Gazelle model architecture defined and saved!")

Cell 7: Training Utilities and Data Loading

In [None]:
#@title Training Utilities and Data Loading
%%writefile gazelle_implementation/training_utils.py

import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import time
import os
from contextlib import contextmanager

class MemoryEfficientDataset(Dataset):
    """Token dataset backed by a memory-mapped uint16 file plus a document-boundary index.

    ``idx[i]`` and ``idx[i + 1]`` delimit document ``i`` inside the token file.
    Each item is an ``(input, target)`` pair of ``ctx_len`` int64 tokens, where
    the target is the input shifted by one position.
    """

    def __init__(self, data_path, idx_path, ctx_len=512, epoch_length=None):
        """Open the token file lazily and load the boundary index.

        Args:
            data_path: Raw uint16 token file (memory-mapped, read-only).
            idx_path: File readable by ``np.load`` holding the boundaries.
            ctx_len: Number of tokens per returned sequence.
            epoch_length: Reported dataset length; a falsy value means one
                item per document.
        """
        self.ctx_len = ctx_len

        # The token file is never fully read into RAM.
        self.data = np.memmap(data_path, dtype=np.uint16, mode='r')
        self.idx = np.load(idx_path).astype(np.int64)

        # N + 1 boundaries describe N documents.
        self.data_length = len(self.idx) - 1
        self.epoch_length = epoch_length if epoch_length else self.data_length

        summary_lines = [
            f"✓ Dataset initialized:",
            f"  - Total tokens: {len(self.data):,}",
            f"  - Total samples: {self.data_length:,}",
            f"  - Context length: {ctx_len}",
        ]
        print("\n".join(summary_lines))

    def __len__(self):
        return self.epoch_length

    def __getitem__(self, idx):
        """Return ``(x, y)`` for item ``idx``; indices past the document count wrap around."""
        doc_no = idx % self.data_length
        first, last = self.idx[doc_no], self.idx[doc_no + 1]
        tokens = self.data[first:last]

        # One extra token is needed so inputs and targets can be offset by one.
        window = self.ctx_len + 1
        missing = window - len(tokens)
        if missing > 0:
            # Short document: right-pad with token id 0.
            tokens = np.pad(tokens, (0, missing), constant_values=0)
        else:
            # Long enough: take a random window for variety.
            offset = np.random.randint(0, len(tokens) - self.ctx_len)
            tokens = tokens[offset:offset + window]

        inputs = torch.from_numpy(tokens[:-1].astype(np.int64))
        targets = torch.from_numpy(tokens[1:].astype(np.int64))
        return inputs, targets

class ColabCheckpointer:
    """Handle checkpointing with Colab session management.

    A checkpoint is written when the session approaches ``max_runtime_hours``,
    when another 30-minute interval has elapsed, or when the caller forces it.
    ``latest.pt`` inside ``save_dir`` always refers to the newest checkpoint.
    """

    def __init__(self, save_dir, max_runtime_hours=11.5):
        self.save_dir = save_dir
        self.start_time = time.time()
        self.max_runtime = max_runtime_hours * 3600
        # Number of time-triggered saves so far; drives the 30-minute cadence.
        self.checkpoint_counter = 0

    def should_checkpoint(self, force=False):
        """Return ``(should_save, reason)`` for the current moment.

        ``reason`` is "approaching_timeout", "regular_interval", "forced", or
        "" when nothing is due.
        """
        elapsed = time.time() - self.start_time

        # Save if approaching Colab timeout
        if elapsed > self.max_runtime - 600:  # 10 min buffer
            return True, "approaching_timeout"

        # Regular interval saves (every 30 min)
        if elapsed > (self.checkpoint_counter + 1) * 1800:
            return True, "regular_interval"

        return force, "forced" if force else ""

    def save(self, model, optimizer, epoch, step, metrics=None, force=False):
        """Write a checkpoint if one is due, or unconditionally when ``force`` is True.

        Args:
            model: Object exposing ``state_dict()`` and ``config``.
            optimizer: Object exposing ``state_dict()``.
            epoch: Epoch number stored in the checkpoint and the filename.
            step: Global step stored in the checkpoint and the filename.
            metrics: Optional dict stored alongside the weights.
            force: Save even if no time-based trigger fired (end of epoch,
                interrupt handlers, ...).

        Returns:
            True if a checkpoint was written, False otherwise.
        """
        should_save, reason = self.should_checkpoint(force)
        if not should_save:
            return False

        # The directory may not exist yet on a fresh runtime.
        os.makedirs(self.save_dir, exist_ok=True)

        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'step': step,
            'metrics': metrics or {},
            'config': model.config,
            'timestamp': time.time(),
            'reason': reason
        }

        # Save with descriptive filename
        filename = f"gazelle_e{epoch}_s{step}_{reason}.pt"
        path = os.path.join(self.save_dir, filename)

        torch.save(checkpoint, path)
        # Forced saves must not push back the next time-triggered save.
        if reason != "forced":
            self.checkpoint_counter += 1

        print(f"✓ Checkpoint saved: {filename}")

        self._point_latest_at(filename)
        return True

    def _point_latest_at(self, filename):
        """Make ``<save_dir>/latest.pt`` refer to ``filename`` inside ``save_dir``."""
        import shutil

        latest_path = os.path.join(self.save_dir, "latest.pt")
        # lexists() also reports dangling symlinks, which exists() treats as
        # missing and which would make os.symlink fail with FileExistsError.
        if os.path.lexists(latest_path):
            os.unlink(latest_path)
        try:
            # A bare filename is resolved relative to the link's own directory,
            # so the link stays valid for a relative save_dir as well.
            os.symlink(filename, latest_path)
        except (OSError, NotImplementedError):
            # Filesystems without symlink support: fall back to a plain copy.
            shutil.copyfile(os.path.join(self.save_dir, filename), latest_path)

    def load_latest(self, model, optimizer=None):
        """Load ``latest.pt`` into ``model`` (and ``optimizer`` when state is available).

        Returns:
            The loaded checkpoint dict, or None when no ``latest.pt`` exists.
        """
        latest_path = os.path.join(self.save_dir, "latest.pt")

        if not os.path.exists(latest_path):
            return None

        # Checkpoints embed the model config object, so they need full
        # unpickling (weights_only=False); only load files this project wrote.
        # map_location='cpu' keeps a second copy of the weights off the GPU;
        # load_state_dict copies the values to wherever the parameters live.
        checkpoint = torch.load(latest_path, map_location='cpu', weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])

        # Hand-built checkpoints may carry no optimizer state (missing or None).
        optimizer_state = checkpoint.get('optimizer_state_dict')
        if optimizer is not None and optimizer_state is not None:
            optimizer.load_state_dict(optimizer_state)

        print(f"✓ Loaded checkpoint from epoch {checkpoint['epoch']}, step {checkpoint['step']}")
        return checkpoint

class GradientAccumulator:
    """Count micro-batches so the optimizer only steps every ``accumulation_steps`` batches.

    After the ``accumulate`` block exits, ``step_count == 0`` means the
    accumulation window is complete and the caller should step the optimizer.
    """

    def __init__(self, accumulation_steps=16):
        self.accumulation_steps = accumulation_steps
        self.step_count = 0

    @contextmanager
    def accumulate(self, model):
        """Context manager wrapping one micro-batch.

        Yields ``accumulation_steps`` so the caller can divide its loss by it;
        this class does not scale anything itself. On all but the last
        micro-batch of a window the block runs inside ``model.no_sync()`` when
        the model offers it (DistributedDataParallel does), so gradient
        all-reduce is deferred.
        """
        # Imported locally: only `contextmanager` is imported at module level.
        from contextlib import nullcontext

        self.step_count += 1

        if self.step_count < self.accumulation_steps:
            # Don't sync gradients yet
            sync_guard = model.no_sync() if hasattr(model, 'no_sync') else nullcontext()
            with sync_guard:
                yield self.accumulation_steps
        else:
            # Time to sync and step
            yield self.accumulation_steps
            self.step_count = 0

class ThinkingLossCalculator:
    """Calculate auxiliary losses for thinking mechanism."""

    def __init__(self, config):
        self.config = config
        self.state_change_weight = 0.1
        self.convergence_weight = 0.05
        self.efficiency_weight = 0.02

    @staticmethod
    def _as_tensor(values, device):
        """Return ``values`` unchanged if it is a tensor, else a float32 tensor on ``device``."""
        if torch.is_tensor(values):
            return values
        return torch.as_tensor(values, dtype=torch.float32, device=device)

    def calculate(self, thinking_states, complexities, actual_steps):
        """Calculate thinking-specific losses.

        Args:
            thinking_states: Sequence of state snapshots (tensors of identical
                shape), one per thinking step, initial state first.
            complexities: Estimated complexities, as a tensor or a sequence of
                numbers.
            actual_steps: Thinking steps actually taken, as a tensor or a
                sequence of numbers, broadcastable against ``complexities``.

        Returns:
            Dict of weighted scalar losses. It is empty when fewer than two
            snapshots are given; otherwise it has 'state_change', plus
            'convergence' when there are at least three snapshots, plus
            'efficiency' when steps and complexities are non-empty.
            Values built from plain Python numbers carry no gradient.
        """
        losses = {}

        if not thinking_states or len(thinking_states) < 2:
            return losses

        # State change loss - ensure thinking steps are meaningful
        state_changes = [
            torch.norm(after - before, dim=-1)
            for before, after in zip(thinking_states[:-1], thinking_states[1:])
        ]

        # Negative sign: larger state changes lower the loss.
        state_change_loss = -torch.stack(state_changes).mean() * self.state_change_weight
        losses['state_change'] = state_change_loss

        # Convergence loss - encourage diminishing changes
        if len(state_changes) > 1:
            convergence_ratios = [
                later / (earlier + 1e-6)  # epsilon guards against a zero change
                for earlier, later in zip(state_changes[:-1], state_changes[1:])
            ]

            convergence_loss = torch.stack(convergence_ratios).mean() * self.convergence_weight
            losses['convergence'] = convergence_loss

        # Efficiency loss - penalize unnecessary thinking
        device = thinking_states[0].device
        steps = self._as_tensor(actual_steps, device).float()
        target = self._as_tensor(complexities, device)
        # The mean of an empty tensor is NaN and would poison the total loss.
        if steps.numel() > 0 and target.numel() > 0:
            efficiency_loss = (steps - target).pow(2).mean() * self.efficiency_weight
            losses['efficiency'] = efficiency_loss

        return losses

class TrainingMetrics:
    """Collect loss, learning-rate and thinking statistics and print them periodically."""

    def __init__(self, log_interval=100):
        self.log_interval = log_interval
        self.reset()

    def reset(self):
        """Drop all recorded values and restart the step counter at zero."""
        self.losses, self.lrs = [], []
        self.thinking_stats = {
            name: [] for name in ('avg_steps', 'complexity', 'state_changes')
        }
        self.step = 0

    def update(self, loss, lr, thinking_info=None):
        """Record one step; prints a summary on every ``log_interval``-th call.

        Entries of ``thinking_info`` whose key is not a tracked statistic are
        ignored.
        """
        self.losses.append(loss)
        self.lrs.append(lr)

        for key, value in (thinking_info or {}).items():
            tracked = self.thinking_stats.get(key)
            if tracked is not None:
                tracked.append(value)

        self.step += 1
        if self.step % self.log_interval == 0:
            self.log()

    def log(self):
        """Print the mean loss over the last ``log_interval`` steps and the latest LR."""
        window = self.log_interval
        avg_loss = np.mean(self.losses[-window:])
        current_lr = self.lrs[-1]

        print(f"\nStep {self.step}:")
        print(f"  Loss: {avg_loss:.4f}")
        print(f"  LR: {current_lr:.2e}")

        # Only statistics that received at least one value are reported.
        recorded = {key: values for key, values in self.thinking_stats.items() if values}
        if recorded:
            print("  Thinking stats:")
            for key, values in recorded.items():
                print(f"    {key}: {np.mean(values[-window:]):.3f}")

# Memory monitoring utilities
def log_gpu_memory():
    """Print allocated and reserved CUDA memory in GB; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return

    bytes_per_gb = 1e9
    allocated = torch.cuda.memory_allocated() / bytes_per_gb
    reserved = torch.cuda.memory_reserved() / bytes_per_gb
    print(f"GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")

def optimize_memory():
    """Release cached CUDA blocks, then wait for pending GPU work; does nothing without CUDA."""
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    torch.cuda.synchronize()

# Module-level statement: it executes whenever training_utils.py is imported,
# which is why this line also shows up in the output of train_gazelle.py.
print("✓ Training utilities defined!")

Cell 8: Main Training Script

In [None]:
#@title Main Training Script for Gazelle 0.5B
%%writefile train_gazelle.py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
import os
import sys
import time
import json
from datetime import datetime

# Add implementation to path
sys.path.append('gazelle_implementation')

from gazelle_model import GazelleModel, GazelleConfig
from training_utils import (
    MemoryEfficientDataset, ColabCheckpointer, GradientAccumulator,
    ThinkingLossCalculator, TrainingMetrics, log_gpu_memory, optimize_memory
)

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Cosine learning rate schedule with warmup"""
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

class GazelleTrainer:
    """Main trainer class for Gazelle model.

    Wires together the model, AdamW, mixed precision, gradient accumulation,
    the cosine LR schedule, metrics and Colab-aware checkpointing, and resumes
    from ``latest.pt`` in ``checkpoint_dir`` when one exists.
    """

    def __init__(self, config, checkpoint_dir):
        """Build every training component and try to resume from the latest checkpoint.

        Args:
            config: Model configuration; ``enable_thinking`` and ``ctx_len``
                are read here.
            checkpoint_dir: Directory used by the checkpointer.
        """
        # Imported locally: the script's module-level imports do not bring in
        # DataLoader.
        from torch.utils.data import DataLoader

        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize model
        self.model = GazelleModel(config).to(self.device)

        # Training components
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=3e-4,
            betas=(0.9, 0.99),
            weight_decay=0.1,
            eps=1e-8
        )

        # Mixed precision training
        self.scaler = GradScaler()

        # Gradient accumulation
        self.accumulator = GradientAccumulator(accumulation_steps=16)

        # Checkpointing
        self.checkpointer = ColabCheckpointer(checkpoint_dir)

        # Metrics
        self.metrics = TrainingMetrics()

        # Thinking loss (if enabled)
        if config.enable_thinking:
            self.thinking_loss_calc = ThinkingLossCalculator(config)

        # Load dataset
        self.train_dataset = MemoryEfficientDataset(
            'data/minipile.bin',
            'data/minipile.idx',
            ctx_len=config.ctx_len
        )

        # Micro-batch size; kept on the instance so the summary printed by
        # train() cannot drift from the loader setting.
        self.batch_size = 8  # Small batch for Colab
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=True
        )

        # Calculate training steps (optimizer steps, not micro-batches)
        self.steps_per_epoch = len(self.train_loader) // self.accumulator.accumulation_steps
        self.total_steps = self.steps_per_epoch * 10  # 10 epochs initially

        # Learning rate scheduler
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=100,
            num_training_steps=self.total_steps
        )

        # Try to load checkpoint
        checkpoint = self.checkpointer.load_latest(self.model, self.optimizer)
        if checkpoint:
            self.start_epoch = checkpoint['epoch']
            self.global_step = checkpoint['step']
            if self.global_step > 0:
                # The scheduler is not part of the checkpoint. Move it to the
                # resumed step so training does not restart from warmup.
                # step() advances last_epoch by one and recomputes the LR from
                # the schedule's closed form.
                self.scheduler.last_epoch = self.global_step - 1
                self.scheduler.step()
        else:
            self.start_epoch = 0
            self.global_step = 0

    def train_epoch(self, epoch):
        """Train for one epoch (one pass over ``train_loader``)."""
        self.model.train()
        epoch_start = time.time()

        for batch_idx, (x, y) in enumerate(self.train_loader):
            # Move to device
            x = x.to(self.device)
            y = y.to(self.device)

            # Forward pass with mixed precision
            with autocast():
                with self.accumulator.accumulate(self.model) as acc_steps:
                    logits, loss = self.model(x, y)

                    # Scale loss by accumulation steps
                    loss = loss / acc_steps

                    # Add thinking losses if enabled
                    if self.config.enable_thinking and hasattr(self.model, 'thinking_info'):
                        thinking_losses = self.thinking_loss_calc.calculate(
                            self.model.thinking_info.get('states', []),
                            self.model.thinking_info.get('complexities', []),
                            self.model.thinking_info.get('steps', [])
                        )

                        for t_loss in thinking_losses.values():
                            loss = loss + t_loss / acc_steps

            # Backward pass
            self.scaler.scale(loss).backward()

            # Optimizer step (only when accumulator says so): the accumulator
            # resets step_count to 0 once a full window has been seen.
            if self.accumulator.step_count == 0:
                # Gradient clipping (on unscaled gradients)
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

                # Optimizer step
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()

                # Scheduler step
                self.scheduler.step()

                # Update metrics (undo the accumulation scaling for readability)
                self.metrics.update(
                    loss.item() * acc_steps,
                    self.scheduler.get_last_lr()[0]
                )

                self.global_step += 1

                # Checkpointing: evaluated once per optimizer step, so a
                # multiple of 100 is offered to the checkpointer exactly once
                # instead of on every micro-batch of the following window.
                if self.global_step % 100 == 0:
                    self.checkpointer.save(
                        self.model, self.optimizer, epoch, self.global_step,
                        {'loss': loss.item() * acc_steps}
                    )

            # Memory management
            if batch_idx % 50 == 0:
                optimize_memory()
                log_gpu_memory()

        epoch_time = time.time() - epoch_start
        print(f"\nEpoch {epoch} completed in {epoch_time:.1f}s")
        # max(1, ...) guards datasets smaller than one accumulation window.
        print(f"Average time per step: {epoch_time/max(1, self.steps_per_epoch):.2f}s")

    def train(self, num_epochs=None):
        """Main training loop.

        Args:
            num_epochs: Epochs to run from ``start_epoch``; a falsy value means
                "whatever is left of 10 epochs".

        Returns:
            The trained model.
        """
        num_epochs = num_epochs or (10 - self.start_epoch)

        print(f"\nStarting training:")
        print(f"  Model parameters: {self.model.num_parameters():.2f}M")
        print(f"  Batch size: {self.batch_size} x {self.accumulator.accumulation_steps} = {self.batch_size * self.accumulator.accumulation_steps}")
        print(f"  Total steps: {self.total_steps}")
        print(f"  Starting from epoch: {self.start_epoch}")

        # Bound before the loop so the interrupt handler can always refer to it.
        epoch = self.start_epoch
        try:
            for epoch in range(self.start_epoch, self.start_epoch + num_epochs):
                print(f"\n{'='*50}")
                print(f"Epoch {epoch + 1}/{self.start_epoch + num_epochs}")
                print(f"{'='*50}")

                self.train_epoch(epoch)

                # Save end-of-epoch checkpoint
                # NOTE(review): the checkpointer decides on its own whether a
                # save is due, so this call (and the one in the interrupt
                # handler) may write nothing -- confirm that is acceptable.
                self.checkpointer.save(
                    self.model, self.optimizer, epoch + 1, self.global_step,
                    {'epoch_complete': True}
                )

        except KeyboardInterrupt:
            print("\n\nTraining interrupted! Saving checkpoint...")
            self.checkpointer.save(
                self.model, self.optimizer, epoch, self.global_step,
                {'interrupted': True}
            )

        print("\n✓ Training completed!")
        return self.model

# Configuration for phased training
def get_phase_config(phase):
    """Build the GazelleConfig for a training phase.

    Phases 1-4 switch the ghost and thinking features on step by step and
    print a one-line banner. Any other value returns an untouched default
    config without printing anything.
    """
    base_config = GazelleConfig()

    # (phase, use_ghost, enable_thinking, max_think_steps, banner)
    # max_think_steps of None means "leave the config default alone".
    phase_table = (
        (1, False, False, None, "Phase 1: Baseline RWKV-7 training"),
        (2, True, False, None, "Phase 2: GhostRNN integration"),
        (3, True, True, 3, "Phase 3: Thinking mechanism"),  # Start small
        (4, True, True, 5, "Phase 4: Full Gazelle model"),
    )

    for number, use_ghost, enable_thinking, think_steps, banner in phase_table:
        if phase == number:
            base_config.use_ghost = use_ghost
            base_config.enable_thinking = enable_thinking
            if think_steps is not None:
                base_config.max_think_steps = think_steps
            print(banner)
            break

    return base_config

if __name__ == "__main__":
    # Command-line entry point: pick the phase config, record it next to the
    # checkpoints, train, and store the final weights.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--phase', type=int, default=1, help='Training phase (1-4)')
    parser.add_argument('--epochs', type=int, default=10, help='Number of epochs')
    parser.add_argument('--checkpoint_dir', type=str, default='/content/gazelle_checkpoints/models')
    args = parser.parse_args()

    # Get config for current phase
    config = get_phase_config(args.phase)

    # The directory may not exist on a fresh runtime; without it the config
    # dump below fails before training even starts.
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    # Save config
    config_path = f"{args.checkpoint_dir}/phase{args.phase}_config.json"
    with open(config_path, 'w') as f:
        json.dump(config.__dict__, f, indent=2)

    # Initialize trainer
    trainer = GazelleTrainer(config, args.checkpoint_dir)

    # Train model
    model = trainer.train(args.epochs)

    # Save final model (weights only, no optimizer state)
    final_path = f"{args.checkpoint_dir}/gazelle_phase{args.phase}_final.pt"
    torch.save(model.state_dict(), final_path)
    print(f"✓ Final model saved to {final_path}")

print("✓ Training script created!")

Cell 9: Run Phase 1 Training (Baseline RWKV-7)

In [25]:
# Check what's in the data directory
!ls -la /content/RWKV-LM/data/

total 8
drwxr-xr-x  2 root root 4096 Jul  2 05:54 .
drwxr-xr-x 17 root root 4096 Jul  2 05:52 ..
-rw-r--r--  1 root root    0 Jul  2 05:44 20B_tokenizer.json
lrwxrwxrwx  1 root root   19 Jul  2 05:54 minipile.bin -> dolphin_distill.bin
lrwxrwxrwx  1 root root   23 Jul  2 05:54 minipile.idx -> dolphin_distill.idx.npy
lrwxrwxrwx  1 root root   23 Jul  2 05:54 minipile_val.bin -> dolphin_distill_val.bin
lrwxrwxrwx  1 root root   27 Jul  2 05:54 minipile_val.idx -> dolphin_distill_val.idx.npy


In [26]:
# Create the expected data directory if it doesn't exist
!mkdir -p /content/RWKV-LM/data

# Copy or move the files to where the script expects them
# Adjust the source path based on what the find command shows
!cp /content/data/dolphin_distill* /content/RWKV-LM/data/
!ls -la /content/RWKV-LM/data/

cp: cannot stat '/content/data/dolphin_distill*': No such file or directory
total 8
drwxr-xr-x  2 root root 4096 Jul  2 05:54 .
drwxr-xr-x 17 root root 4096 Jul  2 05:52 ..
-rw-r--r--  1 root root    0 Jul  2 05:44 20B_tokenizer.json
lrwxrwxrwx  1 root root   19 Jul  2 05:54 minipile.bin -> dolphin_distill.bin
lrwxrwxrwx  1 root root   23 Jul  2 05:54 minipile.idx -> dolphin_distill.idx.npy
lrwxrwxrwx  1 root root   23 Jul  2 05:54 minipile_val.bin -> dolphin_distill_val.bin
lrwxrwxrwx  1 root root   27 Jul  2 05:54 minipile_val.idx -> dolphin_distill_val.idx.npy


In [27]:
# Search for the data files
!find /content -name "dolphin_distill.bin" -type f 2>/dev/null
!find /content -name "*.bin" -type f 2>/dev/null | grep -E "(dolphin|minipile)"

# Check current working directory
!pwd

# List what's in various data directories
!ls -la /content/data/ 2>/dev/null || echo "No /content/data/"
!ls -la /content/RWKV-LM/data/ 2>/dev/null || echo "No /content/RWKV-LM/data/"

/content/RWKV-LM
No /content/data/
total 8
drwxr-xr-x  2 root root 4096 Jul  2 05:54 .
drwxr-xr-x 17 root root 4096 Jul  2 05:52 ..
-rw-r--r--  1 root root    0 Jul  2 05:44 20B_tokenizer.json
lrwxrwxrwx  1 root root   19 Jul  2 05:54 minipile.bin -> dolphin_distill.bin
lrwxrwxrwx  1 root root   23 Jul  2 05:54 minipile.idx -> dolphin_distill.idx.npy
lrwxrwxrwx  1 root root   23 Jul  2 05:54 minipile_val.bin -> dolphin_distill_val.bin
lrwxrwxrwx  1 root root   27 Jul  2 05:54 minipile_val.idx -> dolphin_distill_val.idx.npy


In [28]:
%cd /content/RWKV-LM/data

# Create symbolic links
!ln -sf dolphin_distill.bin minipile.bin
!ln -sf dolphin_distill.idx.npy minipile.idx
!ln -sf dolphin_distill_val.bin minipile_val.bin
!ln -sf dolphin_distill_val.idx.npy minipile_val.idx

# Verify links
!ls -la *.bin *.idx*
%cd /content/RWKV-LM

/content/RWKV-LM/data
lrwxrwxrwx 1 root root 19 Jul  2 05:58 minipile.bin -> dolphin_distill.bin
lrwxrwxrwx 1 root root 23 Jul  2 05:58 minipile.idx -> dolphin_distill.idx.npy
lrwxrwxrwx 1 root root 23 Jul  2 05:58 minipile_val.bin -> dolphin_distill_val.bin
lrwxrwxrwx 1 root root 27 Jul  2 05:58 minipile_val.idx -> dolphin_distill_val.idx.npy
/content/RWKV-LM


In [29]:
# Check where train_gazelle.py is looking for data
!grep -n "data/" /content/RWKV-LM/train_gazelle.py | head -10

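# Replace with absolute paths. NOTE: /content/data/ does not exist in this session (see the "No /content/data/" output above), so change the targets below to the directory that actually holds the dolphin_distill files.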
!sed -i "s|'data/dolphin_distill.bin'|'/content/data/dolphin_distill.bin'|g" /content/RWKV-LM/train_gazelle.py
!sed -i "s|'data/dolphin_distill.idx.npy'|'/content/data/dolphin_distill.idx.npy'|g" /content/RWKV-LM/train_gazelle.py
!sed -i "s|'data/dolphin_distill_val.bin'|'/content/data/dolphin_distill_val.bin'|g" /content/RWKV-LM/train_gazelle.py
!sed -i "s|'data/dolphin_distill_val.idx.npy'|'/content/data/dolphin_distill_val.idx.npy'|g" /content/RWKV-LM/train_gazelle.py

68:            'data/dolphin_distill.bin',
69:            'data/dolphin_distill.idx.npy',


In [30]:
# Retry training
!python train_gazelle.py --phase 1 --epochs 5

✓ Training utilities defined!
Phase 1: Baseline RWKV-7 training
✓ Gazelle Model initialized with 597.81M parameters
  self.scaler = GradScaler('cuda')
Traceback (most recent call last):
  File "/content/RWKV-LM/train_gazelle.py", line 256, in <module>
    trainer = GazelleTrainer(config, args.checkpoint_dir)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/RWKV-LM/train_gazelle.py", line 67, in __init__
    self.train_dataset = MemoryEfficientDataset(
                         ^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/RWKV-LM/gazelle_implementation/training_utils.py", line 17, in __init__
    self.data = np.memmap(data_path, dtype=np.uint16, mode='r')
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/numpy/_core/memmap.py", line 233, in __new__
    f_ctx = open(
            ^^^^^
FileNotFoundError: [Errno 2] No such file or directory: '/content/data/dolphin_distill.bin'


In [None]:
#@title Phase 1: Train Baseline RWKV-7
# Runs the training script as a shell command with the Phase 1 (baseline) config.
!python train_gazelle.py --phase 1 --epochs 5

# Monitor training progress
# NOTE(review): these messages presumably print only after the shell command
# above has returned, i.e. once the script has finished or failed -- confirm
# the "initiated" wording.
print("\n✓ Phase 1 training initiated!")
print("Monitor the output above for training progress.")
print("Checkpoints will be saved automatically.")

Cell 10: Run Phase 2 Training (Add GhostRNN)

In [None]:
#@title Phase 2: Integrate GhostRNN
# First, load the Phase 1 checkpoint
import torch
import sys
sys.path.append('gazelle_implementation')
from gazelle_model import GazelleConfig, GazelleModel

# Load Phase 1 model
phase1_checkpoint = torch.load('/content/gazelle_checkpoints/models/gazelle_phase1_final.pt')

# Create Phase 2 config
phase2_config = GazelleConfig()
phase2_config.use_ghost = True
phase2_config.enable_thinking = False

# Initialize Phase 2 model
phase2_model = GazelleModel(phase2_config)

# Transfer weights where possible
phase2_model.load_state_dict(phase1_checkpoint, strict=False)

# Save as starting point for Phase 2
torch.save({
    'model_state_dict': phase2_model.state_dict(),
    'optimizer_state_dict': None,
    'epoch': 0,
    'step': 0,
    'config': phase2_config
}, '/content/gazelle_checkpoints/models/latest.pt')

print("✓ Phase 1 weights transferred to Phase 2 model")

# Run Phase 2 training
!python train_gazelle.py --phase 2 --epochs 5

Cell 11: Run Phase 3 Training (Add Thinking)

In [None]:
#@title Phase 3: Add Thinking Mechanism
# Similar transfer process for Phase 3
!python train_gazelle.py --phase 3 --epochs 10

print("\n✓ Phase 3 training with thinking mechanism initiated!")

Cell 12: Interactive Demo

In [None]:
#@title Interactive Gazelle Demo
import gradio as gr
import torch
import sys
sys.path.append('gazelle_implementation')
from gazelle_model import GazelleModel, GazelleConfig

# Load model for demo
model_path = '/content/gazelle_checkpoints/models/latest.pt'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load model
# map_location=device: lets a checkpoint saved from a GPU run load on the CPU
#   fallback selected above instead of failing while restoring CUDA tensors.
# weights_only=False: the checkpoint may embed a pickled config object (read just
#   below), which the restricted unpickler rejects. Unpickling executes code -
#   only load checkpoints you created yourself.
checkpoint = torch.load(model_path, map_location=device, weights_only=False)

# Fall back to the default config when none was stored (or it was stored as None).
config = checkpoint.get('config')
if config is None:
    config = GazelleConfig()

model = GazelleModel(config).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

def generate_text(prompt, max_length=100, temperature=0.8):
    """Run the demo model once and return a formatted report string.

    Placeholder implementation: `prompt` is not tokenized yet. The model is fed
    random token ids and the generated ids are not decoded back into text.

    Args:
        prompt: Text entered in the UI; currently unused.
        max_length: Maximum number of tokens to generate. Cast to int because
            UI sliders may deliver floats.
        temperature: Sampling temperature forwarded to `model.generate`.

    Returns:
        A string containing the placeholder output line and, when
        `config.enable_thinking` is set, the average of the `think_steps`
        values found in the thinking log ("n/a" if the log is empty).
    """
    # Simplified - you'd use proper tokenization
    input_ids = torch.randint(0, 1000, (1, 10)).to(device)

    # Inference only: skip autograd bookkeeping while generating.
    with torch.no_grad():
        _generated, thinking_log = model.generate(
            input_ids,
            max_tokens=int(max_length),
            temperature=temperature
        )

    # Format output
    output = "Generated text: [Model output would appear here]\n\n"

    if config.enable_thinking:
        output += "Thinking Analysis:\n"
        # An empty (or missing) log must not turn into "nan" plus a runtime warning.
        steps = [float(log['think_steps']) for log in (thinking_log or [])]
        if steps:
            avg_steps = sum(steps) / len(steps)
            output += f"Average thinking steps: {avg_steps:.2f}\n"
        else:
            output += "Average thinking steps: n/a\n"

    return output

# Build the UI components first, then wire them into the interface.
prompt_box = gr.Textbox(label="Input Prompt", lines=3)
max_length_slider = gr.Slider(50, 500, value=100, label="Max Length")
temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
output_box = gr.Textbox(label="Generated Output", lines=10)

# Inputs are passed positionally to generate_text(prompt, max_length, temperature).
iface = gr.Interface(
    fn=generate_text,
    inputs=[prompt_box, max_length_slider, temperature_slider],
    outputs=output_box,
    title="Gazelle 0.5B Demo",
    description="Test the Gazelle model with adaptive thinking",
)

# Start the demo server; share=True requests a public link.
iface.launch(share=True)