In [1]:
# ==========================================
# 0. INSTALL DEPENDENCIES (Run this once)
# ==========================================
# Unsloth is installed straight from the tip of its GitHub repository (no tag or
# commit is pinned), so a later re-run may resolve a different version.
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# --no-deps installs only the named packages and skips their own dependencies.
!pip install --no-deps packaging ninja einops
!pip install --no-deps xformers trl peft accelerate bitsandbytes
# Retrieval / data stack used by the later cells.
# NOTE(review): faiss-cpu is installed here, but no `faiss` import appears in the
# cells shown — confirm it is still needed.
!pip install datasets sentence-transformers faiss-cpu tqdm

import os
import json
import pandas as pd
from unsloth import FastLanguageModel
from google.colab import drive
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# Choose the progress-bar implementation: the widget-based bar when the code
# runs inside an IPython/Jupyter kernel, the plain console bar otherwise.
try:
    from IPython import get_ipython
    if get_ipython():
        from tqdm.notebook import tqdm
    else:
        from tqdm import tqdm
except ImportError:
    # IPython (or the notebook flavour of tqdm) cannot be imported: fall back to
    # the console bar. Only import failures are handled, so unrelated errors
    # (and KeyboardInterrupt) are no longer silently swallowed.
    from tqdm import tqdm


Collecting unsloth@ git+https://github.com/unslothai/unsloth.git (from unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git)
  Cloning https://github.com/unslothai/unsloth.git to /tmp/pip-install-d0gk9f7w/unsloth_cfd28cdab7944412b2e5c104feb9a66b
  Running command git clone --filter=blob:none --quiet https://github.com/unslothai/unsloth.git /tmp/pip-install-d0gk9f7w/unsloth_cfd28cdab7944412b2e5c104feb9a66b
  Resolved https://github.com/unslothai/unsloth.git to commit b96a04c17bc6bcb5522eafb17adc2b104be38f99
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting unsloth_zoo>=2026.1.2 (from unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git)
  Downloading unsloth_zoo-2026.1.2-py3-none-any.whl.metadata (32 kB)
