<a href="https://colab.research.google.com/github/ayagup/stablediffusion/blob/main/hf_lora_tpu_inferencing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig,
)
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig
import warnings
warnings.filterwarnings('ignore')
import time

# Check TPU availability
print(f"PyTorch version: {torch.__version__}")
print(f"XLA version: {torch_xla.__version__}")

# Module-level device; LoRAInferenceEngine moves the model and its inputs here.
device = xm.xla_device()
print(f"Using device: {device}")

try:
    world_size = xr.world_size()
    print(f"Number of TPU cores: {world_size}")
except Exception:
    # Best-effort diagnostic only. Catch Exception rather than using a bare
    # except so KeyboardInterrupt/SystemExit still propagate.
    print("World size not available, but TPU is working")


In [None]:

class LoRAInferenceEngine:
    """Load a causal LM plus a LoRA adapter, move it to the XLA device, and generate text.

    Uses the module-level ``device`` (created in the setup cell via
    ``xm.xla_device()``) as the target for the model and all input tensors.
    """

    def __init__(self):
        # All three are populated by download_and_load_models().
        self.tokenizer = None
        self.model = None
        self.generation_config = None

    def setup_tokenizer_padding(self, tokenizer):
        """Ensure ``tokenizer`` has a pad token and pads on the left.

        Pad-token fallback order: existing pad token, unk token, bos token,
        otherwise a newly added ``[PAD]`` special token.

        Args:
            tokenizer: A Hugging Face tokenizer; mutated in place.

        Returns:
            The same tokenizer object.
        """
        if tokenizer.pad_token is None:
            # NOTE(review): for GPT-2-style tokenizers (e.g. DialoGPT) the unk
            # token is presumably the same string as eos, which makes pad == eos;
            # test_tokenizer_setup() reports that case.
            if tokenizer.unk_token is not None:
                tokenizer.pad_token = tokenizer.unk_token
                print(f"✓ Using unk_token as pad_token: {tokenizer.pad_token}")
            elif tokenizer.bos_token is not None:
                tokenizer.pad_token = tokenizer.bos_token
                print(f"✓ Using bos_token as pad_token: {tokenizer.pad_token}")
            else:
                # Adding a token grows the vocab; download_and_load_models()
                # resizes the embeddings to match.
                tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                print(f"✓ Added new pad_token: {tokenizer.pad_token}")
        else:
            print(f"✓ Using existing pad_token: {tokenizer.pad_token}")

        # Decoder-only generation continues from the last position, so any
        # padding must go on the left of the prompt.
        tokenizer.padding_side = "left"

        return tokenizer

    @staticmethod
    def _resolve_bos_token_id(tokenizer):
        """Return the tokenizer's BOS id, falling back to EOS only when BOS is None.

        An explicit ``is None`` check is used because 0 is a valid token id.
        """
        bos_id = tokenizer.bos_token_id
        return bos_id if bos_id is not None else tokenizer.eos_token_id

    def download_and_load_models(self,
                                base_model_name="microsoft/DialoGPT-medium",
                                lora_model_name=None,
                                create_synthetic_lora=True):
        """Load tokenizer and base model, attach a LoRA adapter, and move to ``device``.

        Args:
            base_model_name: Hub id of the base causal LM.
            lora_model_name: Hub id of a LoRA adapter. Only used when
                ``create_synthetic_lora`` is False.
            create_synthetic_lora: If True, attach a randomly perturbed LoRA
                adapter instead of loading ``lora_model_name``.
        """
        print(f"Loading base model: {base_model_name}")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

        # Setup proper padding
        self.tokenizer = self.setup_tokenizer_padding(self.tokenizer)

        # Load base model
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.float32,  # TPU works better with float32
            device_map=None,  # We'll move to TPU manually
            # NOTE(review): trust_remote_code executes Python code shipped with
            # the checkpoint; keep it only for trusted model ids.
            trust_remote_code=True
        )

        # Grow the embedding matrix only when the tokenizer has more tokens
        # than the model has rows (e.g. a new [PAD] was added). An
        # unconditional resize would shrink models whose embedding table is
        # padded beyond the tokenizer size.
        num_embeddings = self.model.get_input_embeddings().weight.shape[0]
        if len(self.tokenizer) > num_embeddings:
            self.model.resize_token_embeddings(len(self.tokenizer))

        print(f"✓ Base model loaded successfully")
        print(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        print(f"Vocabulary size: {len(self.tokenizer)}")

        # Create or load LoRA adapter
        if create_synthetic_lora:
            print("Creating synthetic LoRA adapter...")
            self.create_synthetic_lora()
        elif lora_model_name:
            print(f"Loading LoRA adapter: {lora_model_name}")
            self.load_lora_adapter(lora_model_name)

        # Move model to TPU
        print("Moving model to TPU...")
        self.model = self.model.to(device)
        # Inference only: disable dropout, including the LoRA dropout layers
        # that get_peft_model() inserts.
        self.model.eval()
        print("✓ Model moved to TPU successfully")

        # Setup generation config with proper token IDs
        self.generation_config = GenerationConfig(
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.1,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            bos_token_id=self._resolve_bos_token_id(self.tokenizer),
        )

        print("✓ Model setup completed!")
        print(f"Pad token ID: {self.tokenizer.pad_token_id}")
        print(f"EOS token ID: {self.tokenizer.eos_token_id}")

    def create_synthetic_lora(self):
        """Wrap ``self.model`` with a fresh LoRA adapter (r=16) and perturb its weights."""
        lora_config = LoraConfig(
            r=16,  # rank
            lora_alpha=32,
            # Module names of the GPT-2 architecture that DialoGPT is built on.
            target_modules=["c_attn", "c_proj"],
            lora_dropout=0.1,
            bias="none",
            task_type="CAUSAL_LM",
        )

        # Add LoRA adapter to the model
        self.model = get_peft_model(self.model, lora_config)

        print(f"✓ Synthetic LoRA adapter created")
        print(f"Trainable parameters: {sum(p.numel() for p in self.model.parameters() if p.requires_grad):,}")

        # Simulate some training by slightly modifying LoRA weights
        self.simulate_lora_training()

    def simulate_lora_training(self):
        """Add N(0, 0.01^2) noise in place to every trainable parameter whose name contains 'lora_'."""
        print("Simulating LoRA training...")

        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if 'lora_' in name and param.requires_grad:
                    noise = torch.randn_like(param) * 0.01
                    param.add_(noise)

        print("✓ LoRA simulation completed")

    def load_lora_adapter(self, lora_model_name):
        """Load a LoRA adapter from the Hub onto ``self.model``.

        Best-effort: on any failure, falls back to create_synthetic_lora().
        """
        try:
            self.model = PeftModel.from_pretrained(
                self.model,
                lora_model_name,
                torch_dtype=torch.float32
            )

            print(f"✓ LoRA adapter loaded from {lora_model_name}")

        except Exception as e:
            print(f"Failed to load LoRA adapter: {e}")
            print("Creating synthetic LoRA instead...")
            self.create_synthetic_lora()

    def generate_response(self, prompt, add_context=True):
        """Generate a continuation for ``prompt``.

        Args:
            prompt: User text.
            add_context: If True, wrap the prompt as ``"Human: ... Assistant:"``.

        Returns:
            ``(response, generation_time)``: the decoded newly generated text
            (prompt excluded, stripped) and wall-clock seconds spent generating.
        """
        formatted_prompt = f"Human: {prompt} Assistant:" if add_context else prompt

        encoded = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
            return_attention_mask=True
        )

        input_ids = encoded['input_ids'].to(device)
        attention_mask = encoded['attention_mask'].to(device)

        start_time = time.time()
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                generation_config=self.generation_config,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            # Copy to host inside the timed region so any pending device
            # execution is included in the measurement.
            outputs = outputs.cpu()
        generation_time = time.time() - start_time

        # For decoder-only models generate() returns prompt + continuation;
        # decode only the tokens after the (possibly left-padded) prompt rather
        # than string-slicing decoded text, which is fragile.
        new_tokens = outputs[0][input_ids.shape[-1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        return response, generation_time

    def batch_inference(self, prompts, add_context=True):
        """Run generate_response() on each prompt sequentially and print a summary.

        Errors are caught per prompt so one failure does not stop the loop.

        Returns:
            A list of dicts with keys ``prompt``, ``response``,
            ``generation_time`` and ``status`` (``'success'`` or ``'error'``).
        """
        print(f"\n{'='*80}")
        print(f"🔥 PERFORMING BATCH INFERENCE ON {len(prompts)} PROMPTS")
        print(f"{'='*80}")

        results = []
        total_time = 0

        for i, prompt in enumerate(prompts, 1):
            print(f"\n📝 PROMPT {i}/{len(prompts)}:")
            print(f"{'─'*60}")
            print(f"💭 Input: {prompt}")
            print(f"{'─'*60}")

            try:
                response, gen_time = self.generate_response(prompt, add_context)
                total_time += gen_time

                print(f"🤖 Output: {response}")
                print(f"⏱️  Generation time: {gen_time:.2f}s")

                results.append({
                    'prompt': prompt,
                    'response': response,
                    'generation_time': gen_time,
                    'status': 'success'
                })

            except Exception as e:
                print(f"❌ Error generating response: {e}")
                results.append({
                    'prompt': prompt,
                    'response': f"Error: {str(e)}",
                    'generation_time': 0,
                    'status': 'error'
                })

            print(f"{'─'*60}")

        successful_results = [r for r in results if r['status'] == 'success']
        print(f"\n{'='*80}")
        print(f"📊 INFERENCE SUMMARY")
        print(f"{'='*80}")
        print(f"Total prompts processed: {len(prompts)}")
        print(f"Successful generations: {len(successful_results)}")
        print(f"Failed generations: {len(prompts) - len(successful_results)}")
        if successful_results:
            print(f"Total time: {total_time:.2f}s")
            print(f"Average time per prompt: {total_time/len(successful_results):.2f}s")
        print(f"{'='*80}")

        return results

    def test_tokenizer_setup(self):
        """Print how the tokenizer encodes a sample string and warn if pad == eos."""
        print(f"\n{'='*60}")
        print(f"🔧 TESTING TOKENIZER CONFIGURATION")
        print(f"{'='*60}")

        test_text = "Hello, how are you?"

        encoded = self.tokenizer(
            test_text,
            return_tensors="pt",
            padding=True,
            return_attention_mask=True
        )

        print(f"Test text: {test_text}")
        print(f"Input IDs shape: {encoded['input_ids'].shape}")
        print(f"Attention mask shape: {encoded['attention_mask'].shape}")
        print(f"Input IDs: {encoded['input_ids']}")
        print(f"Attention mask: {encoded['attention_mask']}")
        print(f"Pad token: '{self.tokenizer.pad_token}' (ID: {self.tokenizer.pad_token_id})")
        print(f"EOS token: '{self.tokenizer.eos_token}' (ID: {self.tokenizer.eos_token_id})")

        # pad == eos makes padding indistinguishable from end-of-sequence ids.
        if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
            print("⚠️  WARNING: Pad token is same as EOS token!")
        else:
            print("✅ Pad token is different from EOS token")

        print(f"{'='*60}")


In [None]:

def create_sample_prompts():
    """Return a fresh list of 15 demo prompts, grouped from short to long, then mixed topics."""
    short_ones = ["Hello!", "How are you?", "What's AI?"]
    medium_ones = [
        "Tell me about machine learning",
        "What's your favorite programming language?",
        "Explain quantum computing simply",
    ]
    long_ones = [
        "Can you write a short story about a robot who learns to paint?",
        "What are the most important technological advances of the 21st century?",
        "How do you think artificial intelligence will change education in the future?",
    ]
    assorted = [
        "What makes a good friend?",
        "Describe the perfect vacation",
        "What would you do with a million dollars?",
        "How can we protect the environment?",
        "What's the best way to learn a new skill?",
        "Tell me about space exploration",
    ]
    # Concatenation builds a new list on every call, so callers may mutate it.
    return short_ones + medium_ones + long_ones + assorted


In [None]:

def main():
    """Run the full demo.

    Loads DialoGPT-medium with a synthetic LoRA adapter, generates responses
    for the first 8 sample prompts, and prints up to 5 of the fastest
    successful results.

    Returns:
        ``(inference_engine, results)`` on success, where ``results`` is the
        list returned by ``LoRAInferenceEngine.batch_inference``;
        ``(None, None)`` if any step raises (the traceback is printed).
    """
    preview_chars = 100  # max characters of each response shown in the summary

    try:
        print("🚀 Starting LoRA Inference Engine on TPU...")

        inference_engine = LoRAInferenceEngine()

        inference_engine.download_and_load_models(
            base_model_name="microsoft/DialoGPT-medium",
            create_synthetic_lora=True
        )

        inference_engine.test_tokenizer_setup()

        sample_prompts = create_sample_prompts()

        # Use only the first 8 prompts to keep the demo short.
        results = inference_engine.batch_inference(sample_prompts[:8])

        successful_results = [r for r in results if r['status'] == 'success']
        if successful_results:
            print(f"\n{'='*80}")
            print(f"🏆 SUCCESSFUL GENERATIONS")
            print(f"{'='*80}")

            # Fastest generations first.
            sorted_results = sorted(successful_results, key=lambda x: x['generation_time'])
            for i, result in enumerate(sorted_results[:5], 1):
                response = result['response']
                # Append an ellipsis only when the response was actually truncated.
                if len(response) > preview_chars:
                    preview = response[:preview_chars] + "..."
                else:
                    preview = response
                print(f"\n{i}. ⚡ Response ({result['generation_time']:.2f}s):")
                print(f"   💭 Prompt: {result['prompt']}")
                print(f"   🤖 Response: {preview}")

        print(f"\n✅ LoRA Inference completed successfully!")
        return inference_engine, results

    except Exception as e:
        # Top-level demo boundary: report and return sentinels instead of crashing the notebook.
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
        return None, None


In [None]:

def quick_test():
    """Smoke-test the engine with DialoGPT-small, a synthetic LoRA and three short prompts.

    Returns:
        The list of result dicts from ``batch_inference``, or ``None`` if setup
        or inference raised (the traceback is printed).
    """
    print("🔬 Quick LoRA Inference Test")

    try:
        engine = LoRAInferenceEngine()
        # The smaller checkpoint keeps the smoke test fast.
        engine.download_and_load_models(
            base_model_name="microsoft/DialoGPT-small",
            create_synthetic_lora=True,
        )
        engine.test_tokenizer_setup()

        smoke_prompts = ["Hi there!", "How's it going?", "Tell me something interesting"]
        outcome = engine.batch_inference(smoke_prompts)

        print("✅ Quick test completed!")
        return outcome
    except Exception as exc:
        print(f"❌ Quick test failed: {exc}")
        import traceback
        traceback.print_exc()
        return None


In [None]:

if __name__ == "__main__":
    print("🎯 LoRA Inference Engine on TPU (Fixed Version)")
    print("Choose an option:")
    print("1. Full inference demo")
    print("2. Quick test")

    # Hard-coded selection (no input is read): "2" runs the quick test, the
    # safer default; set to "1" for the full demo.
    choice = "2"

    if choice == "1":
        engine, results = main()
    else:
        # Any value other than "1" or "2" falls back to the quick test.
        if choice != "2":
            print("Running quick test by default...")
        results = quick_test()