In [2]:
# Install Dependencies (Colab / Cloud)
import sys
import subprocess

def install_packages():
    """Best-effort ``pip install`` of the notebook's runtime requirements.

    Runs pip through the current interpreter (``sys.executable -m pip``) so the
    packages land in the kernel's own environment. Any failure is reported on
    stdout and swallowed; nothing is returned and nothing is raised.
    """
    print("üì¶ Installing dependencies...")

    # Minimum versions only; pip is free to pick anything newer.
    requirements = (
        "torch>=2.0.0",
        "transformers>=4.30.0",
        "datasets>=2.12.0",
        "bitsandbytes>=0.41.0",
        "accelerate>=0.20.0",
    )
    pip_command = [sys.executable, "-m", "pip", "install", *requirements]

    try:
        subprocess.check_call(pip_command)
    except Exception as err:
        print(f"‚ùå Error installing packages: {err}")
    else:
        print("‚úÖ Dependencies installed.")

# Detect Colab by probing for its helper package; only auto-install there.
try:
    import google.colab
except ImportError:
    # Package missing -> treat as a local machine and leave the environment alone.
    IN_COLAB = False
    print("‚ÑπÔ∏è Local environment detected. Skipping auto-install (ensure requirements are met).")
else:
    IN_COLAB = True
    install_packages()

üì¶ Installing dependencies...
‚úÖ Dependencies installed.


In [3]:
import os
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset
import bitsandbytes as bnb

# Environment probe: `drive` is only bound when the Colab helper package imports.
try:
    from google.colab import drive
except ImportError:
    IN_COLAB = False
    print("‚ö†Ô∏è Not running in Colab - Google Drive features disabled")
else:
    IN_COLAB = True
    print("‚úÖ Detected Colab Environment")

‚úÖ Detected Colab Environment


In [None]:
# ==========================================
# 1. CONFIGURATION
# ==========================================
class Config:
    """Hyperparameters for the GroundThink run, stored as plain class attributes."""

    # --- bookkeeping -------------------------------------------------------
    project_name = "groundthink_1B"   # used to build the checkpoint directory

    # --- model shape -------------------------------------------------------
    vocab_size = 50257    # size of the GPT-2 tokenizer vocabulary
    d_model = 2048        # hidden width
    n_layer = 18          # number of stacked blocks
    head_size = 64        # per-head width; d_model // head_size heads
    max_seq_len = 512     # truncation length applied when tokenizing

    # --- optimisation ------------------------------------------------------
    learning_rate = 4e-4
    micro_batch_size = 1    # sequences per forward/backward pass
    grad_accum_steps = 64   # micro-batches accumulated per optimizer step


config = Config()

