<a href="https://colab.research.google.com/github/ayagup/stablediffusion/blob/main/hf_llm_lora_train_gpu_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers datasets peft accelerate bitsandbytes torch

In [None]:
# Kaggle API credentials.
# SECURITY: never hard-code credentials in a notebook. An API key used to be committed here in
# plain text; it must be revoked/rotated on kaggle.com (Account -> API).
# Supply KAGGLE_USERNAME / KAGGLE_KEY via environment variables (e.g. Colab or Kaggle secrets);
# if they are missing you are prompted, and the key is read without being echoed.
import getpass
import json
import os
from pathlib import Path

_kaggle_username = os.environ.get("KAGGLE_USERNAME") or input("Kaggle username: ")
_kaggle_key = os.environ.get("KAGGLE_KEY") or getpass.getpass("Kaggle API key: ")

# The kaggle client also accepts credentials from these environment variables.
os.environ["KAGGLE_USERNAME"] = _kaggle_username
os.environ["KAGGLE_KEY"] = _kaggle_key

# Keep writing kaggle.json (as before) for tools that expect the file, readable only by the owner.
_kaggle_config = Path("/root/.config/kaggle/kaggle.json")
_kaggle_config.parent.mkdir(parents=True, exist_ok=True)
_kaggle_config.write_text(json.dumps({"username": _kaggle_username, "key": _kaggle_key}))
_kaggle_config.chmod(0o600)

In [None]:
# Download the Spider text-to-SQL dataset into ./data (relative to the working directory).
# Skipped when the CSV is already present so "Restart & Run All" doesn't re-download it.
from pathlib import Path

SPIDER_CSV = Path("./data/spider_text_sql.csv")
if not SPIDER_CSV.exists():
    # Imported lazily: `import kaggle` authenticates immediately and needs the credentials cell.
    import kaggle

    kaggle.api.dataset_download_files('mohammadnouralawad/spider-text-sql', path='./data', unzip=True)

In [None]:
import os

# Relative paths such as ./data (used by the download cell) resolve against this directory.
working_dir = os.getcwd()
print(working_dir)

In [None]:
!head /kaggle/working/data/spider_text_sql.csv

In [None]:
# Preview the Spider CSV that WorkingLoRASQLTrainer.create_sql_dataset loads, and check that
# the columns it relies on are present before spending time on model setup.
import pandas as pd

spider_preview = pd.read_csv("./data/spider_text_sql.csv")
assert {"text_query", "sql_command"} <= set(spider_preview.columns), "unexpected Spider CSV schema"
spider_preview[["text_query", "sql_command"]].head()

In [None]:
import torch
import os
import gc
import json
import random
from datetime import datetime
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    TrainerCallback
)
from datasets import Dataset
from peft import (
    LoraConfig,
    get_peft_model,
    TaskType,
    PeftModel,
    prepare_model_for_kbit_training
)
import numpy as np



In [None]:
def check_gpu():
    """Report the first CUDA device and return the torch device training should use.

    When CUDA is available, prints the name and total memory (in GB) of device 0.

    Returns:
        torch.device: ``cuda`` if CUDA is available, otherwise ``cpu``.
    """
    if not torch.cuda.is_available():
        print("⚠️  No GPU available, using CPU")
        return torch.device("cpu")

    device_name = torch.cuda.get_device_name(0)
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"🚀 GPU Available: {device_name}")
    print(f"💾 GPU Memory: {total_gb:.1f}GB")
    return torch.device("cuda")



