In [None]:
!pip install bitsandbytes -q

In [None]:
from huggingface_hub import login
import os

# SECURITY: never hard-code access tokens in a notebook. The token that used to
# be written here has been exposed and should be revoked on huggingface.co.
# Provide the token via the HF_TOKEN environment variable (e.g. a Kaggle/Colab
# secret exported to it); when no token is given, login() asks for one
# interactively.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import json
from datetime import datetime

# ============================================
# LOAD YOUR FINE-TUNED MODEL
# ============================================

def load_finetuned_model(lora_path, base_model_id="google/gemma-2-2b-it"):
    """Load the 4-bit quantized base model and attach the fine-tuned LoRA adapter.

    Args:
        lora_path: location of the saved LoRA adapter, passed straight to
            ``PeftModel.from_pretrained``.
        base_model_id: Hub id of the base model the adapter was trained on.

    Returns:
        ``(model, tokenizer)``: the adapter-wrapped model switched to eval mode,
        and the base model's tokenizer configured for generation.
    """

    print("Loading base model with quantization...")

    # Same quantization config as training
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True
    )

    # Load base model
    # NOTE(review): trust_remote_code=True executes code shipped in the Hub repo;
    # it is kept for compatibility with other base_model_id values.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only models continue generating from the last position, so any
    # padding used at inference must be on the left (right padding is a
    # training-time setting and corrupts batched generation).
    tokenizer.padding_side = "left"

    print("Loading LoRA adapter...")
    # Load LoRA weights on top of the quantized base model
    model = PeftModel.from_pretrained(base_model, lora_path)
    model.eval()

    # Was mis-encoded as "âœ“" (UTF-8 bytes decoded as cp1252).
    print("✓ Model loaded successfully!")
    return model, tokenizer


# ============================================
# GENERATION FUNCTIONS WITH DIFFERENT STRATEGIES
# ============================================

def generate_story(model, tokenizer, prompt, strategy="greedy", temperature=1.0, top_k=50, top_p=0.9, num_beams=1, max_length=512):
    """
    Generate a story for ``prompt`` using one of several decoding strategies.

    Strategies:
    - greedy: always picks the highest-probability token (deterministic)
    - sampling: pure temperature sampling over the full vocabulary
    - top_k: temperature sampling restricted to the K most likely tokens
    - top_p: nucleus sampling - sample from the smallest set of tokens whose
      cumulative probability is >= p
    - beam: beam search with ``num_beams`` beams (deterministic)

    Args:
        model: causal LM exposing ``generate`` and ``device``.
        tokenizer: tokenizer matching ``model``.
        prompt: user message; wrapped in Gemma's chat-turn markup below.
        strategy: one of the strategy names listed above.
        temperature: softmax temperature (sampling, top_k, top_p only).
        top_k: K for the ``top_k`` strategy.
        top_p: p for the ``top_p`` strategy.
        num_beams: beam width for the ``beam`` strategy.
        max_length: passed to ``generate`` as ``max_length``, i.e. a cap on
            prompt tokens + generated tokens.

    Returns:
        Only the model's reply (prompt tokens excluded), decoded without
        special tokens and with surrounding whitespace stripped.

    Raises:
        ValueError: if ``strategy`` is not one of the supported names.
    """
    if strategy not in ("greedy", "sampling", "top_k", "top_p", "beam"):
        raise ValueError(f"Unknown decoding strategy: {strategy!r}")

    # Format input in Gemma's chat template
    formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

    # Tokenize
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    # The chat markup above closes each turn with <end_of_turn>. Stop on that as
    # well as <eos>; stopping on <eos> alone lets the model run on into invented
    # follow-up turns until max_length.
    stop_ids = [tokenizer.eos_token_id]
    end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
    if (
        end_of_turn_id is not None
        and end_of_turn_id != tokenizer.unk_token_id
        and end_of_turn_id not in stop_ids
    ):
        stop_ids.append(end_of_turn_id)

    gen_kwargs = {
        "max_length": max_length,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": stop_ids,
    }

    if strategy == "greedy":
        gen_kwargs.update(do_sample=False, num_beams=1)
    elif strategy == "beam":
        gen_kwargs.update(do_sample=False, num_beams=num_beams, early_stopping=True)
    else:
        # Set both filters explicitly. Otherwise generate() inherits them from
        # the model's generation config (transformers' default top_k is 50).
        # That made "sampling" identical to top_k with K=50, and put a hidden
        # top-k filter on the "top_p" runs. top_k=0 / top_p=1.0 disable a filter.
        gen_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_k=top_k if strategy == "top_k" else 0,
            top_p=top_p if strategy == "top_p" else 1.0,
        )

    # Generate
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens. Searching the decoded text for
    # "<start_of_turn>model" does not work: skip_special_tokens=True can strip
    # that marker, so the prompt used to leak into every returned story.
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()


