In [1]:
# ─── Cell: Define Retrieval & Attribute Functions ────────────────────────────
!pip install --no-deps --quiet nltk rank_bm25 sentence-transformers tqdm

# ─── Imports & Setup ─────────────────────────────────────────────────────────
import os, heapq
import numpy as np
import nltk
import torch

from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.corpus    import stopwords
from nltk.util      import everygrams
from transformers   import pipeline
from rank_bm25      import BM25Okapi
from sentence_transformers import SentenceTransformer, util
from tqdm.auto     import tqdm

2025-05-26 09:40:11.886513: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748252412.078439      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748252412.134251      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# Ensure NLTK assets
# Downloads are no-ops when the packages are already present (the recorded
# output reports both as up-to-date).
# "punkt" is presumably the data behind word_tokenize / sent_tokenize;
# NOTE(review): newer NLTK releases may also need "punkt_tab" — confirm.
nltk.download("punkt")
nltk.download("stopwords")
# English stop words as a set, used for membership tests in the token filters
# of generate_multiview_queries and _process_chunk_multiview.
stops = set(stopwords.words("english"))

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /usr/share/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [3]:
# Device for encoder
# "cuda" when a GPU is visible to torch, otherwise "cpu". Reused by the NER
# pipeline, the sentence encoder and the reasoner model below.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Path to the dump
# Plain-text file read line by line by iter_wiki_paragraphs (one paragraph
# per line).
DUMP_PATH = "/kaggle/input/wikidump/wikipedia-dump.txt"

In [4]:
# 1) Multiview Query Generator
# Named-entity recogniser used by generate_multiview_queries to pull entity
# surface forms out of a question.
ner = pipeline(
    "ner",
    model="dslim/bert-large-NER",
    tokenizer="dslim/bert-large-NER",
    # Merge word pieces into whole entity spans; each result carries a "word" key.
    aggregation_strategy="simple",
    # Pipeline device index: 0 = first CUDA device, -1 = CPU.
    device=0 if device=="cuda" else -1
)

config.json:   0%|          | 0.00/1.45k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Some weights of the model checkpoint at dslim/bert-large-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tokenizer_config.json:   0%|          | 0.00/40.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Device set to use cuda:0


In [5]:
def generate_multiview_queries(question: str, max_ngram: int = 2) -> list[str]:
    """
    Build several keyword "views" of `question` for sparse retrieval.

    Two sources are merged and de-duplicated:
      * entity spans reported by the module-level `ner` pipeline
        (kept with their original casing);
      * every n-gram of length 1..max_ngram over the lower-cased,
        purely alphabetic, non-stop-word tokens of the question.

    Views of a single character are dropped. The result comes from a set,
    so its order is not guaranteed.
    """
    # View 1: named-entity surface forms
    entity_views = set()
    for span in ner(question):
        entity_views.add(span["word"])

    # View 2: content-word n-grams
    content_tokens = []
    for raw in word_tokenize(question):
        lowered = raw.lower()
        if raw.isalpha() and lowered not in stops:
            content_tokens.append(lowered)

    ngram_views = set()
    for gram in everygrams(content_tokens, max_len=max_ngram):
        ngram_views.add(" ".join(gram))

    candidates = entity_views | ngram_views
    return [view for view in candidates if len(view) > 1]

In [6]:
import os

# Walk the Kaggle input mount and print every directory with its files,
# to confirm where the dump file lives before pointing DUMP_PATH at it.
for dirname, _, filenames in os.walk("/kaggle/input"):
    print(dirname, filenames)


/kaggle/input []
/kaggle/input/wikidump ['wikipedia-dump.txt']


Topic Retriever

In [7]:
# 2) Paragraph Streamer
def iter_wiki_paragraphs(path: str, min_chars: int = 50):
    """
    Lazily yield (title, paragraph) pairs from a one-paragraph-per-line dump.

    Args:
        path: UTF-8 text file to read.
        min_chars: Minimum length (after stripping surrounding whitespace)
            a line must have to be yielded. Defaults to 50, the threshold
            previously hard-coded here.

    Yields:
        ("Wikipedia", paragraph) tuples. The dump carries no per-paragraph
        title, so the title is always the constant placeholder "Wikipedia".
    """
    with open(path, encoding="utf-8") as f:
        for line in f:
            paragraph = line.strip()
            if len(paragraph) >= min_chars:
                yield "Wikipedia", paragraph

In [8]:
# 3) Multiview BM25 Topic Retriever
def _process_chunk_multiview(paras, titles, queries, top_d, heap):
    tokenized = [
        [w for w in word_tokenize(p.lower()) if w.isalpha() and w not in stops]
        for p in paras
    ]
    bm25 = BM25Okapi(tokenized)
    all_scores = np.zeros(len(paras), dtype=float)
    for q in queries:
        q_tok = [w for w in word_tokenize(q.lower()) if w.isalpha()]
        all_scores = np.maximum(all_scores, bm25.get_scores(q_tok))
    for sc, t, p in zip(all_scores, titles, paras):
        if len(heap) < top_d:
            heapq.heappush(heap, (sc, t, p))
        elif sc > heap[0][0]:
            heapq.heapreplace(heap, (sc, t, p))

