In [1]:
# ==========================================
# 0. INSTALL DEPENDENCIES (Run this once)
# ==========================================
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps packaging ninja einops
!pip install --no-deps xformers trl peft accelerate bitsandbytes
!pip install datasets sentence-transformers faiss-cpu tqdm

import os
import json
import pandas as pd
from unsloth import FastLanguageModel
from google.colab import drive
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import numpy as np
from huggingface_hub import login
from tqdm import tqdm


Collecting unsloth@ git+https://github.com/unslothai/unsloth.git (from unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git)
  Cloning https://github.com/unslothai/unsloth.git to /tmp/pip-install-no9j3008/unsloth_80e801b7f6204e06be6701e39ce6585a
  Running command git clone --filter=blob:none --quiet https://github.com/unslothai/unsloth.git /tmp/pip-install-no9j3008/unsloth_80e801b7f6204e06be6701e39ce6585a
  Resolved https://github.com/unslothai/unsloth.git to commit b96a04c17bc6bcb5522eafb17adc2b104be38f99
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting unsloth_zoo>=2026.1.2 (from unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git)
  Downloading unsloth_zoo-2026.1.2-py3-none-any.whl.metadata (32 kB)
Collecting tyro (from unsloth@ git+https://github.com/unslothai/unsloth.git-

In [2]:
# ==========================================
# 1. CONFIGURATION
# ==========================================
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Hugging Face Cache to Drive (Persistent)
DRIVE_CACHE = "/content/drive/MyDrive/LLM project/Cache/HF"
os.makedirs(DRIVE_CACHE, exist_ok=True)

os.environ["HF_HOME"] = f"{DRIVE_CACHE}/hf_home"
os.environ["HF_HUB_CACHE"] = f"{DRIVE_CACHE}/hf_hub"
os.environ["HF_DATASETS_CACHE"] = f"{DRIVE_CACHE}/datasets"
os.environ["TRANSFORMERS_CACHE"] = f"{DRIVE_CACHE}/transformers"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = f"{DRIVE_CACHE}/sentence_transformers"

from huggingface_hub import login
# login(token="YOUR_HF_TOKEN")

model_id = "unsloth/llama-3-8b-Instruct-bnb-4bit"
input_path = '/content/task-a-en.tsv'
output_file = "/content/outputs_llama_dual_rag.jsonl"

# Config for Retrieval
N_WIKI_DOCS = 20000
N_JOKE_DOCS = 5000

Mounted at /content/drive


In [3]:
# ==========================================
# 2. LOAD DUAL RAG SYSTEM (Retriever) - FIXED CACHING
# - wiki loads only train[:N] (no full train)
# - jokes embeddings cached per (n_jokes, encoder)
# - works after restart (HF caches go to Drive if you set env vars)
# ==========================================
class DualRetriever:
    def __init__(self, n_wiki=20000, n_jokes=5000, seed=42):
        print("Initializing Dual Retriever...")

        self.seed = seed

        # Cache folder (Drive)
        cache_dir = "/content/drive/MyDrive/LLM project/Cache/Retriever"
        os.makedirs(cache_dir, exist_ok=True)

        # Embedding model (CPU)
        self.encoder_name = "mixedbread-ai/mxbai-embed-large-v1"
        print("Loading embedding model on CPU (Safe Mode)...")
        self.encoder = SentenceTransformer(self.encoder_name, device="cpu")

        # -------------------------
        # A) Wikipedia (embedded) - load ONLY the slice you need
        # -------------------------
        print("Loading Wikipedia (embedded) slice...")
        wiki_full = load_dataset("not-lain/wikipedia", revision="embedded", split="train")
        actual_n_wiki = min(int(n_wiki), len(wiki_full))
        print(f"Using {actual_n_wiki} wiki docs (from {len(wiki_full)} total)")

        # load only first N, then shuffle that small subset
        wiki_ds = load_dataset(
            "not-lain/wikipedia",
            revision="embedded",
            split=f"train[:{actual_n_wiki}]"
        ).shuffle(seed=seed)

        self.wiki_texts = [str(x) for x in wiki_ds["text"]]
        self.wiki_embs = np.asarray(wiki_ds["embeddings"], dtype=np.float32)

        # normalize once
        self.wiki_embs /= (np.linalg.norm(self.wiki_embs, axis=1, keepdims=True) + 1e-12)

        # -------------------------
        # B) Jokes (style) - cache per (n_jokes, encoder)
        # -------------------------
        actual_n_jokes = int(n_jokes)
        emb_tag = self.encoder_name.replace("/", "_")
        jokes_emb_path = os.path.join(cache_dir, f"jokes_embs_{actual_n_jokes}_{emb_tag}.npy")
        jokes_txt_path = os.path.join(cache_dir, f"jokes_texts_{actual_n_jokes}.npy")

        if os.path.exists(jokes_emb_path) and os.path.exists(jokes_txt_path):
            print(f"✅ Found cached jokes index. Loading from Drive:\n- {jokes_emb_path}\n- {jokes_txt_path}")
            self.joke_embs = np.load(jokes_emb_path)
            self.joke_texts = np.load(jokes_txt_path, allow_pickle=True).tolist()
        else:
            print("⚠️ Cache not found. Building jokes index ONCE (this takes time)...")
            jokes_ds = load_dataset("weixingxing/short-jokes-dataset", split="train")

            actual_n_jokes = min(actual_n_jokes, len(jokes_ds))
            print(f"Using {actual_n_jokes} jokes (from {len(jokes_ds)} total)")

            jokes_ds = jokes_ds.shuffle(seed=seed).select(range(actual_n_jokes))

            # clean text
            joke_texts = []
            for ex in jokes_ds:
                s = ex.get("Joke", "")
                s = "" if s is None else str(s).strip()
                if s and s.lower() != "nan":
                    joke_texts.append(s)

            self.joke_texts = joke_texts

            print("Embedding jokes (CPU)...")
            self.joke_embs = self.encoder.encode(
                self.joke_texts,
                convert_to_numpy=True,
                show_progress_bar=True,
                batch_size=64,
                normalize_embeddings=True,
            ).astype(np.float32)

            print(f"💾 Saving jokes index to Drive: {cache_dir}")
            np.save(jokes_emb_path, self.joke_embs)
            np.save(jokes_txt_path, np.array(self.joke_texts, dtype=object))

        print("Dual RAG Ready!")

    def retrieve(self, query, k_wiki=1, k_jokes=5):
        try:
            query = str(query).strip()
            if not query:
                return "Fact:\n\nStyle Inspiration:\n-"

            q_vec = self.encoder.encode([query], convert_to_numpy=True, normalize_embeddings=True).astype(np.float32)[0]

            # Wiki
            sims_wiki = self.wiki_embs @ q_vec
            # Use argpartition for speed
            top_w_idx = np.argpartition(-sims_wiki, kth=min(k_wiki, len(sims_wiki)-1))[:k_wiki]
            # Sort only top k
            top_w_idx = top_w_idx[np.argsort(-sims_wiki[top_w_idx])]
            wiki_str = self.wiki_texts[top_w_idx[0]][:350] if len(top_w_idx) else "" # Limit chars

            # Jokes (retrieve multiple for style pack)
            sims_j = self.joke_embs @ q_vec
            top_j_idx = np.argpartition(-sims_j, kth=min(k_jokes, len(sims_j)-1))[:k_jokes]
            top_j_idx = top_j_idx[np.argsort(-sims_j[top_j_idx])]

            joke_samples = [self.joke_texts[i][:150] for i in top_j_idx] # Limit chars per joke
            joke_str = "\n- ".join(joke_samples)

            return f"Fact: {wiki_str}\n\nStyle Inspiration:\n- {joke_str}"

        except Exception as e:
            print(f"Retrieval failed: {e}")
            return ""

# Initialize Retriever
retriever = DualRetriever(n_wiki=N_WIKI_DOCS, n_jokes=N_JOKE_DOCS)

Initializing Dual Retriever...
Loading embedding model on CPU (Safe Mode)...
Loading Wikipedia (embedded) slice...


README.md:   0%|          | 0.00/417 [00:00<?, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/49.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3000 [00:00<?, ? examples/s]

Using 3000 wiki docs (from 3000 total)
✅ Found cached jokes index. Loading from Drive:
- /content/drive/MyDrive/LLM project/Cache/Retriever/jokes_embs_5000_mixedbread-ai_mxbai-embed-large-v1.npy
- /content/drive/MyDrive/LLM project/Cache/Retriever/jokes_texts_5000.npy
Dual RAG Ready!


In [4]:
# ==========================================
# 3. LOAD MODEL
# ==========================================
# Drive Cache Path
drive_model_path = "/content/drive/MyDrive/LLM project/Models/llama-3-8b-Instruct-bnb-4bit"

if os.path.exists(drive_model_path):
    print(f"✅ Found local Llama model at {drive_model_path}. Loading from Drive...")
    model_id_to_load = drive_model_path
    save_after_load = False
else:
    print(f"⚠️ Local model not found. Downloading {model_id} from Hugging Face...")
    model_id_to_load = model_id
    save_after_load = True

print(f"Loading Model from {model_id_to_load}...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_id_to_load,
    max_seq_length = 4096,
    load_in_4bit = True,
    dtype = None,
)

if save_after_load:
    print(f"💾 Saving model to Drive for future use: {drive_model_path}")
    model.save_pretrained(drive_model_path)
    tokenizer.save_pretrained(drive_model_path)

# Force model to CUDA explicitly to avoid device mismatch
# model = model.to("cuda") # CAUSES ERROR with 4-bit models
FastLanguageModel.for_inference(model)


⚠️ Local model not found. Downloading unsloth/llama-3-8b-Instruct-bnb-4bit from Hugging Face...
Loading Model from unsloth/llama-3-8b-Instruct-bnb-4bit...
==((====))==  Unsloth 2026.1.2: Fast Llama patching. Transformers: 4.57.3.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.33.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/5.70G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/220 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/345 [00:00<?, ?B/s]

💾 Saving model to Drive for future use: /content/drive/MyDrive/LLM project/Models/llama-3-8b-Instruct-bnb-4bit


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096, padding_idx=128255)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm):

