In [18]:
# all the imports
import os
import re
import json
import math
import textwrap
import warnings
from dataclasses import dataclass
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import s3fs
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import io


In [6]:
warnings.filterwarnings("ignore")

# --- PARAMETERS YOU MAY TUNE ---
SUBSET_MAX_TXT = 80      # keep <100 as requested
TOP_K = 12               # top results per query
SUMMARY_MAX_WORDS = 70   # aim ~2-3 sentences
KEYWORDS_TOP_N = 6

# Models (small/CPU-friendly)
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# For summarization we'll try a distilled BART checkpoint fine-tuned for CNN/DM.
# If unavailable locally, fallback to t5-small with the "summarize: " prefix.
SUMM_MODEL_PRIMARY = "sshleifer/distilbart-cnn-12-6"
SUMM_MODEL_FALLBACK = "t5-small"

# Optional verifier (zero-shot)
ZS_MODEL = "facebook/bart-large-mnli"  # ok on CPU for small batches

# Themes for optional verifier
THEMES = ["Deep Learning", "Clinical Trial", "Traditional Methods"]

# Example input queries (from prompt)
QUERIES = [
    "Adverse events with mRNA vaccines in pediatrics",
    "Transformer-based models for protein folding",
    "Clinical trial outcomes for monoclonal antibodies in oncology",
]

In [13]:
# Public S3 bucket per brief:
BUCKET = "pmc-oa-opendata"
FILELIST_KEY = "oa_comm/txt/metadata/csv/oa_comm.filelist.csv"
TXT_PREFIX = "oa_comm/txt/all/"

def get_fs():
    # Anonymous access to the public bucket
    # If your network blocks anon, set AWS creds and drop anon=True.
    try:
        return s3fs.S3FileSystem(anon=True)
    except Exception:
        return s3fs.S3FileSystem(anon=False)

fs = get_fs()
print(fs)

<s3fs.core.S3FileSystem object at 0x14a3e8830>


In [19]:
def load_filelist_sample_robust(max_txt=SUBSET_MAX_TXT, seed=7):
    s3_url = f"s3://{BUCKET}/{FILELIST_KEY}"
    with fs.open(s3_url, "rb") as f:
        raw = f.read()

    # Try to parse as a regular CSV with unknown columns
    df_try = pd.read_csv(io.BytesIO(raw), dtype=str, low_memory=False)

    candidates = []

    def collect_from_series(series):
        vals = series.dropna().astype(str)
        # accept relative keys or full s3 URIs; match .../oa_comm/txt/all/....txt
        mask = vals.str.contains(r"oa_comm/txt/all/.+\.txt", regex=True)
        return vals[mask].tolist()

    # scan all columns if tabular
    for col in df_try.columns:
        candidates.extend(collect_from_series(df_try[col]))

    # last-resort: regex scan the raw text (handles odd formats)
    if not candidates:
        txt = raw.decode("utf-8", errors="ignore")
        found = re.findall(r"(?:s3://pmc-oa-opendata/)?oa_comm/txt/all/[^\s,]+?\.txt", txt)
        candidates.extend(found)

    # Normalize to s3:// paths
    s3_paths = [
        x if x.startswith("s3://") else f"s3://{BUCKET}/{x.lstrip('/')}"
        for x in pd.Series(candidates).drop_duplicates().tolist()
    ]

    if not s3_paths:
        return []

    sample = pd.Series(s3_paths).sample(
        n=min(len(s3_paths), max_txt),
        random_state=seed
    ).tolist()
    return sample

sample_paths = load_filelist_sample_robust()
len(sample_paths), sample_paths[:3]


(80,
 ['s3://pmc-oa-opendata/oa_comm/txt/all/PMC6078067.txt',
  's3://pmc-oa-opendata/oa_comm/txt/all/PMC11010288.txt',
  's3://pmc-oa-opendata/oa_comm/txt/all/PMC6345036.txt'])