In [9]:
def local_topic_retrieve_mv(
    question: str,
    top_d: int = 5,
    chunk_size: int = 50_000
) -> list[tuple[float,str,str]]:
    """
    Stream the dump at DUMP_PATH and rank its paragraphs with multiview BM25.

    Paragraphs are collected into batches of `chunk_size`; each batch is
    scored by `_process_chunk_multiview`, which maintains a running top-`top_d`
    heap across batches.

    Returns:
        Up to `top_d` (score, title, paragraph) tuples, highest score first.
    """
    queries = generate_multiview_queries(question)

    best = []
    pending_paras, pending_titles = [], []
    stream = tqdm(iter_wiki_paragraphs(DUMP_PATH), desc="Topic retrieve")
    for title, para in stream:
        pending_paras.append(para)
        pending_titles.append(title)
        if len(pending_paras) < chunk_size:
            continue
        # Batch is full: score it and start a fresh one.
        _process_chunk_multiview(pending_paras, pending_titles, queries, top_d, best)
        pending_paras, pending_titles = [], []

    # Score whatever is left in the final, partially filled batch.
    if pending_paras:
        _process_chunk_multiview(pending_paras, pending_titles, queries, top_d, best)

    ranked = sorted(best, key=lambda item: -item[0])
    return ranked[:top_d]

In [10]:

# Quick sanity-check: print the first 5 paragraphs yielded from the dump
# (title plus the first 80 characters of each), then stop reading.
# `dump_path` holds the same literal as the DUMP_PATH constant.
dump_path = "/kaggle/input/wikidump/wikipedia-dump.txt"
for i, (t, p) in enumerate(iter_wiki_paragraphs(dump_path)):
    print(f"[{i}] {t} → {p[:80]}…")
    if i >= 4: break

[0] Wikipedia → anarchism is political philosophy and movement that rejects all involuntary coer…
[1] Wikipedia → autism is developmental disorder characterized by difficulties with social inter…
[2] Wikipedia → diffusely reflected sunlight relative to various surface conditions albedo meani…
[3] Wikipedia → or is the first letter and the first vowel letter of the modern english alphabet…
[4] Wikipedia → alabama is state in the southeastern region of the united states it is bordered …


In [12]:
# ─── Cell 12: Test Local‐dump Topic Retriever ───────────────────────────────
# Defines the notebook-global `question` and `docs` that the later caching,
# attribute-retrieval and reasoning cells read.
question = "What is the ruling party of the country that hosted the 2016 Olympics?"
# Scans the whole dump; returns up to 50 (score, title, paragraph) tuples.
docs = local_topic_retrieve_mv(question, top_d=50, chunk_size=50_000)

# The "50" in this heading is hard-coded; keep it in sync with top_d above.
print("Top 50 paragraphs from local dump:")
for sc, t, p in docs:
    print(f"{sc:.1f} → {t}: {p[:100]}…")

print(f"Retrieved {len(docs)} paragraphs.")


Topic retrieve: 0it [00:00, ?it/s]

Top 50 paragraphs from local dump:
17.7 → Wikipedia: for the winter olympics in st moritz switzerland total of eight sports venues were used the five ven…
17.5 → Wikipedia: right horse guards parade hosted the beach volleyball events at the summer olympics in london copaca…
17.3 → Wikipedia: kazakoshi park arena is an indoor arena located in karuizawa nagano japan constructed in with an ope…
17.2 → Wikipedia: the human rights party or kanakpak sethi manus is cambodian political party founded on july led by k…
17.1 → Wikipedia: the adler arena trade and exhibition center адлер арена is an seat speed skating oval in the olympic…
16.6 → Wikipedia: london hosted the olympic games in and the summer olympics made london the first city to have hosted…
16.6 → Wikipedia: snow harp is cross country skiing venue located in hakuba nagano japan for the winter olympics the v…
16.6 → Wikipedia: paris has hosted several olympiads and and will host in paris will be the second city in the modern …
16.5 

Save the retrieved docs to the cache

In [13]:
import os, json
os.makedirs("cache", exist_ok=True)
docs_path = "cache/docs.json"

# quick check: fail early rather than caching an empty result
assert len(docs) > 0, "No paragraphs retrieved—check your retriever!"
print("Sample:", docs[0])

# save: one {"score", "title", "para"} object per retrieved paragraph
with open(docs_path, "w", encoding="utf-8") as f:
    json.dump([{"score":s,"title":t,"para":p} for s,t,p in docs], f, ensure_ascii=False, indent=2)
print("Saved to", docs_path)

# verify: read the file back and report its size and entry count
size = os.path.getsize(docs_path)
print("File size:", size, "bytes")
# Use a context manager so the read handle is closed deterministically
# (previously the file object passed to json.load was never closed).
with open(docs_path, "r", encoding="utf-8") as f:
    loaded = json.load(f)
print("Loaded entries:", len(loaded))

