In [15]:
# Cell 1: Import Libraries
import pandas as pd
import pickle
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
import os

# Set device: prefer CUDA when PyTorch can see a GPU, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


In [16]:
# Cell 2: Load Fine-tuned Model
# NOTE(review): machine-specific absolute path; set the CHECKTHAT_MODEL_DIR
# environment variable to point at the fine-tuned model elsewhere.
model_path = os.environ.get(
    "CHECKTHAT_MODEL_DIR",
    "/Users/mataonbas/AIR-CheckThat!-GroupProject/CheckThat-ScientificClaimSourceRetrieval/fine_tuned_sbert",
)
# SentenceTransformer interprets a string that is not an existing path as a
# Hugging Face Hub model id, which yields a confusing download error; fail early
# with a clear message instead.
if not os.path.isdir(model_path):
    raise FileNotFoundError(f"Fine-tuned model directory not found: {model_path}")
model = SentenceTransformer(model_path)
model.to(device)
print("Model loaded successfully")

Model loaded successfully


In [17]:
# Cell 3: Load Test Data and Papers
# Base directory of the subtask 4b data files.
# NOTE(review): machine-specific absolute path; set the CHECKTHAT_DATA_DIR
# environment variable to run elsewhere.
data_dir = os.environ.get(
    "CHECKTHAT_DATA_DIR",
    "/Users/mataonbas/AIR-CheckThat!-GroupProject/CheckThat-ScientificClaimSourceRetrieval",
)

# Load test data (tab-separated file of query tweets)
test_df = pd.read_csv(os.path.join(data_dir, 'subtask4b_query_tweets_test.tsv'), sep='\t')
print(f"Loaded {len(test_df)} test examples")

# Load papers data.
# SECURITY: pickle.load can execute arbitrary code -- only load this file from a trusted source.
with open(os.path.join(data_dir, 'subtask4b_collection_data.pkl'), 'rb') as f:
    papers_df = pickle.load(f)
# A missing title or abstract would make the whole concatenation NaN, which is
# later embedded as the literal string "nan"; treat missing parts as empty text.
papers_df['text'] = papers_df['title'].fillna('') + '. ' + papers_df['abstract'].fillna('')
print(f"Loaded {len(papers_df)} papers")

Loaded 1446 test examples
Loaded 7718 papers


In [18]:
# Cell 4: Define Prediction Function
def get_predictions(model, test_df, papers_df, top_k=5, device=None, query_chunk_size=1024):
    """Rank papers for each tweet by cosine similarity of model embeddings.

    Parameters
    ----------
    model : object with a SentenceTransformer-style ``encode`` method
        Used to embed both the tweets and the papers.
    test_df : pandas.DataFrame
        Must contain a ``tweet_text`` column; each row is one query.
    papers_df : pandas.DataFrame
        Must contain ``text`` (the string to embed) and ``cord_uid`` (the id
        reported in the predictions) columns.
    top_k : int, default 5
        Number of paper ids returned per query, capped at the number of papers.
    device : str or torch.device, optional
        Forwarded to ``model.encode``. ``None`` lets the model use the device it
        already lives on (Cell 2 moves the model to the selected device).
    query_chunk_size : int, default 1024
        Number of queries scored per matrix multiplication; bounds the size of
        the (chunk, n_papers) similarity matrix held in memory.

    Returns
    -------
    list[list]
        One list per row of ``test_df`` (same order) holding the ``cord_uid``s
        of the ``top_k`` most similar papers, most similar first.
    """
    # Missing values (NaN) would otherwise be embedded as the literal string "nan".
    test_texts = test_df['tweet_text'].fillna('').astype(str).tolist()
    paper_texts = papers_df['text'].fillna('').astype(str).tolist()
    paper_ids = papers_df['cord_uid'].tolist()

    if not test_texts:
        return []
    k = min(top_k, len(paper_texts))
    if k <= 0:
        # No papers to rank (or a non-positive top_k): one empty list per query.
        return [[] for _ in test_texts]

    print("Encoding test queries...")
    query_embeddings = model.encode(
        test_texts,
        show_progress_bar=True,
        convert_to_tensor=True,
        device=device
    )

    print("Encoding papers...")
    paper_embeddings = model.encode(
        paper_texts,
        show_progress_bar=True,
        convert_to_tensor=True,
        device=device
    )

    print("Computing similarities and generating predictions...")
    # L2-normalise once so a plain dot product equals cosine similarity.
    query_norm = torch.nn.functional.normalize(query_embeddings, p=2, dim=1)
    paper_norm = torch.nn.functional.normalize(paper_embeddings, p=2, dim=1)

    predictions = []
    step = max(1, int(query_chunk_size))
    for start in range(0, query_norm.shape[0], step):
        # (chunk, n_papers) scores; keeping the 2-D shape avoids the 0-d tensor
        # that squeeze() produced when there is only a single paper.
        scores = query_norm[start:start + step] @ paper_norm.T
        top_indices = torch.topk(scores, k=k, dim=1).indices.tolist()
        predictions.extend([paper_ids[i] for i in row] for row in top_indices)

    return predictions

In [19]:
# Cell 5: Generate Predictions
# Embed every tweet and paper with the fine-tuned model and keep the
# highest-ranked paper ids for each tweet.
print("Generating predictions...")
predictions = get_predictions(
    model=model,
    test_df=test_df,
    papers_df=papers_df,
)
num_predicted = len(predictions)
print(f"Generated predictions for {num_predicted} test examples")

Generating predictions...
Encoding test queries...


Batches: 100%|██████████| 46/46 [00:26<00:00,  1.71it/s]


Encoding papers...


Batches: 100%|██████████| 242/242 [26:09<00:00,  6.49s/it]


Computing similarities and generating predictions...
Generated predictions for 1446 test examples


In [20]:
# Cell 6: Save Predictions
# Row i of `predictions` belongs to row i of `test_df` (get_predictions keeps the
# query order), so the lengths must agree; otherwise post_ids would be silently
# dropped from the submission or the build would fail part-way through.
if len(predictions) != len(test_df):
    raise ValueError(
        f"Got {len(predictions)} predictions for {len(test_df)} test examples"
    )

# Build the output frame column-wise; `preds` holds the string form of each
# ranked id list, e.g. "['id1', 'id2', ...]".
output_df = pd.DataFrame({
    'post_id': test_df['post_id'].to_numpy(),
    'preds': [str(preds) for preds in predictions],
})

# Save to TSV file
output_df.to_csv('predictions_fromfinetunedmodel.tsv',
                 sep='\t',
                 index=False,
                 header=True)  # Include header
print("Predictions saved to predictions_fromfinetunedmodel.tsv")

# Print sample of predictions
print("\nSample predictions:")
print(output_df.head())

Predictions saved to predictions_fromfinetunedmodel.tsv

Sample predictions:
   post_id                                              preds
0        1  ['qgwu9fsk', 'bv7hvc1e', 'x4zuv4jo', '0oq0xmzr...
1        2  ['wubcq0xx', 'mm2aotem', 'u4ntxo0y', '4vkkaqhz...
2        3  ['gtp5daep', 'vh3qs9xv', 'm3m2n3fw', 'hdk02l2r...
3        4  ['ru2ty1y9', 'enlj85zc', 'zs78ndoa', '609b8j39...
4        5  ['f5p37j7g', 'nzat41wu', 'x9zg7ulr', 'n9zqc1gm...
