In [16]:
import os
import faiss
import numpy as np
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
from huggingface_hub import HfApi
from huggingface_hub import hf_hub_download
import torch

In [12]:
# ─── Config ─────────────────────────────────────────────────────────────
# Everything a reader might want to tune lives in this cell.
HF_TOKEN        = os.getenv("HF_TOKEN")               # Hub access token; None when the env var is unset
MCQA_DS         = "GingerBled/MNLP_M2_mcqa_dataset"   # your existing MCQA-only data
CHUNKS_REPO     = "GingerBled/RAG_corpus_docs"        # dataset repo holding the chunk texts and the index files
INDEX_PATH      = "index/index.faiss"                 # FAISS index file inside CHUNKS_REPO
IDMAP_PATH      = "index/id_map.npy"                  # id-map array file inside CHUNKS_REPO
ENCODER_REPO    = "GingerBled/MNLP_M2_document_encoder"  # tokenizer + encoder used to embed queries
OUT_DS          = "GingerBled/MNLP_M2_mcqa_with_context" # destination dataset repo for the RAG examples
TOP_K           = 5                                   # number of passages retrieved per question
# Use the GPU when one is visible; otherwise fall back to CPU so the notebook
# still runs (slowly) instead of failing on the first `.to(DEVICE)`.
DEVICE          = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# ─── Load models & data ───────────────────────────────────────────────────
# Step 1: fetch the train split of the MCQA dataset named by MCQA_DS.
print("[1/6] Loading MCQA examples…")
mcqa = load_dataset(path=MCQA_DS, split="train")

[1/6] Loading MCQA examples…


README.md:   0%|          | 0.00/420 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/1.25M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10687 [00:00<?, ? examples/s]

In [9]:
print("[2/6] Loading chunk index & id_map…")
# Both artifacts live in the same dataset repo, so the shared download
# arguments are stated once and only the filename varies per call.
hub_kwargs = dict(repo_id=CHUNKS_REPO, repo_type="dataset", token=HF_TOKEN)
# Each call returns the local path of the downloaded (or cached) file.
index_file = hf_hub_download(filename=INDEX_PATH, **hub_kwargs)
id_map_file = hf_hub_download(filename=IDMAP_PATH, **hub_kwargs)

[2/6] Loading chunk index & id_map…


index.faiss:   0%|          | 0.00/2.12G [00:00<?, ?B/s]

id_map.npy:   0%|          | 0.00/19.4M [00:00<?, ?B/s]

In [10]:
# Deserialize the FAISS index from the local path returned by hf_hub_download.
faiss_index = faiss.read_index(index_file)
# Load the id-map array stored next to the index.
# SECURITY: allow_pickle=True lets np.load unpickle object arrays, which can run
# arbitrary code if the downloaded file is not trusted. Keep it only if the
# array really is object-dtype — TODO confirm and switch to allow_pickle=False.
# NOTE(review): none of the later cells visible here read `id_map`; retrieval
# indexes the chunk texts with the raw FAISS ids — confirm that is intended.
id_map       = np.load(id_map_file, allow_pickle=True)

In [17]:
print("[3/6] Loading encoder…")

tokenizer_q = AutoTokenizer.from_pretrained(ENCODER_REPO)
# Half precision is only requested for CUDA devices; on any other device the
# weights are loaded in float32, where half-precision kernels may be slow or missing.
# With DEVICE == "cuda" this is exactly the original float16 load.
encoder_dtype = torch.float16 if str(DEVICE).startswith("cuda") else torch.float32
model_q     = AutoModel.from_pretrained(ENCODER_REPO, torch_dtype=encoder_dtype).to(DEVICE)
# Inference only: switch off dropout and other train-time behaviour.
model_q.eval()

[3/6] Loading encoder…


