In [1]:
# Cell 1: Setup everything


%env WANDB_DISABLED=true

# 1) Install needed packages
!pip install --quiet rank-bm25 sentence-transformers torch tqdm

# 2) Imports
import pandas as pd
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, InputExample, losses, util
from torch.utils.data import DataLoader
import torch
from tqdm.auto import tqdm

# 3) Load data files
# NOTE(review): unpickling can execute arbitrary code; only load a trusted collection file.
papers_df = pd.read_pickle("subtask4b_collection_data.pkl")

# The three query splits are tab-separated files read with identical options.
train_df, dev_df, test_df = (
    pd.read_csv(tsv_path, sep="\t")
    for tsv_path in (
        "subtask4b_query_tweets_train.tsv",
        "subtask4b_query_tweets_dev.tsv",
        "subtask4b_query_tweets_test.tsv",
    )
)

# 4) Quick check: print (rows, columns) of every frame
for split_label, split_frame in (
    ("Papers:", papers_df),
    ("Train:", train_df),
    ("Dev:", dev_df),
    ("Test:", test_df),
):
    print(split_label, split_frame.shape)


env: WANDB_DISABLED=true
Papers: (7718, 17)
Train: (12853, 3)
Dev: (1400, 3)
Test: (1446, 2)


In [2]:
# Cell 2: BM25 baseline & MRR

# Build BM25 index
# Pull the columns out once instead of calling papers_df.iloc[i] twice per row
# (every .iloc[i] materialises a full row Series).
paper_titles = papers_df['title'].tolist()
paper_abstracts = (
    papers_df['abstract'].tolist()
    if 'abstract' in papers_df.columns
    else [None] * len(papers_df)
)

# Whitespace-tokenise "title abstract" for every paper. A missing abstract (None or
# NaN) counts as empty text: `nan or ""` evaluates to nan, which would break the
# string concatenation.
tokenized = [
    (title + " " + (abstract if isinstance(abstract, str) else "")).split()
    for title, abstract in zip(paper_titles, paper_abstracts)
]
bm25 = BM25Okapi(tokenized)

# Retrieval + MRR functions
def get_topk_bm25(tweet, k=5):
    """Return the cord_uids of the k best-scoring papers for a tweet under BM25.

    Args:
        tweet: Query text; it is split on whitespace before scoring.
        k: Maximum number of papers to return. Values <= 0 yield an empty list.

    Returns:
        List of cord_uid values ordered from highest to lowest BM25 score,
        containing at most k entries.
    """
    if k <= 0:
        # scores[-0:] would slice the *whole* array, so k=0 needs an explicit guard.
        return []
    toks = tweet.split()
    scores = bm25.get_scores(toks)
    # argsort is ascending: keep the last k positions, then reverse for best-first order.
    idxs = np.argsort(scores)[-k:][::-1]
    return papers_df.iloc[idxs]['cord_uid'].tolist()

def compute_mrr(true_ids, preds, k=5):
    """Mean reciprocal rank of the gold id within the first k predictions.

    Args:
        true_ids: Iterable with one gold id per query.
        preds: Iterable with one ranked list of predicted ids per query, aligned
            with true_ids.
        k: Rank cutoff; a gold id found at a rank greater than k scores 0.

    Returns:
        Average of 1/rank over all queries (0 for a query whose gold id is absent
        or ranked beyond k). Returns 0.0 when there are no queries, instead of
        dividing by zero.
    """
    reciprocal_ranks = []
    for true_id, ranked in zip(true_ids, preds):
        # 1-based position of the gold id, or None when it was not predicted at all.
        rank = ranked.index(true_id) + 1 if true_id in ranked else None
        hit = rank is not None and rank <= k
        reciprocal_ranks.append(1.0 / rank if hit else 0.0)
    if not reciprocal_ranks:
        return 0.0
    return sum(reciprocal_ranks) / len(reciprocal_ranks)