Sample: (17.741650576283966, 'Wikipedia', 'for the winter olympics in st moritz switzerland total of eight sports venues were used the five venues used for the winter olympics were reused for these games three new venues were added for alpine skiing which had been added to the winter olympics program twelve years earlier in garmisch partenkirchen germany allied occupied germany during the games as of the bob run continues to be used for bobsleigh and the cresta run for skeleton while alpine skiing remains popular in st moritz venues venue sports capacity ref around the hills of st moritz cross country skiing nordic combined cross country skiing not listed cresta run skeleton not listed kulm ice hockey not listed olympiaschanze st moritz ski jumping nordic combined ski jumping not listed olympic stadium opening closing ceremonies figure skating ice hockey final speed skating not listed piz nair alpine skiing not listed st moritz celerina olympic bobrun bobsleigh not listed suvretta ice 

In [25]:
# Re-read the cached docs file to inspect its contents.
p = "./cache/docs.json"

# Context manager closes the handle (previously left open by json.load(open(...))).
with open(p, "r", encoding="utf-8") as f:
    fil = json.load(f)
fil

[{'score': 17.741650576283966,
  'title': 'Wikipedia',
  'para': 'for the winter olympics in st moritz switzerland total of eight sports venues were used the five venues used for the winter olympics were reused for these games three new venues were added for alpine skiing which had been added to the winter olympics program twelve years earlier in garmisch partenkirchen germany allied occupied germany during the games as of the bob run continues to be used for bobsleigh and the cresta run for skeleton while alpine skiing remains popular in st moritz venues venue sports capacity ref around the hills of st moritz cross country skiing nordic combined cross country skiing not listed cresta run skeleton not listed kulm ice hockey not listed olympiaschanze st moritz ski jumping nordic combined ski jumping not listed olympic stadium opening closing ceremonies figure skating ice hockey final speed skating not listed piz nair alpine skiing not listed st moritz celerina olympic bobrun bobsleigh n

In [28]:
from IPython.display import FileLink

# Display a download link in the notebook output
# (`display` is the IPython built-in available in notebook cells).
display(FileLink("cache/docs.json"))


In [27]:
# Bundle the whole cache/ directory into output.zip (shell command) and show
# a download link for the archive.
!zip -r output.zip cache/
from IPython.display import FileLink
display(FileLink("output.zip"))


  adding: cache/ (stored 0%)
  adding: cache/docs.json (deflated 71%)
  adding: cache/sents.json (deflated 70%)


In [29]:
import json

# Dummy payload used only to check that the working directory is writable.
# NOTE(review): the name `data` is reused by a later cell for the loaded
# docs list — re-running cells out of order changes what it refers to.
data = {"foo": 42, "bar": [1, 2, 3]}

# Write to the working directory
out_path = "/kaggle/working/results.json"
with open(out_path, "w") as f:
    json.dump(data, f, indent=2)

print(f"Saved results to {out_path}")


Saved results to /kaggle/working/results.json


import os
import json

def find_json_file(filename, search_dir="."):
    """
    Locate `filename` anywhere below `search_dir`.

    The directory tree is walked with os.walk; the path of the first
    directory that contains a file with exactly that name is returned.
    Returns None when no such file exists.
    """
    hits = (
        os.path.join(folder, filename)
        for folder, _subdirs, names in os.walk(search_dir)
        if filename in names
    )
    return next(hits, None)

# Example usage: locate docs.json below the current directory, failing loudly
# if it is absent.
target = "docs.json"
path = find_json_file(target, search_dir=".")
if path is None:
    raise FileNotFoundError(f"Could not find {target} under current folder")
print(f"Found at: {path}")

# Now load it:
with open(path, "r", encoding="utf-8") as f:
    data = json.load(f)

# `data` now holds your JSON contents; for example:
# (slicing assumes the file holds a JSON list, as written by the save cell)
print(data[:2])  # print first two entries if it's a list of objects


In [None]:
# docs path : ./cache/docs.json
# Restores the notebook-global `docs` from the cache file so the expensive
# topic retrieval does not have to be re-run.

import os
import json

CACHE_DIR = "cache"
docs_path = os.path.join(CACHE_DIR, "docs.json")

# 1) Verify the file exists and is not empty
if not os.path.exists(docs_path):
    raise FileNotFoundError(f"{docs_path} not found. Run the Topic Retriever first.")
if os.path.getsize(docs_path) == 0:
    raise RuntimeError(f"{docs_path} is empty. Delete it and rerun the Topic Retriever.")

# 2) Try to load; on a parse failure the cache file is removed and the error
#    is re-raised (nothing is recomputed here — rerun the retriever manually).
try:
    with open(docs_path, "r", encoding="utf-8") as f:
        docs_data = json.load(f)
    # Convert back to the in-memory format: (score, title, paragraph) tuples
    docs = [(d["score"], d["title"], d["para"]) for d in docs_data]
except json.JSONDecodeError:
    # If parsing fails, delete the bad cache so the next retriever run
    # writes a fresh file, then propagate the original exception.
    os.remove(docs_path)
    print(f"⚠️ Cache corrupted, removed {docs_path}. Please rerun the Topic Retriever.")
    raise


In [14]:
# Inspect the retrieved (score, title, paragraph) tuples.
# NOTE(review): this dumps every paragraph in full; consider docs[:3].
docs

[(17.741650576283966,
  'Wikipedia',
  'for the winter olympics in st moritz switzerland total of eight sports venues were used the five venues used for the winter olympics were reused for these games three new venues were added for alpine skiing which had been added to the winter olympics program twelve years earlier in garmisch partenkirchen germany allied occupied germany during the games as of the bob run continues to be used for bobsleigh and the cresta run for skeleton while alpine skiing remains popular in st moritz venues venue sports capacity ref around the hills of st moritz cross country skiing nordic combined cross country skiing not listed cresta run skeleton not listed kulm ice hockey not listed olympiaschanze st moritz ski jumping nordic combined ski jumping not listed olympic stadium opening closing ceremonies figure skating ice hockey final speed skating not listed piz nair alpine skiing not listed st moritz celerina olympic bobrun bobsleigh not listed suvretta ice hoc

In [15]:
# 4) Attribute Retriever
# Sentence encoder used by attribute_retrieve, loaded on `device`.
enc = SentenceTransformer("paraphrase-mpnet-base-v2", device=device)

def attribute_retrieve(
    question: str,
    docs: list[tuple[float,str,str]],
    top_k: int = 10
) -> list[tuple[str,float]]:
    """
    Rank the sentences of `docs` by dense similarity to `question`.

    Words that occur in any document title are removed from both the
    question and the sentences before encoding; the similarity is computed
    on those masked texts, but the original sentences are returned.

    Args:
        question: Natural-language question.
        docs: (score, title, paragraph) tuples, e.g. from the topic retriever.
        top_k: Maximum number of sentences to return.

    Returns:
        Up to `top_k` (sentence, cosine_similarity) pairs, best first.
        An empty list when there are no sentences or `top_k` is not positive.
    """
    # Lower-cased alphabetic title words; these are stripped before encoding.
    mask = {w.lower() for _, title, _ in docs for w in title.split() if w.isalpha()}

    def _strip_masked(text: str) -> str:
        """Drop every whitespace-separated word whose lower-case form is masked."""
        return " ".join(w for w in text.split() if w.lower() not in mask)

    sents = [s for _, _, para in docs for s in sent_tokenize(para)]

    # Nothing to rank: return before calling the encoder / topk on empty input.
    if not sents or top_k <= 0:
        return []

    masked = [_strip_masked(s) for s in sents]

    q_vec = enc.encode(_strip_masked(question), convert_to_tensor=True)
    s_vecs = enc.encode(masked, convert_to_tensor=True)
    cos = util.cos_sim(q_vec, s_vecs)[0]
    # Plain Python ints for list indexing instead of 0-dim tensors.
    topi = torch.topk(cos, k=min(top_k, len(sents))).indices.tolist()

    return [(sents[i], float(cos[i])) for i in topi]

modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/3.52k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/594 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.19k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [15]:
#question = "What is the ruling party of the country that hosted the 2016 Olympics?"

In [16]:
# ─── Cell X+2: Attribute Retrieval ──────────────────────────────────────────
# Uses the notebook-global `question` and `docs`; defines `top_sents`, a list
# of (sentence, score) pairs consumed by the caching and reasoning cells.
top_sents = attribute_retrieve(question, docs, top_k=5)

print("Top 5 sentences:")
for sent, sc in top_sents:
    print(f"{sc:.2f} → {sent}")


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Top 5 sentences:
0.53 → there were six bids initially submitted for the summer olympics tokyo was ultimately elected as the host city at the th ioc session in buenos aires argentina on september bidding process the olympic bidding process begins with the submission of city application to the international olympic committee ioc by its national olympic committee noc and ends with the election of the host city by the members of the ioc during an ordinary session the process is governed by the olympic charter as stated in chapter rule since the process has consisted of two phases during the first phase which begins immediately after the bid submission deadline the applicant cities are required to answer questionnaire covering themes of importance to successful games organization this information allows the ioc to analyze the cities hosting capacities and the strengths and weaknesses of their plans following detailed study of the submitted questionnaires and ensuing reports the ioc executiv

Cache the retrieved results and sentences

In [17]:
# ─── Cell X+3: Cache the Retrieved Sentences ────────────────────────────────
import os, json

CACHE_DIR = "cache"
os.makedirs(CACHE_DIR, exist_ok=True)

sents_path = os.path.join(CACHE_DIR, "sents.json")

# Quick check: ensure we have some sentences before writing the cache
assert len(top_sents) > 0, "No sentences retrieved—check your attribute retriever!"
print("Sample sentence:", top_sents[0])

# Save out: one {"sent", "score"} object per retrieved sentence
with open(sents_path, "w", encoding="utf-8") as f:
    json.dump(
        [{"sent": s, "score": sc} for s, sc in top_sents],
        f, ensure_ascii=False, indent=2
    )
print("✅ Saved", len(top_sents), "sentences to", sents_path)

# Verify: read the file back and report its size and entry count
size = os.path.getsize(sents_path)
print("File size:", size, "bytes")
# Context manager closes the read handle (previously left open by
# json.load(open(...))).
with open(sents_path, "r", encoding="utf-8") as f:
    loaded = json.load(f)
print("Loaded entries:", len(loaded))


Sample sentence: ('there were six bids initially submitted for the summer olympics tokyo was ultimately elected as the host city at the th ioc session in buenos aires argentina on september bidding process the olympic bidding process begins with the submission of city application to the international olympic committee ioc by its national olympic committee noc and ends with the election of the host city by the members of the ioc during an ordinary session the process is governed by the olympic charter as stated in chapter rule since the process has consisted of two phases during the first phase which begins immediately after the bid submission deadline the applicant cities are required to answer questionnaire covering themes of importance to successful games organization this information allows the ioc to analyze the cities hosting capacities and the strengths and weaknesses of their plans following detailed study of the submitted questionnaires and ensuing reports the ioc executive boa

In [18]:
# Restore `top_sents` as (sentence, score) pairs from the sentence cache;
# abort if the cache file has not been written yet.
if not os.path.exists(sents_path):
    raise FileNotFoundError(f"{sents_path} not found; run the Attribute Retriever first.")

with open(sents_path, "r", encoding="utf-8") as f:
    top_sents = [(d["sent"], d["score"]) for d in json.load(f)]
print(f"✅ Loaded {len(top_sents)} cached sentences")

✅ Loaded 5 cached sentences


In [None]:
import time

# Keep-alive loop: prints a counter and sleeps 2 minutes, forever.
# NOTE(review): this never terminates on its own — it blocks every cell below
# it during "Run All" and must be interrupted manually.
count = 0
while True:
    print(count)
    count += 1

    print("Waiting for 2 minutes...")
    time.sleep(120)
    print("Resuming execution now.")

Disentangled reasoning

In [19]:
# ─── Cell X: Install & import reasoning deps ────────────────────────────────
!pip install --no-deps --quiet transformers

import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel


# 1) Load tokenizer & shared RoBERTa encoder
# ENCODER is a single module object; DisentangledReasoner stores a reference
# to it rather than a copy.
TOKENIZER = RobertaTokenizer.from_pretrained("roberta-base")
ENCODER   = RobertaModel.from_pretrained("roberta-base")
# Width of the encoder's hidden states; sizes the linear heads of the reasoner.
HIDDEN    = ENCODER.config.hidden_size
# Number of strategy classes; must match the length of STRAT_LABELS used at
# inference time.
NUM_STRAT = 5   # e.g. comparison, logical, entail, binary, numerical


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [20]:
# Same expression as the earlier device cell; repeated so the reasoning
# section has `device` defined when run on its own.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [21]:
# ─── Cell Y: Disentangled Reasoner Definition ───────────────────────────────
class DisentangledReasoner(nn.Module):
    """
    Two-headed reasoner on top of one shared text encoder.

    * Strategy head: classifies the first-token representation of the
      "strategy" input into `num_strategies` classes.
    * Answer head: maps the concatenation of the strategy and answer
      first-token representations to a single yes/no logit.

    Args:
        encoder: Module called as ``encoder(**inputs)`` whose result exposes
            ``last_hidden_state`` of shape (batch, seq, hidden). Defaults to
            the module-level ``ENCODER``.
        hidden_size: Width of the encoder's hidden states. Defaults to the
            module-level ``HIDDEN``.
        num_strategies: Number of strategy classes. Defaults to the
            module-level ``NUM_STRAT``.

    Calling ``DisentangledReasoner()`` with no arguments builds exactly the
    same model as before these parameters existed.
    """

    def __init__(self, encoder=None, hidden_size=None, num_strategies=None):
        super().__init__()
        # Globals are resolved lazily so they are only needed when defaults are used.
        hidden = HIDDEN if hidden_size is None else hidden_size
        n_strat = NUM_STRAT if num_strategies is None else num_strategies

        # The encoder object is stored by reference, not copied: instances
        # built from the same encoder share its weights.
        self.encoder    = ENCODER if encoder is None else encoder
        # Strategy head: logits over the strategy classes (softmax applied by caller/loss)
        self.strat_head = nn.Sequential(
            nn.Linear(hidden, hidden//2),
            nn.ReLU(),
            nn.Linear(hidden//2, n_strat)
        )
        # Answer head: single logit (sigmoid applied by caller/loss)
        self.ans_head   = nn.Sequential(
            nn.Linear(hidden*2, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1)
        )

    def forward(self, strat_inputs, ans_inputs):
        """
        Args:
            strat_inputs: Keyword arguments for the encoder built from the
                masked question + masked sentences.
            ans_inputs: Keyword arguments for the encoder built from the
                original question + original sentences.

        Returns:
            (strat_logits, ans_logit): tensors of shape (B, num_strategies)
            and (B,) respectively.
        """
        # Encode for strategy; position 0 is the sequence-start token.
        out_strat = self.encoder(**strat_inputs).last_hidden_state[:,0]
        strat_logits = self.strat_head(out_strat)                       # (B,num_strategies)
        # Encode for answer with the same encoder
        out_ans = self.encoder(**ans_inputs).last_hidden_state[:,0]
        # The answer head sees both representations
        combined = torch.cat([out_strat, out_ans], dim=-1)              # (B,2*H)
        ans_logit = self.ans_head(combined).squeeze(-1)                 # (B,)
        return strat_logits, ans_logit

# Instantiate
# Builds the reasoner around the shared ENCODER and moves all parameters
# (encoder included) to `device`.
model = DisentangledReasoner().to(device)


In [23]:
# ─── Cell Z: Prepare inputs & run a single‐example inference ───────────────
# Assume you still have:
#   q         = original question string
#   docs      = output of local_topic_retrieve_mv(...)
#   top_sents = output of attribute_retrieve(q, docs, top_k=5)
#
# NOTE(review): no training step for `model` appears in the visible notebook,
# so strat_head / ans_head presumably still hold their random initial weights
# and the predictions printed below are not meaningful — confirm.

q= question

# (A) Build mask set from titles
# Lower-cased alphabetic title words. With the local dump retriever every
# title is the placeholder "Wikipedia", so the mask is just {"wikipedia"}.
mask = {w.lower() for _, title, _ in docs for w in title.split() if w.isalpha()}

# (B) Masked inputs for strategy predictor
# Question and sentences with masked words removed, joined into one string.
q_masked = " ".join(w for w in q.split() if w.lower() not in mask)
sents_masked = [
    " ".join(w for w in s.split() if w.lower() not in mask)
    for s,_ in top_sents
]
strat_text = q_masked + " " + " ".join(sents_masked)
# Single sequence, truncated to 512 tokens; text beyond that is dropped.
strat_inputs = TOKENIZER(
    strat_text,
    padding=True, truncation=True, max_length=512, return_tensors="pt"
).to(device)

# (C) Original inputs for answer predictor
# Same layout as (B) but with the unmasked question and sentences.
orig_sents = [s for s,_ in top_sents]
ans_text = q + " " + " ".join(orig_sents)
ans_inputs = TOKENIZER(
    ans_text,
    padding=True, truncation=True, max_length=512, return_tensors="pt"
).to(device)

# (D) Inference
# eval() + no_grad(): no gradient tracking for this forward pass.
model.eval()
with torch.no_grad():
    strat_logits, ans_logit = model(strat_inputs, ans_inputs)
    # Batch size is 1 here, so .item() yields a plain Python number.
    strat_pred = torch.argmax(strat_logits, dim=-1).item()
    ans_prob  = torch.sigmoid(ans_logit).item()

# (E) Print
# Label order must line up with the NUM_STRAT outputs of the strategy head.
STRAT_LABELS = ["comparison","logical","entail","binary","numerical"]
print("Predicted strategy:", STRAT_LABELS[strat_pred])
print("Answer probability:", ans_prob)
print("Predicted answer:", "Yes" if ans_prob>0.5 else "No")


Predicted strategy: entail
Answer probability: 0.5046440362930298
Predicted answer: Yes


In [44]:
import time

# Keep-alive loop: prints a counter and sleeps 2 minutes, forever.
# NOTE(review): never terminates on its own; the recorded run of this cell
# ended with a KeyboardInterrupt. It blocks all later cells during "Run All".
count = 0
while True:
    print(count)
    count += 1

    print("Waiting for 2 minutes...")
    time.sleep(120)
    print("Resuming execution now.")
       


0
Waiting for 2 minutes...
Resuming execution now.
1
Waiting for 2 minutes...
Resuming execution now.
2
Waiting for 2 minutes...
Resuming execution now.
3
Waiting for 2 minutes...
Resuming execution now.
4
Waiting for 2 minutes...
Resuming execution now.
5
Waiting for 2 minutes...
Resuming execution now.
6
Waiting for 2 minutes...
Resuming execution now.
7
Waiting for 2 minutes...
Resuming execution now.
8
Waiting for 2 minutes...
Resuming execution now.
9
Waiting for 2 minutes...
Resuming execution now.
10
Waiting for 2 minutes...
Resuming execution now.
11
Waiting for 2 minutes...


KeyboardInterrupt: 

In [43]:
# Inspect the retrieved (score, title, paragraph) tuples again.
# NOTE(review): this dumps every paragraph in full; consider docs[:3].
docs

[(17.741650576283966,
  'Wikipedia',
  'for the winter olympics in st moritz switzerland total of eight sports venues were used the five venues used for the winter olympics were reused for these games three new venues were added for alpine skiing which had been added to the winter olympics program twelve years earlier in garmisch partenkirchen germany allied occupied germany during the games as of the bob run continues to be used for bobsleigh and the cresta run for skeleton while alpine skiing remains popular in st moritz venues venue sports capacity ref around the hills of st moritz cross country skiing nordic combined cross country skiing not listed cresta run skeleton not listed kulm ice hockey not listed olympiaschanze st moritz ski jumping nordic combined ski jumping not listed olympic stadium opening closing ceremonies figure skating ice hockey final speed skating not listed piz nair alpine skiing not listed st moritz celerina olympic bobrun bobsleigh not listed suvretta ice hoc

StrategyQA

In [39]:
from datasets import load_dataset

# ─── Load StrategyQA (ChilleD mirror) ───────────────────────────────────────
# The recorded run of this cell reported 1,603 train and 687 test examples
# (not the 2,780 / 490 previously stated here).
train_ds = load_dataset("ChilleD/StrategyQA", split="train")
test_ds  = load_dataset("ChilleD/StrategyQA", split="test")

print(f"Train size: {len(train_ds)}")
print(f"Test  size: {len(test_ds)}")
# Inspect a single example; the recorded one has the keys qid, term,
# description, question, answer (bool) and facts.
print(train_ds[0])


README.md:   0%|          | 0.00/433 [00:00<?, ?B/s]

(…)-00000-of-00001-506370352f622815.parquet:   0%|          | 0.00/369k [00:00<?, ?B/s]

(…)-00000-of-00001-bae602f3ee37f4ca.parquet:   0%|          | 0.00/161k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1603 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/687 [00:00<?, ? examples/s]

Train size: 1603
Test  size: 687
{'qid': '4fd64bb6ce5b78ab20b6', 'term': 'Mixed martial arts', 'description': 'full contact combat sport', 'question': 'Is Mixed martial arts totally original from Roman Colosseum games?', 'answer': False, 'facts': 'Mixed Martial arts in the UFC takes place in an enclosed structure called The Octagon. The Roman Colosseum games were fought in enclosed arenas where combatants would fight until the last man was standing. Mixed martial arts contests are stopped when one of the combatants is incapacitated. The Roman Colosseum was performed in front of crowds that numbered in the tens of thousands. Over 56,000 people attended UFC 193.'}


In [47]:
# ─── Cell: load_or_compute_docs helper ──────────────────────────────────────
import os, json, hashlib, heapq, numpy as np
from nltk.tokenize    import word_tokenize
from rank_bm25        import BM25Okapi
from tqdm.auto        import tqdm


CACHE_DIR = "cache"
DUMP_PATH = "/kaggle/input/wikidump/wikipedia-dump.txt"


def get_cache_paths(question: str):
    """
    Map a question to its two per-question cache files.

    The file names embed the first 8 hex characters of the SHA-256 digest of
    the UTF-8 encoded question, so the same question always yields the same
    paths.

    Returns:
        (docs_path, sents_path), both inside CACHE_DIR.
    """
    digest = hashlib.sha256(question.encode("utf-8")).hexdigest()
    h = digest[:8]
    docs_file = os.path.join(CACHE_DIR, f"docs_{h}.json")
    sents_file = os.path.join(CACHE_DIR, f"sents_{h}.json")
    return docs_file, sents_file

def load_or_compute_docs(
    question: str,
    top_d: int = 5,
    chunk_size: int = 50_000
) -> list[tuple[float,str,str]]:
    """
    Return the top_d (score, title, paragraph) tuples for `question`,
    using CACHE_DIR/docs_<hash>.json as an on-disk cache.

    Cache behaviour:
      * The cache file is keyed by the question only. A cached list with at
        least `top_d` entries is reused (truncated to `top_d`); a shorter
        one is treated as a miss and recomputed, so a request for more
        paragraphs than were cached is not silently answered with fewer.
      * A cache file that cannot be parsed, or lacks the expected keys, is
        treated as a miss and overwritten.
      * The cache is written to a temporary file and then renamed, so an
        interrupted run does not leave a half-written cache behind.

    Args:
        question: Natural-language question.
        top_d: Number of paragraphs to return.
        chunk_size: Batch size forwarded to the topic retriever.

    Returns:
        Up to `top_d` tuples, highest score first.
    """
    docs_path, _ = get_cache_paths(question)

    # 1) Try loading cached
    if os.path.exists(docs_path):
        try:
            with open(docs_path, "r", encoding="utf-8") as f:
                cached = [(d["score"], d["title"], d["para"]) for d in json.load(f)]
        except (json.JSONDecodeError, KeyError, TypeError):
            cached = None  # unreadable cache: fall through and recompute
        if cached is not None and len(cached) >= top_d:
            # Entries were stored best-first, so a prefix is the top_d.
            return cached[:top_d]

    # 2) Otherwise compute, via the same retriever used interactively
    docs = local_topic_retrieve_mv(question, top_d=top_d, chunk_size=chunk_size)

    # 3) Cache to disk (write to a temp file, then rename over the target)
    os.makedirs(CACHE_DIR, exist_ok=True)
    tmp_path = docs_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(
            [{"score": s, "title": t, "para": p} for s, t, p in docs],
            f, ensure_ascii=False, indent=2
        )
    os.replace(tmp_path, docs_path)

    return docs


delete it


In [50]:
# ─── Cell 9 (cache-only): Batch Process Questions via Cached Docs/Sents ────
import os, json, hashlib


CACHE_DIR = "cache"
DUMP_PATH = "/kaggle/input/wikidump/wikipedia-dump.txt"


def get_cache_paths(question: str):
    """
    Return (docs_path, sents_path) for `question`.

    Both files live in CACHE_DIR and are named after the first 8 hex
    characters of the SHA-256 digest of the UTF-8 encoded question.
    """
    encoded = question.encode("utf-8")
    h = hashlib.sha256(encoded).hexdigest()[:8]
    docs_name, sents_name = f"docs_{h}.json", f"sents_{h}.json"
    return os.path.join(CACHE_DIR, docs_name), os.path.join(CACHE_DIR, sents_name)

# Questions to process; each one must already have both cache files on disk.
questions = [
    "What is the ruling party of the country that hosted the 2016 Olympics?",
    "Is the Eiffel Tower located in Paris?",
    "Was Barack Obama born in the United States?"
    # … add more …
]

# Cache-only pass: nothing is computed here. The loop stops with
# FileNotFoundError at the first question whose cache is missing (as in the
# recorded run). It overwrites the notebook-globals `docs` and `top_sents`.
for q in questions:
    print(f"\n--- {q} ---")
    docs_path, sents_path = get_cache_paths(q)

    # 1) Load docs from cache
    if not os.path.exists(docs_path):
        raise FileNotFoundError(f"{docs_path} not found. Run load_or_compute_docs() once first.")
    with open(docs_path, "r", encoding="utf-8") as f:
        docs = [(d["score"], d["title"], d["para"]) for d in json.load(f)]
    print(f"🔄 Loaded {len(docs)} paragraphs from {docs_path}")

    # 2) Load sents from cache
    # NOTE(review): the message below refers to load_or_compute_sents(), which
    # is not defined in the visible part of this notebook.
    if not os.path.exists(sents_path):
        raise FileNotFoundError(f"{sents_path} not found. Run load_or_compute_sents() once first.")
    with open(sents_path, "r", encoding="utf-8") as f:
        top_sents = [(d["sent"], d["score"]) for d in json.load(f)]
    print(f"🔄 Loaded {len(top_sents)} sentences from {sents_path}")

    # 3) (Optional) print a sample
    # Indexing [0] assumes both cached lists are non-empty.
    print("\nSample paragraph:", docs[0][2][:100], "…")
    print("Sample sentence:", top_sents[0][0])



--- What is the ruling party of the country that hosted the 2016 Olympics? ---


FileNotFoundError: cache/docs_d866e7b8.json not found. Run load_or_compute_docs() once first.

In [None]:
# ─── Cell 9: Batch Process Multiple Questions ───────────────────────────────
# Questions to retrieve for; each uncached question triggers a full scan of
# the dump inside load_or_compute_docs.
questions = [
    "What is the ruling party of the country that hosted the 2016 Olympics?",
    "Is the Eiffel Tower located in Paris?",
    "Was Barack Obama born in the United States?"
    # … add more …
    ]
for q in questions:
    print(f"\n--- {q} ---")
    docs     = load_or_compute_docs(q)
    # NOTE(review): load_or_compute_sents is not defined anywhere in the
    # visible notebook; unless it is defined elsewhere this line raises
    # NameError after the (expensive) docs step completes.
    sents    = load_or_compute_sents(q, docs)
    print("Docs:", len(docs), "Sentences:", len(sents))



--- What is the ruling party of the country that hosted the 2016 Olympics? ---


Topic retrieve: 0it [00:00, ?it/s]

In [None]:
# need to update the cell 9 for that docs thingy.

In [None]:
import time

# Keep-alive loop: prints a counter and sleeps 2 minutes, forever.
# NOTE(review): never terminates on its own; every cell below it (model
# save/reload) is unreachable under "Run All" until it is interrupted.
count = 0
while True:
    print(count)
    count += 1

    print("Waiting for 2 minutes...")
    time.sleep(120)
    print("Resuming execution now.")
       


Saving and reloading the model state

In [None]:
# ─── Saving the model state ───────────────────────────────────────────────────
import torch

# Assume `model` is your DisentangledReasoner instance on `device`
# Only the weights (state_dict) are written, relative to the current working
# directory; the architecture must be re-created in code before loading.
save_path = "disentangled_reasoner.pt"
torch.save(model.state_dict(), save_path)
print(f"✅ Model weights saved to {save_path}")


In [None]:
# ─── Reloading the model state ───────────────────────────────────────────────
import torch

# 1) Re-instantiate the model architecture
# The new instance wraps the same module-level ENCODER object as before.
model = DisentangledReasoner().to(device)

# 2) Load the saved weights
# map_location remaps stored tensors onto the current `device`, so a file
# saved on GPU can be loaded on a CPU-only session.
checkpoint = torch.load("disentangled_reasoner.pt", map_location=device)
model.load_state_dict(checkpoint)

# 3) Set to eval mode (if you are doing inference)
model.eval()
print("✅ Model weights loaded and ready for inference")


In [None]:
# Full training checkpoint: model weights plus optimizer state and progress.
# NOTE(review): `optimizer`, `current_epoch` and `last_loss` are not defined
# anywhere in the visible notebook — this cell is a template that needs a
# training loop to supply them.
torch.save({
    "model_state": model.state_dict(),
    "optim_state": optimizer.state_dict(),
    "epoch": current_epoch,
    "loss": last_loss,
}, "checkpoint.pt")


In [None]:
# Restore a full training checkpoint written with the keys "model_state",
# "optim_state" and "epoch".
# map_location=device matches the weights-only reload cell and lets a
# checkpoint saved on GPU be restored in a CPU-only session.
# NOTE(review): `optimizer` must already exist (built for this `model`)
# before this cell runs; it is not defined in the visible notebook.
ckpt = torch.load("checkpoint.pt", map_location=device)
model.load_state_dict(ckpt["model_state"])
optimizer.load_state_dict(ckpt["optim_state"])
# Resume with the epoch after the one that was saved.
start_epoch = ckpt["epoch"] + 1
