In [11]:
# ==========================================
# 0. INSTALL DEPENDENCIES (Run this once)
# ==========================================
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps packaging ninja einops
!pip install --no-deps xformers trl peft accelerate bitsandbytes
!pip install datasets sentence-transformers faiss-cpu


import os
import json
import pandas as pd
from unsloth import FastLanguageModel
from google.colab import drive
import torch
import re
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import numpy as np

# Progress bar: tqdm.auto picks the notebook widget when running inside a
# Jupyter/Colab kernel and the plain-text bar otherwise. It replaces a manual
# get_ipython() probe whose bare `except:` also swallowed KeyboardInterrupt.
from tqdm.auto import tqdm


Collecting unsloth@ git+https://github.com/unslothai/unsloth.git (from unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git)
  Cloning https://github.com/unslothai/unsloth.git to /tmp/pip-install-89io03ub/unsloth_c8994cde17f74b488ed280d93d240577
  Running command git clone --filter=blob:none --quiet https://github.com/unslothai/unsloth.git /tmp/pip-install-89io03ub/unsloth_c8994cde17f74b488ed280d93d240577
  Resolved https://github.com/unslothai/unsloth.git to commit b96a04c17bc6bcb5522eafb17adc2b104be38f99
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


In [12]:
# ==========================================
# 1. SETUP & CONFIGURATION
# ==========================================
import sys
import warnings

# Mount Drive. Test for MyDrive rather than the bare mountpoint: /content/drive
# may exist as a directory without an active mount, in which case checking
# only `os.path.exists('/content/drive')` would skip mounting.
if not os.path.isdir('/content/drive/MyDrive'):
    drive.mount('/content/drive')

PROJECT_DIR = "/content/drive/MyDrive/LLM project"

# Hugging Face Cache to Drive (Persistent)
DRIVE_CACHE = f"{PROJECT_DIR}/Cache/HF"
os.makedirs(DRIVE_CACHE, exist_ok=True)

# huggingface_hub, datasets and transformers read these variables when they
# are first imported, so they only take effect if set BEFORE those imports.
# TRANSFORMERS_CACHE is intentionally not set: transformers deprecated it in
# favour of HF_HOME / HF_HUB_CACHE.
os.environ["HF_HOME"] = f"{DRIVE_CACHE}/hf_home"
os.environ["HF_HUB_CACHE"] = f"{DRIVE_CACHE}/hf_hub"
os.environ["HF_DATASETS_CACHE"] = f"{DRIVE_CACHE}/datasets"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = f"{DRIVE_CACHE}/sentence_transformers"

# Surface the otherwise-silent case where the cache redirection has no effect.
_hf_libs_already_imported = [
    name for name in ("huggingface_hub", "datasets", "transformers") if name in sys.modules
]
if _hf_libs_already_imported:
    warnings.warn(
        "HF cache environment variables were set after importing "
        + ", ".join(_hf_libs_already_imported)
        + "; those libraries may keep using their default cache locations. "
        "Set the variables before importing them to cache on Drive.",
        stacklevel=2,
    )

from huggingface_hub import login
# Prefer Colab secrets (or getpass) over pasting a token into the notebook:
# login(token="YOUR_HF_TOKEN")

# Paths
model_id = "unsloth/Qwen2.5-7B-Instruct-bnb-4bit"
input_path = f"{PROJECT_DIR}/DATA/task-a-en.tsv"
output_file = f"{PROJECT_DIR}/DATA/outputs_qwen_dual_rag.jsonl"

# Config for Retrieval. These are upper bounds: DualRetriever clamps them to
# the corpus size (the last run found only 3000 rows in the embedded
# Wikipedia split).
N_WIKI_DOCS = 20000
N_JOKE_DOCS = 5000