# Evaluate on dev: retrieve the default top-5 for every dev tweet, then score against gold
dev_preds = [get_topk_bm25(tweet) for tweet in dev_df['tweet_text']]
dev_gold_ids = dev_df['cord_uid'].tolist()
print("BM25 Dev MRR@5:", compute_mrr(dev_gold_ids, dev_preds))


BM25 Dev MRR@5: 0.5519166666666668


In [4]:
# Cell 3: Prepare & fine-tune


# NOTE(review): this repeats the `%env WANDB_DISABLED=true` setting from Cell 1, and the
# library output of this cell reports the WANDB_DISABLED variable as deprecated.
import os
os.environ["WANDB_DISABLED"] = "true"

# Build InputExample pairs (tweet text, paper text) from a query DataFrame
def make_input_examples(df):
    """Pair every tweet in df with the text of the paper it refers to.

    Args:
        df: DataFrame with 'tweet_text' and 'cord_uid' columns.

    Returns:
        List of InputExample objects whose texts are [tweet, "title abstract"].
        A missing abstract (None or NaN) contributes empty text.

    Raises:
        IndexError: If a cord_uid in df has no row in papers_df. The exception
            type matches the earlier `.iloc[0]` lookup, now with a descriptive message.
    """
    # Index the collection by cord_uid once; filtering papers_df inside the loop
    # rescans every paper for every tweet.
    abstracts = (
        papers_df['abstract']
        if 'abstract' in papers_df.columns
        else [None] * len(papers_df)
    )
    paper_by_uid = {}
    for uid, title, abstract in zip(papers_df['cord_uid'], papers_df['title'], abstracts):
        # setdefault keeps the first row of a duplicated cord_uid, as `.iloc[0]` did.
        paper_by_uid.setdefault(uid, (title, abstract))

    exs = []
    query_pairs = zip(df['tweet_text'], df['cord_uid'])
    for q, uid in tqdm(query_pairs, total=len(df), desc="Making examples"):
        if uid not in paper_by_uid:
            raise IndexError(f"cord_uid {uid!r} not found in papers_df")
        title, abstract = paper_by_uid[uid]
        # `nan or ""` evaluates to nan, so test for a real string instead.
        doc = title + " " + (abstract if isinstance(abstract, str) else "")
        exs.append(InputExample(texts=[q, doc]))
    return exs

# Turn both labelled splits into (tweet, paper text) examples: train first, then dev
train_ex, dev_ex = (make_input_examples(split_df) for split_df in (train_df, dev_df))

# DataLoaders with small batch sizes; only the training loader is shuffled
# NOTE(review): dev_loader is not used by any cell visible here.
train_loader = DataLoader(train_ex, batch_size=8, shuffle=True)
dev_loader = DataLoader(dev_ex, batch_size=16, shuffle=False)

# Load a lighter bi-encoder and attach the ranking loss to it
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
train_loss = losses.MultipleNegativesRankingLoss(model)

# Fine-tune for 2 epochs; the settings are collected in one place and passed to fit()
fit_settings = dict(
    train_objectives=[(train_loader, train_loss)],
    epochs=2,
    optimizer_params={'lr': 2e-5},
    show_progress_bar=True,
    output_path='fine_tuned_mini',
)
model.fit(**fit_settings)


Making examples:   0%|          | 0/12853 [00:00<?, ?it/s]