In [23]:
def encode_batch(texts, max_length=512):
    """Embed a batch of strings by masked mean pooling of the encoder output.

    Args:
        texts: List[str] to embed.
        max_length: Token limit per text; longer inputs are truncated.

    Returns:
        A float32 numpy array of shape [B, D], one pooled vector per text.

    The pooling arithmetic runs in float32 even when the encoder produces
    float16 activations: summing many half-precision token vectors can
    overflow to inf (fp16 max is about 65504) and loses precision. The token
    count is clamped to at least 1 so a row whose attention mask is all zeros
    yields a zero vector instead of NaN.
    """
    # texts: List[str]
    inputs = tokenizer_q(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    ).to(DEVICE)
    with torch.no_grad():
        outputs = model_q(**inputs, return_dict=True)
        last = outputs.last_hidden_state.float()              # [B, T, D], upcast for pooling
        mask = inputs.attention_mask.unsqueeze(-1).float()    # [B, T, 1], 1 = real token, 0 = padding
        summed = (last * mask).sum(dim=1)                     # [B, D], padding positions contribute 0
        counts = mask.sum(dim=1).clamp(min=1.0)               # [B, 1], never zero
        pooled = summed / counts                              # [B, D]
    return pooled.cpu().numpy()

In [18]:
print("[4/6] Loading chunks parquet into memory …")
# Fetch the train split of the chunk corpus, then keep only its "text" column;
# retrieval later looks passages up in `texts` by integer position.
chunks_ds = load_dataset(path=CHUNKS_REPO, split="train")
texts = chunks_ds["text"]

[3/6] Loading chunks parquet into memory …


Generating train split:   0%|          | 0/484987 [00:00<?, ? examples/s]

In [27]:
print("[5/6] Building {} RAG examples …".format(len(mcqa)))
# One {"input_text", "target_text"} row is produced per MCQA example:
#   input_text  = retrieved passages + blank line + the formatted question prompt
#   target_text = the example's "answer" field
out_rows = []
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
topic = "knowledge and skills in advanced master-level STEM courses"
prompt = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n"
for ex in tqdm(mcqa, total=len(mcqa)):
    # Render the options as "(A) …", "(B) …", one per line.
    choices = ex['options']
    opts_block = "\n".join(
            f"({LETTERS[i]}) {c}" for i, c in enumerate(choices)
        )

    # The full prompt (header, question, options, answer cue) doubles as the retrieval query.
    query_text = (
            f"{prompt}\n"
            f"{ex['question'].strip()}\n"
            f"{opts_block}\n"
            "### Answer:"
        )

    # embed + retrieve
    q_vec = encode_batch([query_text])                               # (1,D)
    _scores, hit_ids = faiss_index.search(q_vec.astype("float32"), TOP_K)  # each (1,TOP_K)
    # gather passages
    # FAISS pads with label -1 when it finds fewer than TOP_K neighbours; without the
    # guard, texts[-1] would silently inject the last chunk of the corpus.
    # NOTE(review): FAISS ids are used directly as positions in `texts`; the `id_map`
    # loaded earlier is not consulted — confirm index rows follow chunk order.
    passages = [ texts[int(idx)] for idx in hit_ids[0] if idx >= 0 ]
    context  = "\n\n".join(passages)
    input_text = f"{context}\n\n{query_text}"
    label      = ex["answer"]
    out_rows.append({"input_text": input_text, "target_text": label})


[5/6] Building 10687 RAG examples …


  0%|          | 0/10687 [00:00<?, ?it/s]

In [28]:
# Final step: the earlier "[5/6]" label belongs to the example-building loop,
# so this announcement is numbered 6 of 6.
print("[6/6] Pushing {} examples → {} …".format(len(out_rows), OUT_DS))
# Turn the list of row dicts into a Dataset; the upload itself happens in the next cell.
ds = Dataset.from_list(out_rows)
# Quick sanity check of the row count and column schema before uploading.
print(len(ds))
print(ds.features)

[5/6] Pushing 10687 examples → GingerBled/MNLP_M2_mcqa_with_context …
10687
{'input_text': Value(dtype='string', id=None), 'target_text': Value(dtype='string', id=None)}


In [38]:
# Authenticate the API client with the same token that push_to_hub uses, so that
# repo creation and upload act under the same credentials (previously
# create_repo ignored HF_TOKEN and relied on whatever login was cached).
api = HfApi(token=HF_TOKEN)

# Make sure a public dataset repo exists; exist_ok=True makes re-runs a no-op.
api.create_repo(OUT_DS, repo_type="dataset", private=False, exist_ok=True)
ds.push_to_hub(OUT_DS, token=HF_TOKEN)
print("✅ Dataset ready:", OUT_DS)

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/11 [00:00<?, ?ba/s]

✅ Dataset ready: GingerBled/MNLP_M2_mcqa_with_context