In [13]:
# ==========================================
# 2. LOAD DUAL RAG SYSTEM (Retriever) - FIXED CACHING
# ==========================================
class DualRetriever:
    """Dense retriever over two corpora: Wikipedia passages (facts) and short jokes (style).

    Wikipedia vectors are read from the ``embeddings`` column of the
    ``not-lain/wikipedia`` dataset (``embedded`` revision); joke vectors are
    computed with ``self.encoder`` and cached as ``.npy`` files. Queries are
    encoded with the same encoder and ranked by dot product against
    L2-normalised corpus vectors (i.e. cosine similarity).
    """

    # Section headers of the context string returned by ``retrieve``.
    FACT_HEADER = "--- FACTUAL CONTEXT ---\n"
    STYLE_HEADER = "\n\n--- COMEDIC INSPIRATION (Similar Jokes) ---\n"

    def __init__(self, n_wiki=20000, n_jokes=5000, seed=42,
                 cache_dir="/content/drive/MyDrive/LLM project/Cache/Retriever"):
        """Load the encoder and build both indexes.

        Args:
            n_wiki: maximum number of Wikipedia rows to index (clamped to the split size).
            n_jokes: maximum number of jokes to index (clamped to the dataset size).
            seed: shuffle seed used for both corpora.
            cache_dir: folder holding the cached joke texts and embeddings.

        Raises:
            ValueError: if the stored Wikipedia embeddings do not have the
                encoder's output dimension.
        """
        print("Initializing Dual Retriever...")
        self.seed = seed
        os.makedirs(cache_dir, exist_ok=True)

        # Embedding model (CPU)
        self.encoder_name = "mixedbread-ai/mxbai-embed-large-v1"
        print("Loading embedding model on CPU (Safe Mode)...")
        self.encoder = SentenceTransformer(self.encoder_name, device="cpu")

        self._load_wiki(n_wiki, seed)
        self._load_jokes(n_jokes, seed, cache_dir)
        print("Dual RAG Ready!")

    def _load_wiki(self, n_wiki, seed):
        """Populate ``wiki_texts`` / ``wiki_embs`` from the first ``n_wiki`` rows, shuffled."""
        print("Loading Wikipedia (embedded) slice...")
        wiki_full = load_dataset("not-lain/wikipedia", revision="embedded", split="train")
        actual_n_wiki = min(int(n_wiki), len(wiki_full))
        print(f"Using {actual_n_wiki} wiki docs (from {len(wiki_full)} total)")

        # Slice the split that is already loaded instead of calling load_dataset
        # a second time with split="train[:n]"; same rows, same shuffle order.
        wiki_ds = wiki_full.select(range(actual_n_wiki)).shuffle(seed=seed)

        self.wiki_texts = [str(x) for x in wiki_ds["text"]]
        wiki_embs = np.asarray(wiki_ds["embeddings"], dtype=np.float32)

        # A dimension mismatch means these vectors cannot come from our encoder.
        # Fail here: inside retrieve() the error would be swallowed and every
        # query would silently return an empty context.
        enc_dim = self.encoder.get_sentence_embedding_dimension()
        if wiki_embs.ndim != 2 or (enc_dim is not None and wiki_embs.shape[1] != enc_dim):
            raise ValueError(
                f"Wikipedia embeddings have shape {wiki_embs.shape}, expected (n, {enc_dim}) "
                f"for encoder {self.encoder_name}"
            )
        wiki_embs /= np.linalg.norm(wiki_embs, axis=1, keepdims=True) + 1e-12
        self.wiki_embs = wiki_embs

    def _load_jokes(self, n_jokes, seed, cache_dir):
        """Populate ``joke_texts`` / ``joke_embs`` from the Drive cache, building it if needed."""
        actual_n_jokes = int(n_jokes)
        emb_tag = self.encoder_name.replace("/", "_")
        # NOTE(review): the cache key does not include `seed`, so a cache built
        # with a different seed is reused as-is.
        jokes_emb_path = os.path.join(cache_dir, f"jokes_embs_{actual_n_jokes}_{emb_tag}.npy")
        jokes_txt_path = os.path.join(cache_dir, f"jokes_texts_{actual_n_jokes}.npy")

        if os.path.exists(jokes_emb_path) and os.path.exists(jokes_txt_path):
            print("✅ Found cached jokes index. Loading from Drive...")
            joke_embs = np.load(jokes_emb_path)
            # allow_pickle is required for the object-dtype text array written
            # below; unpickling can execute code, so only load caches you made.
            joke_texts = np.load(jokes_txt_path, allow_pickle=True).tolist()
            if len(joke_texts) == len(joke_embs):
                self.joke_embs, self.joke_texts = joke_embs, joke_texts
                return
            print("⚠️ Cached jokes index is inconsistent (text/embedding counts differ). Rebuilding...")
        else:
            print("⚠️ Cache not found. Building jokes index ONCE...")

        jokes_ds = load_dataset("weixingxing/short-jokes-dataset", split="train")
        actual_n_jokes = min(actual_n_jokes, len(jokes_ds))
        print(f"Using {actual_n_jokes} jokes (from {len(jokes_ds)} total)")
        jokes_ds = jokes_ds.shuffle(seed=seed).select(range(actual_n_jokes))

        # Drop empty / "nan" rows so every kept text has exactly one embedding.
        self.joke_texts = []
        for ex in jokes_ds:
            joke = ex.get("Joke", "")
            joke = "" if joke is None else str(joke).strip()
            if joke and joke.lower() != "nan":
                self.joke_texts.append(joke)

        print("Embedding jokes (CPU)...")
        self.joke_embs = self.encoder.encode(
            self.joke_texts,
            convert_to_numpy=True,
            show_progress_bar=True,
            batch_size=64,
            normalize_embeddings=True,
        ).astype(np.float32)

        print(f"💾 Saving jokes index to Drive: {cache_dir}")
        np.save(jokes_emb_path, self.joke_embs)
        np.save(jokes_txt_path, np.array(self.joke_texts, dtype=object))

    @staticmethod
    def _top_k(embs, texts, q_vec, k):
        """Return the ``k`` texts whose embeddings have the highest dot product with ``q_vec``."""
        sims = embs @ q_vec
        return [texts[i] for i in np.argsort(-sims)[:k]]

    def retrieve(self, query, k_wiki=2, k_jokes=1, max_chars=1200):
        """Return a two-section context string for ``query``.

        The result is ``FACT_HEADER`` + the top-``k_wiki`` passages joined by
        newlines, followed by ``STYLE_HEADER`` + the top-``k_jokes`` jokes.
        When ``max_chars`` is set, the factual section is truncated first so
        the joke section survives, and the result is at most ``max_chars``
        characters long. ``max_chars=None`` disables truncation.

        Best-effort: any error is printed and "" is returned.
        """
        try:
            query = str(query).strip()
            if not query:
                return self.FACT_HEADER + self.STYLE_HEADER

            q_vec = self.encoder.encode(
                [query], convert_to_numpy=True, normalize_embeddings=True
            ).astype(np.float32)[0]

            wiki_res = self._top_k(self.wiki_embs, self.wiki_texts, q_vec, k_wiki)
            joke_res = self._top_k(self.joke_embs, self.joke_texts, q_vec, k_jokes)

            fact_section = self.FACT_HEADER + "\n".join(wiki_res)
            style_section = self.STYLE_HEADER + "\n".join(joke_res)
            if max_chars is None:
                return fact_section + style_section

            # Trim the facts, not the jokes. Wikipedia articles are long, so
            # cutting the plain concatenation at max_chars dropped the whole
            # style section and silently turned "dual" RAG into wiki-only RAG.
            fact_budget = max(max_chars - len(style_section), len(self.FACT_HEADER))
            return (fact_section[:fact_budget] + style_section)[:max_chars]

        except Exception as e:
            print(f"Retrieval failed: {e}")
            return ""

