# VishwamAI Colab Training

Train the custom VishwamAI Transformer (a mixture-of-experts architecture) on Google Colab, using neural memory, tree-of-thoughts reward configuration, and curriculum learning.

In [1]:
# Setup Environment
!nvidia-smi

from google.colab import drive
drive.mount('/content/drive')

# Setup directories
import os
DRIVE_DIR = '/content/drive/MyDrive/VishwamAI'
CHECKPOINT_DIR = f'{DRIVE_DIR}/checkpoints'
!mkdir -p {CHECKPOINT_DIR}

Tue Feb 18 16:52:51 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   45C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
# Install dependencies and clone repo
# NOTE(review): re-running this cell after the `%cd` below clones a second copy inside
# /content/VishwamAI (and cds into it); restart the runtime or skip the clone on re-runs.
!git clone https://github.com/VishwamAI/VishwamAI.git
%cd VishwamAI

# NOTE(review): packages are unpinned, so a later re-run may resolve different (possibly
# incompatible) versions; pin them (`pkg==X.Y.Z`) once a working set is known.
%pip install -q torch transformers datasets accelerate bitsandbytes wandb triton transformer-engine
# Editable install of the cloned repository so `import vishwamai` uses this checkout.
%pip install -e .

# Secure Hugging Face authentication
from huggingface_hub import login
import os
import getpass

def get_huggingface_token():
    """Return a Hugging Face access token.

    Lookup order:
      1. the ``HUGGINGFACE_TOKEN`` environment variable (this notebook's name for it),
      2. the ``HF_TOKEN`` environment variable (the name ``huggingface_hub`` reads),
      3. an interactive hidden prompt via ``getpass``.

    Surrounding whitespace (e.g. a trailing newline from copy/paste) is stripped. A
    non-empty token obtained from the prompt is cached in
    ``os.environ['HUGGINGFACE_TOKEN']`` so the prompt is not repeated this session.

    Returns:
        The token string; empty if the user entered nothing at the prompt.
    """
    token = os.getenv('HUGGINGFACE_TOKEN') or os.getenv('HF_TOKEN')
    if token:
        return token.strip()

    print("HUGGINGFACE_TOKEN not found in environment")
    token = getpass.getpass('Enter your Hugging Face token (input will be hidden): ').strip()
    if token:
        # Store temporarily for this session. Empty input is not cached, so an
        # accidental Enter does not make later lookups "succeed" with "".
        os.environ['HUGGINGFACE_TOKEN'] = token
    return token

# Log in once at notebook start. Errors are printed rather than raised, so the rest of
# the notebook still runs (public datasets do not require authentication).
# NOTE(review): the training loop later calls trainer.push_to_hub(...), which is likely
# to fail if this login did not succeed — check the message printed here first.
try:
    token = get_huggingface_token()
    login(token=token)
    print("Successfully logged in to Hugging Face")
except Exception as e:
    print(f"Error logging in to Hugging Face: {str(e)}")

