In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
import wandb
import torch
import json
import subprocess
import time
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, setup_chat_format
from lighteval.pipeline import Pipeline, PipelineParameters, ParallelismManager
from lighteval.logging.evaluation_tracker import EvaluationTracker
from huggingface_hub import login

In [None]:
# Experiment tracking: open the Weights & Biases run used by the rest of the notebook.
# Hugging Face Hub login is left disabled; uncomment if Hub access is needed.
# login()
WANDB_PROJECT = "trl-Teuken3.73T"
wandb.init(project=WANDB_PROJECT)

In [None]:
# Run-level constants.
# NOTE(review): neither name is referenced by the cells visible in this notebook
# (the model is placed via device_map="auto") — confirm they are still needed.
device = "cuda"
finetune_name = "Teuken-Instruct-TRL"


In [None]:
# Load the base checkpoint and its tokenizer from the same directory, then let
# TRL's setup_chat_format hand back the (model, tokenizer) pair used for training.
BASE_CHECKPOINT = "/raid/s3/opengptx/mfrey/3.73T-Tokens/checkpoints/aug15_tokfix"

model = AutoModelForCausalLM.from_pretrained(
    BASE_CHECKPOINT,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(BASE_CHECKPOINT)
model, tokenizer = setup_chat_format(model, tokenizer)

In [None]:
# `!export ...` runs in a throw-away subshell, so the kernel process never saw
# the variable. Set it on the kernel's own environment instead.
# NOTE(review): Hugging Face libraries may resolve HF_HOME when they are first
# imported; if the cache location must apply to this kernel, move this line
# above the imports cell — confirm.
os.environ["HF_HOME"] = "/raid/s3/opengptx/mfrey/huggingface"

In [None]:
# Load the training corpus. MetaMathQA is the active dataset; the Alpaca line is
# kept as a commented-out alternative (it pairs with `format_alpaca`).
# ds = load_dataset("tatsu-lab/alpaca", split="train")
ds = load_dataset("meta-math/MetaMathQA", split="train")

def format_alpaca(example):
    """Turn one Alpaca record into a two-turn chat sample.

    The user turn is the ``instruction``; when the record's ``input`` field is
    non-empty it is appended after a blank line as ``Input: ...``. The
    assistant turn is the record's ``output``.

    Returns a dict with a single ``messages`` key holding both turns.
    """
    has_extra_input = bool(example["input"])
    user_text = (
        f"{example['instruction']}\n\nInput: {example['input']}"
        if has_extra_input
        else example["instruction"]
    )

    user_turn = {"role": "user", "content": user_text}
    assistant_turn = {"role": "assistant", "content": example["output"]}
    return {"messages": [user_turn, assistant_turn]}

def format_metamath(example):
    """Map a MetaMathQA record onto the chat ``messages`` layout.

    ``query`` becomes the user turn and ``response`` the assistant turn.
    """
    role_to_field = (("user", "query"), ("assistant", "response"))
    turns = [
        {"role": role, "content": example[field]}
        for role, field in role_to_field
    ]
    return {"messages": turns}

# Convert every record to chat format (dropping the raw columns), then hold out
# 1% as the test split with a fixed seed.
# NOTE(review): `ds` is rebound here, so re-run the load cell before re-running this one.
ds = (
    ds
    .map(format_metamath, remove_columns=ds.column_names)
    .train_test_split(test_size=0.01, seed=42)
)

In [None]:
def run_lighteval_cli_async(checkpoint_path, step, gpu_id=6, max_samples=10):
    """
    Run the LightEval CLI on a checkpoint and return the parsed results JSON.

    The call itself is blocking (it waits for the subprocess to exit); it is
    meant to be submitted to a worker thread by the caller.

    Results are written below ``checkpoint_path`` by passing ``--output-dir``,
    and the newest ``results_*.json`` found anywhere under that directory is
    loaded.

    Args:
        checkpoint_path: Directory of the checkpoint to evaluate; also used as
            the LightEval output directory.
        step: Training step, used only in log messages.
        gpu_id: Value exported as ``CUDA_VISIBLE_DEVICES`` for the subprocess.
        max_samples: Forwarded as ``--max-samples``; ``None`` omits the flag.

    Returns:
        The decoded results dict, or ``None`` if the command failed, no results
        file was found, or any exception was raised.
    """
    import shlex

    try:
        print(f"üîç Starting CLI evaluation for step {step} on {checkpoint_path}")

        checkpoint_dir = Path(checkpoint_path)

        model_args = f"model_name={checkpoint_path},use_chat_template=True,trust_remote_code=True,batch_size=16"
        tasks = "leaderboard|hellaswag|0|1,leaderboard|gsm8k|0|1"

        # Build an argv list instead of a shell string: no quoting pitfalls and
        # no shell interpretation of characters in the checkpoint path.
        cmd = ["lighteval", "accelerate", model_args, tasks]
        if max_samples is not None:
            cmd += ["--max-samples", str(max_samples)]
        # Without an explicit output directory the CLI writes to its default
        # location, and the lookup below would never find anything.
        cmd += ["--output-dir", str(checkpoint_dir)]

        # Child environment: pin the GPU and the Hugging Face cache location.
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        env["HF_HOME"] = "/raid/s3/opengptx/mfrey/huggingface"

        print(f"üöÄ Running command: CUDA_VISIBLE_DEVICES={gpu_id} {shlex.join(cmd)}")

        result = subprocess.run(
            cmd,
            env=env,
            capture_output=True,
            text=True,
            check=False,
        )

        if result.returncode != 0:
            print(f"‚ùå Evaluation failed for step {step}")
            print(f"STDOUT: {result.stdout}")
            print(f"STDERR: {result.stderr}")
            return None

        print(f"‚úÖ CLI evaluation completed for step {step}")

        # LightEval nests its result files in sub-folders of the output
        # directory, so search recursively.
        json_files = list(checkpoint_dir.rglob("results_*.json"))
        if not json_files:
            print(f"‚ùå No results JSON file found in {checkpoint_dir}")
            return None

        # Several runs may have left files behind; take the most recent one.
        results_file = max(json_files, key=lambda p: p.stat().st_mtime)
        print(f"üìñ Reading results from {results_file}")

        with open(results_file, 'r') as f:
            eval_results = json.load(f)

        return eval_results

    except Exception as e:
        # Best-effort: evaluation must never take the training run down.
        print(f"‚ùå Evaluation failed for step {step}: {e}")
        return None

def parse_and_log_results(eval_results, step):
    """Flatten a LightEval results payload into ``eval/*`` entries and log them to WandB.

    Args:
        eval_results: Decoded results dict; must contain a ``results`` mapping
            of task name -> {metric name -> value}. Falsy or malformed input
            is reported and ignored.
        step: Step passed to ``wandb.log`` and used in log messages.

    The ``all`` aggregate and every ``*_stderr`` metric are left out. If
    ``config_general`` is present, its evaluation time and model size are
    logged as well.
    """
    if not eval_results or "results" not in eval_results:
        print(f"‚ùå No valid results to log for step {step}")
        return

    def short_task(task_name):
        # Strip the suite prefix, then keep only the part before the next "|".
        return task_name.replace("leaderboard|", "").split("|")[0]

    payload = {
        f"eval/{short_task(task)}_{metric}": value
        for task, task_metrics in eval_results["results"].items()
        if task != "all"
        for metric, value in task_metrics.items()
        if not metric.endswith("_stderr")
    }

    # Optional run metadata.
    general_cfg = eval_results["config_general"] if "config_general" in eval_results else {}
    if "total_evaluation_time_secondes" in general_cfg:
        payload["eval/evaluation_time_seconds"] = float(general_cfg["total_evaluation_time_secondes"])
    if "model_size" in general_cfg:
        payload["eval/model_size"] = general_cfg["model_size"]

    if not payload:
        print(f"‚ùå No metrics to log for step {step}")
        return

    # NOTE(review): an explicit `step` is mixed with whatever stepping the
    # Trainer's own W&B reporting uses — confirm the charts line up as intended.
    wandb.log(payload, step=step)
    print(f"üìä Logged {len(payload)} metrics to WandB for step {step}")
    for key, value in payload.items():
        print(f"  {key}: {value}")

class EvalCallback(TrainerCallback):
    """Callback that runs the LightEval CLI on a worker thread after each checkpoint save.

    Args:
        gpu_id: GPU index handed to the evaluation subprocess.
        max_samples: Sample cap forwarded to the evaluation.
        max_workers: Size of the thread pool running evaluations.
    """
    def __init__(self, gpu_id=6, max_samples=10, max_workers=2):
        self.gpu_id = gpu_id
        self.max_samples = max_samples
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.futures = []

    def on_save(self, args, state, control, **kwargs):
        """Submit an evaluation job for the checkpoint that was just saved."""
        # Snapshot the step now. The job runs later on another thread, and
        # reading `state.global_step` there could yield a newer step, so the
        # metrics would be logged against the wrong checkpoint.
        step = state.global_step
        checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{step}")

        if not os.path.exists(checkpoint_path):
            print(f"‚ö†Ô∏è  Checkpoint path {checkpoint_path} does not exist, skipping evaluation")
            return

        print(f"üíæ Checkpoint saved at step {step}, triggering evaluation...")

        gpu_id = self.gpu_id
        max_samples = self.max_samples

        def eval_and_log():
            results = run_lighteval_cli_async(checkpoint_path, step, gpu_id, max_samples)
            if results:
                parse_and_log_results(results, step)

        self.futures.append(self.executor.submit(eval_and_log))

        # Drop finished jobs, but report any exception they raised instead of
        # discarding it silently.
        still_pending = []
        for future in self.futures:
            if not future.done():
                still_pending.append(future)
            elif future.exception() is not None:
                print(f"‚ùå Evaluation future failed: {future.exception()}")
        self.futures = still_pending

        print(f"üéØ Evaluation job submitted for step {step}")

    def on_train_end(self, args, state, control, **kwargs):
        """Wait for outstanding evaluations, then shut the thread pool down."""
        if self.futures:
            print("‚è≥ Waiting for remaining evaluations to complete...")
            for future in self.futures:
                try:
                    future.result(timeout=300)  # 5 minute timeout per evaluation
                except Exception as e:
                    print(f"‚ùå Evaluation future failed: {e}")
            self.futures = []

        # shutdown(wait=True) still blocks until every submitted job returns.
        self.executor.shutdown(wait=True)
        print("‚úÖ All evaluations completed")

# Checkpoint directory handed to the trainer as the source of the custom-code files.
SOURCE_MODEL_PATH = "/raid/s3/opengptx/mfrey/3.73T-Tokens/checkpoints/aug15_tokfix"
class CustomSFTTrainer(SFTTrainer):
    """SFTTrainer whose ``save_model`` also runs ``save_model_with_custom_code``.

    Args:
        source_model_path: Directory passed on as the source of the custom
            code files; all other arguments go to ``SFTTrainer`` unchanged.
    """

    def __init__(self, source_model_path, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.source_model_path = source_model_path

    def save_model(self, output_dir=None, _internal_call=False):
        """Save as the parent does, then add the custom code to the same directory."""
        super().save_model(output_dir, _internal_call)
        # Same fallback as used after the parent call: default to the configured output dir.
        target_dir = self.args.output_dir if output_dir is None else output_dir
        save_model_with_custom_code(self.model, target_dir, self.source_model_path)

def save_model_with_custom_code(model, save_path, source_model_path,
                                custom_files=("modeling_gpt2.py", "configuration_gpt2.py")):
    """Save the model (when it supports it) and copy custom code files next to it.

    Args:
        model: Object to save; ``save_pretrained(save_path)`` is called only
            if the attribute exists.
        save_path: Destination directory. Created if missing, so the copy
            step works even when nothing was saved beforehand.
        source_model_path: Directory the custom code files are copied from.
        custom_files: File names to copy; defaults to the two GPT-2 code files.

    Files missing from the source directory are reported and skipped.
    """
    import shutil
    from pathlib import Path

    if hasattr(model, 'save_pretrained'):
        model.save_pretrained(save_path)

    source_dir = Path(source_model_path)
    target_dir = Path(save_path)
    # shutil.copy does not create parent directories.
    target_dir.mkdir(parents=True, exist_ok=True)

    for file_name in custom_files:
        source_file = source_dir / file_name
        if not source_file.exists():
            print(f"‚ö†Ô∏è  Warning: {file_name} not found in {source_dir}")
            continue

        shutil.copy(source_file, target_dir / file_name)
        print(f"‚úÖ Copied {file_name} to {target_dir}")


In [None]:
# SFT training configuration; metrics are reported to W&B.
sft_config = SFTConfig(
    # Output location
    output_dir="/raid/s3/opengptx/mfrey/3.73T-Tokens/checkpoints/aug15_tokfix/instruct",
    push_to_hub=False,
    # Run length and batch shape (1 sample x 4 accumulation steps per device)
    num_train_epochs=1,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    # Optimiser schedule
    learning_rate=1e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    # Mixed precision: bf16 on, fp16 off
    bf16=True,
    fp16=False,
    # Logging / evaluation / checkpoint cadence
    report_to="wandb",
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=1000,
    save_steps=1000,
    save_total_limit=3,
    # Best-checkpoint selection: lowest eval_loss wins
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

In [None]:
# Assemble the trainer: model/tokenizer pair, train/test splits, and the
# callback that evaluates each saved checkpoint.
lighteval_callback = EvalCallback(gpu_id=7, max_samples=500)

trainer = CustomSFTTrainer(
    source_model_path=SOURCE_MODEL_PATH,
    model=model,
    args=sft_config,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    processing_class=tokenizer,
    callbacks=[lighteval_callback],
)

In [None]:

# Launch fine-tuning. Each checkpoint save also triggers the EvalCallback
# registered on the trainer.
print("üöÄ Starting training...")
trainer.train()



In [None]:
# Close the W&B run opened in the setup cell so it is marked as finished.
wandb.finish()

print("‚ú® Training and evaluation complete!")