# Full Token Sonar Training (Production Grade)

This notebook implements a robust training pipeline for the **Full Token Sonar** student model. It is designed for "set it and forget it" execution on platforms like Kaggle or Colab, with features for checkpointing, easy resuming, and loss monitoring (early stopping is driven by the average training loss per epoch; no separate validation split is used).

### Features:
- **Architecture**: 6-layer Transformer Encoder (Student) vs 24-layer SONAR (Teacher).
- **Robustness**: Automated Checkpointing (Best & Last), Early Stopping.
- **Optimization**: Mixed Precision (AMP), Multi-GPU DataParallel.
- **Resume Capability**: Automatically detects and loads the best previous checkpoint.
- **Verification**: Integrated Unit Tests to ensure model integrity before training.
- **Integration**: GitHub synchronization for code updates.

In [None]:
# @title 1. Environment Setup & Dependencies
import os
import sys
import shutil
from pathlib import Path

# Detect Environment
# Kaggle is recognised by a non-empty KAGGLE_KERNEL_RUN_TYPE environment variable;
# Colab by the `google.colab` module already being loaded in this interpreter.
# Later cells read IS_KAGGLE / IS_COLAB to choose paths and decide whether to sync the repo.
IS_KAGGLE = os.environ.get('KAGGLE_KERNEL_RUN_TYPE', '') != ''
IS_COLAB = 'google.colab' in sys.modules

print(f"Environment: {'Kaggle' if IS_KAGGLE else 'Colab' if IS_COLAB else 'Local'}")

# AGGRESSIVE CLEANUP
# We must uninstall numpy and pandas first to remove any system-level mismatched binaries.
# NOTE: `%pip` is an IPython magic, so this cell only runs inside a notebook kernel.
print("Cleaning up environment (Uninstalling existing numpy/pandas)...")
%pip uninstall -y numpy pandas

# FRESH INSTALL
# Pin numpy==1.26.4 explicitly to ensure maximum compatibility.
# pandas is capped below 2.2.0; the remaining packages are installed unpinned.
print("Installing dependencies...")
%pip install -U --force-reinstall "numpy==1.26.4" "pandas<2.2.0" datasets sonar-space fairseq2 onnxruntime-gpu

# Verify Imports Immediately
# Importing right after the install surfaces a broken environment here rather than
# in a later cell. Both ImportError and ValueError are treated as installation failures
# (presumably because a binary-incompatible numpy build can raise ValueError on import —
# TODO confirm). The error is re-raised after printing restart instructions so that
# execution of the notebook stops at this cell.
print("Verifying installations...")
try:
    import numpy
    print(f"✓ Numpy Version: {numpy.__version__}")
    import pandas
    print(f"✓ Pandas Version: {pandas.__version__}")
    import sonar
    from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
    print(f"✓ SONAR installed successfully")
except (ImportError, ValueError) as e:
    print(f"❌ CRITICAL INSTALLATION ERROR: {e}")
    print("\n!!! PLEASE RESTART THE KERNEL AND RUN THIS CELL AGAIN !!!")
    print("In Colab: Runtime > Restart Session")
    print("In Kaggle: Run > Restart Kernel")
    raise e


In [None]:
# @title 2. Git & Checkpoint Management
# Setup paths to persist data across sessions

if IS_KAGGLE:
    # Kaggle specific paths
    WORK_DIR = Path("/kaggle/working")
    CHECKPOINT_DIR = Path("/kaggle/working/checkpoints")
    # Try to load input dataset if available (pseudo-code path)
    INPUT_DIR = Path("/kaggle/input/fst-checkpoints")
elif IS_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    WORK_DIR = Path("/content/fst")
    CHECKPOINT_DIR = Path("/content/drive/MyDrive/FST_Checkpoints")
else:
    # Local
    WORK_DIR = Path(os.getcwd())
    CHECKPOINT_DIR = WORK_DIR / "checkpoints"

CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
print(f"Checkpoints will be saved to: {CHECKPOINT_DIR}")

