In [None]:
import os
import re
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

from vishwamai.conceptual_tokenizer import ConceptualTokenizer, ConceptualTokenizerConfig
from vishwamai.model import VishwamaiModel, VishwamaiConfig
from vishwamai.training import VishwamaiTrainer

In [None]:
# Smoke-test: build a small tokenizer and print an encode -> decode round trip
def test_tokenizer_setup():
    """Create a ConceptualTokenizer with vocab_size=256 and show a round trip.

    Encodes a fixed sample sentence, decodes the resulting ids, and prints both
    so they can be compared by eye. Nothing is asserted.

    Returns:
        The ConceptualTokenizer instance that was created.
    """
    settings = dict(
        vocab_size=256,             # raised from 64 so every character gets an id
        max_length=512,
        model_prefix="test_tokenizer",
        character_coverage=0.9995,  # lowered coverage
    )
    tok = ConceptualTokenizer(ConceptualTokenizerConfig(**settings))

    text = "Test equation: 2 + 2 = 4"
    decoded = tok.decode(tok.encode(text))
    print(f"Original: {text}")
    print(f"Decoded: {decoded}")
    return tok


# Run the smoke test once and keep the resulting tokenizer for inspection
test_tokenizer = test_tokenizer_setup()

In [None]:
class GSM8KDataset(Dataset):
    """Map-style dataset yielding raw GSM8K question/answer text pairs.

    ``data`` is expected to be a pandas DataFrame with 'question' and 'answer'
    columns; rows are addressed positionally. No tokenization happens here.
    """

    def __init__(self, data, max_length=512):
        self.data = data
        # Stored for callers; not consulted by this class itself.
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]  # positional lookup, independent of the frame's index labels
        return dict(question=row['question'], answer=row['answer'])

def _pad_or_truncate(tensor, length, value):
    """Return ``tensor`` cut or right-padded along dim 1 to exactly ``length`` columns.

    Args:
        tensor: 2-D tensor laid out as (batch, seq).
        length: target number of columns.
        value: fill value for the padded positions.
    """
    tensor = tensor[:, :length]
    missing = length - tensor.size(1)
    if missing > 0:
        tensor = torch.nn.functional.pad(tensor, (0, missing), value=value)
    return tensor


def collate_fn(batch, tokenizer, max_length=512):
    """Tokenize a list of GSM8K items into fixed-length training tensors.

    Args:
        batch: list of dicts with 'question' and 'answer' strings.
        tokenizer: object exposing ``encode(list_of_str)`` and ``pad_token_id``.
            ``encode`` must return equal-length id lists for ``torch.tensor``
            to accept them.
        max_length: every returned tensor has exactly this many columns.

    Returns:
        dict with 'input_ids', 'attention_mask' and 'labels', each of shape
        (len(batch), max_length). Padding positions in 'labels' (both those the
        tokenizer inserted and those added here) are -100, so a loss with
        ignore_index=-100 skips them.

    Raises:
        Whatever the tokenizer or ``torch.tensor`` raises; the message is
        printed before re-raising.
    """
    questions = [f"Question: {item['question']}\nLet's solve this step by step:" for item in batch]
    answers = [item['answer'] for item in batch]
    pad_id = tokenizer.pad_token_id

    try:
        input_ids = torch.tensor(tokenizer.encode(questions))
        attention_mask = (input_ids != pad_id).long()

        # NOTE(review): labels are the answer tokens position-aligned with the
        # question tokens (not a shifted continuation of the prompt); confirm
        # this is what VishwamaiModel's loss expects.
        labels = torch.tensor(tokenizer.encode(answers))
        # Pads the tokenizer inserted must be ignored by the loss, just like
        # the padding added below.
        labels = labels.masked_fill(labels == pad_id, -100)

        # Fit each tensor to max_length on its own: questions and answers
        # tokenize to different lengths, and all three outputs must share
        # the same shape.
        return {
            'input_ids': _pad_or_truncate(input_ids, max_length, pad_id),
            'attention_mask': _pad_or_truncate(attention_mask, max_length, 0),
            'labels': _pad_or_truncate(labels, max_length, -100),
        }
    except Exception as e:
        print(f"Error in collate_fn: {str(e)}")
        raise

In [None]:
# Load the GSM8K parquet splits and wrap each one in a GSM8KDataset.
GSM8K_DIR = "gsm8k"  # directory holding the train/test parquet files

try:
    train_data = pd.read_parquet(os.path.join(GSM8K_DIR, 'train-00000-of-00001.parquet'))
    test_data = pd.read_parquet(os.path.join(GSM8K_DIR, 'test-00000-of-00001.parquet'))

    print("Loading data successful")
    print(f"Training examples: {len(train_data)}")
    print(f"Test examples: {len(test_data)}")

    train_dataset = GSM8KDataset(train_data)
    test_dataset = GSM8KDataset(test_data)