Cloning into 'VishwamAI'...
remote: Enumerating objects: 1453, done.[K
remote: Counting objects: 100% (422/422), done.[K
remote: Compressing objects: 100% (293/293), done.[K
remote: Total 1453 (delta 209), reused 304 (delta 122), pack-reused 1031 (from 1)[K
Receiving objects: 100% (1453/1453), 34.06 MiB | 22.38 MiB/s, done.
Resolving deltas: 100% (742/742), done.

  lfs.transfer.maxretries
  lfs.transfer.maxverifies
  lfs.transfer.maxconcurrenttransfers
  filter.lfs.clean
  filter.lfs.smudge
  filter.lfs.process
  filter.lfs.required

  lfs.transfer.maxretries
  lfs.transfer.maxverifies
  lfs.transfer.maxconcurrenttransfers
  filter.lfs.clean
  filter.lfs.smudge
  filter.lfs.process
  filter.lfs.required
/content/VishwamAI
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m70.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━

HUGGINGFACE_TOKEN not found in environment
Enter your Hugging Face token (input will be hidden): ··········
Successfully logged in to Hugging Face


In [3]:
# Import required modules
# NOTE(review): AutoTokenizer, Linear, Transformer, create_model, precompute_freqs_cis and
# main are not referenced in any of the cells shown; drop them if nothing else needs them.
import torch
from transformers import AutoTokenizer
from datasets import load_dataset

# Import VishwamAI components with correct paths
from vishwamai.base_layers import Linear  # Import from base_layers instead of utils
from vishwamai.Transformer import Transformer
from vishwamai import (
    create_model,
    ModelArgs,
    VishwamAITokenizer,
    TokenizerConfig
)

# Import training components
# NOTE(review): `main` is a very generic name to bind at notebook scope; consider
# aliasing it (e.g. `as fp8_cast_main`) to avoid accidental shadowing.
from vishwamai.advanced_training import AdvancedTrainer
from vishwamai.fp8_cast_bf16 import main
from vishwamai.neural_memory import NeuralMemory
from vishwamai.tree_of_thoughts import TreeConfig, RewardConfig
from vishwamai.curriculum import CurriculumConfig
from vishwamai.initialize import initialize_model_and_trainer
# NOTE(review): this rebinds ModelArgs (already imported from `vishwamai` above); the
# vishwamai.config version is the one used below. Confirm both refer to the same class.
from vishwamai.config import ModelArgs
from vishwamai.utils import precompute_freqs_cis
# NOTE(review): the four imports below repeat earlier lines verbatim; they are harmless
# no-ops and can be deleted.
from vishwamai.advanced_training import AdvancedTrainer
from vishwamai.tree_of_thoughts import TreeConfig, RewardConfig
from vishwamai.curriculum import CurriculumConfig
from vishwamai.neural_memory import NeuralMemory

# Import visualization tools
import matplotlib.pyplot as plt
import pandas as pd

# Verify imports were successful
print("Imports completed successfully")

Imports completed successfully


In [4]:
# Initialize visualization and analysis tools
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from datetime import datetime

# Configure plotting style
# The matplotlib style name 'seaborn' was deprecated (renamed 'seaborn-v0_8') in
# matplotlib 3.6 and later removed, which is presumably why `plt.style.use('seaborn')`
# was disabled here; only seaborn's color palette is applied below.
sns.set_palette("husl")
# Defaults for every figure created afterwards; an explicit figsize passed to
# plt.subplots (as in plot_training_progress) overrides the size.
plt.rcParams['figure.figsize'] = [12, 6]
plt.rcParams['figure.dpi'] = 100
plt.rcParams['axes.grid'] = True

# Initialize performance tracking: one initially-empty column per metric recorded by
# update_performance_tracking(); the tuple order below is the frame's column order.
_TRACKED_METRICS = (
    'steps',
    'loss',
    'learning_rate',
    'memory_usage',
    'curriculum_level',
    'expert_usage',
    'evaluation_scores',
)
performance_history = {metric: [] for metric in _TRACKED_METRICS}

performance_df = pd.DataFrame(performance_history)

In [5]:
def update_performance_tracking(stats, step, df=None):
    """Append one row of training statistics to the performance-tracking table.

    Args:
        stats: per-step statistics dict from the trainer. Keys read here:
            ``'loss'``, ``'lr'``, ``['memory_usage']['allocated']``,
            ``['curriculum_stats']['current_difficulty']`` and, optionally,
            ``'moe_metrics'`` (a mapping whose values are averaged into
            ``expert_usage``) and ``'eval_score'`` (defaults to 0).
        step: training step number, stored in the ``'steps'`` column.
        df: DataFrame to append to; defaults to the notebook-level ``performance_df``.
            The row is added in place at label ``len(df)``, which assumes the default
            RangeIndex.

    Raises:
        KeyError: if one of the required keys above is missing from ``stats``.
    """
    target = performance_df if df is None else df

    moe_metrics = stats.get('moe_metrics') or {}
    # Average expert usage, or 0.0 when the step reported no MoE metrics. The previous
    # expression divided by len({}) in that case and raised ZeroDivisionError.
    expert_usage = sum(moe_metrics.values()) / len(moe_metrics) if moe_metrics else 0.0

    target.loc[len(target)] = {
        'steps': step,
        'loss': stats['loss'],
        'learning_rate': stats['lr'],
        'memory_usage': stats['memory_usage']['allocated'],
        'curriculum_level': stats['curriculum_stats']['current_difficulty'],
        'expert_usage': expert_usage,
        'evaluation_scores': stats.get('eval_score', 0)
    }

def plot_training_progress(df=None, save_path=None):
    """Draw a 2x2 overview of training metrics, save it as a PNG, and display it.

    Panels (each against the ``'steps'`` column): training loss, learning rate,
    curriculum difficulty, and average expert usage.

    Args:
        df: metrics table with columns ``steps``, ``loss``, ``learning_rate``,
            ``curriculum_level`` and ``expert_usage``; defaults to the notebook-level
            ``performance_df``.
        save_path: output file; defaults to ``f"{DRIVE_DIR}/training_progress.png"``,
            which is overwritten on every call.
    """
    if df is None:
        df = performance_df
    if save_path is None:
        save_path = f"{DRIVE_DIR}/training_progress.png"

    # (subplot position, metric column, panel title, y-axis label)
    panels = (
        ((0, 0), 'loss', 'Training Loss', 'Loss'),
        ((0, 1), 'learning_rate', 'Learning Rate', 'Learning Rate'),
        ((1, 0), 'curriculum_level', 'Curriculum Difficulty', 'Difficulty Level'),
        ((1, 1), 'expert_usage', 'Expert Usage', 'Average Usage'),
    )

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Training Progress Overview', fontsize=16)
    for (row, col), column, title, ylabel in panels:
        ax = axes[row, col]
        ax.plot(df['steps'], df[column])
        ax.set_title(title)
        ax.set_xlabel('Steps')
        ax.set_ylabel(ylabel)

    # Use the figure's own methods rather than pyplot's "current figure" state.
    fig.tight_layout()
    fig.savefig(save_path)
    plt.show()
    # Called every 1000 steps from the training loop: close explicitly so figures do
    # not accumulate on backends where show() leaves them open.
    plt.close(fig)

print("Visualization and performance tracking initialized")

Visualization and performance tracking initialized


In [6]:
# Create model arguments with explicit typing
# NOTE(review): adjust_model_config() later overwrites max_batch_size, max_seq_len, dim,
# n_layers, n_heads and n_routed_experts for T4/V100/A100 GPUs, so several values here
# are only effective on other GPU types.
model_args = ModelArgs(
    max_batch_size=4,
    max_seq_len=2048,
    # NOTE(review): T4 (Turing) GPUs, as used in this run, lack native bfloat16 tensor-core
    # support; bf16 may be slow or unsupported in some kernels there — confirm, or use
    # float32/float16 on T4.
    dtype="bfloat16",  # Change dtype to 'bfloat16' or 'float32'
    # NOTE(review): the tokenizer trained below has vocab_size=26519; ids stay below 32000,
    # so the extra embedding rows are simply unused.
    vocab_size=32000,
    dim=1024,
    inter_dim=2816,
    moe_inter_dim=512,
    n_layers=12,
    n_dense_layers=1,
    n_heads=16,
    n_routed_experts=8,
    n_shared_experts=1,
    n_activated_experts=2,
    n_expert_groups=1,
    n_limited_groups=1,
    score_func="softmax",
    route_scale=1.0,
    q_lora_rank=0,
    kv_lora_rank=64,
    qk_nope_head_dim=64,
    qk_rope_head_dim=32,
    v_head_dim=64,
    original_seq_len=2048,
    rope_theta=10000.0,
    rope_factor=20,
    beta_fast=16,
    beta_slow=1,
    mscale=0.5,
    use_alibi=False,
    use_rope_scaling=True,
    gradient_checkpointing=True,
    parallel_attn=True,
    rope_condense_ratio=1.0,
    max_steps=100000  # Set default number of training steps
)

print("\nInitializing neural memory...")
# NOTE(review): `neural_memory` is not referenced by later cells shown here, and the
# initializer's log ("Initializing default neural memory...") suggests it builds its own;
# this cell mainly acts as a smoke test that NeuralMemory accepts model_args.
try:
    # Initialize neural memory
    neural_memory = NeuralMemory(
        args=model_args,  # Pass model_args as 'args'
        memory_size=512,  # Size of memory buffer
        num_memory_heads=4  # Number of memory attention heads
    )
    print("✓ Neural memory initialized successfully")
except Exception as e:
    # Report, then re-raise so Run All stops here instead of failing later.
    print(f"✗ Neural memory initialization failed: {str(e)}")
    raise



Initializing neural memory...
✓ Neural memory initialized successfully


In [7]:
# Initialize training components
import inspect

# Debug aid: print RewardConfig's constructor signature (shown below the cell) to check
# the keyword names passed to RewardConfig further down.
# NOTE(review): leftover debugging output; remove once the parameter names are settled.
print(inspect.signature(RewardConfig.__init__))
# Use the correct parameter names based on TreeConfig signature
# NOTE(review): the comment above mentions TreeConfig, but the signature printed is
# RewardConfig's. The "Corresponds to ..." notes below presumably map to an older naming
# (num_thoughts / thought_depth / thought_width) — confirm in vishwamai.tree_of_thoughts.
tot_config = TreeConfig(
    num_beams=4,  # Corresponds to num_thoughts
    max_depth=3,  # Corresponds to thought_depth
    beam_width=2, # Corresponds to thought_width
    reward_gamma=0.95
)

# These four weights sum to 1.0 (0.2 + 0.2 + 0.2 + 0.4), as do the defaults in the
# printed signature (0.3 + 0.3 + 0.2 + 0.2).
reward_config = RewardConfig(
    math_reasoning_weight=0.2,  # Part of reasoning_weight
    logical_coherence_weight=0.2,  # Part of reasoning_weight and consistency_weight
    real_world_applicability_weight=0.2, # Keeping as is, could reflect external knowledge
    solution_validity_weight=0.4  # Corresponds to accuracy_weight
)
# Curriculum bounds (sequence length, vocabulary complexity, reasoning steps) and
# advancement settings; the exact semantics live in vishwamai.curriculum.
# NOTE(review): max_sequence_length=512 is well below the 2048-token padding used in
# process_dataset — confirm how the curriculum is meant to interact with that.
curriculum_config = CurriculumConfig(
    min_sequence_length=32,
    max_sequence_length=512,
    min_vocab_complexity=0.3,
    max_vocab_complexity=1.0,
    min_reasoning_steps=1,
    max_reasoning_steps=8,
    pacing_function='root',
    total_curriculum_steps=10000,
    performance_threshold=0.8,
    min_samples_before_advance=100,
    smoothing_factor=0.95
)

(self, math_reasoning_weight: float = 0.3, logical_coherence_weight: float = 0.3, real_world_applicability_weight: float = 0.2, solution_validity_weight: float = 0.2) -> None


In [8]:
# Train tokenizer on dataset texts
print("Training tokenizer...")

import logging
import sentencepiece as spm
from pathlib import Path
# Redundant with the imports cell, but keeps this cell runnable on its own.
from datasets import load_dataset # Import load_dataset here

# Setup logging
# force=True (Python 3.8+) replaces handlers already attached to the root logger.
# Without it, basicConfig() silently does nothing when the root logger already has
# handlers (as Colab/IPython kernels may install), and the logger.info() progress
# messages below never appear — the recorded output of this cell shows none of them.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True,
)
logger = logging.getLogger(__name__)

# Ensure tokenizer directory exists (relative to the current working directory, which
# is the cloned repository after `%cd VishwamAI`).
tokenizer_dir = Path("tokenizer")
tokenizer_dir.mkdir(exist_ok=True)

# Prepare training data
def prepare_training_data(datasets, max_samples=10000):
    """Write raw text from the given datasets to ``tokenizer_dir / "train.txt"``.

    At most ``max_samples`` items are examined per dataset:
      * ``"gsm8k"`` / ``"mmlu"``: the ``question`` and ``answer`` fields, one per line
        (items missing either field are skipped but still count toward the limit;
        for MMLU the answer is presumably the choice index — verify);
      * ``"code"``: the ``content`` field;
      * any other name: nothing is written (only the "Processing ..." log line).

    Args:
        datasets: mapping of dataset name -> iterable of dict-like items.
        max_samples: per-dataset cap on the number of items examined.

    Returns:
        Path of the written training file (overwritten on each call).
    """
    logger.info("Preparing training data...")
    train_path = tokenizer_dir / "train.txt"

    qa_names = ("gsm8k", "mmlu")
    with open(train_path, "w", encoding="utf-8") as out:
        for ds_name, ds in datasets.items():
            logger.info(f"Processing {ds_name} dataset")
            if ds_name in qa_names:
                fields = ("question", "answer")
            elif ds_name == "code":
                fields = ("content",)
            else:
                continue

            for count, item in enumerate(ds):
                if count >= max_samples:
                    break
                if all(key in item for key in fields):
                    for key in fields:
                        out.write(f"{item[key]}\n")

    return train_path

try:
    # Initialize tokenizer with reduced vocabulary size
    # NOTE(review): model_args.vocab_size is 32000; ids from this 26519-piece tokenizer fit,
    # but the two values should probably be kept in sync deliberately.
    tokenizer_config = TokenizerConfig(
        vocab_size=26519,  # Reduced vocab size
        model_prefix="vishwamai",
        character_coverage=0.9995,
        max_sentence_length=2048,
        pad_id=0,
        bos_id=1,
        eos_id=2,
        unk_id=3
    )

    # Load datasets before calling prepare_training_data
    # NOTE(review): the name `datasets` shadows the HF `datasets` package name at notebook
    # scope (only `load_dataset` is imported from it, so nothing shown here breaks).
    # The MMLU abstract_algebra validation split is tiny (11 rows in the recorded output).
    datasets = { # Define datasets here
        "gsm8k": load_dataset("gsm8k", split="train"),
        "mmlu": load_dataset("cais/mmlu", "abstract_algebra", split="validation"),  # Use validation or test split
    }

    # Create training data file
    train_path = prepare_training_data(datasets)
    logger.info(f"Training data saved to {train_path}")

    # Train SentencePiece model
    # Writes <model_prefix>.model / .vocab under tokenizer_dir.
    # NOTE(review): input_sentence_size=10_000_000 and train_extremely_large_corpus=True are
    # sized for huge corpora; the file above holds at most ~2 lines per sample from two
    # small datasets — likely unnecessary, confirm against the SentencePiece options docs.
    model_prefix = str(tokenizer_dir / tokenizer_config.model_prefix)
    spm.SentencePieceTrainer.train(
        input=str(train_path),
        model_prefix=model_prefix,
        vocab_size=tokenizer_config.vocab_size,
        character_coverage=tokenizer_config.character_coverage,
        model_type="bpe",
        max_sentence_length=tokenizer_config.max_sentence_length,
        pad_id=tokenizer_config.pad_id,
        bos_id=tokenizer_config.bos_id,
        eos_id=tokenizer_config.eos_id,
        unk_id=tokenizer_config.unk_id,
        input_sentence_size=10000000,
        shuffle_input_sentence=True,
        train_extremely_large_corpus=True
    )

    # Load the trained tokenizer
    tokenizer = VishwamAITokenizer(tokenizer_config)
    tokenizer.load(f"{model_prefix}.model")
    logger.info("Tokenizer training complete")

    # Verify tokenizer works
    # Round-trip a short string as a smoke test (results go to the log only).
    test_text = "Hello world"
    encoded = tokenizer.encode(test_text)
    decoded = tokenizer.decode(encoded)
    logger.info(f"Tokenizer test - Encoded: {encoded}")
    logger.info(f"Tokenizer test - Decoded: {decoded}")

except Exception as e:
    # Log, then re-raise so Run All stops: later cells need `tokenizer`.
    logger.error(f"Error during tokenizer training: {str(e)}")
    raise

Training tokenizer...


Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/53.2k [00:00<?, ?B/s]

dataset_infos.json:   0%|          | 0.00/138k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/9.96k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/3.73k [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/3.45k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

In [9]:
# Load and process datasets with trained tokenizer
print("Loading datasets...")
# Same two datasets as the tokenizer cell; reloading keeps this cell runnable on its own
# (load_dataset should serve them from the local HF cache the second time).
datasets = {
        "gsm8k": load_dataset("gsm8k", split="train"),
        "mmlu": load_dataset("cais/mmlu", "abstract_algebra", split="validation"),  # Use validation or test split
}

def process_dataset(examples, dataset_type, max_len=2048):
    """Tokenize a batch of examples into fixed-length, right-padded id sequences.

    Written for ``Dataset.map(..., batched=True)``, where ``examples`` is a dict of
    column name -> list of values. Uses the notebook-level ``tokenizer``.

    Args:
        examples: batch dict. For ``"gsm8k"``/``"mmlu"`` it must contain ``"question"``
            and ``"answer"`` lists (rendered as ``"Question: ...\\nAnswer: ..."``);
            for any other ``dataset_type`` a ``"content"`` list.
        dataset_type: dataset name selecting the text template.
        max_len: length of every output row (truncated or padded to exactly this).
            Defaults to 2048, the previously hard-coded value.
            NOTE(review): adjust_model_config() lowers ``max_seq_len`` to 1024 on T4, so
            callers may want to pass ``model_args.max_seq_len`` here.

    Returns:
        dict with ``"input_ids"`` (list of id lists) and ``"attention_mask"`` (1 for real
        tokens, 0 for padding), each row ``max_len`` long.
    """
    if dataset_type in ["gsm8k", "mmlu"]:
        text = [f"Question: {q}\nAnswer: {a}"
                for q, a in zip(examples["question"], examples["answer"])]
    else:
        text = examples["content"]

    pad_id = tokenizer.config.pad_id
    encoded = []
    attention_mask = []

    for t in text:
        # Encode with trained tokenizer
        ids = tokenizer.encode(
            t,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True
        )
        # Truncate again in case the tokenizer does not honour max_length, then right-pad.
        ids = list(ids[:max_len])
        n_real = len(ids)
        n_pad = max_len - n_real
        encoded.append(ids + [pad_id] * n_pad)
        attention_mask.append([1] * n_real + [0] * n_pad)

    return {
        "input_ids": encoded,
        "attention_mask": attention_mask
    }

# Process datasets
# Tokenize each dataset, dropping its original columns so only the columns returned by
# process_dataset (input_ids, attention_mask) remain.
processed_datasets = {}
for name, dataset in datasets.items():
    print(f"Processing {name}...")
    # The lambda reads the loop variable `name` when it is called; that is safe here only
    # because Dataset.map() runs immediately within this iteration (see the Map progress
    # bars in the output).
    processed_datasets[name] = dataset.map(
        lambda x: process_dataset(x, name),
        batched=True,
        remove_columns=dataset.column_names
    )
    print(f"Processed {len(processed_datasets[name])} examples from {name}")

Loading datasets...
Processing gsm8k...


Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

Processed 7473 examples from gsm8k
Processing mmlu...


Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Processed 11 examples from mmlu


In [11]:
# GPU Setup and Verification
def setup_gpu():
    """Check for a CUDA device, enable cuDNN autotuning/TF32, and choose a memory budget.

    Returns:
        ``(gpu_type, budget_gb)``. ``gpu_type`` is ``'T4'``, ``'V100'``, ``'A100'`` or
        ``'unknown'`` (first substring match on device 0's name, in that order);
        ``budget_gb`` is a fraction of device 0's total memory in GB, capped per type:
        T4 80% / 12, V100 85% / 28, A100 90% / 36, unknown 70% / 8.

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("This notebook requires a GPU runtime. Please change runtime type to GPU.")

    name = torch.cuda.get_device_name(0)
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Using GPU: {name}")
    print(f"Total GPU memory: {total_gb:.1f} GB")

    # cuDNN autotuning plus TF32 matmuls (TF32 only has an effect on Ampere-or-newer GPUs).
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    # (substring of device name, banner, fraction of total memory, cap in GB), in match order.
    profiles = (
        ('T4', "Detected T4 GPU - Adjusting model configuration for optimal performance", 0.8, 12),
        ('V100', "Detected V100 GPU - Using high-performance configuration", 0.85, 28),
        ('A100', "Detected A100 GPU - Using maximum performance configuration", 0.9, 36),
    )
    for tag, banner, fraction, cap_gb in profiles:
        if tag in name:
            print(banner)
            return tag, min(total_gb * fraction, cap_gb)

    print("Unknown GPU type - Using conservative settings")
    return 'unknown', min(total_gb * 0.7, 8)

# Execute GPU setup
# gpu_type is 'T4' / 'V100' / 'A100' / 'unknown'; available_memory is a GB budget (a capped
# fraction of total GPU memory) consumed by adjust_model_config() below.
gpu_type, available_memory = setup_gpu()

# Adjust model configuration based on GPU
def adjust_model_config(model_args, gpu_type, available_memory):
    """Adjust model configuration based on available GPU.

    Applies a fixed size preset for known GPU types, then shrinks the model if a rough
    size heuristic exceeds ``available_memory``. ``model_args`` is modified in place and
    also returned.

    Args:
        model_args: config object exposing ``max_batch_size``, ``max_seq_len``, ``dim``,
            ``n_layers``, ``n_heads`` and ``n_routed_experts`` (read and written) and,
            optionally, ``dtype`` (read; e.g. ``"bfloat16"`` or ``"float32"``).
        gpu_type: ``'T4'``, ``'V100'`` or ``'A100'`` selects a preset; any other value
            (e.g. ``'unknown'``) keeps the caller's sizes.
        available_memory: memory budget in GB, as returned by ``setup_gpu()``.

    Returns:
        The same ``model_args`` object, after adjustment.
    """
    # Bytes per stored element for the configured dtype. A fixed 2 (fp16/bf16) would
    # under-estimate a float32 model by half.
    dtype_name = str(getattr(model_args, 'dtype', 'bfloat16')).replace('torch.', '')
    mem_per_param = 4 if dtype_name in ('float32', 'fp32', 'float') else 2

    # Per-GPU presets: conservative for T4, moderate for V100, maximum for A100.
    # NOTE(review): the T4 preset sets max_seq_len=1024, while process_dataset pads to
    # 2048 tokens by default — keep the two in sync.
    presets = {
        'T4': {'max_batch_size': 2, 'max_seq_len': 1024, 'dim': 1024,
               'n_layers': 12, 'n_heads': 16, 'n_routed_experts': 8},
        'V100': {'max_batch_size': 4, 'max_seq_len': 2048, 'dim': 2048,
                 'n_layers': 24, 'n_heads': 32, 'n_routed_experts': 16},
        'A100': {'max_batch_size': 8, 'max_seq_len': 4096, 'dim': 4096,
                 'n_layers': 32, 'n_heads': 64, 'n_routed_experts': 32},
    }
    for attr, value in presets.get(gpu_type, {}).items():
        setattr(model_args, attr, value)

    # Rough size heuristic kept from the original notebook.
    # NOTE(review): this is not a parameter count — it grows with max_seq_len and n_heads
    # (which add no parameters) and ignores vocab/expert sizes. Treat it only as a
    # relative "too big?" signal.
    num_params = (model_args.dim * model_args.n_layers * 4 *
                  model_args.max_seq_len * model_args.n_heads)
    estimated_memory = num_params * mem_per_param / 1e9  # Convert to GB

    if estimated_memory > available_memory:
        reduction_factor = (available_memory / estimated_memory) ** 0.5
        print(f"Warning: Reducing model size by {(1-reduction_factor)*100:.1f}% to fit in GPU memory")

        model_args.dim = int(model_args.dim * reduction_factor)
        # int() truncation could otherwise remove every layer.
        model_args.n_layers = max(1, int(model_args.n_layers * reduction_factor))
        model_args.n_heads = max(8, int(model_args.n_heads * reduction_factor))
        model_args.n_routed_experts = max(4, int(model_args.n_routed_experts * reduction_factor))
        # NOTE(review): dim is not rounded to a multiple of n_heads (or of 64); confirm
        # whether the attention implementation requires that.

    print("\nAdjusted model configuration:")
    print(f"Dimension: {model_args.dim}")
    print(f"Layers: {model_args.n_layers}")
    print(f"Heads: {model_args.n_heads}")
    print(f"Experts: {model_args.n_routed_experts}")
    print(f"Sequence Length: {model_args.max_seq_len}")
    print(f"Batch Size: {model_args.max_batch_size}")

    return model_args

# Update model initialization
# Build the model and trainer from the GPU-adjusted config; start_step is the step to
# resume from (presumably read from a checkpoint in CHECKPOINT_DIR — confirm).
# NOTE(review): in the recorded run this failed inside the library with
# "precompute_freqs_cis() got an unexpected keyword argument 'dtype'": the installed
# vishwamai.utils.precompute_freqs_cis does not accept `dtype`, so its caller inside
# initialize_model_and_trainer (presumably model construction) needs fixing in the repo.
try:
    # Adjust model configuration based on GPU
    model_args = adjust_model_config(model_args, gpu_type, available_memory)

    # Initialize model and trainer
    model, trainer, start_step = initialize_model_and_trainer(
        model_args=model_args,
        checkpoint_dir=CHECKPOINT_DIR,
        tot_config=tot_config,
        reward_config=reward_config,
        curriculum_config=curriculum_config
    )

    print("Model initialization successful!")
    print(f"Starting training from step {start_step}")
    print(f"Model will train for {model_args.max_steps} steps")

except torch.cuda.OutOfMemoryError as e:
    print("Error: GPU out of memory!")
    print("Try reducing model size or batch size")
    raise e
except Exception as e:
    # Report and re-raise so Run All stops: the training cells need `trainer`.
    print(f"Error during initialization: {str(e)}")
    raise

Using GPU: Tesla T4
Total GPU memory: 15.8 GB
Detected T4 GPU - Adjusting model configuration for optimal performance

Adjusted model configuration:
Dimension: 1024
Layers: 12
Heads: 16
Experts: 8
Sequence Length: 1024
Batch Size: 2
Validating model arguments...
Dimension: 1024
Max sequence length: 1024
Initializing default neural memory...
Creating model with configuration:
  dim: 1024
  max_seq_len: 1024
  n_layers: 12
  n_heads: 16
  n_routed_experts: 8
Using dtype: torch.bfloat16
Computing frequencies with dim=1024, max_seq_len=1024
Error computing frequencies: precompute_freqs_cis() got an unexpected keyword argument 'dtype'
Error during initialization: precompute_freqs_cis() got an unexpected keyword argument 'dtype'
Error during initialization: precompute_freqs_cis() got an unexpected keyword argument 'dtype'


TypeError: precompute_freqs_cis() got an unexpected keyword argument 'dtype'

In [None]:
# Training Loop with Performance Tracking
# Requires the initialization cell above to have succeeded (`trainer`, `start_step`).
from tqdm.notebook import tqdm
import wandb

wandb.init(project="vishwamai-training")

# Per-step (step, loss, memory usage) records kept alongside `performance_df`.
performance_data = []
interrupted = False

# Resume from the step returned by initialize_model_and_trainer() instead of 0, so a run
# restored from a checkpoint is not retrained from the beginning (which would also
# overwrite the earlier step_*.pt checkpoints).
first_step = int(start_step or 0)

try:
    for step in tqdm(range(first_step, model_args.max_steps)):  # Use model_args.max_steps
        stats = trainer.train_step()
        update_performance_tracking(stats, step)
        performance_data.append({
            "step": step,
            "loss": stats["loss"],
            "memory_usage": stats["memory_usage"]["allocated"],
        })

        wandb.log({
            "loss": stats["loss"],
            "learning_rate": stats["lr"],
            "batch_size": stats["batch_size"],
            "curriculum_level": stats["curriculum_stats"]["current_difficulty"],
            "memory_usage": stats["memory_usage"]["allocated"],
            "moe_loss": stats.get("moe_loss", 0),
            "gradient_norm": stats["gradient_norm"],
            "expert_usage": stats.get("moe_metrics", {})
        })

        if step % 1000 == 0:
            plot_training_progress()
            checkpoint_path = f"{CHECKPOINT_DIR}/step_{step}.pt"
            trainer.save_checkpoint(checkpoint_path)

            # NOTE(review): this also fires at step 0 (an untrained checkpoint) and needs
            # the Hugging Face login above to have succeeded.
            trainer.push_to_hub(
                "VishwamAI/VishwamAI",
                commit_message=f"Training checkpoint at step {step}"
            )

        if step % 5000 == 0:
            print(f"\nEvaluating at step {step}...")
            eval_metrics = trainer.evaluate()
            wandb.log({"eval": eval_metrics})

except KeyboardInterrupt:
    interrupted = True
    print("\nTraining interrupted. Saving final visualization...")
    plot_training_progress()
    trainer.save_checkpoint(f"{CHECKPOINT_DIR}/interrupted.pt")
finally:
    # Close the W&B run even if training stops early or raises.
    wandb.finish()

if not interrupted:
    # Final plot, plus the final weights (the loop only checkpoints every 1000 steps).
    plot_training_progress()
    trainer.save_checkpoint(f"{CHECKPOINT_DIR}/final.pt")
performance_df.to_csv(f"{DRIVE_DIR}/training_metrics.csv", index=False)
print("Training interrupted; metrics saved" if interrupted else "Training complete with performance tracking")

In [None]:
# Generate Performance Graphs
# Plot from `performance_df`, which update_performance_tracking() fills on every training
# step. (Rebuilding `performance_df` from `performance_data`, as this cell previously did,
# replaced the tracked metrics and indexed a "step" column; the tracking table names its
# step column "steps".)


def _plot_metric_over_steps(df, column, label, title, filename):
    """Line-plot one tracked metric against training step and save it under DRIVE_DIR."""
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(df["steps"], df[column], label=label)
    ax.set_xlabel("Step")
    ax.set_ylabel(label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)
    fig.savefig(f"{DRIVE_DIR}/{filename}")
    plt.show()
    plt.close(fig)


_plot_metric_over_steps(performance_df, "loss", "Loss", "Training Loss Over Time", "training_loss.png")
# NOTE(review): the "GB" unit assumes stats["memory_usage"]["allocated"] is reported in GB — confirm.
_plot_metric_over_steps(performance_df, "memory_usage", "Memory Usage (GB)", "Memory Usage Over Time", "memory_usage.png")

In [None]:
# Final Evaluation
import json

print("Running final evaluation...")

# NOTE(review): some of these hub datasets need a config name (e.g. MMMU subjects) or may
# not publish a "test" split, and the raw (untokenized) split is passed to
# trainer.evaluate(); failures are caught and recorded per dataset below.
eval_datasets = [
    "gsm8k",
    "TIGER-Lab/MMLU-Pro",
    "MMMU/MMMU",
    "microsoft/SCBench",
    "camel-ai/math",
    "camel-ai/code"
]

results = {}
for dataset_name in eval_datasets:
    print(f"\nEvaluating on {dataset_name}...")
    try:
        eval_data = load_dataset(dataset_name, split="test")
        metrics = trainer.evaluate(eval_data)
        results[dataset_name] = metrics
        print(f"{dataset_name}: {metrics}")
    except Exception as e:
        print(f"Error evaluating {dataset_name}: {str(e)}")
        # Keep the failure in the saved results so it is visible which benchmarks did not run.
        results[dataset_name] = {"error": str(e)}

# Save results
# default=str: metrics may contain values json cannot encode natively (e.g. torch tensors
# or numpy float32); stringify them instead of crashing after every evaluation has run.
with open(f"{DRIVE_DIR}/final_evaluation.json", "w") as f:
    json.dump(results, f, indent=2, default=str)

print("\nTraining complete!")