# GitHub Sync (Optional - for code updates)
REPO_URL = "https://github.com/KidIkaros/fst.git"
if not (WORK_DIR / ".git").exists() and (IS_KAGGLE or IS_COLAB):
    print("Cloning repository...")
    !git clone $REPO_URL $WORK_DIR
else:
    print("Repository already present or local mode.")
    # If we are in the repo but want to pull latest changes (e.g. after a fix)
    if (WORK_DIR / ".git").exists() and (IS_KAGGLE or IS_COLAB):
        print("Pulling latest changes...")
        !cd $WORK_DIR && git pull

# Reconstruct Checkpoint from Split Parts (if needed)
def reconstruct_checkpoint():
    """Rebuild ``full_token_sonar_best.pt`` from split part files, if needed.

    Looks for ``full_token_sonar_best_part_*`` files under
    ``WORK_DIR / "checkpoints"`` and concatenates them, in order, into
    ``CHECKPOINT_DIR / "full_token_sonar_best.pt"``. Nothing is written when
    the target already exists or when no parts are found.

    The operation is best-effort: I/O failures are reported with ``print`` and
    not raised. The merged data is first written to a ``.partial`` sibling and
    only renamed onto the target once every part has been copied, so a failed
    or interrupted merge never leaves a truncated checkpoint that a later run
    would mistake for a complete one.
    """
    import re
    import shutil

    def _natural_key(path):
        # Split the file name into alternating text / digit runs and compare the
        # digit runs numerically, so "..._part_10" sorts after "..._part_2".
        # Zero-padded or alphabetic suffixes keep their lexicographic order.
        pieces = re.split(r"(\d+)", path.name)
        return [int(tok) if idx % 2 else tok for idx, tok in enumerate(pieces)]

    parts = sorted(
        (WORK_DIR / "checkpoints").glob("full_token_sonar_best_part_*"),
        key=_natural_key,
    )
    target = CHECKPOINT_DIR / "full_token_sonar_best.pt"

    if target.exists():
        print("Checkpoint file already exists.")
        return
    if not parts:
        print("No checkpoint parts found to reconstruct.")
        return

    print(f"Found {len(parts)} checkpoint parts. Reconstructing to {target}...")
    staging = target.with_name(target.name + ".partial")
    try:
        with open(staging, "wb") as outfile:
            for part in parts:
                print(f"Mergin {part}...")
                with open(part, "rb") as infile:
                    # Stream in chunks instead of loading a whole part into memory.
                    shutil.copyfileobj(infile, outfile)
        # Publish the merged file only after every part was copied successfully.
        staging.replace(target)
        print("Reconstruction complete.")
    except OSError as e:
        staging.unlink(missing_ok=True)
        print(f"Reconstruction failed: {e}")

reconstruct_checkpoint()