In [None]:
class WorkingLoRASQLTrainer:
    """End-to-end LoRA fine-tuning pipeline that teaches a causal LM to write SQL.

    Steps: load a base model and tokenizer, build an instruction-formatted text-to-SQL dataset
    (a small hand-written set plus the Spider CSV), wrap the model with LoRA adapters, train with
    the Hugging Face ``Trainer``, then save the adapters, the tokenizer and a
    ``training_info.json`` summary into a timestamped output directory.
    """

    # Locations tried (in order) for the Spider CSV when no explicit path is given: relative to
    # the working directory (where the download cell writes it), then the Kaggle working dir.
    SPIDER_CSV_CANDIDATES = (
        "./data/spider_text_sql.csv",
        "/kaggle/working/data/spider_text_sql.csv",
    )

    def __init__(self):
        """Select the device, create a timestamped output directory and free cached GPU memory."""
        print("🔧 Initializing Working LoRA SQL Training Pipeline")
        self.device = check_gpu()
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.output_dir = f"./lora_sql_training_{self.timestamp}"

        os.makedirs(self.output_dir, exist_ok=True)
        print(f"📁 Output directory: {self.output_dir}")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def download_and_setup_model(self, model_name="microsoft/DialoGPT-medium", trust_remote_code=True):
        """Load the tokenizer and model for LoRA training and move the model to ``self.device``.

        Args:
            model_name: Hugging Face Hub id (or local path) of a causal language model.
            trust_remote_code: Forwarded to ``from_pretrained``. When True, custom modelling code
                shipped with the checkpoint is executed, so keep it only for trusted models.

        Returns:
            ``(model, tokenizer)``. The tokenizer always has a pad token (EOS is reused if absent).

        Raises:
            Re-raises any loading error after printing it.
        """
        print(f"📥 Setting up model: {model_name}")

        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            if tokenizer.pad_token is None:
                # No pad token defined: reuse EOS so batches can be padded.
                tokenizer.pad_token = tokenizer.eos_token
                tokenizer.pad_token_id = tokenizer.eos_token_id

            print(f"📝 Tokenizer loaded, vocab size: {len(tokenizer)}")

            # Load directly in float32. Mixed precision on CUDA comes from the Trainer's fp16
            # autocast (see create_training_arguments), which expects float32 trainable weights.
            # Loading float16 and then calling prepare_model_for_kbit_training (a helper meant
            # for 8/4-bit quantized models) only upcast the weights back to float32 after rounding
            # them through fp16, so the end dtype is unchanged and the round-trip is avoided.
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
                device_map=None,
                trust_remote_code=trust_remote_code,
                use_cache=False  # Disable cache for training
            )
            model = model.to(self.device)

            print(f"🤖 Model prepared for LoRA training")
            print(f"🧠 Model parameters: {sum(p.numel() for p in model.parameters()):,}")

            return model, tokenizer

        except Exception as e:
            print(f"❌ Error setting up model: {e}")
            raise

    def create_sql_dataset(self, num_samples=200, csv_path=None):
        """Build an instruction-formatted text-to-SQL ``Dataset``.

        The example pool is a small hand-written list followed by every row of the Spider CSV.
        The pool is cycled until ``num_samples`` examples exist, so the first ``num_samples``
        entries are used (and repeated if the pool is smaller). Each example is one ``text``
        field of the form ``Instruction: ... / Input: ... / Output: ...`` (newline-separated).

        Args:
            num_samples: Number of examples to produce.
            csv_path: Path to ``spider_text_sql.csv``. When None, the first existing entry of
                ``SPIDER_CSV_CANDIDATES`` is used (falling back to the first candidate, which
                makes ``read_csv`` report a clear missing-file error).

        Returns:
            ``datasets.Dataset`` with a single ``text`` column.
        """
        import pandas as pd

        print(f"📊 Creating SQL dataset with {num_samples} samples...")

        sql_examples = [
            # Basic SELECT queries
            ("Write a SELECT query", "Get all records from users table", "SELECT * FROM users;"),
            ("Write a SELECT query", "Get name and email from users", "SELECT name, email FROM users;"),
            ("Write a SELECT query", "Get all products", "SELECT * FROM products;"),

            # WHERE clauses
            ("Write a WHERE query", "Find users older than 25", "SELECT * FROM users WHERE age > 25;"),
            ("Write a WHERE query", "Find active users", "SELECT * FROM users WHERE status = 'active';"),
            ("Write a WHERE query", "Find products under $50", "SELECT * FROM products WHERE price < 50;"),

            # COUNT queries
            ("Write a COUNT query", "Count all users", "SELECT COUNT(*) FROM users;"),
            ("Write a COUNT query", "Count active orders", "SELECT COUNT(*) FROM orders WHERE status = 'active';"),

            # ORDER BY queries
            ("Write an ORDER BY query", "Sort users by name", "SELECT * FROM users ORDER BY name;"),
            ("Write an ORDER BY query", "Sort products by price descending", "SELECT * FROM products ORDER BY price DESC;"),

            # GROUP BY queries
            ("Write a GROUP BY query", "Count users by department", "SELECT department, COUNT(*) FROM users GROUP BY department;"),
            ("Write a GROUP BY query", "Sum sales by region", "SELECT region, SUM(amount) FROM sales GROUP BY region;"),

            # JOIN queries
            ("Write a JOIN query", "Join users and orders", "SELECT u.name, o.total FROM users u JOIN orders o ON u.id = o.user_id;"),
            ("Write a JOIN query", "Join products and categories", "SELECT p.name, c.category_name FROM products p JOIN categories c ON p.category_id = c.id;"),

            # INSERT queries
            ("Write an INSERT query", "Insert new user", "INSERT INTO users (name, email) VALUES ('John Doe', 'john@email.com');"),
            ("Write an INSERT query", "Insert new product", "INSERT INTO products (name, price) VALUES ('Widget', 19.99);"),

            # UPDATE queries
            ("Write an UPDATE query", "Update user email", "UPDATE users SET email = 'new@email.com' WHERE id = 1;"),
            ("Write an UPDATE query", "Update product price", "UPDATE products SET price = 29.99 WHERE name = 'Widget';"),

            # DELETE queries
            ("Write a DELETE query", "Delete inactive users", "DELETE FROM users WHERE status = 'inactive';"),
            ("Write a DELETE query", "Delete old orders", "DELETE FROM orders WHERE date < '2023-01-01';"),
        ]

        if csv_path is None:
            csv_path = next(
                (path for path in self.SPIDER_CSV_CANDIDATES if os.path.exists(path)),
                self.SPIDER_CSV_CANDIDATES[0],
            )

        spider_examples = (
            pd.read_csv(csv_path)
            # Missing questions/queries would otherwise become the literal text "nan".
            .dropna(subset=["text_query", "sql_command"])
            .assign(type="Write a SELECT query")
            [["type", "text_query", "sql_command"]]
        )
        sql_examples.extend(spider_examples.itertuples(index=False, name=None))

        # Generate dataset by cycling through examples
        dataset = []
        for i in range(num_samples):
            instruction, problem, solution = sql_examples[i % len(sql_examples)]
            text = f"Instruction: {instruction}\nInput: {problem}\nOutput: {solution}"
            dataset.append({"text": text})

        print(f"✅ Created {len(dataset)} SQL examples")
        return Dataset.from_list(dataset)

    def setup_lora_config(self):
        """Return the LoRA configuration (rank 4, alpha 8, attention projection only)."""
        print("🔧 Setting up LoRA configuration...")

        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=4,                     # Very small rank for stability
            lora_alpha=8,            # Alpha = 2 * r
            lora_dropout=0.05,
            # "c_attn" is the attention projection name in GPT-2-style models (such as the
            # default DialoGPT); other architectures need different module names.
            target_modules=["c_attn"],
            bias="none",
            modules_to_save=None,
        )

        print(f"✅ LoRA Config: r={lora_config.r}, alpha={lora_config.lora_alpha}")
        return lora_config

    def apply_lora_to_model(self, model, lora_config):
        """Wrap ``model`` with LoRA adapters, print parameter counts and return the PEFT model.

        Raises:
            RuntimeError: if no parameter is left trainable (e.g. target modules not found).
        """
        print("🔄 Applying LoRA to model...")

        try:
            peft_model = get_peft_model(model, lora_config)
            peft_model.train()

            trainable_params = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
            total_params = sum(p.numel() for p in peft_model.parameters())

            print(f"🧠 Total parameters: {total_params:,}")
            print(f"🎯 Trainable parameters: {trainable_params:,}")
            print(f"📊 Trainable percentage: {(trainable_params/total_params)*100:.2f}%")

            if trainable_params == 0:
                raise RuntimeError("No trainable parameters found!")

            return peft_model

        except Exception as e:
            print(f"❌ Error applying LoRA: {e}")
            raise

    def tokenize_dataset(self, dataset, tokenizer, max_length=200, min_tokens=10):
        """Tokenize the ``text`` column and drop very short examples.

        Sequences are deliberately NOT padded here. ``DataCollatorForLanguageModeling``
        (``mlm=False``) pads each training batch dynamically and builds ``labels`` from
        ``input_ids`` with pad positions set to -100. Padding inside ``map`` would also give every
        row of a map batch the same length, which made the length filter below a no-op.

        NOTE(review): with pad == eos, EOS positions are masked out of the loss as well, so the
        model is presumably never taught where a query ends; generation relies on max_new_tokens.

        Args:
            dataset: ``Dataset`` with a ``text`` column.
            tokenizer: Tokenizer matching the model.
            max_length: Truncation length in tokens.
            min_tokens: Rows with ``min_tokens`` tokens or fewer are dropped.

        Returns:
            Tokenized ``Dataset`` (``input_ids`` / ``attention_mask``).
        """
        print("🔤 Tokenizing dataset...")

        def tokenize_function(examples):
            return tokenizer(
                examples["text"],
                truncation=True,
                max_length=max_length,  # Short for memory efficiency
            )

        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            batch_size=50,
            remove_columns=dataset.column_names,
            desc="Tokenizing"
        )

        # Filter short sequences (lengths are real token counts because nothing is padded yet)
        tokenized_dataset = tokenized_dataset.filter(lambda x: len(x["input_ids"]) > min_tokens)

        print(f"✅ Tokenized {len(tokenized_dataset)} examples")
        return tokenized_dataset

    def create_training_arguments(self, num_train_epochs=20):
        """Create ``TrainingArguments`` for a small, memory-friendly LoRA run.

        Args:
            num_train_epochs: Number of passes over the training set.
        """
        return TrainingArguments(
            output_dir=self.output_dir,
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=1,   # Very small batch
            gradient_accumulation_steps=4,   # Effective batch size of 4
            learning_rate=5e-4,              # Higher LR for LoRA
            lr_scheduler_type="linear",
            warmup_steps=10,
            weight_decay=0.01,
            max_grad_norm=0.3,
            fp16=True if self.device.type == 'cuda' else False,  # fp16 autocast only on GPU
            gradient_checkpointing=False,
            dataloader_drop_last=True,
            dataloader_num_workers=0,
            logging_steps=5,
            save_steps=50,
            save_total_limit=1,
            eval_strategy="no",
            prediction_loss_only=True,
            seed=42,
            report_to=[],
            remove_unused_columns=True,
        )

    def test_model_generation(self, model, tokenizer, stage=""):
        """Sample a short completion for a fixed SQL prompt and print it.

        The model is put in eval mode for generation and restored to its previous train/eval
        mode afterwards, even if generation fails. Output is sampled, so it varies between runs.
        """
        print(f"🧪 Testing model generation {stage}...")

        test_prompt = "Instruction: Write a SELECT query\nInput: Get all users\nOutput:"
        inputs = tokenizer(test_prompt, return_tensors="pt").to(self.device)

        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                generated = model.generate(
                    **inputs,  # includes attention_mask; it can't be inferred when pad == eos
                    max_new_tokens=20,
                    temperature=0.8,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id,
                    use_cache=True,  # the model config disables the cache for training
                )

            new_tokens = generated[0][inputs["input_ids"].shape[1]:]
            generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
            print(f"🎯 Generated {stage}: {generated_text.strip()}")
        finally:
            model.train(was_training)

    def run_lora_training(self, model_name="microsoft/DialoGPT-medium", num_samples=100):
        """Run the full pipeline: load, build data, apply LoRA, train, sample, save.

        Args:
            model_name: Base model to fine-tune; also recorded in ``training_info.json``.
            num_samples: Number of training examples to build.

        Returns:
            ``(lora_model, tokenizer, adapter_path)``; ``adapter_path`` holds the saved adapters
            and tokenizer.

        Raises:
            Re-raises any failure after printing the traceback and freeing cached memory.
        """
        try:
            print("🚀 Starting Working LoRA SQL Training")
            print("=" * 60)

            # Setup components
            model, tokenizer = self.download_and_setup_model(model_name)
            dataset = self.create_sql_dataset(num_samples=num_samples)

            # Apply LoRA
            lora_config = self.setup_lora_config()
            lora_model = self.apply_lora_to_model(model, lora_config)

            # Test before training
            self.test_model_generation(lora_model, tokenizer, "before training")

            # Tokenize
            tokenized_dataset = self.tokenize_dataset(dataset, tokenizer)

            # Setup training
            training_args = self.create_training_arguments()
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=tokenizer,
                mlm=False,
                return_tensors="pt"
            )

            class SimpleProgressCallback(TrainerCallback):
                """Print the loss whenever the Trainer logs, plus start/end markers."""

                def on_log(self, args, state, control, logs=None, **kwargs):
                    if not logs:
                        return
                    # Periodic logs report the running loss under "loss"; only the final
                    # summary log uses "train_loss".
                    loss = logs.get("loss", logs.get("train_loss"))
                    if loss is not None:
                        print(f"📊 Step {state.global_step}: Loss = {loss:.4f}")

                def on_train_begin(self, args, state, control, **kwargs):
                    print("🏁 Training started!")

                def on_train_end(self, args, state, control, **kwargs):
                    print("🏁 Training completed!")

            trainer = Trainer(
                model=lora_model,
                args=training_args,
                train_dataset=tokenized_dataset,
                data_collator=data_collator,
                processing_class=tokenizer,
                callbacks=[SimpleProgressCallback()],
            )

            # Start training
            print(f"\n🏃 Starting LoRA training...")
            training_result = trainer.train()

            print(f"✅ Training completed!")
            print(f"📊 Final loss: {training_result.training_loss:.4f}")

            # Test after training
            self.test_model_generation(lora_model, tokenizer, "after training")

            # Save LoRA adapters
            adapter_path = f"{self.output_dir}/lora_adapters"
            print(f"\n💾 Saving LoRA adapters to {adapter_path}")
            lora_model.save_pretrained(adapter_path)
            tokenizer.save_pretrained(adapter_path)

            # Save training info
            training_info = {
                "timestamp": self.timestamp,
                "base_model": model_name,
                "dataset_size": len(dataset),
                "lora_rank": lora_config.r,
                "final_loss": float(training_result.training_loss),
                "device": str(self.device)
            }

            with open(f"{self.output_dir}/training_info.json", "w") as f:
                json.dump(training_info, f, indent=2)

            print(f"📁 All outputs saved to: {self.output_dir}")
            print("🎉 LoRA training completed successfully!")

            return lora_model, tokenizer, adapter_path

        except Exception as e:
            print(f"❌ LoRA training failed: {e}")
            import traceback
            traceback.print_exc()

            # Cleanup
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            raise



