In [6]:
import warnings
warnings.filterwarnings("ignore")

import torch
import wandb
from math import ceil
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Import our custom modules
from config import Config
from metrics import create_compute_metrics
from data_utils import set_seeds, load_and_prepare_dataset, preprocess_dataset
from utils import clear_memory, create_directories, safe_training_check, save_model_safe
from train import setup_model, create_training_args

def main():
    """Run the end-to-end Nepali grammar-error-correction training pipeline.

    Steps, in order: load the config, seed and prepare the environment, start
    a wandb run, load the dataset, build the model and tokenizer, preprocess,
    build the trainer, run the pre-training safety check, train, and save the
    model under ``<config.output_dir>/best_model``.

    The wandb run opened here is closed on every exit path: success, a failed
    safety check, an exception during setup or training, and a user interrupt
    (``KeyboardInterrupt``).

    Returns:
        None. A failed safety check or a training exception is reported on
        stdout and the function returns early without saving a model.
    """

    # Load config
    config = Config()

    print("=" * 60)
    print("üöÄ Nepali Grammar Error Correction Training")
    print("=" * 60)
    print(f"Model: {config.model_id}")
    print(f"LoRA: {config.use_lora}")
    print(f"Samples: {config.num_samples or 'Full dataset'}")
    print("=" * 60)

    # Setup
    set_seeds(config.seed)
    clear_memory()
    create_directories(config.output_dir)

    # Initialize wandb. The leading finish() closes any run left open by a
    # previous execution in the same process (e.g. a re-run notebook cell).
    wandb.finish()
    wandb.init(
        project=config.wandb_project,
        config=vars(config)
    )

    # Everything after wandb.init() lives inside try/finally so the run is
    # always closed, whichever way this function exits.
    try:
        run_id = wandb.run.id

        # Load data
        dataset = load_and_prepare_dataset(config)

        # Setup model
        model, tokenizer = setup_model(config)

        # Preprocess
        dataset_encoded = preprocess_dataset(dataset, tokenizer, config)

        # Create training args
        training_args = create_training_args(config, dataset_encoded, run_id)

        # Data collator
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            padding=True
        )

        # Create metrics
        compute_metrics = create_compute_metrics(tokenizer)

        # Create trainer
        trainer = Seq2SeqTrainer(
            model=model,
            args=training_args,
            train_dataset=dataset_encoded["train"],
            eval_dataset=dataset_encoded["valid"],
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
            callbacks=[
                EarlyStoppingCallback(early_stopping_patience=config.early_stopping_patience)
            ]
        )

        # Safety check
        if not safe_training_check(trainer):
            print("\n‚ùå Safety checks failed! Fix issues before training.")
            return

        # Train!
        print("\n" + "=" * 60)
        print("üèãÔ∏è  Starting training...")
        print("=" * 60)

        # Only the training call is guarded: a training failure is reported
        # and ends the run without saving, as a deliberate best-effort exit.
        try:
            trainer.train()
        except Exception as e:
            print(f"\n‚ùå Training failed: {e}")
            return
        print("\n‚úÖ Training complete!")

        # Save model
        best_model_path = f"{config.output_dir}/best_model"
        save_model_safe(model, tokenizer, best_model_path, use_lora=config.use_lora)

        print(f"\nüéâ All done! Model saved to {best_model_path}")
    finally:
        # Runs on normal return, early return, exception and KeyboardInterrupt.
        wandb.finish()

# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()


üöÄ Nepali Grammar Error Correction Training
Model: google/mt5-small
LoRA: True
Samples: 15
‚úÖ Seeds set to 42


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


Memory cleared
Directories created in ../outputs