In [None]:
# @title 3. Model Definition
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SinusoidalPositionEncoder(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    The table is precomputed once for ``max_seq_len`` positions and stored as
    the non-trainable buffer ``pe`` of shape ``(1, max_seq_len, encoding_dim)``.
    Even feature indices hold ``sin(pos * w_k)`` and odd indices hold
    ``cos(pos * w_k)`` with ``w_k = 10000 ** (-2k / encoding_dim)``.

    Args:
        encoding_dim: Size of the feature dimension. May be even or odd; for
            an odd size the final sine column has no cosine partner.
        max_seq_len: Number of positions to precompute.
    """
    def __init__(self, encoding_dim, max_seq_len=514):
        super().__init__()
        self.encoding_dim = encoding_dim
        position = torch.arange(max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, encoding_dim, 2) * (-math.log(10000.0) / encoding_dim))
        angles = position * div_term
        pe = torch.zeros(max_seq_len, encoding_dim)
        pe[:, 0::2] = torch.sin(angles)
        # There are only encoding_dim // 2 odd columns; for an odd encoding_dim
        # that is one fewer than the number of frequencies, so trim to fit.
        pe[:, 1::2] = torch.cos(angles[:, : encoding_dim // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is indexed as ``(batch, seq, encoding_dim)``. A sequence longer
        than ``max_seq_len`` cannot be broadcast against the table and makes
        the addition fail.
        """
        return x + self.pe[:, :x.size(1)]

class FullTokenSonar(nn.Module):
    """Student sentence encoder: token ids in, one unit-length vector out.

    Pipeline: token embedding -> sinusoidal positions -> dropout ->
    pre-norm Transformer encoder -> learned attention pooling over the
    sequence -> LayerNorm/Linear/LayerNorm projection -> L2 normalisation.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Width of the token embeddings and of the encoder.
        layers: Number of Transformer encoder layers.
        num_heads: Attention heads per encoder layer.
        output_dim: Size of the returned sentence embedding.
        dropout: Dropout probability used after the positional encoding and
            inside the encoder layers.
        max_seq_len: Number of positions precomputed by the positional
            encoder (previously hard-coded to 514).
        padding_idx: Embedding row treated as padding by ``nn.Embedding``
            (previously hard-coded to 1).
            NOTE(review): ``Trainer.train_epoch`` pads ``input_ids`` with 0;
            padded positions are masked out so this should not influence the
            output, but confirm which id the tokenizer reserves for padding.
    """
    def __init__(self, vocab_size=256206, embed_dim=512, layers=6, num_heads=8, output_dim=1024, dropout=0.1,
                 max_seq_len=514, padding_idx=1):
        super().__init__()
        # Sub-modules are created in a fixed order so that seeded weight
        # initialisation and state_dict keys stay the same as before.
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.pos_encoder = SinusoidalPositionEncoder(embed_dim, max_seq_len=max_seq_len)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            batch_first=True,
            activation="gelu",
            norm_first=True,
            dropout=dropout
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=layers)

        # Scores each position with a single scalar used for weighted pooling.
        self.attention_pool = nn.Sequential(
            nn.Linear(embed_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

        self.projection = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, output_dim),
            nn.LayerNorm(output_dim)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask):
        """Encode a batch of token sequences into unit-norm sentence vectors.

        Args:
            input_ids: Integer token ids, indexed as ``(batch, seq)``.
            attention_mask: Tensor of the same leading shape where a true
                value marks a position to IGNORE (it is passed to the encoder
                as ``src_key_padding_mask`` and pushed to -1e9 before the
                pooling softmax).

        Returns:
            Tensor indexed as ``(batch, output_dim)``, L2-normalised along the
            last dimension.
        """
        x = self.embed(input_ids)
        x = self.pos_encoder(x)
        x = self.dropout(x)
        x = self.encoder(x, src_key_padding_mask=attention_mask)

        attn_scores = self.attention_pool(x).squeeze(-1)
        # A large negative offset drives the softmax weight of ignored
        # positions to zero.
        mask_float = attention_mask.float() * -1e9
        attn_scores = attn_scores + mask_float
        attn_weights = F.softmax(attn_scores, dim=-1).unsqueeze(-1)

        x_pooled = torch.sum(x * attn_weights, dim=1)
        x = self.projection(x_pooled)
        return F.normalize(x, p=2, dim=-1)

In [None]:
# @title 4. Unit Tests
import unittest

class TestFullTokenSonar(unittest.TestCase):
    """Smoke tests for FullTokenSonar on a deliberately tiny configuration."""

    def setUp(self):
        # A small model keeps the checks fast; eval() switches dropout off.
        tiny_config = dict(vocab_size=100, embed_dim=32, layers=2, num_heads=4, output_dim=64)
        self.model = FullTokenSonar(**tiny_config)
        self.model.eval()

    def test_output_shape(self):
        """One 64-dimensional vector is produced per input row."""
        n_rows, n_tokens = 2, 10
        token_ids = torch.randint(0, 100, (n_rows, n_tokens))
        nothing_masked = torch.zeros((n_rows, n_tokens), dtype=torch.bool)

        embedding = self.model(token_ids, nothing_masked)

        self.assertEqual(embedding.shape, (n_rows, 64))

    def test_masking_effect(self):
        """Masking the tail of a sequence must change the pooled output."""
        token_ids = torch.randint(0, 100, (1, 10))
        nothing_masked = torch.zeros((1, 10), dtype=torch.bool)
        tail_masked = torch.zeros((1, 10), dtype=torch.bool)
        tail_masked[0, 5:] = True  # hide the final five positions

        # Same tokens, different masks: pooling should skip the hidden half.
        full_view = self.model(token_ids, nothing_masked)
        partial_view = self.model(token_ids, tail_masked)

        outputs_match = torch.allclose(full_view, partial_view)
        self.assertFalse(outputs_match, "Masking did not affect output")