In [None]:
def demonstrate_trained_model(adapter_path, base_model_name="microsoft/DialoGPT-medium"):
    """Reload saved LoRA adapters on top of the base model and sample SQL for a few prompts.

    Args:
        adapter_path: Directory containing the saved adapters and tokenizer.
        base_model_name: Hub id of the model the adapters were trained on; defaults to the
            model the training pipeline uses by default.

    This is a best-effort demo: any failure is printed (with traceback) and swallowed so the
    rest of the notebook keeps running.
    """
    print("\n🎯 Demonstrating Trained LoRA Model")
    print("=" * 50)

    try:
        tokenizer = AutoTokenizer.from_pretrained(adapter_path)
        # Half precision only when a GPU is present; on CPU stay in float32, as
        # download_and_setup_model does.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=dtype,
            device_map="auto"
        )
        model = PeftModel.from_pretrained(base_model, adapter_path)

        # Test prompts
        test_prompts = [
            "Instruction: Write a SELECT query\nInput: Get all customers\nOutput:",
            "Instruction: Write a COUNT query\nInput: Count total orders\nOutput:",
            "Instruction: Write a JOIN query\nInput: Join users and orders\nOutput:",
        ]

        model.eval()
        for i, prompt in enumerate(test_prompts, 1):
            task = prompt.split('Input: ')[1].split('Output:')[0].strip()
            print(f"\n🧪 Test {i}: {task}")

            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

            with torch.no_grad():
                generated = model.generate(
                    **inputs,
                    max_new_tokens=30,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id
                )

            # Decode only the continuation, not the echoed prompt.
            new_tokens = generated[0][inputs["input_ids"].shape[1]:]
            generated_sql = tokenizer.decode(new_tokens, skip_special_tokens=True)
            print(f"🎉 Generated: {generated_sql.strip()}")

        print("\n✅ Model demonstration completed!")

    except Exception as e:
        import traceback

        print(f"❌ Demo failed: {e}")
        traceback.print_exc()



In [None]:
# ===================================================================
# Main Execution: train LoRA adapters, then reload them for a demo
# ===================================================================

if __name__ == "__main__":
    # Step 1: fine-tune the adapters and write them to disk.
    trainer = WorkingLoRASQLTrainer()
    lora_model, tokenizer, adapter_path = trainer.run_lora_training()

    print(f"\n{'=' * 60}")
    print("🎉 LoRA Training Complete!")
    print(f"📁 Adapters saved to: {adapter_path}")

    # Step 2: load the saved adapters onto a fresh base model and sample from them.
    demonstrate_trained_model(adapter_path)