In [5]:
# ==========================================
# 2. THE SELECTIVE-WKV BLOCK (CHUNKING ENABLED)
# ==========================================
class SelectiveWKV_1B(nn.Module):
    """Recurrent token mixer with a per-head matrix state and input-dependent decay.

    Each head keeps a ``head_size x head_size`` state ``S``. For every token::

        S   = (1 - w_t) * S + k_t v_t^T      # decay, then write an outer product
        out = r_t @ S                        # gated read-out

    ``w_t`` and ``r_t`` are sigmoids of linear projections of the (layer-normed)
    input, so both decay and read gate depend on the current token.

    Args:
        config: object exposing ``d_model`` and ``head_size``. An optional
            ``chunk_size`` attribute (default 64) sets how many time steps are
            stacked together at once; it does not change the result.

    Raises:
        ValueError: if ``d_model`` is not a multiple of ``head_size`` or
            ``chunk_size`` is smaller than 1.
    """

    def __init__(self, config):
        super().__init__()
        if config.d_model % config.head_size != 0:
            raise ValueError(
                f"d_model ({config.d_model}) must be divisible by head_size ({config.head_size})"
            )
        self.dim = config.d_model
        self.n_head = config.d_model // config.head_size
        self.head_size = config.head_size

        # Time steps processed per chunk; previously a constant inside forward().
        self.chunk_size = getattr(config, "chunk_size", 64)
        if self.chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {self.chunk_size}")

        # Projections feeding the selective decay gate
        self.x_proj = nn.Linear(self.dim, self.dim, bias=False)
        self.w_proj = nn.Linear(self.dim, self.dim)  # Selective Decay

        # Key / value / receptance and output projections
        self.k_proj = nn.Linear(self.dim, self.dim, bias=False)
        self.v_proj = nn.Linear(self.dim, self.dim, bias=False)
        self.r_proj = nn.Linear(self.dim, self.dim, bias=False)
        self.out_proj = nn.Linear(self.dim, self.dim, bias=False)

        # Layer Norms
        self.ln_x = nn.LayerNorm(self.dim)

    def forward(self, x, state=None):
        """Run the recurrence over a ``(B, T, C)`` input.

        Args:
            x: tensor of shape ``(B, T, d_model)``.
            state: optional ``(B, n_head, head_size, head_size)`` state to
                continue from; zeros are used when omitted.

        Returns:
            ``(output, state)`` where ``output`` is ``(B, T, d_model)`` and
            ``state`` is the recurrent state after the last time step.
        """
        B, T, C = x.size()
        H, D = self.n_head, self.head_size
        x = self.ln_x(x)

        # Selection logic
        w = torch.sigmoid(self.w_proj(self.x_proj(x)))

        k = self.k_proj(x)
        v = self.v_proj(x)
        r = torch.sigmoid(self.r_proj(x))

        # dtype the projections produce (fp16/bf16 under autocast or a half model)
        out_dtype = k.dtype

        # Reshape so that k @ v is a per-head outer product and w scales state rows
        k = k.view(B, T, H, D, 1)
        v = v.view(B, T, H, 1, D)
        w = w.view(B, T, H, D, 1)

        # Initial State: [B, n_head, head_size, head_size]
        if state is None:
            # Accumulate in at least fp32: the state is a running sum over T steps.
            state = torch.zeros(
                B, H, D, D,
                device=x.device,
                dtype=torch.promote_types(out_dtype, torch.float32),
            )

        # The update below promotes the state to this dtype on the first step
        # anyway; doing it up front lets the read gate be cast exactly once so
        # `r @ state` never mixes dtypes (which raises outside autocast).
        acc_dtype = torch.promote_types(state.dtype, out_dtype)
        state = state.to(acc_dtype)
        r = r.view(B, T, H, 1, D).to(acc_dtype)

        outputs = []
        for start in range(0, T, self.chunk_size):
            end = min(start + self.chunk_size, T)
            chunk_out = []

            # Sequential scan inside the chunk
            for t in range(start, end):
                # Matrix update: S = (1-w)*S + (k @ v^T)
                state = (1 - w[:, t]) * state + k[:, t] @ v[:, t]

                # Read: r @ state -> [B, H, 1, D], flattened back to [B, C]
                chunk_out.append((r[:, t] @ state).reshape(B, C))

            outputs.append(torch.stack(chunk_out, dim=1))

        if outputs:
            mixed = torch.cat(outputs, dim=1).to(out_dtype)
        else:
            # T == 0: torch.cat([]) would raise, so emit an empty sequence instead.
            mixed = torch.zeros(B, 0, C, device=x.device, dtype=out_dtype)

        return self.out_proj(mixed), state