def run_tests():
    """Run the FullTokenSonar unit tests and stop the notebook if any fail.

    Returns:
        The ``unittest`` result object of the run.

    Raises:
        RuntimeError: If at least one test failed or errored. Without this the
            outcome was only printed, so a "Run All" would carry on into
            training with a broken model definition.
    """
    suite = unittest.TestLoader().loadTestsFromTestCase(TestFullTokenSonar)
    result = unittest.TextTestRunner(verbosity=2).run(suite)
    if not result.wasSuccessful():
        raise RuntimeError(
            f"Model unit tests failed ({len(result.failures)} failures, "
            f"{len(result.errors)} errors); fix the model before training."
        )
    return result

run_tests()

In [None]:
# @title 5. Training Infrastructure (The Engine)
import random
from datasets import load_dataset
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline

class Trainer:
    """Distil a teacher text encoder into a student model with a cosine loss.

    Args:
        student: The ``nn.Module`` being trained; called as
            ``student(input_ids, mask)``.
        teacher: Object providing ``tokenizer.create_encoder()`` (returns a
            callable mapping a text to a 1-D token tensor) and
            ``predict(texts, source_lang=...)`` (returns one embedding per text).
        device: Device the input tensors are moved to and checkpoints are
            mapped onto.
        learning_rate: AdamW learning rate.

    Attributes:
        best_loss: Lowest epoch loss seen so far (``inf`` until set); also
            restored from a checkpoint by ``load_checkpoint``.
        patience_counter: Bookkeeping slot for the caller's early stopping.
    """
    def __init__(self, student, teacher, device, learning_rate=2e-4):
        self.student = student
        self.teacher = teacher
        self.device = device

        # Check for multi-GPU
        if torch.cuda.device_count() > 1:
            print(f"Trainer: Wrapping model in DataParallel ({torch.cuda.device_count()} GPUs)")
            self.student_parallel = nn.DataParallel(student)
        else:
            self.student_parallel = student

        # The optimizer works on the unwrapped student's parameters, which are
        # the same tensors the DataParallel wrapper uses.
        self.optimizer = torch.optim.AdamW(
            student.parameters(), lr=learning_rate, weight_decay=0.01, betas=(0.9, 0.999)
        )
        self.scaler = torch.amp.GradScaler('cuda')
        self.best_loss = float('inf')
        self.patience_counter = 0

    def save_checkpoint(self, path, epoch, loss):
        """Write model and optimizer state plus ``epoch`` and ``loss`` to ``path``.

        The model state is always saved unwrapped (without the DataParallel
        ``module.`` prefix) so it can be loaded on any number of GPUs.
        """
        if isinstance(self.student_parallel, nn.DataParallel):
            model_state = self.student_parallel.module.state_dict()
        else:
            model_state = self.student_parallel.state_dict()

        torch.save({
            'epoch': epoch,
            'model_state_dict': model_state,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
        }, path)

    def load_checkpoint(self, path):
        """Restore model, optimizer and best-loss state from ``path``.

        Returns:
            The ``epoch`` value stored in the checkpoint (0 if the file does
            not exist or has no such key). This is whatever the caller passed
            to ``save_checkpoint``; a caller that stores the index of the
            epoch just completed must add one to continue with the next epoch.
        """
        if not os.path.exists(path):
            print(f"No checkpoint found at {path}")
            return 0

        print(f"Loading checkpoint from {path}...")
        checkpoint = torch.load(path, map_location=self.device)

        # State is saved unwrapped (see save_checkpoint), so it loads directly
        # into the plain student regardless of DataParallel wrapping.
        self.student.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Carry the recorded loss over as the bar to beat. Otherwise best_loss
        # stays at inf after a resume and the very next epoch overwrites the
        # best checkpoint even when it is worse.
        saved_loss = checkpoint.get('loss')
        if saved_loss is not None:
            self.best_loss = min(self.best_loss, saved_loss)
        return checkpoint.get('epoch', 0)

    def train_epoch(self, texts, batch_size=16, step_callback=None):
        """Run one pass over ``texts`` and return the mean loss per trained batch.

        Texts are shuffled (on a copy; the caller's list is left untouched) and
        consumed in ``len(texts) // batch_size`` full batches; a trailing
        partial batch is dropped. Within a batch, texts that fail to tokenize
        or encode to 512 tokens or more are skipped, and the teacher is run on
        exactly the texts the student sees so the two embedding batches line up
        row by row.

        Args:
            texts: List of input strings.
            batch_size: Number of texts per batch.
            step_callback: Optional ``callback(step_index, loss_value)`` invoked
                after every successful optimisation step.

        Returns:
            Mean of ``1 - cosine_similarity`` over the batches that actually
            trained, or ``nan`` when none did (e.g. fewer texts than
            ``batch_size``). ``nan`` never compares as smaller than a previous
            best, so an empty epoch cannot be mistaken for an improvement.
        """
        self.student_parallel.train()
        total_loss = 0.0
        completed_steps = 0
        steps = len(texts) // batch_size

        if steps == 0:
            print(f"Warning: {len(texts)} texts is fewer than batch_size={batch_size}; nothing to train.")
            return float('nan')

        shuffled = list(texts)
        random.shuffle(shuffled)

        # Loop-invariant: one encoder is reused for every text of the epoch.
        encode = self.teacher.tokenizer.create_encoder()

        for i in range(steps):
            batch_texts = shuffled[i*batch_size : (i+1)*batch_size]

            # Batch Preparation (CPU): keep texts and their tokens in lockstep.
            kept_texts = []
            batch_tokens = []
            for text in batch_texts:
                try:
                    enc = encode(text)
                except Exception:
                    # Best-effort: an untokenizable text is simply skipped.
                    continue
                if len(enc) < 512:
                    kept_texts.append(text)
                    batch_tokens.append(enc)

            if not batch_tokens:
                continue

            max_len = max(len(tok) for tok in batch_tokens)

            # Padding: mask is True on padded positions, False on real tokens.
            input_ids = torch.zeros((len(batch_tokens), max_len), dtype=torch.long)
            mask = torch.ones((len(batch_tokens), max_len), dtype=torch.bool)

            for j, tok in enumerate(batch_tokens):
                input_ids[j, :len(tok)] = tok
                mask[j, :len(tok)] = False

            input_ids = input_ids.to(self.device)
            mask = mask.to(self.device)

            try:
                # Teacher Logic: only the texts that survived filtering, so
                # row k of t_vecs corresponds to row k of s_vecs.
                with torch.no_grad():
                    t_vecs = self.teacher.predict(kept_texts, source_lang="eng_Latn")
                    t_vecs = F.normalize(t_vecs, p=2, dim=-1)

                # Student Logic
                with torch.amp.autocast('cuda'):
                    s_vecs = self.student_parallel(input_ids, mask)
                    loss = 1.0 - F.cosine_similarity(s_vecs, t_vecs).mean()

                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                # Unscale before clipping so the norm threshold applies to the
                # true gradient magnitudes.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.student_parallel.parameters(), 1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()

                loss_value = loss.item()
                total_loss += loss_value
                completed_steps += 1
                if step_callback:
                    step_callback(i, loss_value)

            except RuntimeError as e:
                if "out of memory" in str(e):
                    # Drop this batch, release cached blocks and keep going.
                    print(f"Step {i}: out of memory, batch skipped.")
                    torch.cuda.empty_cache()
                else:
                    print(f"Error: {e}")

        if completed_steps == 0:
            return float('nan')
        # Average over batches that actually trained, not over planned steps.
        return total_loss / completed_steps

