# Sanskrit Tutor - QLoRA Fine-tuning

This notebook fine-tunes a language model on your Sanskrit text corpus using QLoRA (Quantized Low-Rank Adaptation).

## ⚠️ IMPORTANT: User Assets Required

This notebook expects you to have uploaded your Sanskrit corpus and model files to Google Drive. The required structure is:

```
drive/MyDrive/sanskrit-tutor-user-assets/
├── passages.jsonl
├── qa_pairs.jsonl  
├── config.yaml
└── models/ (optional - for base models)
```

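**This notebook will NOT download your data for you.** You must provide it. The default `MODEL_NAME` in section 4 is a Hugging Face Hub identifier, which `from_pretrained` fetches on first use; point it at a local directory under `models/` to use your own copy.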
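
The validation cell in section 2 checks the first few records of each JSONL file for specific fields. As a rough sketch of what one record per file might look like, the snippet below builds two example records and serialises them as JSON Lines. The field names come from the validation and data-loading code in this notebook; the ID scheme, question, answer, and notes are illustrative placeholders.

```python
import json

# Hypothetical example records; field names follow the checks in this notebook.
passage = {
    "id": "bg_2_47",                      # illustrative ID scheme
    "text_devanagari": "कर्मण्येवाधिकारस्ते मा फलेषु कदाचन",
    "text_iast": "karmaṇy evādhikāras te mā phaleṣu kadācana",
    "work": "Bhagavad Gita",
    "chapter": 2,                          # read when building context in section 3
    "verse": 47,                           # read when building context in section 3
    "notes": "First half of the verse.",   # optional
}

qa_pair = {
    "id": "qa_0001",
    "question": "What does this verse say about the fruits of action?",
    "answer": "One has a claim on action alone, never on its fruits [bg_2_47].",
    "related_passage_ids": ["bg_2_47"],    # optional: links to passages.jsonl
    "difficulty": "beginner",              # optional
}

def to_jsonl_line(record: dict) -> str:
    """Serialise one record as a single JSON Lines row, keeping Devanagari readable."""
    return json.dumps(record, ensure_ascii=False)

passage_line = to_jsonl_line(passage)
qa_line = to_jsonl_line(qa_pair)
print(passage_line)
print(qa_line)
```

Each file holds one such JSON object per line, with no enclosing array.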

## 1. Setup and Installation

In [None]:
# Install required packages for fine-tuning
# Version specifiers are quoted so the shell does not treat '>' as a redirect
!pip install -q \
    "transformers>=4.30.0" \
    "peft>=0.4.0" \
    "bitsandbytes>=0.39.0" \
    "accelerate>=0.20.0" \
    "datasets>=2.12.0" \
    "torch>=2.0.0" \
    pyyaml \
    jsonlines

print("✅ Packages installed successfully!")

In [None]:
# Mount Google Drive
from google.colab import drive
import os
from pathlib import Path

drive.mount('/content/drive')

# Define user assets path
USER_ASSETS_PATH = Path('/content/drive/MyDrive/sanskrit-tutor-user-assets')
print(f"User assets expected at: {USER_ASSETS_PATH}")

# Check if the directory exists
if USER_ASSETS_PATH.exists():
    print("✅ User assets directory found!")
    print("Contents:")
    for item in USER_ASSETS_PATH.iterdir():
        print(f"  {item.name}")
else:
    print("❌ User assets directory not found!")
    print("Please create the directory and upload your files according to the README.")

## 2. Validate User Assets

In [None]:
import json
import yaml
from pathlib import Path

def validate_user_assets(assets_path: Path):
    """
    Validate that required user assets are present and properly formatted.
    """
    errors = []
    
    # Check required files
    required_files = {
        'passages.jsonl': 'Sanskrit passages for fine-tuning',
        'qa_pairs.jsonl': 'Question-answer pairs for training',
        'config.yaml': 'Configuration file'
    }
    
    for filename, description in required_files.items():
        filepath = assets_path / filename
        if not filepath.exists():
            errors.append(f"Missing {filename}: {description}")
        else:
            print(f"✅ Found {filename}")
    
    if errors:
        print("\n❌ Validation failed:")
        for error in errors:
            print(f"  - {error}")
        return False
    
    # Validate file contents
    try:
        # Check passages.jsonl
        passages_file = assets_path / 'passages.jsonl'
        with open(passages_file, 'r', encoding='utf-8') as f:
            passage_count = 0
            for line_num, line in enumerate(f, 1):
                if line.strip():
                    passage_count += 1
                    if line_num <= 3:  # Check first 3 lines for required fields
                        try:
                            obj = json.loads(line)
                            required_fields = ['id', 'text_devanagari', 'text_iast', 'work']
                            for field in required_fields:
                                if field not in obj:
                                    errors.append(f"passages.jsonl line {line_num}: missing field '{field}'")
                        except json.JSONDecodeError:
                            errors.append(f"passages.jsonl line {line_num}: invalid JSON")
        
        print(f"📚 Found {passage_count} passages")
        
        # Check qa_pairs.jsonl
        qa_file = assets_path / 'qa_pairs.jsonl'
        with open(qa_file, 'r', encoding='utf-8') as f:
            qa_count = 0
            for line_num, line in enumerate(f, 1):
                if line.strip():
                    qa_count += 1
                    if line_num <= 3:  # Check first 3 lines
                        try:
                            obj = json.loads(line)
                            required_fields = ['id', 'question', 'answer']
                            for field in required_fields:
                                if field not in obj:
                                    errors.append(f"qa_pairs.jsonl line {line_num}: missing field '{field}'")
                        except json.JSONDecodeError:
                            errors.append(f"qa_pairs.jsonl line {line_num}: invalid JSON")
        
        print(f"❓ Found {qa_count} QA pairs")
        
        # Check config.yaml
        config_file = assets_path / 'config.yaml'
        with open(config_file, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
            print(f"⚙️ Configuration loaded")
    
    except Exception as e:
        errors.append(f"Error reading files: {str(e)}")
    
    if errors:
        print("\n❌ Validation errors:")
        for error in errors:
            print(f"  - {error}")
        return False
    
    print("\n✅ All user assets validated successfully!")
    return True

# Run validation
validation_success = validate_user_assets(USER_ASSETS_PATH)

if not validation_success:
    print("\n🛑 Please fix the validation errors before proceeding.")
    print("\nFor help with the required file formats, see:")
    print("https://github.com/yourusername/sanskrit-tutor/blob/main/README.md")

## 3. Prepare Training Data

In [None]:
import json
import pandas as pd
from datasets import Dataset
import re

def load_training_data(assets_path: Path):
    """
    Load and prepare training data from user assets.
    """
    # Load passages
    passages = {}
    with open(assets_path / 'passages.jsonl', 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                passage = json.loads(line)
                passages[passage['id']] = passage
    
    print(f"Loaded {len(passages)} passages")
    
    # Load QA pairs and create training examples
    training_examples = []
    
    with open(assets_path / 'qa_pairs.jsonl', 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                qa = json.loads(line)
                
                # Create context from related passages
                context_parts = []
                if 'related_passage_ids' in qa:
                    for passage_id in qa['related_passage_ids']:
                        if passage_id in passages:
                            p = passages[passage_id]
                            # 'chapter' and 'verse' are not checked during validation, so fall back to '?'
                            context_part = f"[{p['id']}] {p['work']} {p.get('chapter', '?')}.{p.get('verse', '?')}\n"
                            context_part += f"Devanagari: {p['text_devanagari']}\n"
                            context_part += f"IAST: {p['text_iast']}"
                            if p.get('notes'):
                                context_part += f"\nNotes: {p['notes']}"
                            context_parts.append(context_part)
                
                context = "\n\n".join(context_parts) if context_parts else "No specific passages provided."
                
                # Format as instruction-following example
                instruction = f"""You are a Sanskrit tutor. Answer the following question using the provided context and cite your sources using the exact passage IDs in square brackets.

Context:
{context}

Question: {qa['question']}"""
                
                response = qa['answer']
                
                training_examples.append({
                    'instruction': instruction,
                    'response': response,
                    'qa_id': qa['id'],
                    'difficulty': qa.get('difficulty', 'unknown')
                })
    
    print(f"Created {len(training_examples)} training examples")
    
    # Convert to dataset
    dataset = Dataset.from_list(training_examples)
    
    # Print some statistics
    difficulties = [ex['difficulty'] for ex in training_examples]
    difficulty_counts = pd.Series(difficulties).value_counts()
    print(f"\nDifficulty distribution:")
    print(difficulty_counts)
    
    return dataset, passages

# Only proceed if validation passed
if validation_success:
    training_dataset, passage_dict = load_training_data(USER_ASSETS_PATH)
    print(f"\n📊 Training dataset ready: {len(training_dataset)} examples")
    
    # Show a sample
    print("\n📝 Sample training example:")
    sample = training_dataset[0]
    print(f"Instruction: {sample['instruction'][:200]}...")
    print(f"Response: {sample['response'][:200]}...")
else:
    print("❌ Cannot proceed without valid user assets")

## 4. Setup Base Model and QLoRA Configuration

In [None]:
import torch
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    BitsAndBytesConfig,
    TrainingArguments
)
from peft import LoraConfig, get_peft_model, TaskType

# Configuration
MODEL_NAME = "microsoft/DialoGPT-small"  # Default small model for testing
# For better results, consider: "mistralai/Mistral-7B-Instruct-v0.1" or "meta-llama/Llama-2-7b-chat-hf"
# Note: Larger models require more GPU memory

def setup_model_and_tokenizer(model_name: str):
    """
    Setup the base model with 4-bit quantization and LoRA configuration.
    """
    print(f"Setting up model: {model_name}")
    
    # Quantization configuration for QLoRA
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    
    # Load model with quantization
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # LoRA configuration
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,  # Rank of adaptation
        lora_alpha=32,  # LoRA scaling parameter
        lora_dropout=0.1,  # LoRA dropout
        target_modules=["c_attn", "c_proj"],  # Target modules for DialoGPT
        # For Mistral/Llama, use: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )
    
    # Apply LoRA to model
    model = get_peft_model(model, lora_config)
    
    print("✅ Model and tokenizer loaded with QLoRA configuration")
    # print_trainable_parameters() prints its own summary and returns None
    model.print_trainable_parameters()
    
    return model, tokenizer

if validation_success:
    # Check GPU availability
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    if device.type == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    
    # Load model and tokenizer
    model, tokenizer = setup_model_and_tokenizer(MODEL_NAME)
else:
    print("❌ Cannot proceed without valid user assets")
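
The note about larger models needing more GPU memory can be made concrete with a back-of-the-envelope estimate. The sketch below counts only the base weights (half a byte per parameter at 4 bits); it ignores quantization constants, LoRA weights, activations, and optimizer state, so real usage will be higher. The function name and the parameter counts are illustrative assumptions.

```python
def estimate_weight_memory_gb(num_params: float, bits_per_param: int = 4) -> float:
    """Approximate memory for model weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    bytes_total = num_params * bits_per_param / 8
    return bytes_total / 1e9

# Assumed round parameter counts, for illustration only.
for name, n_params in [("0.125B-parameter model", 0.125e9), ("7B-parameter model", 7e9)]:
    fp16 = estimate_weight_memory_gb(n_params, bits_per_param=16)
    four_bit = estimate_weight_memory_gb(n_params, bits_per_param=4)
    print(f"{name}: 16-bit weights about {fp16:.2f} GB, 4-bit weights about {four_bit:.2f} GB")
```

Comparing the 4-bit figure with the VRAM printed by the cell above gives a first indication of whether a given base model is likely to fit.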

## 5. Training Configuration and Data Processing

In [None]:
from transformers import DataCollatorForLanguageModeling

def format_training_example(example, tokenizer, max_length=512):
    """
    Format a training example for instruction following.
    """
    # Create the full training text
    full_text = f"{example['instruction']}\n\n{example['response']}{tokenizer.eos_token}"
    
    # Tokenize
    tokenized = tokenizer(
        full_text,
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_tensors="pt"
    )
    
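    # For causal language modeling, labels mirror input_ids; padded positions
    # are set to -100 so the loss ignores them
    labels = tokenized["input_ids"].clone()
    labels[tokenized["attention_mask"] == 0] = -100
    tokenized["labels"] = labels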
    
    return tokenized

def tokenize_dataset(dataset, tokenizer, max_length=512):
    """
    Tokenize the entire dataset.
    """
    def tokenize_function(examples):
        # Process each example
        results = {
            "input_ids": [],
            "attention_mask": [],
            "labels": []
        }
        
        for i in range(len(examples['instruction'])):
            example = {
                'instruction': examples['instruction'][i],
                'response': examples['response'][i]
            }
            
            tokenized = format_training_example(example, tokenizer, max_length)
            
            results["input_ids"].append(tokenized["input_ids"][0])
            results["attention_mask"].append(tokenized["attention_mask"][0])
            results["labels"].append(tokenized["labels"][0])
        
        return results
    
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
        desc="Tokenizing dataset"
    )
    
    return tokenized_dataset

if validation_success and 'model' in locals():
    # Tokenize the dataset
    tokenized_dataset = tokenize_dataset(training_dataset, tokenizer, max_length=512)
    
    # Split into train/validation
    train_test_split = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset = train_test_split['train']
    eval_dataset = train_test_split['test']
    
    print(f"📊 Training set: {len(train_dataset)} examples")
    print(f"📊 Validation set: {len(eval_dataset)} examples")
    
    # Data collator
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,  # Not masked language modeling
    )
    
    print("✅ Training data prepared")
else:
    print("❌ Cannot proceed without valid model setup")
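
Because every example is padded to `max_length`, the labels should not ask the model to predict padding. PyTorch's cross-entropy loss skips positions whose label is -100 by default, so a common approach is to set padded label positions to that value. The sketch below shows the idea on plain Python lists using the attention mask; `mask_padding_labels` is a hypothetical helper, not part of the notebook's code, and the token IDs are made up.

```python
IGNORE_INDEX = -100  # label value that PyTorch's CrossEntropyLoss ignores by default

def mask_padding_labels(input_ids, attention_mask):
    """Copy input_ids to labels, replacing padded positions (mask == 0) with IGNORE_INDEX."""
    return [
        token if keep == 1 else IGNORE_INDEX
        for token, keep in zip(input_ids, attention_mask)
    ]

# Toy example: three real tokens followed by two padding tokens (id 0 here).
input_ids = [101, 57, 902, 0, 0]
attention_mask = [1, 1, 1, 0, 0]
labels = mask_padding_labels(input_ids, attention_mask)
print(labels)
```

Masking by attention mask rather than by token ID keeps a real end-of-sequence token in the loss even when the pad token is set to the EOS token, as it is in section 4.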

## 6. Fine-tuning with QLoRA

In [None]:
from transformers import Trainer, TrainingArguments
import os

def setup_training_arguments(output_dir="./sanskrit-tutor-qlora"):
    """
    Setup training arguments for QLoRA fine-tuning.
    """
    return TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=2,  # Small batch size for Colab
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=4,  # Effective batch size = 2 * 4 = 8
        num_train_epochs=3,  # Start with few epochs
        learning_rate=2e-4,  # Higher learning rate for LoRA
        warmup_steps=100,
        logging_steps=10,
        evaluation_strategy="steps",  # named eval_strategy in newer transformers releases
        eval_steps=50,
        save_steps=100,
        save_strategy="steps",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to="none",  # Disable wandb/tensorboard (None would enable all installed integrations)
        dataloader_pin_memory=False,
        remove_unused_columns=False,
        fp16=True,  # Use mixed precision for speed
        ddp_find_unused_parameters=False,
    )

if validation_success and 'train_dataset' in locals():
    # Setup training arguments
    training_args = setup_training_arguments()
    
    # Create trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )
    
    print("✅ Trainer initialized")
    print(f"📊 Total training steps: {len(train_dataset) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps) * training_args.num_train_epochs}")
    
    # Start training
    print("\n🚀 Starting fine-tuning...")
    print("This may take 30-60 minutes depending on your data size and GPU.")
    
    try:
        trainer.train()
        print("✅ Training completed successfully!")
        
        # Save the final model
        trainer.save_model()
        print(f"💾 Model saved to: {training_args.output_dir}")
        
    except Exception as e:
        print(f"❌ Training failed: {str(e)}")
        print("This might be due to memory constraints. Try reducing batch size or sequence length.")
        
else:
    print("❌ Cannot proceed without properly prepared training data")
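
With a small corpus, the fixed `warmup_steps=100` in the training arguments can cover most or all of training. The sketch below estimates the number of optimizer steps from the dataset size and batch settings and reports what fraction would be warmup. It is an approximation that assumes a single GPU and rounds partial batches up; the `Trainer`'s own count may differ slightly. The function name and the example numbers are illustrative.

```python
import math

def estimate_optimizer_steps(num_examples: int, batch_size: int,
                             grad_accum_steps: int, num_epochs: int) -> int:
    """Rough count of optimizer updates for single-device training."""
    batches_per_epoch = math.ceil(num_examples / batch_size)
    updates_per_epoch = max(1, math.ceil(batches_per_epoch / grad_accum_steps))
    return updates_per_epoch * num_epochs

# Assumed dataset size; the batch settings mirror the training arguments above.
total_steps = estimate_optimizer_steps(num_examples=180, batch_size=2,
                                       grad_accum_steps=4, num_epochs=3)
warmup_steps = 100
warmup_fraction = min(1.0, warmup_steps / total_steps)
print(f"Estimated optimizer steps: {total_steps}")
print(f"Warmup covers about {warmup_fraction:.0%} of training")
```

If the warmup fraction comes out large, lowering `warmup_steps` so that it is a small share of the total is a reasonable first adjustment.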

## 7. Save Model to Drive and Test

In [None]:
import shutil
from datetime import datetime

def save_model_to_drive(model, tokenizer, source_dir="./sanskrit-tutor-qlora"):
    """
    Save the fine-tuned model to Google Drive.
    """
    # Create timestamped directory in Drive
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    drive_model_dir = USER_ASSETS_PATH / f"fine_tuned_model_{timestamp}"
    drive_model_dir.mkdir(exist_ok=True)
    
    print(f"💾 Saving model to: {drive_model_dir}")
    
    # Save model and tokenizer
    model.save_pretrained(drive_model_dir)
    tokenizer.save_pretrained(drive_model_dir)
    
    # Copy training logs if they exist
    if os.path.exists(source_dir):
        for item in os.listdir(source_dir):
            if item.endswith('.json') or item.endswith('.txt'):
                shutil.copy2(os.path.join(source_dir, item), drive_model_dir / item)
    
    # Create a README for the saved model
    readme_content = f"""# Sanskrit Tutor Fine-tuned Model

**Training Date:** {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
**Base Model:** {MODEL_NAME}
**Training Examples:** {len(training_dataset) if 'training_dataset' in globals() else 'Unknown'}
**Method:** QLoRA (4-bit quantization + LoRA)

## Usage

This model can be loaded using the `transformers` and `peft` libraries:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Load base model and tokenizer
base_model = AutoModelForCausalLM.from_pretrained("{MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained("{MODEL_NAME}")

# Load LoRA adapter
model = PeftModel.from_pretrained(base_model, "path/to/this/directory")
```

## Files

- `adapter_config.json`: LoRA adapter configuration
- `adapter_model.safetensors`: LoRA adapter weights
- `tokenizer.json`, `tokenizer_config.json`: Tokenizer files
- Training logs and other metadata
"""
    
    with open(drive_model_dir / "README.md", "w", encoding="utf-8") as f:
        f.write(readme_content)
    
    print(f"✅ Model saved successfully to Drive!")
    print(f"📁 Location: {drive_model_dir}")
    
    return drive_model_dir

# Test the fine-tuned model
def test_model(model, tokenizer, test_question="What does 'dharma' mean?"):
    """
    Test the fine-tuned model with a sample question.
    """
    print(f"\n🧪 Testing model with question: '{test_question}'")
    
    # Create a test prompt in the same format as training (instruction, blank line, response)
    test_prompt = f"""You are a Sanskrit tutor. Answer the following question using the provided context and cite your sources using the exact passage IDs in square brackets.

Context:
No specific passages provided.

Question: {test_question}

"""
    
    # Tokenize and generate
    inputs = tokenizer(test_prompt, return_tensors="pt", truncation=True, max_length=400)
    
    # Move inputs to the same device as model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id
        )
    
    # Decode only the newly generated tokens, skipping the echoed prompt
    prompt_length = inputs["input_ids"].shape[1]
    generated_answer = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()
    
    print("\n📝 Generated Response:")
    print("="*50)
    print(generated_answer)
    print("="*50)
    
    return generated_answer

# Save and test if training was successful
if 'trainer' in locals() and hasattr(trainer, 'model'):
    # Save to Drive
    saved_model_path = save_model_to_drive(trainer.model, tokenizer)
    
    # Test the model
    test_model(trainer.model, tokenizer)
    
    print(f"\n🎉 Fine-tuning completed successfully!")
    print(f"📁 Model saved to: {saved_model_path}")
    print(f"\n🔄 To use this model in the Sanskrit Tutor:")
    print(f"1. Copy the adapter files to your local user_assets/adapters/ directory")
    print(f"2. Update your config.yaml to point to the adapter")
    print(f"3. The system will automatically detect and load the adapter")
    
else:
    print("❌ No trained model available to save")
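
The prompt asks the model to cite passage IDs in square brackets, so one quick sanity check on a generated answer is whether the bracketed IDs actually exist in the corpus. The sketch below extracts bracketed tokens with a regular expression and splits them into known and unknown IDs. The helper name and the sample IDs are hypothetical; in this notebook, the set of known IDs would come from the keys of `passage_dict`.

```python
import re

# Matches the text inside one pair of square brackets, e.g. "[bg_2_47]".
CITATION_PATTERN = re.compile(r"\[([^\[\]]+)\]")

def check_citations(answer: str, known_ids: set) -> dict:
    """Split bracketed citations in an answer into known and unknown passage IDs."""
    cited = [c.strip() for c in CITATION_PATTERN.findall(answer)]
    return {
        "valid": [c for c in cited if c in known_ids],
        "unknown": [c for c in cited if c not in known_ids],
    }

known_ids = {"bg_2_47", "bg_2_48"}  # illustrative IDs only
answer = "Action alone is one's concern [bg_2_47], as a later verse repeats [bg_9_99]."
report = check_citations(answer, known_ids)
print(report)
```

A non-empty `unknown` list suggests the model is inventing citations, which is worth checking before relying on its answers.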

## 8. Next Steps

### Using Your Fine-tuned Model

1. **Download the adapter files** from your Google Drive to your local machine
2. **Place them** in the `user_assets/adapters/` directory of your Sanskrit Tutor installation
3. **Update your configuration** to use the fine-tuned model
4. **Test the improved model** in the Sanskrit Tutor interface

### Improving Results

- **More data**: Add more passages and QA pairs to your training data
- **Better base model**: Try larger models like Mistral-7B or Llama-2-7B
- **Hyperparameter tuning**: Adjust learning rate, LoRA rank, and training epochs
- **Data quality**: Ensure your QA pairs have accurate citations and high-quality answers
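
To get a feel for what changing the LoRA rank does, it helps to count the parameters an adapter adds. For a weight matrix of shape d_out × d_in, LoRA trains two low-rank factors of shapes r × d_in and d_out × r, which is r·(d_in + d_out) parameters, and standard LoRA scales their product by lora_alpha / r. The sketch below works through that arithmetic; the 768-dimensional layer is an illustrative assumption, not a measurement of any particular model.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in weight matrix."""
    return rank * (d_in + d_out)

def lora_scaling(lora_alpha: float, rank: int) -> float:
    """Factor applied to the low-rank update in standard LoRA."""
    return lora_alpha / rank

d_in = d_out = 768  # assumed layer size, for illustration
full = d_in * d_out
for rank in (8, 16, 32):
    added = lora_param_count(d_in, d_out, rank)
    print(f"r={rank:>2}: {added:,} adapter params "
          f"({added / full:.1%} of the full matrix), "
          f"scaling with lora_alpha=32: {lora_scaling(32, rank):.1f}")
```

Doubling the rank doubles the adapter size and, with `lora_alpha` held fixed, halves the scaling factor, so the two settings are usually adjusted together.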

### Monitoring and Evaluation

- Check the training logs for loss curves
- Test the model on held-out questions
- Compare citation accuracy before and after fine-tuning
- Monitor for overfitting (training loss keeps falling while validation loss rises)
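
One way to check for overfitting after a run is to scan the logged losses for the point where validation loss stops improving. The sketch below works on a list of dictionaries shaped like the entries in `trainer.state.log_history`, where evaluation entries carry an `eval_loss` key and a `step`. The sample numbers are made up for illustration, and `find_best_eval_step` is a hypothetical helper.

```python
def find_best_eval_step(log_history):
    """Return (best_step, best_loss, rose_afterwards) from Trainer-style log entries."""
    evals = [(e["step"], e["eval_loss"]) for e in log_history if "eval_loss" in e]
    if not evals:
        return None
    best_step, best_loss = min(evals, key=lambda pair: pair[1])
    # Look at evaluations recorded after the best one
    later = [loss for step, loss in evals if step > best_step]
    rose_afterwards = bool(later) and later[-1] > best_loss
    return best_step, best_loss, rose_afterwards

# Made-up log entries: one training entry followed by three evaluations.
sample_log = [
    {"step": 10, "loss": 2.9},
    {"step": 50, "eval_loss": 2.4},
    {"step": 100, "eval_loss": 2.1},
    {"step": 150, "eval_loss": 2.3},
]
result = find_best_eval_step(sample_log)
print(result)
```

If validation loss rose after its best step, training for fewer epochs or adding more data are the usual first things to try.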

---

**Remember**: This notebook requires you to provide your own data and models. The fine-tuned model will only be as good as the training data you provide. For best results, ensure your passages.jsonl and qa_pairs.jsonl files are high-quality and representative of the questions you want the system to answer.