In [6]:
class GroundThinkBlock(nn.Module):
    """One residual layer: pre-norm token mixer followed by a pre-norm MLP.

    The mixer is a ``SelectiveWKV_1B``; the MLP expands to ``4 * d_model``,
    applies GELU and projects back. The mixer's recurrent state is discarded,
    so every forward call starts from a fresh state.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        hidden = 4 * width

        self.ln1 = nn.LayerNorm(width)
        self.mixer = SelectiveWKV_1B(config)
        self.ln2 = nn.LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
        )

    def forward(self, x):
        """Apply both residual sub-layers to ``x`` and return the result."""
        mixed, _final_state = self.mixer(self.ln1(x))
        after_mixer = x + mixed
        return after_mixer + self.mlp(self.ln2(after_mixer))

class GroundThink1B(nn.Module):
    """Language model: token embedding, ``n_layer`` GroundThinkBlocks, final norm, LM head.

    The embedding matrix and the output projection share one weight tensor.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        vocab, width = config.vocab_size, config.d_model

        self.token_emb = nn.Embedding(vocab, width)
        self.blocks = nn.ModuleList([GroundThinkBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(width)
        self.head = nn.Linear(width, vocab, bias=False)

        # Weight tying: the embedding reuses the head's parameter object.
        self.token_emb.weight = self.head.weight

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            idx: 2-D tensor of token ids, ``(B, T)``.
            targets: optional tensor with one class id per position of ``idx``.
                No shifting is applied here: ``logits[b, t]`` is scored against
                ``targets[b, t]``.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``targets`` is ``None``.
        """
        _batch, _seq_len = idx.size()  # unpacking also enforces a 2-D input

        hidden = self.token_emb(idx)
        for layer in self.blocks:
            hidden = layer(hidden)
        logits = self.head(self.ln_f(hidden))

        if targets is None:
            return logits, None

        flat_logits = logits.view(-1, logits.size(-1))
        return logits, F.cross_entropy(flat_logits, targets.view(-1))

In [7]:
# ==========================================
# 3. DATA & TOKENIZATION (TinyStories)
# ==========================================
def get_dataloaders(config):
    """Build a streaming DataLoader over TinyStories that yields ``(input_ids, labels)``.

    Texts are tokenized with the GPT-2 tokenizer (EOS reused as the pad token),
    truncated to ``config.max_seq_len`` and padded to the longest item in the batch.

    ``labels`` are next-token targets aligned with ``input_ids``:
    ``labels[b, t] == input_ids[b, t + 1]``. Positions with no real successor
    (the last column, and anything whose successor is padding) are set to
    ``-100``, the default ``ignore_index`` of ``F.cross_entropy``, so they do
    not contribute to the loss. The model in this notebook scores
    ``logits[t]`` against ``targets[t]`` without shifting, so the shift has to
    happen here; unshifted labels would only teach it to echo its input.

    Args:
        config: object exposing ``max_seq_len`` and ``micro_batch_size``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the streaming train split.
    """
    print("üìö Loading TinyStories dataset...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset("roneneldan/TinyStories", split="train", streaming=True)

    # Custom collate for streaming dataset
    def collate_fn(batch):
        texts = [item['text'] for item in batch]
        encoded = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=config.max_seq_len,
            return_tensors="pt"
        )
        input_ids = encoded['input_ids']
        attention_mask = encoded['attention_mask']

        # Shift left by one; pad == EOS here, so padding must be identified via
        # the attention mask rather than by token id.
        labels = torch.full_like(input_ids, -100)
        labels[:, :-1] = input_ids[:, 1:].masked_fill(attention_mask[:, 1:] == 0, -100)
        # NOTE: a single-token text yields an all -100 row; with a micro-batch
        # of such rows the mean cross-entropy is NaN.
        return input_ids, labels

    return DataLoader(dataset, batch_size=config.micro_batch_size, collate_fn=collate_fn)

In [None]:
# ==========================================
# 4. TRAINING LOOP (Pure FP16 - No Scaler)
# ==========================================
import gc

def learning_rate_schedule(step, warmup_steps=1000, max_lr=4e-4):
    """Linear warmup from 0 to ``max_lr`` over ``warmup_steps``, then constant.

    Args:
        step: current step index.
        warmup_steps: number of steps spent ramping up.
        max_lr: learning rate reached at the end of warmup and held afterwards.

    Returns:
        The learning rate to use at ``step``.
    """
    still_warming_up = step < warmup_steps
    return max_lr * (step / warmup_steps) if still_warming_up else max_lr

def _save_checkpoint(checkpoint_path, step, model, optimizer):
    """Write a checkpoint next to ``checkpoint_path`` and rename it into place.

    Writing to a temporary file first means an interrupted save cannot leave a
    truncated ``latest.pt`` behind.
    """
    tmp_path = checkpoint_path + ".tmp"
    torch.save({
        'step': step,
        'model': model.state_dict(),
        'opt': optimizer.state_dict(),
        # No scaler state
    }, tmp_path)
    os.replace(tmp_path, checkpoint_path)


def train_step(model, optimizer, dataloader, config, start_run_step=0,
               log_every=100, ckpt_every=500):
    """Run the training loop over ``dataloader`` with gradient accumulation.

    Each micro-batch is run under autocast, its loss is divided by
    ``config.grad_accum_steps`` and back-propagated without a GradScaler.
    Every ``grad_accum_steps`` steps the gradients are clipped to norm 1.0 and
    the optimizer is stepped.

    Args:
        model: module whose ``forward(x, y)`` returns ``(logits, loss)``.
        optimizer: optimizer over ``model``'s parameters; its learning rate is
            overwritten every step from ``learning_rate_schedule``.
        dataloader: iterable of ``(input_ids, labels)`` batches.
        config: object exposing ``project_name``, ``grad_accum_steps`` and
            ``learning_rate``.
        start_run_step: value the step counter starts from. NOTE: the
            dataloader itself is still consumed from its beginning, so a
            resumed run replays data from the start of the stream.
        log_every: print the mean loss every this many steps.
        ckpt_every: save ``latest.pt`` every this many steps. NOTE: when this
            is not a multiple of ``grad_accum_steps`` the checkpoint is taken
            mid-accumulation and the partial gradients are not saved.

    Returns:
        None. Progress is printed and checkpoints are written to
        ``/content/drive/MyDrive/<project_name>`` in Colab, otherwise
        ``checkpoints/<project_name>``.
    """
    print("üî• Starting Training (Pure FP16 Mode)...")

    # Aggressive memory cleanup before start
    gc.collect()
    torch.cuda.empty_cache()

    model.train()
    optimizer.zero_grad()

    # Follow the model's device instead of assuming CUDA.
    device = next(model.parameters()).device

    # NOTE: GradScaler is REMOVED.
    # FP16 Weights (required for memory) + GradScaler (requires FP32 grads) = Incompatible.
    # We will run in pure FP16 accumulation.

    running_loss = 0.0
    micro_batches_logged = 0  # micro-batches contributing to running_loss
    t0 = time.time()

    # Setup checkpoints
    save_dir = f"checkpoints/{config.project_name}"
    if IN_COLAB:
        save_dir = f"/content/drive/MyDrive/{config.project_name}"
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_path = os.path.join(save_dir, "latest.pt")

    for step, (x, y) in enumerate(dataloader, start=start_run_step):
        # Update LR with warmup
        lr = learning_rate_schedule(step, max_lr=config.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        x, y = x.to(device), y.to(device)

        # Forward / backward for one micro-batch
        try:
            with torch.amp.autocast(device.type):
                _, loss = model(x, y)

            loss_value = loss.item()
            if not math.isfinite(loss_value):
                # Without a scaler a NaN/Inf loss would poison every accumulated gradient.
                print(f"WARNING: non-finite loss ({loss_value}) at step {step}. Skipping micro-batch.")
                continue

            # Backward (No Scaler)
            (loss / config.grad_accum_steps).backward()

            running_loss += loss_value
            micro_batches_logged += 1

        except torch.cuda.OutOfMemoryError:
            print(f"‚ö†Ô∏è OOM at step {step}. Attempting recovery...")
            torch.cuda.empty_cache()
            continue

        # Step (Gradient Accumulation)
        if (step + 1) % config.grad_accum_steps == 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            if torch.isfinite(grad_norm):
                optimizer.step()
            else:
                # Overflowed fp16 gradients: applying them would write NaNs into the weights.
                print(f"WARNING: non-finite gradient norm at step {step+1}. Skipping optimizer step.")
            optimizer.zero_grad()

        # Logging: independent of the accumulation boundary, averaged over the
        # micro-batches actually seen since the previous log line.
        if (step + 1) % log_every == 0 and micro_batches_logged > 0:
            dt = time.time() - t0
            print(f"Step {step+1} | Loss: {running_loss/micro_batches_logged:.4f} | LR: {lr:.2e} | Time: {dt:.2f}s")
            running_loss = 0.0
            micro_batches_logged = 0
            t0 = time.time()

        # Checkpoint (Resilient)
        if (step + 1) % ckpt_every == 0:
            print(f"üíæ Saving checkpoint at step {step+1}...")
            _save_checkpoint(checkpoint_path, step + 1, model, optimizer)

            # Periodic cleanup
            torch.cuda.empty_cache()

In [None]:
# Execute Training

# 0. MEMORY RECOVERY & CHECK
import gc
import sys

# Drop objects left over from a previous run of this cell so their GPU memory
# can actually be released before a new model is allocated.
keys_to_clean = ['model', 'optimizer', 'dataloader', 'scaler', 'x', 'y', 'loss']
for key in keys_to_clean:
    globals().pop(key, None)
gc.collect()
torch.cuda.empty_cache()

# VRAM Status Check
# mem_get_info reports device-wide free memory, so memory held by other
# processes (e.g. a stale kernel) is visible here; this process's own
# allocator counters would not show it.
free, total = torch.cuda.mem_get_info(0)
allocated = torch.cuda.memory_allocated(0)

print(f"üìä VRAM Status: {free/1024**3:.2f}GB Free | {allocated/1024**3:.2f}GB Allocated")

# Advisory only: warn (do not abort) when little memory is available.
if free < 6 * 1024**3:
    print("‚ö†Ô∏è WARNING: Less than 6GB Free. Trying anyway with FP16 optimization...")

# 1. Setup Drive
if IN_COLAB:
    drive.mount('/content/drive', force_remount=True)

# 2. Initialize
# MVP FALLBACK OPTION: Un-comment these lines if 1B fails repeatedly
# config = Config()
# config.d_model = 768  # 150M Scale (Approximation)
# config.n_layer = 12
# config.project_name = "groundthink_150M_MVP"

config = Config()
print(f"üöÄ Initializing GroundThink ({config.d_model} dim, {config.n_layer} layers)")

# OPTIMIZATION: Robust FP16 Initialization
print("üìâ Converting model to FP16 (Half Precision)...")
try:
    model = GroundThink1B(config)
    model.to(dtype=torch.float16)  # Convert ALL weights to FP16

    # STABILITY FIX: Keep LayerNorms in FP32
    print("üõ°Ô∏è Restoring LayerNorms to FP32 for stability...")
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):
            module.float()

    model = model.cuda()  # Move to GPU
    print(f"‚úÖ Model loaded on GPU. VRAM: {torch.cuda.memory_allocated(0)/1024**3:.2f}GB")
except Exception as e:
    # Report, then re-raise so the cell stops instead of continuing without a model.
    print(f"‚ùå Model Init Failed: {e}")
    raise

# 3. Optimizer (8-bit)
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=config.learning_rate)

# 4. Resume logic
start_step = 0
save_dir = f"/content/drive/MyDrive/{config.project_name}" if IN_COLAB else f"checkpoints/{config.project_name}"
checkpoint_path = os.path.join(save_dir, "latest.pt")

if os.path.exists(checkpoint_path):
    print(f"üîÑ Resuming from {checkpoint_path}")
    # Stage the checkpoint in host memory: loading it straight onto the GPU
    # would hold a second copy of the weights and optimizer state in VRAM
    # while load_state_dict copies them into the live objects.
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['opt'])
    start_step = ckpt['step']
    del ckpt  # release the staged copy
    gc.collect()
else:
    print("üÜï Starting fresh run")

# 5. Run
dataloader = get_dataloaders(config)
train_step(model, optimizer, dataloader, config, start_step)

KeyboardInterrupt: 