In [None]:
# @title 6. Execution

# 1. Setup
# Build the frozen SONAR teacher and the student on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Initializing SONAR Teacher... (this takes a moment)")
teacher = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder", 
    tokenizer="text_sonar_basic_encoder", 
    device=device,
    dtype=torch.float32
)

student = FullTokenSonar().to(device)
trainer = Trainer(student, teacher, device)

# 2. Smart Resume
resume_path = CHECKPOINT_DIR / "full_token_sonar_best.pt"
start_epoch = 0

if resume_path.exists():
    print("Found existing checkpoint. Resuming...")
    # Checkpoints written by the loop below store the 0-based index of the
    # epoch that had just finished, so training continues with the next one
    # instead of repeating it.
    start_epoch = trainer.load_checkpoint(str(resume_path)) + 1
    # Shown 1-based to match the "=== Epoch k/N ===" headers below.
    print(f"Resuming from Epoch {start_epoch + 1}")
else:
    print("No checkpoint found. Initializing from scratch.")
    # Initialize embeddings only if starting fresh: copy the first 512 columns
    # of the teacher's token embedding table into the student's table.
    teacher_embed = teacher.model.encoder_frontend.embed.weight.data
    student.embed.weight.data[:, :] = teacher_embed[:, :512]

# 3. Data Loading
# (Simplified for demo) - In production use the robust loader from previous cells
print("Loading Dataset...")
try:
    dataset = load_dataset("wikitext", "wikitext-2-v1", split="train")
    # Keep rows whose raw text is longer than 50 characters.
    texts = [x['text'].strip() for x in dataset if len(x['text']) > 50]
    texts = texts[:20000] # Cap for demo speed
    print(f"Loaded {len(texts)} samples.")