# Build the retriever once; the corpus sizes come from the configuration cell.
retriever = DualRetriever(
    n_wiki=N_WIKI_DOCS,
    n_jokes=N_JOKE_DOCS,
)


Initializing Dual Retriever...
Loading embedding model on CPU (Safe Mode)...
Loading Wikipedia (embedded) slice...
Using 3000 wiki docs (from 3000 total)
✅ Found cached jokes index. Loading from Drive...
Dual RAG Ready!


In [14]:
# ==========================================
# 3. LOAD MODEL
# ==========================================
print(f"Loading Model from {model_id}...")
# Load the pre-quantised 4-bit checkpoint; dtype=None asks Unsloth to choose
# the compute dtype itself (auto-detection, per Unsloth's documentation).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_id,
    dtype=None,
    load_in_4bit=True,
    max_seq_length=2048,
)
# Put the model into Unsloth's inference mode before any generate() call.
FastLanguageModel.for_inference(model)

# Guarantee a pad token (Qwen/Unsloth padding fix): reuse EOS when the
# tokenizer ships without one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


# Characters that indicate the model drifted out of English into Chinese output.
_CJK_PATTERN = re.compile(
    "["
    "\u3000-\u303f"            # CJK symbols & punctuation (ideographic space, 。、「」)
    "\u3400-\u4dbf"            # CJK Unified Ideographs Extension A
    "\u4e00-\u9fff"            # CJK Unified Ideographs (basic block)
    "\uf900-\ufaff"            # CJK Compatibility Ideographs
    "\uff00-\uff65"            # fullwidth ASCII variants + halfwidth CJK punctuation (，！？)
    "\U00020000-\U0003ffff"    # Supplementary/Tertiary Ideographic Planes (Ext. B and later)
    "]"
)


