In [None]:
!pip install bitsandbytes -q

In [None]:
import os

from huggingface_hub import login

# SECURITY: a Hugging Face access token used to be hard-coded on this line.
# Any token that has been committed or shared must be treated as leaked and
# revoked in the Hugging Face account settings.
# The token is now read from the HF_TOKEN environment variable; when it is
# unset, token=None lets login() ask for the token interactively instead.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import os
from datetime import datetime

# ============================================
# IMPORTANT: Library Dependencies
# ============================================

# pip install --upgrade transformers bitsandbytes accelerate

# ============================================

# ============================================
# STORY PROMPT
# ============================================

# Instruction text passed unchanged to generate_story() for every decoding
# configuration, so differences between output files come from the decoding
# settings only. The string is sent to the model verbatim (markdown included).
STORY_PROMPT = """You are a master storyteller. Your task is to rewrite the classic fairy tale 'The Shoemaker and the Elves' into a **dystopian science-fiction story**.

**Your rewrite must adhere to the following rules but should also go beyond them with unexpected twists not stated below:**
* **Setting:** A grimy, neon-lit metropolis in the year 2242, where citizens are judged by their cybernetic enhancements.
* **The 'Shoemaker':** He is a reclusive old craftsman who illegally repairs and builds custom bio-mechanical limbs for outcasts.
* **The 'Elves':** They are a swarm of self-replicating nanobots that are believed to be a myth.
* **The 'Shoes':** The nanobots are not making shoes; they are performing impossibly intricate upgrades on the limbs the craftsman leaves on his workbench overnight.
* **Tone:** The story should be grim, mysterious, and slightly hopeful.
* **Ending:** The story must end with the craftsman discovering the nanobots and leaving a small charging plate with a micro-drop of refined energy for them as a gift, without ever seeing them directly."""

# ============================================
# MODEL LOADING FUNCTION
# ============================================

def load_base_model(base_model_id="google/gemma-2-2b-it", trust_remote_code=False):
    """Load a causal language model in 4-bit precision together with its tokenizer.

    Args:
        base_model_id: Hugging Face Hub model identifier (or local path) to load.
        trust_remote_code: Forwarded to ``from_pretrained``. Defaults to False so
            that no Python code shipped inside the model repository is executed;
            enable it only for repositories you trust and that require it.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has been switched to eval mode.
    """

    print("Loading original base model with quantization...")

    # 4-bit NF4 weights with double quantization; compute runs in bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True
    )

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        quantization_config=bnb_config,
        device_map="auto",
        # Previously hard-coded to True, which lets repository-supplied code run
        # on this machine; it is now an explicit opt-in.
        trust_remote_code=trust_remote_code,
        torch_dtype=torch.bfloat16,
    )

    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    # Reuse EOS as the padding token and pad on the right.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    base_model.eval()
    print("✓ Original model loaded successfully!")
    return base_model, tokenizer

# ============================================
# GENERATION FUNCTION
# ============================================

