# In [None]:
import functools
import logging

import matplotlib.pyplot as plt
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader

from vishwamai.training import VishwamaiTrainer
from vishwamai.conceptual_tokenizer import ConceptualTokenizer, ConceptualTokenizerConfig
from vishwamai.architecture import init_model, VishwamaiConfig
from vishwamai.generate import VishwamaiGenerator, GenerationConfig
from vishwamai.deepthinking import COTGenerationWrapper, GRPOTrainer, ReasioningDataset
from vishwamai.deep_reasoning import ReasioningOutput, EnhancedReasioning

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Local parquet copies of GSM8K, keyed by the split name each file is loaded as.
_GSM8K_PARQUET = {
    'train': 'gsm8k/train-00000-of-00001.parquet',
    'test': 'gsm8k/test-00000-of-00001.parquet',
}

# Load each split on its own (train first, then test), selecting that split
# from a single-entry data_files mapping.
train_dataset, test_dataset = (
    load_dataset('parquet', data_files={split_name: parquet_path}, split=split_name)
    for split_name, parquet_path in _GSM8K_PARQUET.items()
)

# Preprocess the dataset for training and testing
@functools.lru_cache(maxsize=1)
def _default_preprocess_tokenizer():
    """Build, once, the tokenizer that ``preprocess_function`` uses by default.

    Cached so that ``Dataset.map(..., batched=True)`` does not construct a fresh
    ``ConceptualTokenizer`` for every batch it processes.
    """
    # Same settings as the ConceptualTokenizerConfig used for the trainer and
    # generator in this script; keep them in sync.
    return ConceptualTokenizer(ConceptualTokenizerConfig(vocab_size=32000, max_length=512))


def preprocess_function(examples, tokenizer=None):
    """Tokenize a batch of GSM8K rows into model inputs plus labels.

    Args:
        examples: One batch as passed by ``Dataset.map(batched=True)``: a mapping
            whose ``'question'`` and ``'answer'`` entries are lists of texts.
        tokenizer: Optional object exposing ``batch_encode_with_concepts``.
            When omitted, a single shared ``ConceptualTokenizer`` is built on
            first use and reused for every later batch.

    Returns:
        The encoding returned for the questions, with an added ``'labels'``
        entry holding the ``'input_ids'`` of the encoded answers.
    """
    if tokenizer is None:
        tokenizer = _default_preprocess_tokenizer()
    model_inputs = tokenizer.batch_encode_with_concepts(examples['question'])
    # Only the answer token ids are kept; they serve as the training targets.
    model_inputs['labels'] = tokenizer.batch_encode_with_concepts(examples['answer'])['input_ids']
    return model_inputs

# Tokenize both splits (train first, then test), dropping the raw text columns
# once they have been encoded.
train_dataset, test_dataset = (
    split_ds.map(preprocess_function, batched=True, remove_columns=['question', 'answer'])
    for split_ds in (train_dataset, test_dataset)
)

# Batch the tokenized splits; only the training data is shuffled.
# NOTE(review): these loaders use PyTorch's default collate on raw `datasets`
# rows (no set_format / collate_fn). Confirm the encodings are fixed-length,
# otherwise batching will fail or produce transposed per-position tensors.
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Model hyperparameters: 12 layers, hidden size 768, 12 attention heads with
# 4 key/value heads (presumably grouped-query attention — confirm), FFN size
# 3072, and up to 2048 positions.
_model_hparams = dict(
    vocab_size=32000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    num_key_value_heads=4,
    intermediate_size=3072,
    max_position_embeddings=2048,
)
config = VishwamaiConfig(**_model_hparams)

# Build the model from that configuration.
model = init_model(config)

# One tokenizer shared by training and generation, instead of a separate
# instance for each. Its vocabulary size comes from the model config so the
# two cannot drift apart; max_length=512 matches the preprocessing settings.
tokenizer = ConceptualTokenizer(
    ConceptualTokenizerConfig(vocab_size=config.vocab_size, max_length=512)
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize trainer
trainer = VishwamaiTrainer(
    model=model,
    tokenizer=tokenizer,
    # NOTE(review): DataLoaders (not Datasets) are passed as *_dataset here;
    # confirm this is what VishwamaiTrainer expects.
    train_dataset=train_loader,
    eval_dataset=val_loader,
    device=device,
    optimizer_class=torch.optim.AdamW,
    scheduler_class=torch.optim.lr_scheduler.CosineAnnealingLR,
)

# Train the model
trainer.train(
    num_epochs=10,
    save_dir="./models",
    evaluation_steps=100,
    save_steps=1000,
    logging_steps=10,
)

# Generate answers to sample problems with the trained model
generator = VishwamaiGenerator(
    model=model,
    tokenizer=tokenizer,
    config=GenerationConfig(max_length=100, temperature=0.7, top_p=0.9),
)

test_problems = [
    "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and sells the rest at the farmers market daily for $2 per egg. How much money does she make every day at the farmers market?",
    "A robe takes 2 blue pieces of cloth and 5 white pieces of cloth. If I have 18 blue pieces and 45 white pieces, how many complete robes can I make?",
    "John has 5 times as many marbles as Peter. If Peter has 8 marbles, how many marbles does John have?"
]

# Inference only: leave training mode (assumes init_model returns a torch
# nn.Module) and skip autograd bookkeeping while generating.
model.eval()
with torch.no_grad():
    for problem in test_problems:
        answer = generator.generate(problem)
        # NOTE(review): assumes generate() returns a sequence of decoded strings;
        # if it returns a plain str, answer[0] is only its first character.
        logger.info("\nProblem: %s\nAnswer: %s", problem, answer[0])
