In [1]:
# ==========================================
# 0. INSTALL DEPENDENCIES
# ==========================================
!pip install "unsloth[colab-new]"
!pip install --no-deps "unsloth_zoo"
!pip install --no-deps packaging ninja einops
!pip install --no-deps xformers trl peft accelerate bitsandbytes
!pip install datasets sentence-transformers faiss-cpu tqdm

import os
import json
import pandas as pd
from unsloth import FastLanguageModel
from google.colab import drive
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import numpy as np

# Pick the progress-bar flavour: the notebook widget when running under an
# IPython kernel, the plain console bar otherwise.
try:
    from IPython import get_ipython
    if get_ipython():
        from tqdm.notebook import tqdm
    else:
        from tqdm import tqdm
except Exception:
    # Any failure while probing IPython falls back to the console bar.
    # `Exception` (rather than a bare `except:`) keeps KeyboardInterrupt and
    # SystemExit propagating instead of being silently swallowed here.
    from tqdm import tqdm

Collecting faiss-cpu
  Downloading faiss_cpu-1.13.2-cp310-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (7.6 kB)
Downloading faiss_cpu-1.13.2-cp310-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (23.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.8/23.8 MB[0m [31m107.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.13.2
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [2]:
# ==========================================
# 1. SETUP & CONFIGURATION
# ==========================================
# Attach Google Drive unless the mount point is already present
DRIVE_MOUNT_POINT = '/content/drive'
drive_already_mounted = os.path.exists(DRIVE_MOUNT_POINT)
if not drive_already_mounted:
    drive.mount(DRIVE_MOUNT_POINT)

# Model checkpoint, task input (TSV) and generation output (JSONL) locations
model_id = "unsloth/gemma-2-9b-it-bnb-4bit"
input_path = '/content/drive/MyDrive/LLM project/DATA/task-a-en.tsv'
output_file = "/content/drive/MyDrive/LLM project/DATA/outputs_gemma_rag_25k.jsonl"

# Retrieval settings
N_WIKI_DOCS = 25000        # rows requested from the embedded Wikipedia train split
RETRIEVAL_K = 2            # top-scoring passages joined into one context
MAX_CONTEXT_TOKENS = 128   # token cap applied to the joined context


Mounted at /content/drive


In [3]:
# ==========================================
# 2. LOAD RAG SYSTEM (Retriever) - CPU Only
# ==========================================
class HFRetriever:
    """Dense retriever over a pre-embedded Wikipedia subset, ranked by cosine similarity.

    Attributes:
        ds: the dataset slice returned by ``load_dataset``.
        texts: passage texts (stringified), aligned row-for-row with ``embs``.
        embs: float32 matrix of L2-normalised passage embeddings, one row per passage.
        encoder: SentenceTransformer used to embed queries; pinned to CPU.
    """

    def __init__(self, n_docs=N_WIKI_DOCS):
        """Load the passage slice, its stored embeddings and the query encoder.

        Args:
            n_docs: number of rows requested from the ``train`` split. The slice
                may come back shorter when the split holds fewer rows.

        Raises:
            ValueError: if the slice contains no passages at all.
        """
        print(f"Loading embedded Wikipedia subset: {n_docs} docs ...")
        # Load pre-embedded dataset (fast)
        self.ds = load_dataset(
            "not-lain/wikipedia",
            revision="embedded",
            split=f"train[:{n_docs}]"
        )
        self.texts = [str(x) for x in self.ds["text"]]
        # Load embeddings into numpy (fast cosine sim)
        self.embs = np.array(self.ds["embeddings"], dtype=np.float32)

        # Report what was actually loaded: the banner above only echoes the request.
        loaded = len(self.texts)
        if loaded == 0:
            # Fail with a clear message instead of an axis error in the normalisation below.
            raise ValueError("No documents were loaded; cannot build the retrieval index.")
        if loaded < n_docs:
            print(f"Note: only {loaded} docs were available (requested {n_docs}).")

        # Normalize doc vectors once; the epsilon guards against all-zero rows
        self.embs = self.embs / (np.linalg.norm(self.embs, axis=1, keepdims=True) + 1e-12)

        print("Loading query embedding model on CPU (Safe Mode)...")
        # Use the exact same model that was used to embed the dataset!
        # The dataset 'not-lain/wikipedia' was embedded with 'mixedbread-ai/mxbai-embed-large-v1'
        # This model outputs 1024-dim vectors.
        # We CANNOT use 'all-MiniLM-L6-v2' (384-dim) because dimensions mismatch.
        self.encoder = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1", device="cpu")
        print("RAG Index Ready!")

    def retrieve(self, query, tokenizer, k=RETRIEVAL_K, max_tokens=MAX_CONTEXT_TOKENS):
        """Return the top-``k`` passages for ``query`` as one token-truncated string.

        Args:
            query: text to embed and match against the passage embeddings.
            tokenizer: object providing ``encode(text, add_special_tokens=False)``
                and ``decode(ids, skip_special_tokens=True)``; used only to
                truncate the combined context.
            k: number of passages to join. ``None`` keeps every passage; zero or
                a negative value yields an empty context.
            max_tokens: token cap for the combined context. ``None`` disables
                truncation; negative values are treated as zero.

        Returns:
            The joined (blank-line separated) and truncated passage text, or ``""``
            when nothing is requested or any step fails. Failures are printed,
            never raised, so a retrieval problem does not abort the caller's loop.
        """
        try:
            # A negative k would turn the [:k] slice below into "all but the last |k|".
            if k is not None:
                k = max(0, k)
                if k == 0:
                    return ""

            # Encode query and L2-normalise it, matching the stored doc vectors
            q = self.encoder.encode([query]).astype(np.float32)
            q = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-12)

            # Cosine similarity reduces to a dot product on unit vectors
            sims = self.embs @ q[0]
            top_idx = np.argsort(-sims)[:k]

            # Combine results, best match first
            combined_text = "\n\n".join(self.texts[i] for i in top_idx)

            # Truncate by tokens of the generation tokenizer, then decode back to text
            ctx_ids = tokenizer.encode(combined_text, add_special_tokens=False)
            if max_tokens is not None:
                ctx_ids = ctx_ids[:max(0, max_tokens)]
            return tokenizer.decode(ctx_ids, skip_special_tokens=True)
        except Exception as e:
            # Deliberate best-effort: the caller proceeds with an empty context.
            print(f"Retrieval failed: {e}")
            return ""

# Initialize Retriever (Fast Mode)
# Built once here: the constructor loads the dataset slice and the CPU query
# encoder, and the generation loop calls retriever.retrieve(...) for every row.
retriever = HFRetriever(n_docs=N_WIKI_DOCS)


Loading embedded Wikipedia subset: 25000 docs ...


README.md:   0%|          | 0.00/417 [00:00<?, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/49.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3000 [00:00<?, ? examples/s]

Loading query embedding model on CPU (Safe Mode)...


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/266 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/677 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/670M [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/695 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/297 [00:00<?, ?B/s]

RAG Index Ready!


In [4]:
# ==========================================
# 3. LOAD GEMMA MODEL
# ==========================================
print(f"Loading Gemma from {model_id}...")

# Gather the loader arguments in one mapping, then hand them to Unsloth
gemma_load_kwargs = {
    "model_name": model_id,
    "max_seq_length": 2048,
    "load_in_4bit": True,
    "dtype": None,
}
model, tokenizer = FastLanguageModel.from_pretrained(**gemma_load_kwargs)

# Put the loaded model into Unsloth's inference configuration
FastLanguageModel.for_inference(model)

# When the tokenizer ships without a padding token, reuse EOS for padding
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


Loading Gemma from unsloth/gemma-2-9b-it-bnb-4bit...
==((====))==  Unsloth 2026.1.4: Fast Gemma2 patching. Transformers: 4.57.6.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.10.0+cu128. CUDA: 7.5. CUDA Toolkit: 12.8. Triton: 3.6.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.34. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/6.13G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

In [5]:
# ==========================================
# 4. GENERATION LOOP
# ==========================================
# Load the task rows from the TSV as a list of dicts (one per row).
# A missing file is reported and leaves an empty work list, so the rest of
# the cell still runs without raising.
if os.path.exists(input_path):
    df = pd.read_csv(input_path, sep='\t')
    data = df.to_dict('records')

    #data = df.head(5).to_dict('records')

    # DEBUG: Pick 3 headlines + 3 word-pairs (Uncomment to use)
    #df_head = df[df['headline'].notna() & (df['headline'] != "-")].head(3)
    #df_word = df[df['word1'].notna() & (df['word1'] != "-")].head(3)
    #data = pd.concat([df_head, df_word]).to_dict('records')
else:
    print(f"Warning: Input path {input_path} not found. Please check paths.")
    data = []

# Resume Check: collect the ids already present in the output JSONL so a
# rerun skips rows that were generated by a previous session.
processed_ids = set()
if os.path.exists(output_file):
    with open(output_file, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # blank line, nothing to parse
            try:
                processed_ids.add(json.loads(line)['id'])
            except (json.JSONDecodeError, KeyError, TypeError):
                # Malformed or partially written line (e.g. an interrupted run),
                # a record without "id", or a non-object / unhashable value:
                # ignore it so the row is simply regenerated. Anything else
                # (including KeyboardInterrupt) is allowed to propagate.
                continue
    print(f"Resuming... Found {len(processed_ids)} jokes.")

print("Starting RAG Generation with Gemma...")

# Attempts per row: the extra attempt is used when the output is empty or,
# for word pairs, when a required word is missing.
max_retries = 2

# Cell values that mean "no value" in the input TSV
_EMPTY_MARKERS = ("", "-", "nan", "NaN")


def _extract_joke(generated_text):
    """Reduce the model's newly generated text to at most two cleaned joke lines.

    Args:
        generated_text: decoded text of the generated tokens only (no prompt).

    Returns:
        The joke with section echoes, a leading "Joke:" / "Response" label and
        surrounding quotes removed; ``""`` when nothing usable remains.
    """
    text = generated_text.strip()

    # Keep only what precedes a new "###" section the model may have started
    body = text.split("###", 1)[0].strip()
    if not body:
        # The output itself began with a "###" header: use its last non-empty line
        non_empty = [ln.strip() for ln in text.split("\n") if ln.strip()]
        body = non_empty[-1] if non_empty else ""

    if "Joke:" in body:
        body = body.split("Joke:")[-1]
    elif body.startswith("Response"):
        # Strip only a leading label; splitting on the word anywhere would
        # truncate a joke that merely contains "Response".
        body = body[len("Response"):].lstrip(" :\n")

    # Allow up to 2 lines for the joke, then drop wrapping quotes
    kept_lines = [ln.strip() for ln in body.split("\n") if ln.strip()][:2]
    return "\n".join(kept_lines).strip('"').strip("'")


# progress bar (falls back to the plain list if tqdm was never imported)
try:
    data_iter = tqdm(data, desc="Generating Jokes")
except NameError:
    data_iter = data

for i, row in enumerate(data_iter):
    current_id = row.get('id')
    # Rows without a usable id cannot be tracked for resume, so skip them
    if current_id is None or pd.isna(current_id):
        continue
    if current_id in processed_ids:
        continue

    # Parse Input
    headline_val = row.get('headline')
    w1_val = row.get('word1')
    w2_val = row.get('word2')

    # Safely convert to string only if valid; missing cells become "-"
    headline_str = str(headline_val).strip() if pd.notna(headline_val) else "-"
    w1_str = str(w1_val).strip() if pd.notna(w1_val) else "-"
    w2_str = str(w2_val).strip() if pd.notna(w2_val) else "-"

    # --- DETERMINE TYPE & QUERY ---
    if headline_str not in _EMPTY_MARKERS:
        input_type = "headline"
        input_content = headline_str

        # Exact query wrapper for headlines
        retrieval_query = "Background facts and context about: " + headline_str
        context = retriever.retrieve(retrieval_query, tokenizer=tokenizer, k=RETRIEVAL_K, max_tokens=MAX_CONTEXT_TOKENS)

        # HEADLINE PROMPT (ZERO-SHOT)
        prompt_text = f"""### Instruction
You are a witty, cynical stand-up comedian. Write ORIGINAL humor (do not reuse or paraphrase known jokes).

You are given BACKGROUND INFO retrieved from Wikipedia (RAG).

Rules:
- Output EXACTLY ONE joke (1–2 sentences).
- The joke must be STANDALONE: include the premise so it makes sense without reading the headline.
- Use AT LEAST ONE concrete detail or idea from the background (paraphrase it; do NOT quote).
- Be clever, cynical, or ironic; avoid repeating the same setup style.
- Do NOT explain the joke.
- Do NOT summarize the headline.
- Keep it punchy (max ~35 words).

### Background (RAG, paraphrase only)
{context}

### Task
Headline: "{headline_str}"

### Response
Joke:"""

    else:
        # Words Case: substitute neutral placeholders for missing words
        real_w1 = w1_str if w1_str not in _EMPTY_MARKERS else "something"
        real_w2 = w2_str if w2_str not in _EMPTY_MARKERS else "random"
        input_type = "word-pair"
        input_content = f"{real_w1}, {real_w2}"

        # Exact query wrapper for words
        retrieval_query = "Meaning, usage, and related concepts for: " + real_w1 + " and " + real_w2
        context = retriever.retrieve(retrieval_query, tokenizer=tokenizer, k=RETRIEVAL_K, max_tokens=MAX_CONTEXT_TOKENS)

        # WORD-INCLUSION PROMPT (ZERO-SHOT)
        prompt_text = f"""### Instruction
You are a witty, cynical stand-up comedian. Write ORIGINAL humor (do not reuse or paraphrase known jokes).

You are given BACKGROUND INFO retrieved from Wikipedia (RAG).

Rules:
- Output EXACTLY ONE joke (1–2 sentences).
- The joke must be STANDALONE.
- You MUST include BOTH words exactly: "{real_w1}" and "{real_w2}".
- Use AT LEAST ONE concrete idea or detail from the background (paraphrase only; do NOT quote).
- Be clever, cynical, ironic, or absurd; avoid repeating the same setup style.
- Do NOT explain the joke.
- Do NOT list facts or definitions.
- Keep it punchy (max ~35 words).

### Background (RAG, paraphrase only)
{context}

### Task
Words: "{real_w1}", "{real_w2}"

### Response
Joke:"""

    # --- GENERATE WITH RETRY (Word Inclusion / Chat Template) ---
    final_joke = ""

    for attempt in range(max_retries):
        # Wrap the prompt as a single user turn for the Gemma chat template
        messages = [
            {"role": "user", "content": prompt_text}
        ]

        # Only add reminder on retry for word-pair
        if attempt > 0 and input_type == "word-pair":
            messages[0]["content"] += f"\n\nREMINDER: You MUST include the words '{real_w1}' and '{real_w2}' in the joke."

        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize = True,
            add_generation_prompt = True,
            return_tensors = "pt"
        )
        inputs = inputs.to(model.device)

        # Increase max tokens for retries
        current_max_tokens = 64 if attempt == 0 else 80

        with torch.inference_mode():
            # One unpadded sequence: every position is a real token. Deriving the
            # mask from pad_token_id would hide real tokens whenever the pad id
            # coincides with a token in the prompt (e.g. pad == EOS fallback).
            attention_mask = torch.ones_like(inputs)

            # Sampling is unseeded, so outputs differ between runs
            outputs = model.generate(
                input_ids=inputs,
                attention_mask=attention_mask,
                do_sample = True,
                temperature = 0.6,
                top_p = 0.9,
                repetition_penalty = 1.15,
                max_new_tokens = current_max_tokens,
                pad_token_id = tokenizer.eos_token_id
            )

        # --- PARSE ---
        # Decode only the tokens appended after the prompt. Slicing by token
        # count is exact, unlike stripping the decoded prompt as a string
        # prefix, which breaks whenever decode(prompt) does not round-trip.
        new_token_ids = outputs[0][inputs.shape[-1]:]
        generated_text = tokenizer.decode(new_token_ids, skip_special_tokens=True)
        temp_joke = _extract_joke(generated_text)

        if not temp_joke:
            # Nothing usable came back: spend the next attempt (if any)
            continue

        # --- VALIDATE (word pairs must contain both words, case-insensitive) ---
        if input_type == "word-pair":
            joke_lower = temp_joke.lower()
            if not (real_w1.lower() in joke_lower and real_w2.lower() in joke_lower):
                # Keep as best effort; a later attempt may replace it
                final_joke = temp_joke
                continue

        final_joke = temp_joke
        break

    # Fallback if every attempt came back empty
    if not final_joke:
        final_joke = "Error: No joke generated."

    # Save
    result_entry = {
        "id": current_id,
        "type": input_type,
        "input_original": input_content,
        "retrieved_context": context,
        "generated_joke": final_joke,
    }

    # Append immediately so an interrupted run can resume from this row
    with open(output_file, "a", encoding='utf-8') as f:
        f.write(json.dumps(result_entry, ensure_ascii=False) + "\n")

    processed_ids.add(current_id)

    # Periodic Cache Clear
    if (i + 1) % 50 == 0:
        torch.cuda.empty_cache()

print("Finished RAG Generation!")


Starting RAG Generation with Gemma...


Generating Jokes:   0%|          | 0/1200 [00:00<?, ?it/s]

Finished RAG Generation!