# ============================================
# MAIN EXECUTION WITH ALL CONFIGURATIONS
# ============================================

if __name__ == "__main__":

    # Location of the saved LoRA adapter (Kaggle model input)
    LORA_PATH = "/kaggle/input/gemma-2-gold/transformers/default/1"

    model, tokenizer = load_finetuned_model(LORA_PATH)

    # One fixed prompt so every decoding configuration can be compared directly
    prompt = "Create a scene at a mystical pond with mood: mysterious. Characters: Woodcutter, Pond Spirit"

    # Sweep grid: 1 greedy + 10 sampling + 3x10 top-k + 3x10 top-p + 3 beam = 74 runs.
    # Within each K / P value the temperatures run from 0.1 to 1.0.
    TEMPERATURES = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)
    configurations = [("greedy", {})]
    configurations += [("sampling", {"temperature": t}) for t in TEMPERATURES]
    configurations += [
        ("top_k", {"temperature": t, "top_k": k})
        for k in (10, 30, 50)
        for t in TEMPERATURES
    ]
    configurations += [
        ("top_p", {"temperature": t, "top_p": p})
        for p in (0.5, 0.9, 0.98)
        for t in TEMPERATURES
    ]
    configurations += [("beam", {"num_beams": b}) for b in (3, 5, 10)]

    # Human-readable label per strategy, filled from that configuration's kwargs
    LABEL_TEMPLATES = {
        "greedy": "Greedy",
        "sampling": "Sampling (T={temperature})",
        "top_k": "Top-K (T={temperature}, K={top_k})",
        "top_p": "Top-P (T={temperature}, P={top_p})",
        "beam": "Beam Search (beams={num_beams})",
    }

    SEP = "=" * 80
    total = len(configurations)

    print("\n" + SEP)
    print(f"GENERATING STORIES WITH {total} CONFIGURATIONS")
    print(SEP + "\n")

    results = []
    for config_id, (strategy, params) in enumerate(configurations, start=1):
        label = LABEL_TEMPLATES[strategy].format(**params)

        print(f"\n{SEP}")
        print(f"Configuration {config_id}/{total}: {label}")
        print(SEP)

        # A failed run is recorded as an "ERROR: ..." story rather than aborting the sweep
        try:
            story = generate_story(model, tokenizer, prompt, strategy=strategy, **params)
            print(story)
        except Exception as e:
            print(f"ERROR: {e}")
            story = f"ERROR: {e}"

        results.append({
            "config_id": config_id,
            "strategy": strategy,
            "parameters": params,
            "label": label,
            "story": story,
        })

    # Persist everything into a timestamped JSON file (UTF-8, non-ASCII kept as-is)
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_file = f"story_generations_{stamp}.json"
    payload = {
        "prompt": prompt,
        "total_configurations": total,
        "results": results,
    }
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)

    print("\n\n" + SEP)
    print("GENERATION COMPLETE!")
    print(f"Results saved to: {output_file}")
    print(f"Total configurations tested: {total}")
    print(SEP)


# ============================================
# ANALYSIS FUNCTION (OPTIONAL)
# ============================================