def generate_story(model, tokenizer, prompt, strategy="greedy", temperature=1.0, 
                   top_k=50, top_p=0.9, num_beams=1, max_length=1024):
    """Generate a story for ``prompt`` with a single decoding strategy.

    Args:
        model: Causal language model exposing ``.device`` and ``.generate(...)``.
        tokenizer: Tokenizer belonging to ``model``.
        prompt: User instruction; it is wrapped in Gemma's chat-turn markers.
        strategy: One of "greedy", "sampling", "top_k", "top_p" or "beam".
        temperature: Softmax temperature (used by the three sampling strategies).
        top_k: Number of candidates kept per step (used by "top_k" only).
        top_p: Nucleus probability mass (used by "top_p" only).
        num_beams: Beam width (used by "beam" only).
        max_length: Passed to ``generate`` as ``max_length``, i.e. the limit on
            prompt tokens plus generated tokens together.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is not
        included), with surrounding whitespace stripped.

    Raises:
        ValueError: If ``strategy`` is not one of the supported names.
    """

    valid_strategies = ("greedy", "sampling", "top_k", "top_p", "beam")
    if strategy not in valid_strategies:
        # Previously an unknown name silently fell back to library defaults.
        raise ValueError(
            f"Unknown decoding strategy {strategy!r}; expected one of {valid_strategies}"
        )

    # Format input in Gemma's chat template
    formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

    # Tokenize
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    # Stop on the tokenizer's EOS and, when the tokenizer knows it, on the
    # end-of-turn marker as well; otherwise generation may continue past the
    # model's own turn until max_length is reached.
    stop_ids = [tokenizer.eos_token_id]
    end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
    if (
        end_of_turn_id is not None
        and end_of_turn_id != getattr(tokenizer, "unk_token_id", None)
        and end_of_turn_id not in stop_ids
    ):
        stop_ids.append(end_of_turn_id)

    # Generation parameters shared by every strategy
    gen_kwargs = {
        "max_length": max_length,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": stop_ids,
    }

    # The generation config may carry its own top_k / top_p defaults, so each
    # sampling strategy explicitly disables the filter it is not meant to use
    # (top_k=0 and top_p=1.0 switch the respective filter off).
    if strategy == "greedy":
        gen_kwargs.update({
            "do_sample": False,
            "num_beams": 1,
        })

    elif strategy == "sampling":
        gen_kwargs.update({
            "do_sample": True,
            "temperature": temperature,
            "top_k": 0,
            "top_p": 1.0,
        })

    elif strategy == "top_k":
        gen_kwargs.update({
            "do_sample": True,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": 1.0,
        })

    elif strategy == "top_p":
        gen_kwargs.update({
            "do_sample": True,
            "temperature": temperature,
            "top_k": 0,
            "top_p": top_p,
        })

    else:  # "beam"
        gen_kwargs.update({
            "do_sample": False,
            "num_beams": num_beams,
            "early_stopping": True,
        })

    # Generate
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # Decode only the tokens produced after the prompt. Searching the decoded
    # text for a turn marker is unreliable because skip_special_tokens=True can
    # remove that marker, which left the whole prompt inside the response.
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# ============================================
# FILE SAVING FUNCTION
# ============================================

