<a href="https://colab.research.google.com/github/Shristy183/transfer-learning-project/blob/main/transfer_learning_pipeline_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RAT (Rejected Article Tracker) – Crossref & OpenAlex Matcher (Colab)

Upload your **rejected article tracker (RAT) CSV**. The pipeline looks for matches in **both Crossref and OpenAlex**, compares which source has more data, and writes a CSV (`doi_results_fuzzy.csv`) with the columns **MANUSCRIPT_ID**, **Original_Title**, **Crossref_DOI** / **OpenAlex_DOI** (or "Not Found"), **Crossref_Found** / **OpenAlex_Found**, **best_doi**, **best_journal**, **publish_year**, **citations**, **match_score**, **source**, **best_link**, and **matched_title**. Only rejected rows (those with a non-empty `REJECT_REASON`) are processed, and responses from both APIs are cached on disk to save credits.

- **Rejected manuscripts**: Uses your CSV columns `MANUSCRIPT_ID`, `REJECT_REASON`, `TITLE`, `FIRST_AUTHOR`; only rows with non-empty `REJECT_REASON` are processed.
- **Fewer API calls**: Disk cache + deduplication by (title, first author) and by DOI so re-runs and duplicate rows don’t exhaust OpenAlex/CrossRef free credits.
- **Faster**: One OpenAlex search per unique (title, first author); one metadata request per unique DOI.

1. Install deps → 2. Config → 3. Upload CSV → 4. Run pipeline → 5. Download output

## 1. Install dependencies

In [None]:
!pip install -q torch transformers pandas numpy requests tqdm

## 2. Configuration

In [None]:
# ---------- CONFIG (edit as needed) ----------
# API endpoints.
OPENALEX_WORKS_SEARCH = "https://api.openalex.org/works"
OPENALEX_WORKS_BY_DOI = "https://api.openalex.org/works/https://doi.org/{doi}"
CROSSREF_WORKS = "https://api.crossref.org/works"
PER_PAGE = 5  # Results requested per OpenAlex search
CROSSREF_ROWS = 5  # Results requested per Crossref search
MAILTO = "your_email@example.com"  # Contact e-mail: sent to OpenAlex and embedded in USER_AGENT
# Crossref polite pool. Built from MAILTO so the contact address is set in one place only.
USER_AGENT = f"TransferLearningRAT/1.0 (mailto:{MAILTO})"
REQUEST_TIMEOUT_SEC = 30
API_DELAY_SEC = 0.3   # Delay when not using parallel (reduced for speed)
MAX_WORKERS = 5  # Parallel API requests per source (OpenAlex / Crossref polite pool)
TOP_K_CANDIDATES = 5  # Maximum candidates kept per search
MAX_RETRIES = 3  # Attempts per HTTP request
RETRY_BACKOFF_SEC = 2.0  # Base wait between attempts; multiplied by the attempt number

# ---------- CACHE (saves API credits; reuse across runs) ----------
USE_DISK_CACHE = True
CACHE_DIR = "/content/transfer_learning_cache"
SEARCH_CACHE_FILE = "openalex_search_cache.json"
DOI_CACHE_FILE = "openalex_doi_cache.json"
CROSSREF_SEARCH_CACHE_FILE = "crossref_search_cache.json"
CROSSREF_DOI_CACHE_FILE = "crossref_doi_cache.json"

# ---------- REJECTED CSV COLUMNS (must match your file) ----------
COL_MANUSCRIPT_ID = "MANUSCRIPT_ID"
COL_DATE_REJECTION = "DATE_OF_REJECTION"
COL_REJECT_REASON = "REJECT_REASON"
COL_TITLE = "TITLE"
COL_FIRST_AUTHOR = "FIRST_AUTHOR"
COL_CORRESPONDING_AUTHOR = "CORRESPONDING_AUTHOR"
COL_CO_AUTHORS = "CO_AUTHORS"

# ---------- EMBEDDING MODEL ----------
MODEL_NAME = "intfloat/e5-large-v2"
QUERY_PREFIX = "query: "  # Prepended to every manuscript title before encoding
PASSAGE_PREFIX = "passage: "  # Prepended to every candidate title before encoding
BATCH_SIZE = 64
MAX_LENGTH = 512  # Tokenizer truncation length
NORMALIZE_EMBEDDINGS = True  # L2-normalise so dot products are cosine similarities

# ---------- MATCHING / OUTPUT ----------
SIMILARITY_THRESHOLD = 0.80  # Minimum similarity for a candidate to count as a match
OUTPUT_FILENAME = "doi_results_fuzzy.csv"

## 3. Retrieval (OpenAlex + Crossref)

In [None]:
# ---------- Disk cache to avoid exhausting OpenAlex/CrossRef free credits ----------
import json
import hashlib
from pathlib import Path

def _cache_path(filename):
    """Return the path of *filename* inside CACHE_DIR.

    The cache directory is created on demand, but only while disk caching is
    enabled.
    """
    cache_dir = Path(CACHE_DIR)
    if USE_DISK_CACHE:
        cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir.joinpath(filename)

def _norm_key(s):
    return " ".join(str(s).strip().lower().split()) if s else ""

def _search_cache_key(title, first_author=""):
    """Return the SHA-256 hex digest identifying a (title, first author) pair.

    Both parts are normalised with ``_norm_key`` and joined with "|" before
    hashing.
    """
    parts = (_norm_key(title), _norm_key(first_author))
    digest = hashlib.sha256("|".join(parts).encode("utf-8"))
    return digest.hexdigest()