except Exception as e:
    # Fall back to a tiny synthetic corpus so the rest of the cell still runs.
    print(f"Data load failed: {e}")
    texts = ["Test sentence one.", "Another test sentence."] * 100

# 4. Run
MAX_EPOCHS = 20
PATIENCE = 3

# Larger batches when more than one GPU shares the work.
batch_size = 32 if torch.cuda.device_count() > 1 else 16


def log_step(step, loss):
    """Print loss and the implied cosine similarity every 50 steps."""
    if step % 50 == 0:
        print(f"Step {step} | Loss: {loss:.4f} | Sim: {1-loss:.4f}")


if start_epoch >= MAX_EPOCHS:
    print(f"Checkpoint already covers {MAX_EPOCHS} epochs; nothing left to train.")

for epoch in range(start_epoch, MAX_EPOCHS):
    print(f"\n=== Epoch {epoch+1}/{MAX_EPOCHS} ===")

    avg_loss = trainer.train_epoch(texts, batch_size=batch_size, step_callback=log_step)
    print(f"Epoch Complete. Avg Loss: {avg_loss:.4f}")

    # Save Logic: keep the best model separately and track patience.
    if avg_loss < trainer.best_loss:
        print(f"New Best Model (Loss {avg_loss:.4f} < {trainer.best_loss:.4f}). Saving...")
        trainer.best_loss = avg_loss
        trainer.save_checkpoint(str(CHECKPOINT_DIR / "full_token_sonar_best.pt"), epoch, avg_loss)
        trainer.patience_counter = 0
    else:
        trainer.patience_counter += 1
        print(f"No improvement. Patience {trainer.patience_counter}/{PATIENCE}")

    # Regular Checkpoint (written every epoch, improvement or not)
    trainer.save_checkpoint(str(CHECKPOINT_DIR / "last_checkpoint.pt"), epoch, avg_loss)

    # Early Stopping
    if trainer.patience_counter >= PATIENCE:
        print("Early stopping triggered.")
        break