def save_story_to_file(filepath, story_text, prompt, strategy, temperature, model_type, top_k=None, top_p=None):
    """Write ``story_text`` to ``filepath`` as UTF-8, creating parent folders.

    Only the story text is written; no metadata header is added to the file.

    Args:
        filepath: Destination path. Missing parent directories are created. A
            bare file name (no directory part) is written to the current
            working directory.
        story_text: Text to write.
        prompt, strategy, temperature, model_type, top_k, top_p: Accepted so
            existing callers keep working; currently not written to the file.
    """

    # os.path.dirname() returns "" for a bare file name, and os.makedirs("")
    # raises, so only create the directory when there is one.
    parent_dir = os.path.dirname(filepath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(story_text)

# ============================================
# MAIN GENERATION FUNCTION
# ============================================

def generate_all_stories(base_model, tokenizer, output_dir="story_output"):
    """Run every decoding configuration once and save each story to disk.

    For each configuration the story is generated from STORY_PROMPT and saved
    as ``<output_dir>/<strategy>/base_<label>.txt``.

    Args:
        base_model: Model handed to generate_story().
        tokenizer: Tokenizer handed to generate_story().
        output_dir: Root folder that receives one sub-folder per strategy.

    Returns:
        The number of configurations processed.
    """

    banner = "=" * 80

    print("\n" + banner)
    print("GENERATING STORIES WITH THE BASE MODEL")
    print(banner + "\n")

    # Temperatures 0.1, 0.2, ..., 1.0
    temperatures = [step / 10 for step in range(1, 11)]

    # 1 greedy + 10 sampling + 3x10 top-k + 3x10 top-p + 3 beam = 74 runs
    runs = [("greedy", {})]
    runs += [("sampling", {"temperature": t}) for t in temperatures]
    runs += [
        ("top_k", {"temperature": t, "top_k": k})
        for k in (10, 30, 50)
        for t in temperatures
    ]
    runs += [
        ("top_p", {"temperature": t, "top_p": p})
        for p in (0.5, 0.9, 0.98)
        for t in temperatures
    ]
    runs += [("beam", {"num_beams": n}) for n in (3, 5, 10)]

    def describe(name, params):
        """Return (file label, console label) for one configuration."""
        if name == "greedy":
            return "greedy_default", "Greedy"
        if name == "beam":
            n = params['num_beams']
            return f"beam_{n}", f"Beam Search (beams={n})"
        t = params['temperature']
        if name == "sampling":
            return f"sampling_{t:.1f}", f"Sampling (T={t})"
        if name == "top_k":
            k = params['top_k']
            return f"top_k_{t:.1f}_k{k}", f"Top-K (T={t}, K={k})"
        if name == "top_p":
            p = params['top_p']
            return f"top_p_{t:.1f}_p{p}", f"Top-P (T={t}, P={p})"

    total_files = len(runs)

    for position, (strategy, kwargs) in enumerate(runs, start=1):
        label, display_label = describe(strategy, kwargs)

        print("\n" + banner)
        print(f"[{position}/{total_files}] {display_label}")
        print(banner)

        print("Generating with BASE model...")
        base_story = generate_story(
            base_model, tokenizer, STORY_PROMPT, strategy=strategy, **kwargs
        )

        # One sub-folder per strategy, one file per configuration
        filename = f"base_{label}.txt"
        filepath = os.path.join(os.path.join(output_dir, strategy), filename)

        save_story_to_file(
            filepath,
            base_story,
            STORY_PROMPT,
            strategy,
            kwargs.get('temperature'),
            "BASE MODEL",
            top_k=kwargs.get('top_k'),
            top_p=kwargs.get('top_p'),
        )
        print(f"✓ Saved: {filename}")

    return total_files

# ============================================
# MAIN EXECUTION
# ============================================

if __name__ == "__main__":
    
    # Configuration
    BASE_MODEL_ID = "google/gemma-2-2b-it"
    OUTPUT_DIR = "story_generation_output"
    
    print("\n" + "="*80)
    print("STORY GENERATION SCRIPT")
    print("="*80)
    print(f"\nBase Model: {BASE_MODEL_ID}")
    print(f"Output Directory: {OUTPUT_DIR}")
    print("\n" + "="*80)
    
    # Load model
    print("\nStep 1: Loading model...")
    print("-" * 80)
    base_model, tokenizer = load_base_model(BASE_MODEL_ID)
    
    # Generate all stories
    print("\nStep 2: Generating stories...")
    print("-" * 80)
    total_files = generate_all_stories(base_model, tokenizer, OUTPUT_DIR)
    
    # Summary. The per-folder counts mirror the configuration list in
    # generate_all_stories(): 1 greedy, 10 sampling, 30 top-k, 30 top-p, 3 beam.
    print("\n" + "="*80)
    print("GENERATION COMPLETE!")
    print("="*80)
    print(f"\nTotal files generated: {total_files}")
    print(f"Output location: {OUTPUT_DIR}/")
    print("\nFolder structure:")
    print(f"  {OUTPUT_DIR}/")
    print(f"  ├── greedy/")
    print(f"  │   └── base_greedy_default.txt")
    print(f"  ├── sampling/")
    print(f"  │   ├── base_sampling_0.1.txt")
    print(f"  │   ├── base_sampling_0.2.txt")
    print(f"  │   └── ... (10 files)")
    print(f"  ├── top_k/")
    print(f"  │   └── ... (30 files)")
    print(f"  ├── top_p/")
    print(f"  │   └── ... (30 files)")
    print(f"  └── beam/")
    print(f"      └── ... (3 files)")
    # The files hold the story text only; the decoding settings are encoded in
    # the folder and file names rather than in a header inside the file.
    print("\nEach file contains:")
    print("  • The generated story text only (no metadata header)")
    print("  • Strategy and parameters are encoded in the folder and file name")
    print("\n" + "="*80)