def load_search_cache():
    """Load the OpenAlex search cache from disk.

    Returns:
        dict: cached candidate lists keyed by search-cache key. An empty dict
        is returned when the file is missing, unreadable, not valid JSON, or
        does not contain a JSON object.
    """
    p = _cache_path(SEARCH_CACHE_FILE)
    if not p.exists():
        return {}
    try:
        with open(p, "r", encoding="utf-8") as f:
            cache = json.load(f)
    except (OSError, ValueError):
        # Best-effort cache: an unreadable or corrupt file just means a cold start.
        # (json.JSONDecodeError and UnicodeDecodeError are both ValueError.)
        return {}
    # Callers look up and assign by key, so anything but a JSON object is unusable.
    return cache if isinstance(cache, dict) else {}

def save_search_cache(cache):
    """Write the OpenAlex search cache to disk as JSON.

    The data goes to a temporary sibling file that is then moved over the real
    cache file, so an interrupted write cannot leave a truncated cache behind.

    Args:
        cache: JSON-serialisable dict to persist.
    """
    import os  # local import keeps this cell self-contained

    p = _cache_path(SEARCH_CACHE_FILE)
    p.parent.mkdir(parents=True, exist_ok=True)
    tmp = p.with_name(p.name + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(cache, f, ensure_ascii=False, indent=0)
    os.replace(tmp, p)

def load_doi_cache():
    """Load the OpenAlex DOI-metadata cache from disk.

    Returns:
        dict: cached metadata dicts keyed by DOI. An empty dict is returned
        when the file is missing, unreadable, not valid JSON, or does not
        contain a JSON object.
    """
    p = _cache_path(DOI_CACHE_FILE)
    if not p.exists():
        return {}
    try:
        with open(p, "r", encoding="utf-8") as f:
            cache = json.load(f)
    except (OSError, ValueError):
        # Best-effort cache: an unreadable or corrupt file just means a cold start.
        return {}
    # Callers look up and assign by key, so anything but a JSON object is unusable.
    return cache if isinstance(cache, dict) else {}

def save_doi_cache(cache):
    """Write the OpenAlex DOI-metadata cache to disk as JSON.

    The data goes to a temporary sibling file that is then moved over the real
    cache file, so an interrupted write cannot leave a truncated cache behind.

    Args:
        cache: JSON-serialisable dict to persist.
    """
    import os  # local import keeps this cell self-contained

    p = _cache_path(DOI_CACHE_FILE)
    p.parent.mkdir(parents=True, exist_ok=True)
    tmp = p.with_name(p.name + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(cache, f, ensure_ascii=False, indent=0)
    os.replace(tmp, p)

def load_crossref_search_cache():
    """Load the Crossref search cache from disk.

    Returns:
        dict: cached candidate lists keyed by search-cache key. An empty dict
        is returned when the file is missing, unreadable, not valid JSON, or
        does not contain a JSON object.
    """
    p = _cache_path(CROSSREF_SEARCH_CACHE_FILE)
    if not p.exists():
        return {}
    try:
        with open(p, "r", encoding="utf-8") as f:
            cache = json.load(f)
    except (OSError, ValueError):
        # Best-effort cache: an unreadable or corrupt file just means a cold start.
        return {}
    # Callers look up and assign by key, so anything but a JSON object is unusable.
    return cache if isinstance(cache, dict) else {}

def save_crossref_search_cache(cache):
    """Write the Crossref search cache to disk as JSON.

    The data goes to a temporary sibling file that is then moved over the real
    cache file, so an interrupted write cannot leave a truncated cache behind.

    Args:
        cache: JSON-serialisable dict to persist.
    """
    import os  # local import keeps this cell self-contained

    p = _cache_path(CROSSREF_SEARCH_CACHE_FILE)
    p.parent.mkdir(parents=True, exist_ok=True)
    tmp = p.with_name(p.name + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(cache, f, ensure_ascii=False, indent=0)
    os.replace(tmp, p)

def load_crossref_doi_cache():
    """Load the Crossref DOI-metadata cache from disk.

    Returns:
        dict: cached metadata dicts keyed by DOI. An empty dict is returned
        when the file is missing, unreadable, not valid JSON, or does not
        contain a JSON object.
    """
    p = _cache_path(CROSSREF_DOI_CACHE_FILE)
    if not p.exists():
        return {}
    try:
        with open(p, "r", encoding="utf-8") as f:
            cache = json.load(f)
    except (OSError, ValueError):
        # Best-effort cache: an unreadable or corrupt file just means a cold start.
        return {}
    # Callers look up and assign by key, so anything but a JSON object is unusable.
    return cache if isinstance(cache, dict) else {}

def save_crossref_doi_cache(cache):
    """Write the Crossref DOI-metadata cache to disk as JSON.

    The data goes to a temporary sibling file that is then moved over the real
    cache file, so an interrupted write cannot leave a truncated cache behind.

    Args:
        cache: JSON-serialisable dict to persist.
    """
    import os  # local import keeps this cell self-contained

    p = _cache_path(CROSSREF_DOI_CACHE_FILE)
    p.parent.mkdir(parents=True, exist_ok=True)
    tmp = p.with_name(p.name + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(cache, f, ensure_ascii=False, indent=0)
    os.replace(tmp, p)

# Report where the cache lives and whether it is active for this session.
print("Cache dir:", CACHE_DIR, "(disabled)" if not USE_DISK_CACHE else "(enabled)")

In [None]:
import time
import requests
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

def _extract_doi(work):
    doi_url = work.get("doi")
    if not doi_url or not isinstance(doi_url, str):
        return None
    if doi_url.startswith("https://doi.org/"):
        return doi_url.replace("https://doi.org/", "", 1).strip()
    return doi_url.strip() or None

def search_openalex(title):
    """Search OpenAlex works by *title* and return candidate matches.

    The request is attempted up to MAX_RETRIES times, sleeping
    RETRY_BACKOFF_SEC times the attempt number between attempts.

    Returns:
        list[dict]: candidates shaped ``{"title", "doi", "source": "OpenAlex"}``.
        Hits without a DOI are skipped. An empty list is returned when every
        attempt fails with a ``requests`` error.
    """
    query = {"search": title, "per-page": PER_PAGE, "mailto": MAILTO}
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            response = requests.get(OPENALEX_WORKS_SEARCH, params=query, timeout=REQUEST_TIMEOUT_SEC)
            response.raise_for_status()
            payload = response.json()
        except requests.RequestException:
            if attempt == MAX_RETRIES:
                return []
            time.sleep(RETRY_BACKOFF_SEC * attempt)
        else:
            break
    else:
        # Only reached when MAX_RETRIES is zero or negative: nothing was tried.
        return []

    found = []
    for work in payload.get("results") or []:
        doi = _extract_doi(work)
        if doi is None:
            continue
        shown_title = (work.get("title") or "").strip() or "(no title)"
        found.append({"title": shown_title, "doi": doi, "source": "OpenAlex"})
        if len(found) >= TOP_K_CANDIDATES:
            break
    return found

def retrieve_candidates_for_titles(titles, first_authors=None, delay_sec=API_DELAY_SEC):
    """Retrieve OpenAlex candidates with disk cache, deduplication, and parallel API calls.

    Rows sharing the same normalised (title, first author) pair are looked up
    once. The first author only contributes to the dedup/cache key; the search
    itself is by title.

    Empty candidate lists are neither written to nor trusted from the cache:
    ``search_openalex`` returns ``[]`` both when there are no hits and when all
    retries failed, so caching them would make a transient outage or rate limit
    permanent.

    Args:
        titles: sequence of manuscript titles, one per row.
        first_authors: optional sequence aligned with ``titles``; treated as
            all-empty when omitted or when its length differs.
        delay_sec: accepted for backward compatibility; not used, because the
            requests run in a thread pool.

    Returns:
        list[list[dict]]: one candidate list per input row, in input order.
    """
    n = len(titles)
    if first_authors is None or len(first_authors) != n:
        first_authors = [""] * n
    else:
        first_authors = list(first_authors)

    # One key per row; the first title seen for a key is the one searched.
    row_keys = []
    query_by_key = {}
    for title, author in zip(titles, first_authors):
        t = str(title).strip() if title else ""
        fa = str(author).strip() if author else ""
        key = _search_cache_key(t, fa)
        row_keys.append(key)
        query_by_key.setdefault(key, t)

    cache = load_search_cache() if USE_DISK_CACHE else {}
    cache_updated = False
    key_to_candidates = {}
    keys_to_fetch = []
    for key, query_title in query_by_key.items():
        cached = cache.get(key)
        if cached:
            key_to_candidates[key] = cached
        elif not query_title:
            # Nothing to search for.
            key_to_candidates[key] = []
        else:
            keys_to_fetch.append((key, query_title))

    if keys_to_fetch:
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
            future_to_key = {ex.submit(search_openalex, qt): k for k, qt in keys_to_fetch}
            for future in tqdm(as_completed(future_to_key), total=len(future_to_key), desc="OpenAlex (parallel)"):
                key = future_to_key[future]
                try:
                    cands = future.result()
                except Exception:
                    # Best effort: one failed lookup must not abort the whole batch.
                    key_to_candidates[key] = []
                    continue
                key_to_candidates[key] = cands
                if USE_DISK_CACHE and cands:
                    cache[key] = cands
                    cache_updated = True

    if USE_DISK_CACHE and cache_updated:
        save_search_cache(cache)
    return [key_to_candidates.get(key, []) for key in row_keys]

# ---------- Crossref search (cached + deduplicated) ----------
def search_crossref(title):
    """Search Crossref works by *title* and return candidate matches.

    The request is attempted up to MAX_RETRIES times, sleeping
    RETRY_BACKOFF_SEC times the attempt number between attempts.

    Returns:
        list[dict]: candidates shaped ``{"title", "doi", "source": "Crossref"}``
        taken from the first TOP_K_CANDIDATES items; items without a DOI are
        skipped. An empty list is returned when every attempt fails with a
        ``requests`` error.
    """
    request_headers = {"User-Agent": USER_AGENT}
    query = {"query.bibliographic": title, "rows": CROSSREF_ROWS}
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            response = requests.get(CROSSREF_WORKS, params=query, headers=request_headers, timeout=REQUEST_TIMEOUT_SEC)
            response.raise_for_status()
            payload = response.json()
        except requests.RequestException:
            if attempt == MAX_RETRIES:
                return []
            time.sleep(RETRY_BACKOFF_SEC * attempt)
        else:
            break
    else:
        # Only reached when MAX_RETRIES is zero or negative: nothing was tried.
        return []

    items = payload.get("message", {}).get("items") or []
    found = []
    for item in items[:TOP_K_CANDIDATES]:
        doi = (item.get("DOI") or "").strip()
        if doi:
            item_titles = item.get("title") or []
            shown_title = (item_titles[0] if item_titles else "").strip() or "(no title)"
            found.append({"title": shown_title, "doi": doi, "source": "Crossref"})
    return found

def retrieve_crossref_candidates_for_titles(titles, first_authors=None, delay_sec=API_DELAY_SEC):
    """Retrieve Crossref candidates with disk cache, deduplication, and parallel API calls.

    Rows sharing the same normalised (title, first author) pair are looked up
    once. The first author only contributes to the dedup/cache key; the search
    itself is by title.

    Empty candidate lists are neither written to nor trusted from the cache:
    ``search_crossref`` returns ``[]`` both when there are no hits and when all
    retries failed, so caching them would make a transient outage or rate limit
    permanent.

    Args:
        titles: sequence of manuscript titles, one per row.
        first_authors: optional sequence aligned with ``titles``; treated as
            all-empty when omitted or when its length differs.
        delay_sec: accepted for backward compatibility; not used, because the
            requests run in a thread pool.

    Returns:
        list[list[dict]]: one candidate list per input row, in input order.
    """
    n = len(titles)
    if first_authors is None or len(first_authors) != n:
        first_authors = [""] * n
    else:
        first_authors = list(first_authors)

    # One key per row; the first title seen for a key is the one searched.
    row_keys = []
    query_by_key = {}
    for title, author in zip(titles, first_authors):
        t = str(title).strip() if title else ""
        fa = str(author).strip() if author else ""
        key = _search_cache_key(t, fa)
        row_keys.append(key)
        query_by_key.setdefault(key, t)

    cache = load_crossref_search_cache() if USE_DISK_CACHE else {}
    cache_updated = False
    key_to_candidates = {}
    keys_to_fetch = []
    for key, query_title in query_by_key.items():
        cached = cache.get(key)
        if cached:
            key_to_candidates[key] = cached
        elif not query_title:
            # Nothing to search for.
            key_to_candidates[key] = []
        else:
            keys_to_fetch.append((key, query_title))

    if keys_to_fetch:
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
            future_to_key = {ex.submit(search_crossref, qt): k for k, qt in keys_to_fetch}
            for future in tqdm(as_completed(future_to_key), total=len(future_to_key), desc="Crossref (parallel)"):
                key = future_to_key[future]
                try:
                    cands = future.result()
                except Exception:
                    # Best effort: one failed lookup must not abort the whole batch.
                    key_to_candidates[key] = []
                    continue
                key_to_candidates[key] = cands
                if USE_DISK_CACHE and cands:
                    cache[key] = cands
                    cache_updated = True

    if USE_DISK_CACHE and cache_updated:
        save_crossref_search_cache(cache)
    return [key_to_candidates.get(key, []) for key in row_keys]

## 4. E5 Embeddings (batch)

In [None]:
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

# Force T4 GPU when available (Runtime → Change runtime type → T4 GPU)
def _get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda")
        gpu_name = torch.cuda.get_device_name(0) if torch.cuda.device_count() else "GPU"
        print(f"Using device: {device} ({gpu_name})")
        return device
    print("Using device: cpu (no GPU found)")
    return torch.device("cpu")

def load_model_and_tokenizer():
    """Load MODEL_NAME and its tokenizer, move the model to the chosen device.

    Returns:
        tuple: ``(model, tokenizer, device)`` with the model in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME)
    device = _get_device()
    model = model.to(device)
    model.eval()  # inference only
    return model, tokenizer, device

def _encode_batch(model, tokenizer, device, texts, prefix, batch_size=BATCH_SIZE, max_length=MAX_LENGTH, normalize=NORMALIZE_EMBEDDINGS):
    """Embed *texts* in batches using attention-masked mean pooling.

    Args:
        model: encoder whose output exposes ``last_hidden_state``.
        tokenizer: tokenizer matching ``model``.
        device: torch device the tokenised inputs are moved to.
        texts: sequence of strings; falsy entries are encoded as just the prefix.
        prefix: string prepended to every text.
        batch_size: number of texts per forward pass.
        max_length: tokenizer truncation length.
        normalize: when true, each embedding is scaled to unit L2 norm.

    Returns:
        np.ndarray: float32 array with one row per text. For empty input an
        array with zero rows is returned instead of raising.
    """
    prefixed = [prefix + (t or "") for t in texts]
    if not prefixed:
        # np.vstack([]) raises; keep the result 2-D so callers can still read shape[1].
        hidden = getattr(getattr(model, "config", None), "hidden_size", 0) or 0
        return np.zeros((0, hidden), dtype=np.float32)
    all_emb = []
    for start in range(0, len(prefixed), batch_size):
        batch = prefixed[start : start + batch_size]
        inputs = tokenizer(batch, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            out = model(**inputs)
            mask = inputs["attention_mask"]
            last_hidden = out.last_hidden_state
            # Mean over real tokens only: padded positions are zeroed by the mask.
            summed = (last_hidden * mask.unsqueeze(-1)).sum(dim=1)
            lengths = mask.sum(dim=1, keepdim=True).clamp(min=1e-9)
            emb = (summed / lengths).cpu().numpy()
        if normalize:
            norms = np.linalg.norm(emb, axis=1, keepdims=True)
            norms = np.where(norms == 0, 1.0, norms)  # avoid dividing by zero
            emb = emb / norms
        all_emb.append(emb)
    return np.vstack(all_emb).astype(np.float32)

def encode_queries(model, tokenizer, device, titles, batch_size=BATCH_SIZE):
    """Embed *titles* as queries, i.e. with QUERY_PREFIX prepended to each."""
    return _encode_batch(
        model,
        tokenizer,
        device,
        titles,
        prefix=QUERY_PREFIX,
        batch_size=batch_size,
    )

def encode_passages(model, tokenizer, device, titles, batch_size=BATCH_SIZE):
    """Embed *titles* as passages, i.e. with PASSAGE_PREFIX prepended to each."""
    return _encode_batch(
        model,
        tokenizer,
        device,
        titles,
        prefix=PASSAGE_PREFIX,
        batch_size=batch_size,
    )

## 5. Similarity & ranking

## Load E5 on T4 GPU (run this before Step 1 so Colab shows GPU usage)

In [None]:
# Load the E5 model before the long retrieval step so the accelerator (if any) is
# in use early and Colab's resource panel shows its memory.
# load_model_and_tokenizer() falls back to the CPU when no GPU is available.
print("Loading E5 model onto GPU (T4)...")
# These three globals are reused by the embedding step later in the notebook.
model, tokenizer, device = load_model_and_tokenizer()
# Warmup: encode one dummy query so the device allocates memory now; the result is discarded.
encode_queries(model, tokenizer, device, ["warmup"])
print("GPU ready. Run Step 1 (retrieval) next.")

In [None]:
def build_row_splits(candidates_per_row):
    """Return cumulative candidate counts, starting at 0.

    Row ``i`` owns the flat slice ``[result[i], result[i + 1])``.
    """
    offsets = [0]
    running_total = 0
    for row in candidates_per_row:
        running_total += len(row)
        offsets.append(running_total)
    return offsets

def flatten_candidates(candidates_per_row):
    """Flatten per-row candidate lists into two parallel lists.

    Returns:
        tuple: ``(titles, dicts)`` where ``dicts`` holds every candidate in row
        order and ``titles`` holds each candidate's "title" ("" when missing
        or falsy).
    """
    flat = [cand for row in candidates_per_row for cand in row]
    flat_titles = [cand.get("title") or "" for cand in flat]
    return flat_titles, flat

def rank_and_select_best(query_embeddings, passage_embeddings, row_splits, threshold=SIMILARITY_THRESHOLD):
    """Pick the highest-scoring candidate for every query row.

    Scores are dot products between a query embedding and the passage
    embeddings in that row's slice of ``row_splits``.

    Returns:
        list[tuple]: ``(global_index, score)`` per row. The index is -1 when
        the row has no candidates (score -1.0) or when the best score is below
        *threshold* (the score is still reported).
    """
    picks = []
    for row in range(query_embeddings.shape[0]):
        lo, hi = row_splits[row], row_splits[row + 1]
        if lo >= hi:
            picks.append((-1, -1.0))
            continue
        query = query_embeddings[row : row + 1]
        sims = np.dot(passage_embeddings[lo:hi], query.T).ravel()
        top = int(np.argmax(sims))
        top_score = float(sims[top])
        picks.append((-1 if top_score < threshold else lo + top, top_score))
    return picks

def rank_and_select_best_by_source(query_embeddings, passage_embeddings, row_splits, flat_candidate_dicts, threshold=SIMILARITY_THRESHOLD):
    """Select, per query row, the best candidate overall and the best per source.

    Scores are dot products between a query embedding and the passage
    embeddings in that row's slice of ``row_splits``. *threshold* applies to
    all three selections, so a source counts as matched only when one of its
    candidates is similar enough. Without this, any returned candidate would be
    reported as a match regardless of similarity.

    Args:
        query_embeddings: array with one row per query.
        passage_embeddings: array with one row per flattened candidate.
        row_splits: offsets such that row ``i`` owns ``[row_splits[i], row_splits[i + 1])``.
        flat_candidate_dicts: candidate dicts aligned with ``passage_embeddings``;
            their "source" entry is compared with "OpenAlex" and "Crossref".
        threshold: minimum score for a candidate to be selected.

    Returns:
        list[tuple]: per row ``(best_global_idx, best_score, best_oa_global_idx,
        best_oa_score, best_cr_global_idx, best_cr_score)``. A selection with no
        qualifying candidate is reported as index -1 with score -1.0.
    """
    no_match = (-1, -1.0)
    results = []
    for i in range(query_embeddings.shape[0]):
        start, end = row_splits[i], row_splits[i + 1]
        if start >= end:
            results.append(no_match * 3)
            continue
        q = query_embeddings[i : i + 1]
        p = passage_embeddings[start:end]
        scores = np.dot(p, q.T).ravel()

        best_by_source = {"OpenAlex": no_match, "Crossref": no_match}
        for j, raw_score in enumerate(scores):
            s = float(raw_score)
            if s < threshold:
                continue  # too dissimilar to count as a match for any source
            idx = start + j
            src = (flat_candidate_dicts[idx].get("source") or "").strip()
            if src in best_by_source and s > best_by_source[src][1]:
                best_by_source[src] = (idx, s)

        best_local = int(np.argmax(scores))
        best_score = float(scores[best_local])
        overall = no_match if best_score < threshold else (start + best_local, best_score)
        results.append(overall + best_by_source["OpenAlex"] + best_by_source["Crossref"])
    return results

## 6. Metadata enrichment (OpenAlex + Crossref by DOI)

In [None]:
def _safe_int(val, default=None):
    if val is None: return default
    try: return int(val)
    except (TypeError, ValueError): return default

def _safe_str(val):
    return "" if val is None else str(val).strip()

def _empty_metadata():
    return {"journal": None, "publication_year": None, "cited_by_count": None, "type": None, "open_access": None}

def _parse_work(data):
    """Extract the fields the pipeline needs from an OpenAlex work record.

    Returns:
        dict: with keys "journal", "publication_year", "cited_by_count",
        "type" and "open_access". The journal comes from
        ``primary_location.source.display_name``, falling back to
        ``host_venue.display_name``. ``cited_by_count`` defaults to 0.
        ``open_access`` is "true"/"false" when ``is_oa`` is present, else the
        ``status`` text, else None.
    """
    journal = None
    primary = data.get("primary_location")
    if isinstance(primary, dict):
        source = primary.get("source")
        if isinstance(source, dict):
            journal = _safe_str(source.get("display_name")) or None
    if not journal:
        venue = data.get("host_venue")
        if isinstance(venue, dict):
            journal = _safe_str(venue.get("display_name")) or None

    oa = data.get("open_access")
    if isinstance(oa, dict):
        is_oa = oa.get("is_oa")
        if is_oa is None:
            oa_str = _safe_str(oa.get("status")) or None
        elif is_oa:
            oa_str = "true"
        else:
            oa_str = "false"
    elif oa:
        oa_str = _safe_str(oa)
    else:
        oa_str = None

    return {
        "journal": journal,
        "publication_year": _safe_int(data.get("publication_year")),
        "cited_by_count": _safe_int(data.get("cited_by_count"), 0),
        "type": _safe_str(data.get("type")) or None,
        "open_access": oa_str,
    }

def fetch_work_by_doi(doi):
    """Fetch one work from OpenAlex by DOI and return its parsed metadata.

    The DOI is stringified and percent-encoded before being placed in the URL,
    so non-string input and characters such as ``#``, ``?``, ``<`` or spaces do
    not break the request. MAILTO is sent as the ``mailto`` parameter, as
    ``search_openalex`` does.

    Args:
        doi: bare DOI (no resolver prefix).

    Returns:
        dict: the ``_parse_work`` result, or the all-None ``_empty_metadata()``
        placeholder when the DOI is empty or every attempt fails.
    """
    from urllib.parse import quote  # local import, mirroring fetch_crossref_by_doi

    doi_text = str(doi).strip() if doi else ""
    if not doi_text:
        return _empty_metadata()
    # Leave common DOI punctuation literal; encode only what is unsafe in a URL path.
    url = OPENALEX_WORKS_BY_DOI.format(doi=quote(doi_text, safe="/:;()"))
    params = {"mailto": MAILTO}
    for attempt in range(MAX_RETRIES):
        try:
            resp = requests.get(url, params=params, timeout=REQUEST_TIMEOUT_SEC)
            resp.raise_for_status()
            return _parse_work(resp.json())
        except (requests.RequestException, ValueError):
            # ValueError covers a non-JSON body on requests versions whose
            # JSON error is not a RequestException.
            if attempt == MAX_RETRIES - 1:
                return _empty_metadata()
            time.sleep(RETRY_BACKOFF_SEC * (attempt + 1))
    return _empty_metadata()

def get_metadata_by_doi_cached(doi):
    """Fetch OpenAlex work by DOI with disk cache to save API credits.

    Failed lookups are not cached. ``fetch_work_by_doi`` returns the all-None
    placeholder on failure, and a parsed work never equals it because
    ``_parse_work`` defaults ``cited_by_count`` to 0. Placeholder entries already
    in the cache are ignored and fetched again.

    Args:
        doi: bare DOI; falsy or blank input yields the empty placeholder.

    Returns:
        dict: metadata for the DOI, or ``_empty_metadata()`` when unavailable.
    """
    doi_key = str(doi).strip() if doi else ""
    if not doi_key:
        return _empty_metadata()
    empty = _empty_metadata()
    cache = load_doi_cache() if USE_DISK_CACHE else {}
    cached = cache.get(doi_key)
    if cached and cached != empty:
        return cached
    meta = fetch_work_by_doi(doi_key)
    if USE_DISK_CACHE and meta != empty:
        cache[doi_key] = meta
        save_doi_cache(cache)
    return meta

def fetch_metadata_batch_cached(dois, delay_sec=API_DELAY_SEC):
    """Fetch OpenAlex metadata for DOIs; cache + deduplication.

    Failed lookups are returned but not cached. ``fetch_work_by_doi`` returns
    the all-None placeholder on failure, and a parsed work never equals it
    because ``_parse_work`` defaults ``cited_by_count`` to 0. Placeholder entries
    already in the cache are ignored and fetched again.

    Args:
        dois: iterable of DOIs; blank entries are dropped and duplicates are
            fetched once.
        delay_sec: pause after each network fetch; no pause when <= 0.

    Returns:
        dict: metadata dict per unique, stripped DOI.
    """
    dois = [str(d).strip() for d in dois if d and str(d).strip()]
    unique = list(dict.fromkeys(dois))
    cache = load_doi_cache() if USE_DISK_CACHE else {}
    cache_updated = False
    empty = _empty_metadata()
    result = {}
    for doi in tqdm(unique, desc="OpenAlex DOI metadata (cached)"):
        cached = cache.get(doi)
        if cached and cached != empty:
            result[doi] = cached
            continue
        meta = fetch_work_by_doi(doi)
        result[doi] = meta
        if USE_DISK_CACHE and meta != empty:
            cache[doi] = meta
            cache_updated = True
        if delay_sec > 0:
            time.sleep(delay_sec)
    if USE_DISK_CACHE and cache_updated:
        save_doi_cache(cache)
    return result

def _empty_crossref_meta():
    return {"journal": None, "publication_year": None, "cited_by_count": None}

def _parse_crossref_work(msg):
    """Extract journal, publication year and citation count from a Crossref work.

    The year is taken from the first date field that yields one, tried in the
    order "issued", "published-print", "published-online", "published". A field
    that is present but carries no usable year (for example
    ``{"date-parts": [[None]]}``) no longer blocks the fallback.

    Args:
        msg: the "message" object of a Crossref works response.

    Returns:
        dict: ``{"journal", "publication_year", "cited_by_count"}``;
        ``cited_by_count`` defaults to 0 when absent or unparsable.
    """
    ct = msg.get("container-title") or []
    # Guard against a None first entry as well as an empty list.
    journal = str((ct[0] if ct else "") or "").strip() or None

    pub_year = None
    for field in ("issued", "published-print", "published-online", "published"):
        date_parts = (msg.get(field) or {}).get("date-parts") or [[]]
        parts = date_parts[0] or []
        pub_year = _safe_int(parts[0]) if parts else None
        if pub_year is not None:
            break

    cited = _safe_int(msg.get("is-referenced-by-count"), 0)
    return {"journal": journal, "publication_year": pub_year, "cited_by_count": cited}

def fetch_crossref_by_doi(doi):
    """Fetch one work from Crossref by DOI and return its parsed metadata.

    The DOI is fully percent-encoded into the URL path. The request is
    attempted up to MAX_RETRIES times, sleeping RETRY_BACKOFF_SEC times the
    attempt number between attempts.

    Returns:
        dict: the ``_parse_crossref_work`` result, or ``_empty_crossref_meta()``
        when the DOI is blank or every attempt fails with a ``requests`` error.
    """
    from urllib.parse import quote

    doi_text = str(doi).strip() if doi else ""
    if not doi_text:
        return _empty_crossref_meta()
    url = "https://api.crossref.org/works/" + quote(doi_text, safe="")
    request_headers = {"User-Agent": USER_AGENT}
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            response = requests.get(url, headers=request_headers, timeout=REQUEST_TIMEOUT_SEC)
            response.raise_for_status()
            message = response.json().get("message") or {}
            return _parse_crossref_work(message)
        except requests.RequestException:
            if attempt == MAX_RETRIES:
                return _empty_crossref_meta()
            time.sleep(RETRY_BACKOFF_SEC * attempt)
    return _empty_crossref_meta()

def fetch_crossref_metadata_batch_cached(dois, delay_sec=API_DELAY_SEC):
    """Fetch Crossref metadata for DOIs; cache + deduplication.

    Failed lookups are returned but not cached. ``fetch_crossref_by_doi``
    returns the all-None placeholder on failure, and a parsed work never equals
    it because ``_parse_crossref_work`` defaults ``cited_by_count`` to 0.
    Placeholder entries already in the cache are ignored and fetched again.

    Args:
        dois: iterable of DOIs; blank entries are dropped and duplicates are
            fetched once.
        delay_sec: pause after each network fetch; no pause when <= 0.

    Returns:
        dict: metadata dict per unique, stripped DOI.
    """
    dois = [str(d).strip() for d in dois if d and str(d).strip()]
    unique = list(dict.fromkeys(dois))
    cache = load_crossref_doi_cache() if USE_DISK_CACHE else {}
    cache_updated = False
    empty = _empty_crossref_meta()
    result = {}
    for doi in tqdm(unique, desc="Crossref DOI metadata (cached)"):
        cached = cache.get(doi)
        if cached and cached != empty:
            result[doi] = cached
            continue
        meta = fetch_crossref_by_doi(doi)
        result[doi] = meta
        if USE_DISK_CACHE and meta != empty:
            cache[doi] = meta
            cache_updated = True
        if delay_sec > 0:
            time.sleep(delay_sec)
    if USE_DISK_CACHE and cache_updated:
        save_crossref_doi_cache(cache)
    return result

## 7. Upload input CSV & run pipeline

In [None]:
from google.colab import files
import pandas as pd
from pathlib import Path

# Option A: Upload CSV from your machine
uploaded = files.upload()
if not uploaded:
    # NOTE(review): presumably an empty mapping means the dialog was cancelled — confirm.
    # Fail with a clear message instead of an IndexError on the next line.
    raise ValueError("No file was uploaded; re-run this cell and choose your RAT CSV.")
input_name = next(iter(uploaded))  # use first uploaded file
input_path = Path(input_name)

# Option B: If you already have file in Colab (e.g. in /content), set it manually:
# input_path = Path("/content/input_manuscripts.csv")

In [None]:
df = pd.read_csv(input_path)
# MANUSCRIPT_ID and TITLE are mandatory; REJECT_REASON and FIRST_AUTHOR are optional.
for col in (COL_MANUSCRIPT_ID, COL_TITLE):
    if col not in df.columns:
        raise ValueError(f"CSV must have column: {col}")
# Keep only rejected manuscripts, i.e. rows whose REJECT_REASON is present and not blank.
if COL_REJECT_REASON not in df.columns:
    print("Warning: No REJECT_REASON column; processing all rows.")
else:
    df = df[df[COL_REJECT_REASON].notna() & df[COL_REJECT_REASON].astype(str).str.strip().ne("")]
    print(f"Filtered to rejected manuscripts only: {len(df)} rows (REJECT_REASON present).")
manuscript_ids = df[COL_MANUSCRIPT_ID].astype(str).tolist()
titles = df[COL_TITLE].fillna("").astype(str).tolist()
# Without a FIRST_AUTHOR column every row gets an empty author.
if COL_FIRST_AUTHOR in df.columns:
    first_authors = df[COL_FIRST_AUTHOR].fillna("").astype(str).tolist()
else:
    first_authors = [""] * len(df)
n_rows = len(titles)
print(f"Loaded {n_rows} rows. Running pipeline...")

In [None]:
print("Step 1: Retrieving candidates from OpenAlex and Crossref (parallel, cached + deduplicated)...")
from concurrent.futures import ThreadPoolExecutor
# Query both sources at the same time, one worker thread per source; each retrieval
# function runs its own thread pool for the individual searches.
with ThreadPoolExecutor(max_workers=2) as ex:
    f_oa = ex.submit(retrieve_candidates_for_titles, titles, first_authors)
    f_cr = ex.submit(retrieve_crossref_candidates_for_titles, titles, first_authors)
    # .result() blocks until the source has finished and re-raises any exception from it.
    oa_per_row = f_oa.result()
    cr_per_row = f_cr.result()
# Per row: OpenAlex candidates first, then Crossref candidates.
candidates_per_row = [oa_per_row[i] + cr_per_row[i] for i in range(n_rows)]
# Flatten for batch embedding; row_splits records which flat slice belongs to which row.
row_splits = build_row_splits(candidates_per_row)
passage_titles, flat_candidate_dicts = flatten_candidates(candidates_per_row)

In [None]:
print("Step 2: Computing embeddings on GPU (batch)...")
query_embeddings = encode_queries(model, tokenizer, device, titles)
# With no candidates at all, use an empty matrix of matching width so ranking still runs.
passage_embeddings = (
    encode_passages(model, tokenizer, device, passage_titles)
    if passage_titles
    else np.zeros((0, query_embeddings.shape[1]), dtype=np.float32)
)

In [None]:
print("Step 3: Ranking by semantic similarity (per source: OpenAlex & Crossref)...")
# One tuple per row: (best_idx, best_score, openalex_idx, openalex_score, crossref_idx, crossref_score);
# an index of -1 means no candidate was selected.
best_per_row = rank_and_select_best_by_source(query_embeddings, passage_embeddings, row_splits, flat_candidate_dicts)

def _get_doi_or_not_found(global_idx, flat_dicts):
    if global_idx < 0: return "Not Found"
    d = flat_dicts[global_idx]
    return (d.get("doi") or "").strip() or "Not Found"

# Overall best candidate per row (tuple positions 0 and 1 of best_per_row).
best_global_per_row = [best_per_row[i][0] for i in range(n_rows)]
best_score_per_row = [best_per_row[i][1] for i in range(n_rows)]

# Per-source results: position 2 is the OpenAlex index, position 4 the Crossref index.
OpenAlex_DOI_per_row = [_get_doi_or_not_found(best_per_row[i][2], flat_candidate_dicts) for i in range(n_rows)]
Crossref_DOI_per_row = [_get_doi_or_not_found(best_per_row[i][4], flat_candidate_dicts) for i in range(n_rows)]
OpenAlex_Found = ["No" if best_per_row[i][2] < 0 else "Yes" for i in range(n_rows)]
Crossref_Found = ["No" if best_per_row[i][4] < 0 else "Yes" for i in range(n_rows)]

# Fields of the overall best candidate; blank/None when the row has no match.
matched_title_per_row = [
    "" if best_global_per_row[i] < 0 else (flat_candidate_dicts[best_global_per_row[i]].get("title") or "")
    for i in range(n_rows)
]
matched_doi_per_row = [
    None if best_global_per_row[i] < 0 else (flat_candidate_dicts[best_global_per_row[i]].get("doi") or None)
    for i in range(n_rows)
]
source_per_row = [
    "" if best_global_per_row[i] < 0 else (flat_candidate_dicts[best_global_per_row[i]].get("source") or "")
    for i in range(n_rows)
]
score_per_row = [round(best_score_per_row[i], 4) if best_score_per_row[i] >= 0 else None for i in range(n_rows)]

In [None]:
print("Step 4: Fetching metadata for best matches (OpenAlex + Crossref, cached)...")
# Unique DOIs of the best matches, grouped by the source the match came from.
dois_oa = list(dict.fromkeys([
    str(matched_doi_per_row[i]).strip()
    for i in range(n_rows)
    if source_per_row[i] == "OpenAlex" and matched_doi_per_row[i] and str(matched_doi_per_row[i]).strip()
]))
dois_cr = list(dict.fromkeys([
    str(matched_doi_per_row[i]).strip()
    for i in range(n_rows)
    if source_per_row[i] == "Crossref" and matched_doi_per_row[i] and str(matched_doi_per_row[i]).strip()
]))
oa_meta = fetch_metadata_batch_cached(dois_oa) if dois_oa else {}
cr_meta = fetch_crossref_metadata_batch_cached(dois_cr) if dois_cr else {}
metadata_per_row = []
for i in range(n_rows):
    doi = matched_doi_per_row[i]
    key = str(doi).strip() if doi else ""
    if not key:
        # No match for this row: all metadata fields stay empty.
        metadata_per_row.append({"journal": None, "publication_year": None, "cited_by_count": None})
        continue
    # Crossref matches use Crossref metadata; every other match uses OpenAlex metadata.
    if source_per_row[i] == "Crossref":
        m = cr_meta.get(key, _empty_crossref_meta())
    else:
        m = oa_meta.get(key, _empty_metadata())
    metadata_per_row.append({"journal": m.get("journal"), "publication_year": m.get("publication_year"), "cited_by_count": m.get("cited_by_count")})

In [None]:
# Best DOI as text ("" when there is no match) and its resolver link.
best_doi_str = [str(matched_doi_per_row[i]) if matched_doi_per_row[i] else "" for i in range(n_rows)]
best_link_per_row = [("https://doi.org/" + best_doi_str[i]) if matched_doi_per_row[i] else "" for i in range(n_rows)]

OUTPUT_COLUMNS = ["MANUSCRIPT_ID", "Original_Title", "Crossref_DOI", "OpenAlex_DOI", "Crossref_Found", "OpenAlex_Found", "best_doi", "best_journal", "publish_year", "citations", "match_score", "source", "best_link", "matched_title"]
# One output record per input row; missing values are written as empty strings.
rows = [
    {
        "MANUSCRIPT_ID": manuscript_ids[i],
        "Original_Title": titles[i],
        "Crossref_DOI": Crossref_DOI_per_row[i],
        "OpenAlex_DOI": OpenAlex_DOI_per_row[i],
        "Crossref_Found": Crossref_Found[i],
        "OpenAlex_Found": OpenAlex_Found[i],
        "best_doi": best_doi_str[i],
        "best_journal": metadata_per_row[i].get("journal") or "",
        "publish_year": "" if metadata_per_row[i].get("publication_year") is None else metadata_per_row[i].get("publication_year"),
        "citations": "" if metadata_per_row[i].get("cited_by_count") is None else metadata_per_row[i].get("cited_by_count"),
        "match_score": "" if score_per_row[i] is None else score_per_row[i],
        "source": source_per_row[i],
        "best_link": best_link_per_row[i],
        "matched_title": matched_title_per_row[i],
    }
    for i in range(n_rows)
]
out_df = pd.DataFrame(rows, columns=OUTPUT_COLUMNS)
out_df.to_csv(OUTPUT_FILENAME, index=False)
print(f"Step 5: Saved {len(out_df)} rows to {OUTPUT_FILENAME}")
print("Summary: Crossref found", sum(1 for x in Crossref_Found if x == "Yes"), "; OpenAlex found", sum(1 for x in OpenAlex_Found if x == "Yes"))
out_df.head(10)

## 8. Download output CSV

In [None]:
# Offer the CSV written by the previous cell as a browser download.
files.download(OUTPUT_FILENAME)