except Exception as e:
    print(f"Error loading data: {str(e)}")
    # Re-raise: every later cell needs train_data / train_dataset, so
    # swallowing the error here would only resurface later as a NameError.
    raise

In [None]:
# Model & Tokenizer setup.
# The model and tokenizer configs must agree on vocabulary size, and the
# model's position range should cover the tokenizer's max length, so both
# configs are driven from the same constants.
VOCAB_SIZE = 256
MAX_SEQ_LEN = 512
TOKENIZER_TRAIN_SUBSET = 100  # rows of questions + answers used to fit the tokenizer

try:
    # Model configuration
    model_config = VishwamaiConfig(
        vocab_size=VOCAB_SIZE,
        hidden_size=256,
        num_hidden_layers=4,
        num_attention_heads=8,
        intermediate_size=512,
        max_position_embeddings=MAX_SEQ_LEN,
    )

    # Tokenizer configuration
    tokenizer_config = ConceptualTokenizerConfig(
        vocab_size=VOCAB_SIZE,
        max_length=MAX_SEQ_LEN,
        model_prefix="gsm8k_tokenizer",
        concept_tokens=["math", "equation", "solve"],
        reasoning_tokens=["therefore", "because", "result"],
        character_coverage=0.9995,  # reduced coverage
        control_symbols=["[", "]", "=", "+", "-", "*", "/"],
        user_defined_symbols=["$", "%"],
    )

    print("Creating tokenizer...")
    tokenizer = ConceptualTokenizer(tokenizer_config)
    print("Creating model...")
    model = VishwamaiModel(model_config)

    # Train the tokenizer only when no model file with this prefix exists yet.
    if not os.path.exists(f"{tokenizer_config.model_prefix}.model"):
        print("Training tokenizer...")
        all_texts = (
            train_data['question'].iloc[:TOKENIZER_TRAIN_SUBSET].tolist()
            + train_data['answer'].iloc[:TOKENIZER_TRAIN_SUBSET].tolist()
        )
        tokenizer.train_tokenizer(all_texts)
        print("Tokenizer training completed")
    else:
        # NOTE(review): nothing is actually loaded in this branch -- the
        # tokenizer constructed above is used as-is. Confirm ConceptualTokenizer
        # picks up "<model_prefix>.model" on its own; otherwise load it here.
        print("Loading existing tokenizer model")

    print("Setup completed successfully")

except Exception as e:
    print(f"Error in setup: {str(e)}")
    raise

In [None]:
# Build the train/test DataLoaders, sanity-check one batch, and create the trainer.
try:
    batch_size = 4
    print("Creating data loaders...")

    # One shared collator instead of two identical lambdas; unlike a lambda,
    # a partial of a named function can also be pickled.
    batch_collator = partial(collate_fn, tokenizer=tokenizer)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=batch_collator,
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=batch_collator,
    )

    print("Data loaders created successfully")

    # Pull one batch so tokenization/collation errors surface before training.
    print("\nTesting batch generation...")
    test_batch = next(iter(train_loader))
    print(f"Batch keys: {test_batch.keys()}")
    print(f"Input shape: {test_batch['input_ids'].shape}")

    print("\nInitializing trainer...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): DataLoaders are passed under the *_dataset parameter
    # names; confirm VishwamaiTrainer accepts loaders here rather than
    # expecting a Dataset it wraps itself.
    trainer = VishwamaiTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_loader,
        eval_dataset=test_loader,
        device=device,
        optimizer_class=torch.optim.AdamW,
        use_wandb=False,
    )

    print("Trainer initialized successfully")

except Exception as e:
    print(f"Error: {str(e)}")
    raise

In [None]:
# Run training and write checkpoints to CHECKPOINT_DIR.
CHECKPOINT_DIR = "math_model_checkpoints"
NUM_EPOCHS = 1             # start with 1 epoch for testing
EVAL_EVERY_STEPS = 10      # NOTE(review): if each evaluation runs the full test loader, this is costly
SAVE_EVERY_STEPS = 50
LOG_EVERY_STEPS = 5
GRAD_ACCUM_STEPS = 4
MAX_GRAD_NORM = 1.0

try:
    save_dir = CHECKPOINT_DIR
    os.makedirs(save_dir, exist_ok=True)

    trainer.train(
        num_epochs=NUM_EPOCHS,
        save_dir=save_dir,
        evaluation_steps=EVAL_EVERY_STEPS,
        save_steps=SAVE_EVERY_STEPS,
        logging_steps=LOG_EVERY_STEPS,
        gradient_accumulation_steps=GRAD_ACCUM_STEPS,
        max_grad_norm=MAX_GRAD_NORM,
        fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
    )

except Exception as e:
    print(f"Training error: {str(e)}")
    raise