Making examples:   0%|          | 0/1400 [00:00<?, ?it/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.2715
1000,0.1982
1500,0.1717
2000,0.1653


Step,Training Loss
500,0.2715
1000,0.1982
1500,0.1717
2000,0.1653
2500,0.1462
3000,0.1649


In [5]:
# Cell 4: Embed, rank, save

# Read the columns directly rather than building a Series per row with iterrows().
paper_titles = papers_df['title'].tolist()
paper_abstracts = (
    papers_df['abstract'].tolist()
    if 'abstract' in papers_df.columns
    else [None] * len(papers_df)
)

# One "title abstract" string per paper. A missing abstract (None or NaN) counts as
# empty text: `nan or ""` evaluates to nan, which would break the concatenation.
texts = [
    title + " " + (abstract if isinstance(abstract, str) else "")
    for title, abstract in zip(paper_titles, paper_abstracts)
]
# uids[i] is the cord_uid of texts[i], and therefore of row i in paper_embs.
uids = papers_df['cord_uid'].tolist()
paper_embs = model.encode(texts, convert_to_tensor=True, show_progress_bar=True)

def rank_top5(df, k=5):
    """Rank the paper collection for every tweet in df by cosine similarity.

    Args:
        df: DataFrame with 'post_id' and 'tweet_text' columns.
        k: Number of papers to keep per tweet (default 5). It is clamped to the
            number of available papers, because torch.topk raises when k exceeds
            the size of the scored dimension.

    Returns:
        DataFrame with columns 'post_id' and 'ranking', where 'ranking' is a
        comma-separated string of cord_uids ordered from most to least similar.
    """
    top_n = max(0, min(k, len(uids)))
    rows = []
    for _, r in tqdm(df.iterrows(), total=len(df), desc="Ranking"):
        q_emb = model.encode(r['tweet_text'], convert_to_tensor=True)
        # cos_sim returns a (1, n_papers) matrix for a single query; take its only row.
        sims = util.cos_sim(q_emb, paper_embs)[0]
        best = torch.topk(sims, k=top_n).indices.tolist()
        rows.append((r['post_id'], ",".join(uids[i] for i in best)))
    return pd.DataFrame(rows, columns=['post_id','ranking'])

# Rank both query sets against the embedded collection (dev first, then test)
dev_out = rank_top5(dev_df)
test_out = rank_top5(test_df)

# Write each ranking as a tab-separated file without the DataFrame index
for ranking_frame, ranking_path in (
    (dev_out, 'dev_ranking_nn.tsv'),
    (test_out, 'test_ranking_nn.tsv'),
):
    ranking_frame.to_csv(ranking_path, sep='\t', index=False)

print("Done: generated dev_ranking_nn.tsv and test_ranking_nn.tsv")


Batches:   0%|          | 0/242 [00:00<?, ?it/s]

Ranking:   0%|          | 0/1400 [00:00<?, ?it/s]

Ranking:   0%|          | 0/1446 [00:00<?, ?it/s]

Done: generated dev_ranking_nn.tsv and test_ranking_nn.tsv


In [6]:
# List the working directory contents in long format with human-readable sizes.
!ls -lh


total 20M
drwxr-xr-x 6 root root 4.0K May 30 16:29 checkpoints
-rw-r--r-- 1 root root  69K May 30 20:44 dev_ranking_nn.tsv
drwxr-xr-x 4 root root 4.0K May 30 20:25 fine_tuned_mini
-rw-r--r-- 1 root root  47K May 30 15:43 getting_started_subtask4b.ipynb
-rw-r--r-- 1 root root 4.5K May 30 15:43 README.md
drwxr-xr-x 1 root root 4.0K May 28 19:28 sample_data
-rw-r--r-- 1 root root  16M May 30 15:43 subtask4b_collection_data.pkl
-rw-r--r-- 1 root root 292K May 30 15:43 subtask4b_query_tweets_dev.tsv
-rw-r--r-- 1 root root 256K May 30 15:43 subtask4b_query_tweets_test_gold.tsv
-rw-r--r-- 1 root root 244K May 30 15:43 subtask4b_query_tweets_test.tsv
-rw-r--r-- 1 root root 2.7M May 30 15:43 subtask4b_query_tweets_train.tsv
-rw-r--r-- 1 root root  70K May 30 20:44 test_ranking_nn.tsv


In [7]:
from google.colab import files
import shutil

# Download the TSVs directly, dev ranking first
for tsv_name in ('dev_ranking_nn.tsv', 'test_ranking_nn.tsv'):
    files.download(tsv_name)

# The model is a folder, so pack it into fine_tuned_mini.zip before downloading
shutil.make_archive('fine_tuned_mini', 'zip', 'fine_tuned_mini')
files.download('fine_tuned_mini.zip')


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>