Collecting tyro (from unsloth@ git+https://github.com/unslothai/unsloth.git-

In [2]:
# ==========================================
# 1. SETUP & CONFIGURATION
# ==========================================
# Google Drive has to be mounted before any of the Drive paths below can be used;
# the mount call is skipped when the mount point is already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Persistent Hugging Face cache: everything is redirected to a folder on Drive.
DRIVE_CACHE = "/content/drive/MyDrive/LLM project/Cache/HF"
os.makedirs(DRIVE_CACHE, exist_ok=True)

# NOTE(review): the first cell has already imported unsloth, datasets and
# sentence_transformers by the time these variables are set. Some Hugging Face
# cache settings appear to be read at import time — TODO confirm the downloads
# really land in DRIVE_CACHE, or move this configuration above those imports.
os.environ.update({
    "HF_HOME": f"{DRIVE_CACHE}/hf_home",
    "HF_HUB_CACHE": f"{DRIVE_CACHE}/hf_hub",
    "HF_DATASETS_CACHE": f"{DRIVE_CACHE}/datasets",
    "TRANSFORMERS_CACHE": f"{DRIVE_CACHE}/transformers",
    "SENTENCE_TRANSFORMERS_HOME": f"{DRIVE_CACHE}/sentence_transformers",
})

from huggingface_hub import login
# Authentication is optional and left disabled here. To enable it, supply a
# token (ideally through Colab secrets rather than a literal in the notebook):
# login(token="YOUR_HF_TOKEN")

# Model identifier (downloaded directly from the Hub inside Colab).
model_id = "unsloth/gemma-2-9b-it-bnb-4bit"
# Input TSV and output JSONL locations on Drive.
# Common Colab path: /content/drive/MyDrive/LLM project/DATA/task-a-en.tsv
input_path = '/content/drive/MyDrive/LLM project/DATA/task-a-en.tsv'
output_file = "/content/drive/MyDrive/LLM project/DATA/outputs_gemma_rag.jsonl"

import numpy as np


Mounted at /content/drive


In [3]:
# ==========================================
# 2. LOAD RAG SYSTEM (Retriever)
# ==========================================
class HFRetriever:
    """Cosine-similarity retriever over a pre-embedded Wikipedia subset.

    Document texts and their embeddings are loaded from the "embedded" revision
    of the `not-lain/wikipedia` dataset; queries are embedded on CPU with the
    model named in `__init__` and matched against the L2-normalised document
    vectors with a single matrix-vector product.
    """

    # Retrieval prompt that the mxbai-embed-large-v1 model card asks to prepend
    # to *queries* (documents are embedded without it).
    QUERY_PROMPT = "Represent this sentence for searching relevant passages: "

    def __init__(self, n_docs=25000):
        """Load up to `n_docs` documents and the query encoder.

        Args:
            n_docs: upper bound of the `train[:n_docs]` slice that is loaded.
        """
        print(f"Loading embedded Wikipedia subset: {n_docs} docs ...")
        # Load pre-embedded dataset (fast)
        self.ds = load_dataset(
            "not-lain/wikipedia",
            revision="embedded",
            split=f"train[:{n_docs}]"
        )
        self.texts = [str(x) for x in self.ds["text"]]
        # Load embeddings into numpy (fast cosine sim)
        self.embs = np.array(self.ds["embeddings"], dtype=np.float32)

        # Normalize doc vectors once; the epsilon guards against all-zero rows.
        self.embs = self.embs / (np.linalg.norm(self.embs, axis=1, keepdims=True) + 1e-12)

        print("Loading query embedding model on CPU (Safe Mode)...")
        # The query encoder must be the model the dataset was embedded with
        # ('mixedbread-ai/mxbai-embed-large-v1', 1024-dim vectors); a model with
        # another output size such as 'all-MiniLM-L6-v2' (384-dim) cannot be
        # compared against the stored vectors.
        self.encoder = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1", device="cpu")
        print("RAG Index Ready!")

    def retrieve(self, query, k=3, max_chars=1000):
        """Return text from the `k` most similar documents, capped at `max_chars`.

        The character budget is shared between the hits so that every retrieved
        document contributes; previously the joined text was cut at `max_chars`,
        which let one long top hit consume the whole budget and made `k`
        irrelevant. Budget left unused by a short document is passed on to the
        following ones.

        Args:
            query: text to search for.
            k: number of documents to return; non-positive values give "".
            max_chars: maximum length of the returned string. `None` or a
                non-positive value keeps plain slice semantics on the joined text.

        Returns:
            The selected snippets joined by blank lines, or "" when nothing is
            requested or when retrieval fails (the error is printed, not raised).
        """
        try:
            if k is not None and k <= 0:
                return ""

            # Encode query with the retrieval prompt expected by the encoder.
            q = self.encoder.encode([f"{self.QUERY_PROMPT}{query}"]).astype(np.float32)
            q = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-12)

            # Both sides are unit length, so the dot product is the cosine similarity.
            sims = self.embs @ q[0]
            top_idx = np.argsort(-sims)[:k]

            separator = "\n\n"
            if max_chars is None or max_chars <= 0:
                # No positive budget to share out: behave like a plain slice.
                return separator.join([self.texts[i] for i in top_idx])[:max_chars]

            # Give each remaining document an equal share of what is left.
            remaining = max_chars
            snippets = []
            for rank, doc_idx in enumerate(top_idx):
                if remaining <= 0:
                    break
                docs_left = len(top_idx) - rank
                share = max(1, remaining // docs_left)
                snippet = self.texts[doc_idx][:share]
                snippets.append(snippet)
                remaining -= len(snippet) + len(separator)

            # The final slice keeps the hard cap even after separator rounding.
            return separator.join(snippets)[:max_chars]
        except Exception as e:
            # Deliberate best effort: generation continues without context.
            print(f"Retrieval failed: {e}")
            return ""

# Initialize Retriever (Fast Mode)
# Builds the module-level retriever used by the generation loop.
# NOTE(review): the recorded output of this cell reports a train split of 3000
# examples, so the `train[:25000]` slice presumably yields fewer than 25000
# documents — verify the intended corpus size.
retriever = HFRetriever(n_docs=25000)


Loading embedded Wikipedia subset: 25000 docs ...


README.md:   0%|          | 0.00/417 [00:00<?, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/49.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3000 [00:00<?, ? examples/s]

Loading query embedding model on CPU (Safe Mode)...


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/266 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/677 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/670M [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/695 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/297 [00:00<?, ?B/s]

RAG Index Ready!


In [4]:
# ==========================================
# 3. LOAD GEMMA MODEL
# ==========================================
print(f"Loading Gemma from {model_id}...")

# Load the 4-bit checkpoint together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_id,
    max_seq_length=2048,
    load_in_4bit=True,
    dtype=None,
)

# Do not move the model with model.to("cuda") here: according to the original
# author's note this raises an error with 4-bit models.
FastLanguageModel.for_inference(model)

# Make sure a padding token exists; reuse EOS when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


Loading Gemma from unsloth/gemma-2-9b-it-bnb-4bit...
==((====))==  Unsloth 2026.1.2: Fast Gemma2 patching. Transformers: 4.57.3.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.33.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/6.13G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

In [5]:
# ==========================================
# 4. GENERATION LOOP
# ==========================================
# Read Input
df = pd.read_csv(input_path, sep='\t')
data = df.to_dict('records')

# Placeholder that marks an absent field in the input rows.
MISSING_FIELD = "-"

# Shared prompt skeleton; only the task section differs between headline rows
# and word-pair rows.
PROMPT_TEMPLATE = """### Instruction
You are a witty, cynical stand-up comedian.
Use the following background information ONLY if it helps inspire a joke. Otherwise, ignore it.

Background Info:
{context}

{task_block}
- Be clever, cynical, or ironic.
- Do NOT explain the joke.
- Output ONLY the joke.

### Response
Joke:"""


def clean_field(row, key):
    """Return the stripped text of row[key]; absent, NaN or empty cells become "-"."""
    value = row.get(key, MISSING_FIELD)
    if pd.isna(value):
        return MISSING_FIELD
    text = str(value).strip()
    return text if text else MISSING_FIELD


def strip_role_prefix(text):
    """Remove a leaked leading "model" / "assistant" role tag.

    The tag is removed only when it stands alone (end of text, or followed by a
    colon or a newline), so a joke that merely starts with a word such as
    "Models" is left intact.
    """
    for role in ("model", "assistant"):
        if text.lower().startswith(role):
            rest = text[len(role):]
            if rest.lstrip(" \t")[:1] in ("", ":", "\n"):
                text = rest.strip().lstrip(":").strip()
    return text


def extract_joke(decoded_text, prompt_text):
    """Pull the joke out of the decoded generation (prompt + continuation)."""
    # Keep only what follows the last "### Response" / "Joke:" marker; when a
    # marker is absent, split() returns the whole text unchanged.
    response_part = decoded_text.split("### Response")[-1]
    final_joke = response_part.split("Joke:")[-1].strip()

    # First paragraph only, without surrounding quotes.
    final_joke = final_joke.split("\n\n")[0].strip().strip('"').strip("'")
    final_joke = strip_role_prefix(final_joke)

    # Nothing usable left: fall back to the raw continuation after the prompt.
    if not final_joke or final_joke.lower() in ["model", "assistant"]:
        final_joke = decoded_text.replace(prompt_text, "").strip().split("\n\n")[0]

    return final_joke.replace("assistant\n", "").strip()


# Resume Check: collect the ids already present in the output file.
processed_ids = set()
if os.path.exists(output_file):
    with open(output_file, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                saved_item = json.loads(line)
                processed_ids.add(saved_item['id'])
            except (ValueError, KeyError, TypeError):
                # Blank, truncated or malformed line (e.g. from an interrupted
                # write): skip it so that row is simply generated again.
                continue
    print(f"Resuming... Found {len(processed_ids)} jokes.")

print("Starting RAG Generation with Gemma...")

# Number of jokes produced in this run; drives the periodic cache clear.
generated_count = 0

for row in tqdm(data, desc="Generating Jokes"):
    current_id = row['id']
    if current_id in processed_ids:
        continue

    # Parse Input
    headline_val = clean_field(row, 'headline')
    w1_val = clean_field(row, 'word1')
    w2_val = clean_field(row, 'word2')

    # --- RAG RETRIEVAL ---
    if headline_val != MISSING_FIELD and headline_val.lower() != "nan":
        # Search for the headline content
        context = retriever.retrieve(headline_val)
        task_block = (
            f'Headline: "{headline_val}"\n\n'
            'Task: Write EXACTLY ONE punchy joke (1-2 sentences) about this headline.'
        )
        input_type = "headline"
        input_content = headline_val
    else:
        # Search for the words; missing words get neutral stand-ins.
        real_w1 = w1_val if w1_val != MISSING_FIELD else "something"
        real_w2 = w2_val if w2_val != MISSING_FIELD else "random"
        query = f"{real_w1} {real_w2}"
        context = retriever.retrieve(query)
        task_block = (
            'Task: Write EXACTLY ONE punchy joke (1-2 sentences) '
            f'connecting: "{real_w1}" and "{real_w2}".'
        )
        input_type = "words"
        input_content = f"{real_w1}, {real_w2}"

    prompt_text = PROMPT_TEMPLATE.format(context=context, task_block=task_block)

    # --- GENERATE ---
    # Robust: Use direct tokenization, not apply_chat_template which might fail
    inputs = tokenizer([prompt_text], return_tensors="pt").to("cuda")

    with torch.inference_mode(): # Memory optimization
        outputs = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            do_sample = True,
            max_new_tokens = 64, # Optimized
            temperature = 0.9,
            top_p = 0.9,
            repetition_penalty = 1.2,
            pad_token_id = tokenizer.eos_token_id
        )

    # --- PARSE ---
    decoded_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    final_joke = extract_joke(decoded_text, prompt_text)

    # Save: one JSON object per line, appended immediately so an interrupted
    # run can resume from the rows already written.
    result_entry = {
        "id": current_id,
        "type": input_type,
        "input_original": input_content,
        "retrieved_context": context,
        "prompt_used": prompt_text,
        "generated_joke": final_joke
    }

    with open(output_file, "a", encoding='utf-8') as f:
        f.write(json.dumps(result_entry, ensure_ascii=False) + "\n")

    # Periodic Cache Clear: every 50 generated jokes. The previous check used
    # len(processed_ids), which never changes inside the loop, so it fired on
    # every iteration or not at all.
    generated_count += 1
    if generated_count % 50 == 0:
        torch.cuda.empty_cache()

print("Finished RAG Generation!")


Starting RAG Generation with Gemma...


Generating Jokes:   0%|          | 0/1200 [00:00<?, ?it/s]

Finished RAG Generation!
