<a href="https://colab.research.google.com/github/ayagup/stablediffusion/blob/main/hf_lora_gpu_inferencing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch transformers peft accelerate bitsandbytes

In [None]:
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig,
    BitsAndBytesConfig
)
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig
import warnings
warnings.filterwarnings('ignore')
import time
import gc

# Report the runtime environment and choose the default compute device.
_cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_ok}")
if not _cuda_ok:
    print("Using CPU")
    device = torch.device("cpu")
else:
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    _total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU memory: {_total_gib:.1f} GB")
    device = torch.device("cuda")


In [None]:

class LoRAInferenceEngine:
    """Load a causal LM with a LoRA adapter (real or synthetic) and generate text.

    Typical use: ``download_and_load_models()``, optionally
    ``merge_lora_weights()``, then ``generate_response()`` or
    ``batch_inference()``.
    """

    def __init__(self):
        # All fields are populated by download_and_load_models().
        self.tokenizer = None          # AutoTokenizer for base_model_name
        self.model = None              # PEFT-wrapped model (plain model after merge_lora_weights)
        self.generation_config = None  # built at load time; generate_response does not read it
        self.base_model_name = None    # Hub id of the base model
        self.lora_model_name = None    # adapter Hub id as passed in, or None

    def download_and_load_models(self,
                                 base_model_name="microsoft/DialoGPT-medium",
                                 lora_model_name=None,
                                 create_synthetic_lora=True,
                                 use_4bit=False):
        """Download and load the tokenizer, base model and LoRA adapter.

        Args:
            base_model_name: Hugging Face Hub id of the base causal LM.
            lora_model_name: Hub id of a LoRA adapter, or ``None``.
            create_synthetic_lora: if True (the default) a synthetic, randomly
                perturbed adapter is attached even when ``lora_model_name`` is
                given; pass False to load the real adapter.
            use_4bit: load the base model with bitsandbytes NF4 4-bit
                quantization; ignored when CUDA is unavailable.
        """
        self.base_model_name = base_model_name
        self.lora_model_name = lora_model_name
        cuda_available = torch.cuda.is_available()
        load_real_adapter = bool(lora_model_name) and not create_synthetic_lora

        print(f"\n{'='*70}")
        print(f"📥 LOADING MODELS")
        print(f"Base model: {base_model_name}")
        # Report the adapter that will actually be attached.
        print(f"LoRA model: {lora_model_name if load_real_adapter else 'Synthetic LoRA'}")
        print(f"{'='*70}")

        print("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        # Some tokenizers define no pad token; reuse EOS so pad_token_id is set for generate().
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        print("✅ Tokenizer loaded successfully")

        model_kwargs = {
            "torch_dtype": torch.float16 if cuda_available else torch.float32,
            # On GPU, accelerate places the weights itself; on CPU they load onto the CPU.
            "device_map": "auto" if cuda_available else None,
            # NOTE(review): trust_remote_code executes Python code shipped with the
            # model repo; only keep it enabled for trusted repositories.
            "trust_remote_code": True,
        }

        if use_4bit and cuda_available:
            print("Using 4-bit quantization for memory efficiency...")
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )
            model_kwargs["quantization_config"] = quantization_config

        print("Loading base model...")
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            **model_kwargs
        )

        print(f"✅ Base model loaded successfully")
        print(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")

        if load_real_adapter:
            print(f"Loading LoRA adapter: {lora_model_name}")
            self.load_real_lora_adapter(lora_model_name)
        else:
            print("Creating synthetic LoRA adapter...")
            self.create_synthetic_lora()

        # Only move the model when accelerate did NOT already place it via
        # device_map: calling .to() on a dispatched model overrides that
        # placement and may fail for offloaded or quantized weights.
        if model_kwargs["device_map"] is None:
            target_device = torch.device("cuda" if cuda_available else "cpu")
            print(f"Moving model to {target_device}...")
            self.model = self.model.to(target_device)

        # Default sampling settings kept on the engine for callers that want them.
        self.generation_config = GenerationConfig(
            max_new_tokens=150,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.1,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        print("✅ Model setup completed!")
        self.print_gpu_memory()

    def load_real_lora_adapter(self, lora_model_name):
        """Wrap ``self.model`` with a LoRA adapter from the Hugging Face Hub.

        If downloading/loading the adapter raises, a synthetic adapter is
        created instead. Only the loading calls are guarded, so a problem while
        merely reporting adapter info can no longer trigger a second adapter
        being stacked on an already-wrapped model.
        """
        try:
            peft_config = PeftConfig.from_pretrained(lora_model_name)
            print(f"LoRA config: {peft_config}")

            peft_model = PeftModel.from_pretrained(
                self.model,
                lora_model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
            )
        except Exception as e:
            print(f"❌ Failed to load LoRA adapter: {e}")
            print("Creating synthetic LoRA instead...")
            self.create_synthetic_lora()
            return

        self.model = peft_model
        print(f"✅ LoRA adapter loaded from {lora_model_name}")

        if hasattr(self.model, 'peft_config'):
            rank, alpha = self._lora_hyperparameters()
            print(f"LoRA rank: {rank}")
            print(f"LoRA alpha: {alpha}")

    def _lora_hyperparameters(self, adapter_name="default"):
        """Return ``(r, lora_alpha)`` for ``adapter_name``, using ``'N/A'`` when unknown.

        ``peft_config`` maps adapter names to config objects (e.g. ``LoraConfig``)
        whose hyperparameters are attributes, not dict keys.
        """
        adapter_config = getattr(self.model, "peft_config", {}).get(adapter_name)
        return (getattr(adapter_config, "r", "N/A"),
                getattr(adapter_config, "lora_alpha", "N/A"))

    def create_synthetic_lora(self):
        """Attach a freshly initialised LoRA adapter and perturb it for demonstration."""
        lora_config = LoraConfig(
            r=16,  # rank
            lora_alpha=32,
            # Module names of GPT-2-style models such as DialoGPT; other
            # architectures use different names (TODO confirm per base model).
            target_modules=["c_attn", "c_proj", "c_fc"],
            lora_dropout=0.1,
            bias="none",
            task_type="CAUSAL_LM",
        )

        self.model = get_peft_model(self.model, lora_config)

        print(f"✅ Synthetic LoRA adapter created")
        print(f"LoRA rank: {lora_config.r}")
        print(f"LoRA alpha: {lora_config.lora_alpha}")
        print(f"Target modules: {lora_config.target_modules}")

        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in self.model.parameters())
        print(f"Trainable parameters: {trainable_params:,}")
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable %: {100 * trainable_params / total_params:.2f}%")

        self.simulate_lora_training()

    def simulate_lora_training(self):
        """Add small Gaussian noise (std 0.01) to every trainable ``lora_`` parameter.

        PEFT initialises ``lora_B`` to zeros by default, so without this the
        synthetic adapter would leave the model's outputs unchanged.
        """
        print("Simulating LoRA fine-tuning...")

        with torch.no_grad():
            modified_params = 0
            for name, param in self.model.named_parameters():
                if 'lora_' in name and param.requires_grad:
                    noise = torch.randn_like(param) * 0.01
                    param.add_(noise)
                    modified_params += 1

        print(f"✅ Simulated training on {modified_params} LoRA parameters")

    def merge_lora_weights(self):
        """Fold the adapter into the base weights (if supported) and drop the PEFT wrapper."""
        if hasattr(self.model, 'merge_and_unload'):
            print("Merging LoRA weights into base model...")
            self.model = self.model.merge_and_unload()
            print("✅ LoRA weights merged successfully")
            self.print_gpu_memory()
        else:
            print("Model doesn't support LoRA merging, continuing with adapter")

    def generate_response(self, prompt, max_new_tokens=100, temperature=0.7):
        """Generate a reply to ``prompt`` using a ``Human: ...\\nAssistant:`` template.

        Args:
            prompt: user message (the templated prompt is truncated to 512 tokens).
            max_new_tokens: maximum number of tokens to generate.
            temperature: sampling temperature; a value <= 0 selects greedy
                decoding, because ``generate`` rejects non-positive
                temperatures when sampling.

        Returns:
            ``(response, generation_time)``: the decoded newly generated text
            (stripped) and the seconds spent inside ``generate``.
        """
        formatted_prompt = f"Human: {prompt}\nAssistant:"

        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=False
        )

        model_device = self.model.device
        input_ids = inputs['input_ids'].to(model_device)
        attention_mask = inputs.get('attention_mask')
        if attention_mask is not None:
            attention_mask = attention_mask.to(model_device)

        generate_kwargs = {
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": 1.1,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if temperature > 0:
            generate_kwargs.update(do_sample=True, temperature=temperature, top_p=0.9, top_k=50)
        else:
            generate_kwargs["do_sample"] = False

        start_time = time.perf_counter()
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                **generate_kwargs,
            )
        generation_time = time.perf_counter() - start_time

        # The returned sequence starts with the prompt tokens; decode only what
        # follows. Splitting the full text on "Assistant:" would discard output
        # whenever the model itself emits another "Assistant:".
        new_tokens = outputs[0][input_ids.shape[-1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        return response, generation_time

    def batch_inference(self, prompts, max_new_tokens=100, temperature=0.7):
        """Run ``generate_response`` on each prompt, printing progress and a summary.

        Args:
            prompts: sequence of prompt strings.
            max_new_tokens: forwarded to ``generate_response``.
            temperature: forwarded to ``generate_response``.

        Returns:
            List of dicts, one per prompt, with keys ``prompt``, ``response``,
            ``generation_time`` and ``success``. A prompt whose generation
            raised gets ``"Error: ..."`` as response and time 0.
        """
        print(f"\n{'='*80}")
        print(f"🔥 PERFORMING INFERENCE ON {len(prompts)} PROMPTS")
        print(f"Base Model: {self.base_model_name}")
        print(f"LoRA Model: {self.lora_model_name or 'Synthetic'}")
        print(f"Device: {self.model.device}")
        print(f"Max new tokens: {max_new_tokens}")
        print(f"Temperature: {temperature}")
        print(f"{'='*80}")

        results = []
        total_time = 0

        for i, prompt in enumerate(prompts, 1):
            print(f"\n📝 PROMPT {i}/{len(prompts)}:")
            print(f"{'─'*60}")
            print(f"💭 Input: {prompt}")
            print(f"{'─'*60}")

            try:
                response, gen_time = self.generate_response(
                    prompt, max_new_tokens, temperature
                )
                total_time += gen_time

                print(f"🤖 Output: {response}")
                print(f"⏱️  Generation time: {gen_time:.2f}s")
                print(f"📏 Response length: {len(response.split())} words")

                results.append({
                    'prompt': prompt,
                    'response': response,
                    'generation_time': gen_time,
                    'success': True
                })

            except Exception as e:
                # Keep going: one bad prompt should not abort the batch.
                print(f"❌ Error: {e}")
                results.append({
                    'prompt': prompt,
                    'response': f"Error: {e}",
                    'generation_time': 0,
                    'success': False
                })

            print(f"{'─'*60}")

            # Release cached CUDA blocks every third prompt.
            if torch.cuda.is_available() and i % 3 == 0:
                torch.cuda.empty_cache()

        successful_results = [r for r in results if r['success']]
        print(f"\n{'='*80}")
        print(f"📊 INFERENCE SUMMARY")
        print(f"{'='*80}")
        print(f"Total prompts: {len(prompts)}")
        print(f"Successful: {len(successful_results)}")
        print(f"Failed: {len(prompts) - len(successful_results)}")
        if successful_results:
            print(f"Total time: {total_time:.2f}s")
            print(f"Average time per prompt: {total_time/len(successful_results):.2f}s")

        self.print_gpu_memory()
        print(f"{'='*80}")

        return results

    def print_gpu_memory(self):
        """Print allocated and reserved CUDA memory in GiB (no-op without CUDA)."""
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3
            reserved = torch.cuda.memory_reserved() / 1024**3
            print(f"🔧 GPU Memory - Allocated: {allocated:.1f}GB, Reserved: {reserved:.1f}GB")


In [None]:

def create_sample_prompts():
    """Return a new list of 18 varied demo prompts, three per category.

    Categories, in order: conversational, creative writing, question
    answering, problem solving, technical topics, fun/philosophical.
    """
    prompts_by_category = {
        "conversational": (
            "Hello! How are you today?",
            "What's your favorite book and why?",
            "Can you tell me a joke?",
        ),
        "creative_writing": (
            "Write a short story about a time traveler",
            "Describe a beautiful sunset over the ocean",
            "What would happen if animals could talk?",
        ),
        "question_answering": (
            "Explain artificial intelligence in simple terms",
            "What are the benefits of renewable energy?",
            "How does the internet work?",
        ),
        "problem_solving": (
            "Give me 3 tips for staying productive",
            "How can I learn a new language effectively?",
            "What's the best way to reduce stress?",
        ),
        "technical": (
            "Explain machine learning algorithms",
            "What is cloud computing?",
            "How do neural networks work?",
        ),
        "fun_and_philosophical": (
            "If you could have any superpower, what would it be?",
            "What's the meaning of life?",
            "Describe your ideal vacation destination",
        ),
    }
    # Dicts preserve insertion order, so the flattened list keeps category order.
    return [text for group in prompts_by_category.values() for text in group]


In [None]:

def test_different_configurations(engine):
    """Generate a reply to one fixed prompt under three sampling presets.

    Each preset's settings, reply and timing are printed. A failure in one
    preset is reported and does not stop the remaining presets, matching the
    per-prompt error handling of ``LoRAInferenceEngine.batch_inference``.

    Args:
        engine: object with ``generate_response(prompt, max_new_tokens=...,
            temperature=...)`` returning ``(response, seconds)``.

    Returns:
        List of dicts, one per preset, with keys ``name``, ``temperature``,
        ``max_tokens``, ``response`` (``None`` on failure), ``generation_time``
        and ``success``.
    """
    print(f"\n{'='*80}")
    print(f"🧪 TESTING DIFFERENT CONFIGURATIONS")
    print(f"{'='*80}")

    test_prompt = "The future of technology is"

    configurations = [
        {"name": "Conservative", "temperature": 0.3, "max_tokens": 50},
        {"name": "Balanced", "temperature": 0.7, "max_tokens": 80},
        {"name": "Creative", "temperature": 1.0, "max_tokens": 100},
    ]

    results = []
    for config in configurations:
        print(f"\n🔧 {config['name']} Configuration:")
        print(f"   Temperature: {config['temperature']}")
        print(f"   Max tokens: {config['max_tokens']}")
        print(f"{'─'*50}")

        try:
            response, gen_time = engine.generate_response(
                test_prompt,
                max_new_tokens=config['max_tokens'],
                temperature=config['temperature']
            )
        except Exception as e:
            print(f"❌ Error: {e}")
            results.append({**config, 'response': None, 'generation_time': 0, 'success': False})
            continue

        print(f"💭 Prompt: {test_prompt}")
        print(f"🤖 Response: {response}")
        print(f"⏱️  Time: {gen_time:.2f}s")
        results.append({**config, 'response': response, 'generation_time': gen_time, 'success': True})

    return results


# The "test_" prefix is historical: this is a demo helper that needs a loaded
# engine, not a pytest test, so opt it out of pytest collection.
test_different_configurations.__test__ = False


In [None]:

def main():
    """Load a model with a LoRA adapter, then run a preset sweep and batch inference.

    Returns:
        ``(engine, results)`` on success, where ``results`` is the list returned
        by ``engine.batch_inference``; ``(None, None)`` if any step raised (the
        traceback is printed).
    """
    try:
        print("🚀 Starting LoRA Inference Engine on GPU...")

        engine = LoRAInferenceEngine()

        # Model combinations to choose from; "lora": None means a synthetic adapter.
        model_options = [
            {
                "base": "microsoft/DialoGPT-medium",
                "lora": None,  # Will create synthetic LoRA
                "use_4bit": False
            },
            # Example with a real LoRA (uncomment if you have one):
            # {
            #     "base": "microsoft/DialoGPT-medium",
            #     "lora": "some-username/dialogpt-lora-adapter",
            #     "use_4bit": True
            # }
        ]

        config = model_options[0]

        engine.download_and_load_models(
            base_model_name=config["base"],
            lora_model_name=config["lora"],
            # Synthesize an adapter only when none is configured; hard-coding
            # True would silently ignore a real "lora" entry.
            create_synthetic_lora=config["lora"] is None,
            use_4bit=config["use_4bit"]
        )

        # Optionally merge LoRA weights for faster inference
        # engine.merge_lora_weights()

        test_different_configurations(engine)

        sample_prompts = create_sample_prompts()
        results = engine.batch_inference(sample_prompts[:10])  # Use first 10 prompts

        successful_results = [r for r in results if r['success']]
        if successful_results:
            print(f"\n{'='*80}")
            print(f"🏆 TOP RESPONSES")
            print(f"{'='*80}")

            # Fastest first; ties broken in favour of the longer response.
            sorted_results = sorted(successful_results,
                                    key=lambda x: (x['generation_time'], -len(x['response'])))

            for i, result in enumerate(sorted_results[:5], 1):
                response = result['response']
                # Only mark the preview as truncated when it actually is.
                preview = response if len(response) <= 200 else f"{response[:200]}..."
                print(f"\n{i}. ⚡ Response ({result['generation_time']:.2f}s):")
                print(f"   💭 Prompt: {result['prompt']}")
                print(f"   🤖 Response: {preview}")

        print(f"\n✅ LoRA Inference completed successfully!")
        return engine, results

    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
        return None, None


In [None]:

def quick_gpu_test():
    """Smoke-test CUDA by multiplying two random 1000x1000 matrices on the GPU.

    Returns:
        True if CUDA is available and the multiplication succeeded, False
        otherwise (the reason is printed).
    """
    print("🔬 Quick GPU Test")

    if torch.cuda.is_available():
        try:
            lhs, rhs = (torch.randn(1000, 1000).to("cuda") for _ in range(2))
            product = lhs @ rhs
            print(f"✅ GPU computation test passed!")
            print(f"Result shape: {product.shape}")

            # Free the test tensors and return cached blocks to the driver.
            del lhs, rhs, product
            torch.cuda.empty_cache()
        except Exception as err:
            print(f"❌ GPU test failed: {err}")
            return False
        return True

    print("❌ CUDA not available. Make sure you're using GPU runtime.")
    return False


In [None]:

if __name__ == "__main__":
    # Banner with a reminder to enable a GPU runtime in Colab.
    for _banner_line in (
        "🎯 LoRA Inference Engine on GPU",
        "Make sure you're using GPU runtime in Colab!",
        "Runtime -> Change runtime type -> GPU",
    ):
        print(_banner_line)

    # Smoke-test the GPU before spending time downloading models.
    if not quick_gpu_test():
        print("Please switch to GPU runtime and try again.")
    else:
        engine, results = main()