def contains_chinese(text):
    """Return True if ``text`` contains CJK ideographs or CJK/fullwidth punctuation.

    Used as a language guard on generated jokes. The earlier check covered only
    U+4E00–U+9FFF, so replies consisting of rarer ideographs or Chinese
    punctuation such as "，" and "。" slipped through. Empty or None input
    returns False.
    """
    if not text:
        return False
    return _CJK_PATTERN.search(str(text)) is not None

Loading Model from unsloth/Qwen2.5-7B-Instruct-bnb-4bit...
==((====))==  Unsloth 2026.1.2: Fast Qwen2 patching. Transformers: 4.57.3.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.33.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


In [16]:
# ==========================================
# 4. GENERATION LOOP
# ==========================================
# Read cells literally. pandas' default NA parsing would turn genuine inputs
# such as "None", "NA" or "null" into NaN; here only empty cells become NaN.
df = pd.read_csv(input_path, sep='\t', keep_default_na=False, na_values=[""])
data = df.to_dict('records')

# Which retrieved context the model sees:
#   "none" -> no context (plain Qwen), "wiki" -> factual context only,
#   "dual" -> factual context + similar jokes (the configuration run here).
RAG_MODE = "dual"
if RAG_MODE not in ("none", "wiki", "dual"):
    raise ValueError(f"Unknown RAG_MODE: {RAG_MODE!r}")

FACT_HEADER = "--- FACTUAL CONTEXT ---"
STYLE_HEADER = "--- COMEDIC INSPIRATION (Similar Jokes) ---"
GEN_MAX_RETRIES = 3
CACHE_CLEAR_EVERY = 50  # free cached CUDA blocks after this many generated rows


def _clean_field(value):
    """Return a TSV cell as a stripped string, or "" when it is empty, NaN, "-" or "nan"."""
    if not isinstance(value, str) and pd.isna(value):
        return ""
    text = str(value).strip()
    return "" if text in ("-", "") or text.lower() == "nan" else text