In [5]:
# ==========================================
# 4. GENERATION LOOP
# ==========================================
# Read Input
df = pd.read_csv(input_path, sep='\t')
data = df.to_dict('records')

# Resume Check
processed_ids = set()
if os.path.exists(output_file):
    with open(output_file, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                saved_item = json.loads(line)
                processed_ids.add(saved_item['id'])
            except: pass
    print(f"Resuming... Found {len(processed_ids)} jokes.")

print("Starting Dual-RAG Generation with Llama...")

# Use tqdm logic to determine best progress bar
try:
    from IPython import get_ipython
    if get_ipython():
        from tqdm.notebook import tqdm
    else:
        from tqdm import tqdm
except:
    from tqdm import tqdm

for row in tqdm(data, desc="Generating Jokes"):
    current_id = row['id']
    if current_id in processed_ids: continue

    # Parse Input
    headline_val = str(row.get('headline', "-")).strip()
    w1_val = str(row.get('word1', "-")).strip()
    w2_val = str(row.get('word2', "-")).strip()

    # --- RAG RETRIEVAL (DIFFERENT LEVELS) ---

    # LEVEL 1: NO RAG (Just Llama)
    context_no_rag = ""

    # LEVEL 2: WIKI ONLY (Fact)
    context_wiki = ""
    # LEVEL 3: DUAL (Fact + Style)
    context_dual = ""

    if headline_val != "-" and headline_val != "" and headline_val.lower() != "nan":
        # Retrieve full Dual context
        full_context_str = retriever.retrieve(headline_val)

        # Split Context if possible
        if "Style Inspiration:" in full_context_str:
            parts = full_context_str.split("Style Inspiration:")
            fact_part = parts[0].replace("Fact:", "").strip()
            style_part = parts[1].strip()

            context_wiki = f"Fact: {fact_part}"
            context_dual = full_context_str
        else:
            context_wiki = full_context_str
            context_dual = full_context_str

        input_type = "headline"
        input_content = headline_val

        # Base Prompt Template Generator
        def make_prompt(ctx):
            if ctx:
                return f"""### Instruction
You are a witty, cynical stand-up comedian.
Use the following background information ONLY if it helps inspire a joke. Otherwise, ignore it.

Background Info:
{ctx}

Headline: "{headline_val}"

Task: Write EXACTLY ONE punchy joke (1-2 sentences) about this headline.
- Be clever, cynical, or ironic.
- Do NOT explain the joke.
- Output ONLY the joke.

### Response
Joke:"""
            else:
                 return f"""### Instruction
You are a witty, cynical stand-up comedian.
Headline: "{headline_val}"

Task: Write EXACTLY ONE punchy joke (1-2 sentences) about this headline.
- Be clever, cynical, or ironic.
- Do NOT explain the joke.
- Output ONLY the joke.

### Response
Joke:"""

    else:
        # Words Case
        real_w1 = w1_val if w1_val != "-" else "something"
        real_w2 = w2_val if w2_val != "-" else "random"
        query = f"{real_w1} {real_w2}"

        full_context_str = retriever.retrieve(query)

        if "Style Inspiration:" in full_context_str:
            parts = full_context_str.split("Style Inspiration:")
            fact_part = parts[0].replace("Fact:", "").strip()
            style_part = parts[1].strip()

            context_wiki = f"Fact: {fact_part}"
            context_dual = full_context_str
        else:
            context_wiki = full_context_str
            context_dual = full_context_str

        input_type = "word-pair"
        input_content = f"{real_w1} {real_w2}"

        def make_prompt(ctx):
            if ctx:
                 return f"""### Instruction
You are a witty, cynical stand-up comedian.
Use the following background information ONLY if it helps inspire a joke. Otherwise, ignore it.

Background Info:
{ctx}

Task: Write EXACTLY ONE punchy joke (1-2 sentences) connecting: "{real_w1}" and "{real_w2}".
- Be clever, cynical, or ironic.
- Do NOT explain the joke.
- Output ONLY the joke.

### Response
Joke:"""
            else:
                 return f"""### Instruction
You are a witty, cynical stand-up comedian.
Task: Write EXACTLY ONE punchy joke (1-2 sentences) connecting: "{real_w1}" and "{real_w2}".
- Be clever, cynical, or ironic.
- Do NOT explain the joke.
- Output ONLY the joke.

### Response
Joke:"""

    # --- GENERATE (OPTIMIZED) ---

    # Helper to generate
    def generate_one(prompt_str):
        inputs = tokenizer([prompt_str], return_tensors="pt").to("cuda")
        with torch.inference_mode(): # Memory optimization
            outputs = model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                do_sample = True,
                max_new_tokens = 64, # Reduced from 128
                temperature = 0.9,
                top_p = 0.9,
                repetition_penalty = 1.2,
                pad_token_id = tokenizer.eos_token_id
            )
        dec_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        if "### Response" in dec_text:
            resp = dec_text.split("### Response")[-1]
        else:
            resp = dec_text
        if "Joke:" in resp:
            final = resp.split("Joke:")[-1].strip()
        else:
            final = resp.strip()
        final = final.split("\n\n")[0].strip().strip('"').strip("'")
        if final.lower().startswith("assistant"):
            final = final[9:].strip()
        return final


    prompt_final = make_prompt(context_dual)
    joke_final = generate_one(prompt_final)

    # Periodic Cache Clear
    if len(processed_ids) % 50 == 0:
        torch.cuda.empty_cache()

    # Save Combined
    result_entry = {
        "id": current_id,
        "type": input_type,
        "input_original": input_content,
        "context_retrieved": context_dual,
        "generated_joke": joke_final
    }

    with open(output_file, "a", encoding='utf-8') as f:
        f.write(json.dumps(result_entry, ensure_ascii=False) + "\n")

print("Finished Dual-RAG Generation!")

Starting Dual-RAG Generation with Llama...


Generating Jokes:   0%|          | 0/1200 [00:00<?, ?it/s]

Finished Dual-RAG Generation!
