# VishwamAI: GSM8K Training Pipeline

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vishwamai/vishwamai/blob/main/train_gsm8k.ipynb)

This notebook implements the training pipeline for VishwamAI on the GSM8K dataset.

## Setup

First, let's set up our environment and install required packages.

In [1]:
!pip install -q jax jaxlib
!pip install -q flax optax
!pip install -q datasets transformers huggingface_hub
!pip install -q tqdm einops

In [2]:
# Clone VishwamAI repository
!git clone https://github.com/vishwamai/vishwamai.git
# Use the %cd magic: `!cd` runs in a subshell and does not change the notebook's working directory
%cd vishwamai

Cloning into 'vishwamai'...
remote: Enumerating objects: 1946, done.
remote: Counting objects: 100% (335/335), done.
remote: Compressing objects: 100% (289/289), done.
remote: Total 1946 (delta 52), reused 305 (delta 38), pack-reused 1611 (from 1)
Receiving objects: 100% (1946/1946), 34.98 MiB | 50.60 MiB/s, done.
Resolving deltas: 100% (890/890), done.


In [3]:
import os
import json
from pathlib import Path
import jax
import jax.numpy as jnp
from datasets import load_dataset
from huggingface_hub import HfFolder
from tqdm.auto import tqdm

# Import VishwamAI modules
from vishwamai.model import VishwamAIModel, ModelConfig
from vishwamai.tokenizer import VishwamAITokenizer
from vishwamai.training import create_train_state, train_epoch

ModuleNotFoundError: No module named 'vishwamai.model'

## Authentication

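Log in to the Hugging Face Hub. A token with write access is needed later to push the trained model; loading the public GSM8K dataset works without one.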

In [4]:
from huggingface_hub import notebook_login
notebook_login()

## Load GSM8K Dataset

In [5]:
# Load GSM8K dataset
dataset = load_dataset("openai/gsm8k", "main")
print(f"Train size: {len(dataset['train'])}")
print(f"Test size: {len(dataset['test'])}")

# Display sample
print("\nSample question:")
print(dataset['train'][0]['question'])
print("\nSample answer:")
print(dataset['train'][0]['answer'])

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Train size: 7473
Test size: 1319

Sample question:
Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?

Sample answer:
Natalia sold 48/2 = <<48/2=24>>24 clips in May.
Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.
#### 72
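
The sample above shows the two conventions GSM8K answers follow: intermediate arithmetic is wrapped in `<<...>>` calculator annotations, and the final result comes after a `####` marker. A minimal sketch of parsing both, assuming every answer follows this layout (the helper names `strip_annotations` and `extract_final_answer` are hypothetical and not part of VishwamAI):

```python
import re

# Hypothetical helpers for illustration; not part of the VishwamAI codebase.
ANNOTATION_RE = re.compile(r"<<[^>]*>>")
FINAL_RE = re.compile(r"####\s*(-?[\d,]*\.?\d+)")


def strip_annotations(answer: str) -> str:
    """Remove <<...>> calculator annotations from a GSM8K answer."""
    return ANNOTATION_RE.sub("", answer)


def extract_final_answer(answer: str):
    """Return the number after '####' as a string, or None if it is absent."""
    match = FINAL_RE.search(answer)
    if match is None:
        return None
    # Drop thousands separators so "1,000" and "1000" compare equal.
    return match.group(1).replace(",", "")


sample = (
    "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\n"
    "Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n"
    "#### 72"
)

print(strip_annotations(sample).splitlines()[0])  # Natalia sold 48/2 = 24 clips in May.
print(extract_final_answer(sample))  # 72
```

Whether to keep the annotations in the training text is a design choice: keeping them exposes the model to the dataset's calculator format, while stripping them yields plainer prose.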


## Prepare Training Data

In [6]:
def format_example(example):
    """Format GSM8K example for training."""
    return {
        'text': f"Question: {example['question']}\nAnswer: {example['answer']}"
    }

# Format datasets
train_dataset = dataset['train'].map(format_example)
test_dataset = dataset['test'].map(format_example)

print("Sample formatted example:")
print(train_dataset[0]['text'])

Sample formatted example:
Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May.
Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.
#### 72


## Initialize Model and Tokenizer

In [7]:
# Load model configuration
config_path = "vishwamai/configs/config_10B.json"
with open(config_path) as f:
    config = ModelConfig(**json.load(f))

# Initialize model
model = VishwamAIModel(config)

# Initialize tokenizer
tokenizer = VishwamAITokenizer.from_pretrained("gpt2")  # Base tokenizer
tokenizer.save_pretrained("tokenizer")

FileNotFoundError: [Errno 2] No such file or directory: 'vishwamai/configs/config_10B.json'

## Training Setup

In [None]:
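def create_data_loader(dataset, tokenizer, batch_size):
    """Create a re-iterable batch loader over a tokenized dataset."""
    def tokenize(examples):
        # Pad or truncate every example to the model's maximum sequence length
        return tokenizer(
            examples['text'],
            padding='max_length',
            truncation=True,
            max_length=config.max_seq_len,
            return_tensors='np'
        )

    tokenized = dataset.map(
        tokenize,
        batched=True,
        remove_columns=dataset.column_names
    ).with_format('numpy')

    class BatchLoader:
        """Yield fresh batches on every pass.

        `Dataset.iter` returns a one-shot generator, so returning it directly
        would leave every epoch after the first without data.
        """
        def __iter__(self):
            return tokenized.iter(batch_size=batch_size)

    return BatchLoader()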

# Training parameters
batch_size = 32
num_epochs = 10
learning_rate = 1e-4

# Create data loaders
train_loader = create_data_loader(train_dataset, tokenizer, batch_size)
test_loader = create_data_loader(test_dataset, tokenizer, batch_size)

# Initialize training state
rng = jax.random.PRNGKey(42)
state = create_train_state(model, config, learning_rate, rng)
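
With the split sizes reported earlier and the hyperparameters above, the amount of work per epoch is simple arithmetic. A small sketch, assuming the loader keeps the final partial batch (the default for `Dataset.iter`, whose `drop_last_batch` argument defaults to `False`):

```python
import math

train_size = 7473   # train split size printed when loading GSM8K
test_size = 1319    # test split size printed when loading GSM8K
batch_size = 32
num_epochs = 10

# Number of batches per pass, counting the final partial batch.
steps_per_epoch = math.ceil(train_size / batch_size)
# Size of that final partial batch.
last_batch = train_size - (steps_per_epoch - 1) * batch_size
# Optimizer steps over the whole run.
total_steps = steps_per_epoch * num_epochs
# Batches in one evaluation pass.
eval_steps = math.ceil(test_size / batch_size)

print(steps_per_epoch, last_batch, total_steps, eval_steps)  # 234 17 2340 42
```

The total step count is the natural horizon for a learning-rate schedule, should the constant `learning_rate` later be replaced by warmup and decay.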

## Training Loop

In [None]:
# Create output directory
output_dir = Path("checkpoints")
output_dir.mkdir(exist_ok=True)

# Training loop
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")

    # Train
    rng, epoch_rng = jax.random.split(rng)
    state, metrics = train_epoch(
        state=state,
        train_loader=train_loader,
        rng=epoch_rng,
        error_correction=None,  # No error correction for initial training
        epoch=epoch + 1
    )

    print(f"Train - Loss: {metrics['loss']:.4f}, Accuracy: {metrics['accuracy']:.4f}")

    # Save checkpoint
    if (epoch + 1) % 2 == 0:
        checkpoint_dir = output_dir / f"checkpoint-{epoch+1}"
        model.save_pretrained(checkpoint_dir)
        tokenizer.save_pretrained(checkpoint_dir)

## Push to Hugging Face Hub

In [None]:
# Push final model to hub
model.push_to_hub("VishwamAI/VishwamAI", commit_message="Trained on GSM8K")
tokenizer.push_to_hub("VishwamAI/VishwamAI", commit_message="Updated tokenizer")

## Evaluation

In [None]:
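import optax

def compute_loss(logits, labels, mask):
    """Masked mean cross-entropy.

    A stand-in defined here because the notebook never defines or imports
    compute_loss; the loss used by train_epoch may differ.
    """
    per_token = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    return jnp.sum(per_token * mask) / jnp.maximum(jnp.sum(mask), 1)

def evaluate_model(model, test_loader):
    """Evaluate model on test set (teacher-forced next-token loss and accuracy)."""
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for batch in tqdm(test_loader, desc="Evaluating"):
        outputs = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask']
        )

        # The loader yields no 'labels' column. Assuming a causal LM, the logits
        # at position t are scored against the token at position t + 1.
        logits = outputs['logits'][:, :-1]
        labels = batch['input_ids'][:, 1:]
        mask = batch['attention_mask'][:, 1:]
        predictions = jnp.argmax(logits, axis=-1)

        # Compute accuracy over non-padding tokens
        n_tokens = jnp.sum(mask)
        total_correct += jnp.sum((predictions == labels) * mask)
        total_samples += n_tokens

        # Compute loss, weighted by token count so batches average correctly
        total_loss += compute_loss(logits, labels, mask) * n_tokens

    return {
        'loss': total_loss / total_samples,
        'accuracy': total_correct / total_samples
    }

# Evaluate final model
metrics = evaluate_model(model, test_loader)
print("\nTest Results:")
print(f"Loss: {metrics['loss']:.4f}")
print(f"Accuracy: {metrics['accuracy']:.4f}")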
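
Token-level accuracy under teacher forcing does not show whether the model actually solves a problem. GSM8K results are more commonly reported as exact match on the final number. A minimal sketch of that metric, assuming generated texts are already available as strings and end with the same `#### <number>` convention as the references (the names `final_number` and `exact_match` are hypothetical, and text generation itself is not covered here):

```python
import re


def final_number(text):
    """Return the number after the last '####' marker as a float, or None."""
    matches = re.findall(r"####\s*(-?[\d,]*\.?\d+)", text)
    if not matches:
        return None
    # Drop thousands separators before converting.
    return float(matches[-1].replace(",", ""))


def exact_match(predictions, references):
    """Fraction of predictions whose final number equals the reference's."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not references:
        return 0.0
    hits = 0
    for pred, ref in zip(predictions, references):
        p, r = final_number(pred), final_number(ref)
        # A prediction without a parsable final number counts as wrong.
        hits += int(p is not None and p == r)
    return hits / len(references)


# Toy strings for illustration only; these are not model outputs.
references = ["48 + 24 = 72\n#### 72", "10 * 100 = 1000\n#### 1,000"]
predictions = ["She sold 72 clips.\n#### 72.0", "#### 999"]

print(exact_match(predictions, references))  # 0.5
```

Comparing as floats makes `72` and `72.0` equal; a stricter string comparison is an alternative if formatting should count.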