def _load_processed_ids(path):
    """Collect the ``id`` of every readable record in the JSONL file at ``path``."""
    ids = set()
    with open(path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                ids.add(json.loads(line)['id'])
            except (json.JSONDecodeError, KeyError, TypeError) as exc:
                # e.g. a line cut short by an interrupted write; that row is regenerated.
                print(f"Skipping unreadable line {line_no} of {path}: {exc!r}")
    return ids


def _split_context(full_context):
    """Return ``(wiki_only_context, dual_context)`` derived from a retriever string."""
    if STYLE_HEADER not in full_context:
        return full_context, full_context
    fact_part = full_context.split(STYLE_HEADER)[0].replace(FACT_HEADER, "").strip()
    return f"{FACT_HEADER}\n{fact_part}", full_context


def _build_request(row):
    """Describe one TSV row.

    Returns ``(input_type, input_content, retrieval_query, subject_line, task)``.
    ``subject_line`` and ``task`` are the pieces ``make_prompt`` inserts
    around "Task: Write EXACTLY ONE punchy joke (1-2 sentences) ...".
    """
    headline_val = _clean_field(row.get('headline'))
    if headline_val:
        return ("headline", headline_val, headline_val,
                f'Headline: "{headline_val}"\n\n', "about this headline.")
    real_w1 = _clean_field(row.get('word1')) or "something"
    real_w2 = _clean_field(row.get('word2')) or "random"
    pair = f"{real_w1} {real_w2}"
    return ("word-pair", pair, pair, "", f'connecting: "{real_w1}" and "{real_w2}".')


def make_prompt(ctx, subject_line, task):
    """Build the user prompt. A falsy ``ctx`` selects the no-context template."""
    if ctx:
        return f"""### Instruction
You are a witty, cynical stand-up comedian. Write ONLY in English.
Use the provided Context to write a joke.

{ctx}

{subject_line}Task: Write EXACTLY ONE punchy joke (1-2 sentences) {task}
- You can use the "Factual Context" for substance.
- You can use the "Comedic Inspiration" for tone/style.
- Do NOT explain the joke.
- Output ONLY the joke.

### Response
Joke:"""
    return f"""### Instruction
You are a witty, cynical stand-up comedian. Write ONLY in English.
{subject_line}Task: Write EXACTLY ONE punchy joke (1-2 sentences) {task}
- Be clever, cynical, or ironic.
- Do NOT explain the joke.
- Output ONLY the joke.

### Response
Joke:"""


def _clean_joke(response_text):
    """Keep the first paragraph after the last "Joke:" marker, without wrapping quotes."""
    text = response_text.split("Joke:")[-1]
    return text.strip().split("\n\n")[0].strip(' "')


def generate_one(prompt_str, max_retries=GEN_MAX_RETRIES):
    """Sample one joke for ``prompt_str``; return "ERROR" if every attempt is unusable.

    An attempt is resampled when the cleaned text is empty or contains CJK text.
    """
    messages = [{"role": "user", "content": prompt_str}]
    text_input = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # The prompt is identical across retries, so tokenize it once. Passing the
    # attention mask explicitly avoids the Qwen pad-token warning.
    model_inputs = tokenizer([text_input], return_tensors="pt").to(model.device)
    prompt_len = model_inputs.input_ids.shape[1]

    for _ in range(max_retries):
        with torch.inference_mode():
            outputs = model.generate(
                input_ids=model_inputs.input_ids,
                attention_mask=model_inputs.attention_mask,
                do_sample=True,
                max_new_tokens=64,
                temperature=0.9,
                top_p=0.9,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens. Decoding the whole sequence
        # and then deleting the word "assistant" also removed that word from
        # jokes that legitimately used it.
        joke = _clean_joke(tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True))
        if joke and not contains_chinese(joke):
            return joke
    return "ERROR"


processed_ids = set()
if os.path.exists(output_file):
    processed_ids = _load_processed_ids(output_file)
    print(f"Resuming... Found {len(processed_ids)} jokes.")

print("Starting Dual-RAG Generation...")

n_generated = 0
with open(output_file, "a", encoding='utf-8') as out_f:
    for row in tqdm(data, desc="Generating Jokes"):
        current_id = row['id']
        if current_id in processed_ids:
            continue

        input_type, input_content, query, subject_line, task = _build_request(row)

        context_used = ""
        if RAG_MODE != "none":
            context_wiki, context_dual = _split_context(retriever.retrieve(query))
            context_used = context_dual if RAG_MODE == "dual" else context_wiki

        joke_final = generate_one(make_prompt(context_used, subject_line, task))

        result_entry = {
            "id": current_id,
            "type": input_type,
            "input_original": input_content,
            "context_retrieved": context_used,
            "generated_joke": joke_final,
        }
        out_f.write(json.dumps(result_entry, ensure_ascii=False) + "\n")
        out_f.flush()  # keep the file resumable if the runtime dies mid-run
        processed_ids.add(current_id)  # also skips duplicate ids later in this run

        # Count rows generated in this run; the old check used len(processed_ids),
        # which never changed inside the loop.
        n_generated += 1
        if n_generated % CACHE_CLEAR_EVERY == 0:
            torch.cuda.empty_cache()

print("Finished Dual-RAG Generation!")


Starting Dual-RAG Generation...


Generating Jokes:   0%|          | 0/1200 [00:00<?, ?it/s]

Finished Dual-RAG Generation!