0,1
eval/bertscore_f1,‚ñà‚ñÅ‚ñà
eval/bleu,‚ñÅ‚ñà‚ñÅ
eval/chrf,‚ñÅ‚ñà‚ñÅ
eval/correction_accuracy,‚ñÅ‚ñÅ‚ñÅ
eval/gleu,‚ñÅ‚ñÅ‚ñÅ
eval/loss,‚ñà‚ñÅ‚ñÇ
eval/model_preparation_time,‚ñÅ‚ñÅ‚ñÅ
eval/runtime,‚ñà‚ñÅ‚ñÅ
eval/samples_per_second,‚ñÅ‚ñà‚ñà
eval/steps_per_second,‚ñÅ‚ñà‚ñà

0,1
eval/bertscore_f1,0.44887
eval/bleu,0.85086
eval/chrf,0.08999
eval/correction_accuracy,0
eval/gleu,0.0098
eval/loss,15.7358
eval/model_preparation_time,0.002
eval/runtime,15.6512
eval/samples_per_second,0.192
eval/steps_per_second,0.064


--- Logging error ---
Traceback (most recent call last):
  File "C:\Users\Lenovo\AppData\Local\Programs\Python\Python313\Lib\logging\__init__.py", line 1153, in emit
    stream.write(msg + self.terminator)
    ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Lenovo\AppData\Local\Programs\Python\Python313\Lib\encodings\cp1252.py", line 19, in encode
    return codecs.charmap_encode(input,self.errors,encoding_table)[0]
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
UnicodeEncodeError: 'charmap' codec can't encode characters in position 187-191: character maps to <undefined>
Call stack:
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "c:\Users\Lenovo\Desktop\Nepali_GEC\nepali_gec\myenv\Lib\site-packages\ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "c:\Users\Lenovo\Desktop\Nepali_GEC\nepali_gec\myenv\Lib\site-packages\traitlets\config\application.py", line 1075, in l


üìö Loading dataset: sumitaryal/nepali_grammatical_error_correction
  Using 15 samples
  Train: 12 samples
  Valid: 3 samples

 Lodaing model: google/mt5-small
 Using LoRA + 8-bit quantization
trainable params: 1,769,472 || all params: 301,946,240 || trainable%: 0.5860

‚öôÔ∏è  Preprocessing dataset...
  ‚úÖ Preprocessing complete

üìä Training plan:
  Steps per epoch: 1
  Total steps: 5
  Warmup steps: 0

 Running pre-training safety checks...
Model device: cuda:0
Train dataset size: 12
Eval dataset size: 3
 Data loading works
 Performing evaluation check...


üîç Sample - Pred: '<extra_id_0>...' | Ref: '‡§´‡§∞‡•ç‡§ï‡§®‡•á ‡§π‡•ã ‡§â‡§§‡•à?...' | Match: False
 Evaluation successful
Initial metrics: {'eval_loss': 15.911977767944336, 'eval_model_preparation_time': 0.0023, 'eval_bleu': 0.8508564639341064, 'eval_chrf': 0.08999280057595392, 'eval_correction_accuracy': 0.0, 'eval_bertscore_f1': 0.4488728940486908, 'eval_gleu': 0.00980392156862745, 'eval_runtime': 17.2188, 'eval_samples_per_second': 0.174, 'eval_steps_per_second': 0.058}

üèãÔ∏è  Starting training...


Epoch,Training Loss,Validation Loss,Model Preparation Time,Bleu,Chrf,Correction Accuracy,Bertscore F1,Gleu
1,No log,15.697292,0.0023,1.011845,0.179986,0.0,0.442217,0.009804
2,No log,15.735796,0.0023,0.850856,0.089993,0.0,0.448873,0.009804


üîç Sample - Pred: '<extra_id_0>...' | Ref: '‡§´‡§∞‡•ç‡§ï‡§®‡•á ‡§π‡•ã ‡§â‡§§‡•à?...' | Match: False
üîç Sample - Pred: '<extra_id_0>...' | Ref: '‡§´‡§∞‡•ç‡§ï‡§®‡•á ‡§π‡•ã ‡§â‡§§‡•à?...' | Match: False


KeyboardInterrupt: 