def analyze_results(results_file):
    """Print a summary of a JSON results file written by the main sweep.

    Reports:
    - how many stories each decoding strategy produced;
    - how many distinct story texts occur;
    - up to five story texts that more than one configuration produced, each
      with its configuration labels and a 150-character preview.

    Args:
        results_file: path to a JSON file with a top-level ``results`` list
            whose entries carry ``strategy``, ``label`` and ``story`` keys.
    """
    with open(results_file, 'r', encoding='utf-8') as fh:
        payload = json.load(fh)
    results = payload['results']

    banner = "=" * 80
    print("\n" + banner)
    print("ANALYSIS OF GENERATED STORIES")
    print(banner)

    # strategy name -> entries produced by it (first-seen order)
    per_strategy = {}
    for entry in results:
        per_strategy.setdefault(entry['strategy'], []).append(entry)

    print(f"\nTotal stories generated: {len(results)}")
    print("\nBreakdown by strategy:")
    for name, members in per_strategy.items():
        print(f"  {name}: {len(members)} stories")

    # story text -> labels of every configuration that produced it
    labels_by_story = {}
    for entry in results:
        labels_by_story.setdefault(entry['story'], []).append(entry['label'])

    n_unique = len(labels_by_story)
    print(f"\nUnique stories generated: {n_unique}")
    print(f"Duplicate stories: {len(results) - n_unique}")

    # Most frequent first; the stable sort keeps first-seen order among ties
    ranked = sorted(labels_by_story.items(), key=lambda item: len(item[1]), reverse=True)

    print("\n--- Most common outputs (duplicates) ---")
    for story, labels in ranked[:5]:
        if len(labels) <= 1:
            continue
        print(f"\nAppeared {len(labels)} times:")
        print(f"Configs: {', '.join(labels)}")
        print(f"Story preview: {story[:150]}...")


# ============================================
# INTERACTIVE MODE (OPTIONAL)
# ============================================

def interactive_generation(model, tokenizer):
    """Read prompts from stdin, let the user pick a decoding preset, and print the story.

    Loops until the user enters ``quit`` (case-insensitive, surrounding
    whitespace ignored) or stdin reaches end-of-file. A failed generation is
    reported and the session continues.

    Args:
        model: model forwarded to ``generate_story``.
        tokenizer: tokenizer forwarded to ``generate_story``.
    """

    print("\n" + "="*80)
    print("INTERACTIVE STORY GENERATOR")
    print("="*80)

    # Menu presets (a subset of the full sweep)
    configurations = [
        ("greedy", {}),
        ("sampling", {"temperature": 0.1}),
        ("sampling", {"temperature": 0.5}),
        ("sampling", {"temperature": 0.7}),
        ("sampling", {"temperature": 1.0}),
        ("top_k", {"temperature": 0.8, "top_k": 30}),
        ("top_p", {"temperature": 0.8, "top_p": 0.9}),
        ("beam", {"num_beams": 5}),
    ]
    menu_labels = {
        "greedy": "Greedy",
        "sampling": "Sampling (T={temperature})",
        "top_k": "Top-K (T={temperature}, K={top_k})",
        "top_p": "Top-P (T={temperature}, P={top_p})",
        "beam": "Beam Search (beams={num_beams})",
    }

    while True:
        print("\n" + "-"*80)
        try:
            prompt = input("\nEnter your prompt (or 'quit' to exit): ")
        except EOFError:
            break  # stdin closed (e.g. Ctrl-D): treat like 'quit'
        if prompt.strip().lower() == 'quit':
            break

        print("\nSelect decoding strategy:")
        for i, (strat, kwargs) in enumerate(configurations, 1):
            print(f"{i}. {menu_labels[strat].format(**kwargs)}")

        try:
            choice = input(f"\nChoice (1-{len(configurations)}): ")
        except EOFError:
            break

        # Only the integer parse belongs in this try. A ValueError raised during
        # generation must not be reported as bad menu input.
        try:
            idx = int(choice) - 1
        except ValueError:
            print("Invalid input!")
            continue
        if not 0 <= idx < len(configurations):
            print("Invalid choice!")
            continue

        strategy, kwargs = configurations[idx]
        print(f"\nGenerating with {strategy}...")
        try:
            story = generate_story(model, tokenizer, prompt, strategy=strategy, **kwargs)
        except Exception as e:
            # Same reporting as the batch sweep; keep the interactive session alive
            print(f"ERROR: {e}")
            continue

        print("\n" + "="*80)
        print("GENERATED STORY:")
        print("="*80)
        print(story)

# Uncomment to run interactive mode
# interactive_generation(model, tokenizer)

# Uncomment to analyze results from a previous run
# analyze_results("story_generations_20241106_123456.json")