# Non-academic collaborations


## Load data

Load the Dimensions export used for institution parsing.


In [1]:
from pathlib import Path
import json
import time

import pandas as pd

# If you are missing dependencies, run:
# %pip install pandas openpyxl openai

# Relative path of the Dimensions export; find_data_path() looks for it under
# the working directory and under its parent.
DATA_REL = Path("data/dimensions/api/raw/combined/202511/df_dimensions.xlsx")


def find_data_path():
    """Locate the Dimensions export relative to the working directory.

    ``DATA_REL`` is checked under the current working directory first and then
    under its parent; the first existing path is returned.

    Raises:
        FileNotFoundError: if the file exists in neither location.
    """
    roots = (Path.cwd(), Path.cwd().parent)
    found = next(
        (path for path in (root / DATA_REL for root in roots) if path.exists()),
        None,
    )
    if found is None:
        raise FileNotFoundError(f"Could not find {DATA_REL}")
    return found


# Resolve the export once and load it into the working DataFrame.
# NOTE(review): the preview of this frame shows "Unnamed: 0*" columns, which
# look like index columns left over from earlier saves — confirm whether they
# should be dropped before the frame is written out again.
DATA_PATH = find_data_path()
df = pd.read_excel(DATA_PATH)


In [2]:
# Preview the first rows to sanity-check the load.
df.head()

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,id,title,abstract,acknowledgements,altmetric,altmetric_id,authors,authors_count,...,category_hra,funding_section,editors,subtitles,clinical_trial_ids,resulting_publication_doi,book_doi,book_title,isbn,proceedings_title
0,0,0,pub.1142697354,Adiposity by Differing Measures and the Risk o...,Purpose: To examine the association between ad...,The authors thank the participants of the UK B...,1.0,117236054.0,"[{'affiliations': [{'city': 'Guangzhou', 'city...",12,...,,,,,,,,,,
1,1,180,pub.1107515670,Independent and combined associations of mater...,Background: Limited evidence suggests that exp...,This research has been conducted using the UK ...,8.0,50122557.0,"[{'affiliations': [{'city': 'Bristol', 'city_i...",5,...,,This work was supported by the UK Medical Rese...,,,,,,,,
2,2,172,pub.1111460082,Common conditions associated with hereditary h...,OBJECTIVE: To compare prevalent and incident m...,This research was conducted using the UK Bioba...,897.0,54023435.0,"[{'affiliations': [{'city': 'Exeter', 'city_id...",9,...,"[{'id': '3903', 'name': 'Population & Society'}]",Funding: This study was funded by an award to ...,,,,,,,,
3,3,173,pub.1111458679,Hereditary Hemochromatosis Associations with F...,BACKGROUND: Iron is essential for life but con...,This research has been conducted using the UK ...,54.0,54027903.0,"[{'affiliations': [{'city': 'Exeter', 'city_id...",7,...,"[{'id': '3901', 'name': 'Clinical'}]",This work was supported by an award to D.M. by...,,,,,,,,
4,4,174,pub.1111441688,Genetic Assessment of Potential Long-Term On-T...,BACKGROUND: Although short-term trials have su...,,21.0,54234074.0,"[{'affiliations': [{'city': 'Leicester', 'city...",7,...,"[{'id': '3900', 'name': 'Biomedical'}]",,,,,,,,,


## Optional: check available OpenAI models

Useful if you want to verify available model IDs.


In [3]:
from openai import OpenAI
from pathlib import Path

# Look for the key file under the working directory first and then under its
# parent, the same way the data file is located. A fixed "../keys/openai.txt"
# only resolves when the notebook is started from one particular directory.
_key_file = next(
    (
        candidate
        for candidate in (
            base / "keys" / "openai.txt" for base in (Path.cwd(), Path.cwd().parent)
        )
        if candidate.exists()
    ),
    None,
)
if _key_file is None:
    raise FileNotFoundError("Could not find keys/openai.txt")

api_key = _key_file.read_text().strip()
client = OpenAI(api_key=api_key)

# Show only the small-model families this notebook is likely to use.
models = [m.id for m in client.models.list().data]
[m for m in models if "mini" in m or "nano" in m]


['gpt-4o-mini-tts-2025-12-15',
 'gpt-realtime-mini-2025-12-15',
 'gpt-audio-mini-2025-12-15',
 'gpt-4o-mini-2024-07-18',
 'gpt-4o-mini',
 'gpt-4o-mini-realtime-preview-2024-12-17',
 'gpt-4o-mini-audio-preview-2024-12-17',
 'gpt-4o-mini-realtime-preview',
 'gpt-4o-mini-audio-preview',
 'o3-mini',
 'o3-mini-2025-01-31',
 'gpt-4o-mini-search-preview-2025-03-11',
 'gpt-4o-mini-search-preview',
 'gpt-4o-mini-transcribe',
 'gpt-4o-mini-tts',
 'o4-mini-2025-04-16',
 'o4-mini',
 'gpt-4.1-mini-2025-04-14',
 'gpt-4.1-mini',
 'gpt-4.1-nano-2025-04-14',
 'gpt-4.1-nano',
 'codex-mini-latest',
 'o4-mini-deep-research',
 'o4-mini-deep-research-2025-06-26',
 'gpt-5-mini-2025-08-07',
 'gpt-5-mini',
 'gpt-5-nano-2025-08-07',
 'gpt-5-nano',
 'gpt-image-1-mini',
 'gpt-audio-mini',
 'gpt-audio-mini-2025-10-06',
 'gpt-realtime-mini',
 'gpt-realtime-mini-2025-10-06',
 'gpt-5.1-codex-mini',
 'gpt-4o-mini-transcribe-2025-12-15',
 'gpt-4o-mini-transcribe-2025-03-20',
 'gpt-4o-mini-tts-2025-03-20']

## Identify institution-related columns

Pick author/institution columns for parsing.


In [4]:
# Identify author institution-related columns. Adjust this heuristic if needed.
# Matching is a case-insensitive substring test on the column name, so the
# short key "org" also selects every research_org_* column (cities, countries,
# states and types as well as names).
name_keys = ("institution", "affiliation", "affiliations", "organization", "organisation", "org")
author_cols = [c for c in df.columns if "author" in c.lower()]
institution_cols = [c for c in df.columns if any(k in c.lower() for k in name_keys)]
# Prefer institution-like columns that also mention "author"; if there are
# none, fall back to all institution-like columns.
author_institution_cols = [c for c in institution_cols if "author" in c.lower()] or institution_cols

author_institution_cols


['research_org_cities',
 'research_org_countries',
 'research_org_country_names',
 'research_org_names',
 'research_org_state_codes',
 'research_org_state_names',
 'research_org_types',
 'research_orgs']

In [5]:
# Inspect the raw values of the selected columns before parsing them.
df[author_institution_cols].head()

Unnamed: 0,research_org_cities,research_org_countries,research_org_country_names,research_org_names,research_org_state_codes,research_org_state_names,research_org_types,research_orgs
0,"[{'id': '6952201', 'name': 'East Melbourne'}, ...","[{'id': 'CN', 'name': 'China'}, {'id': 'AU', '...","['China', 'Australia']","['Sun Yat-sen University', 'Centre for Eye Res...","[{'id': 'AU-VIC', 'name': 'Victoria'}]",['Victoria'],"['Education', 'Facility', 'Healthcare']","[{'acronym': 'CERA', 'city_name': 'East Melbou..."
1,"[{'id': '2654675', 'name': 'Bristol'}, {'id': ...","[{'id': 'GB', 'name': 'United Kingdom'}, {'id'...","['United Kingdom', 'Norway']","['Norwegian Institute of Public Health', 'Univ...",,,"['Government', 'Education', 'Facility']","[{'city_name': 'Bristol', 'country_code': 'GB'..."
2,"[{'id': '2649808', 'name': 'Exeter'}, {'id': '...","[{'id': 'GB', 'name': 'United Kingdom'}, {'id'...","['United States', 'United Kingdom']","['University of Exeter', 'University of Connec...","[{'id': 'US-MD', 'name': 'Maryland'}, {'id': '...","['Connecticut', 'Maryland']","['Government', 'Education', 'Healthcare']","[{'acronym': 'UCHC', 'city_name': 'Farmington'..."
3,"[{'id': '2649808', 'name': 'Exeter'}, {'id': '...","[{'id': 'GB', 'name': 'United Kingdom'}, {'id'...","['United States', 'United Kingdom']","['University of Connecticut', 'University of E...","[{'id': 'US-MD', 'name': 'Maryland'}, {'id': '...","['Connecticut', 'Maryland']","['Government', 'Education', 'Healthcare']","[{'acronym': 'UCHC', 'city_name': 'Farmington'..."
4,"[{'id': '2867714', 'name': 'Munich'}, {'id': '...","[{'id': 'GB', 'name': 'United Kingdom'}, {'id'...","['United Kingdom', 'Germany']","['German Centre for Cardiovascular Research', ...",,,"['Facility', 'Healthcare', 'Education']","[{'city_name': 'Leicester', 'country_code': 'G..."


## Parse institution strings

Normalize, split, and deduplicate institution-like strings from the author fields.


In [6]:
import ast
import json
import re

# Dict keys whose values are followed when walking parsed cell values; values
# under any other key are ignored by _extract_strings().
INSTITUTION_KEYS = {
    "name",
    "institution",
    "organization",
    "organisation",
    "org",
    "affiliation",
    "affiliations",
    "raw_affiliation",
}

# Lower-case markers that suggest a string names an organisation. Used to pick
# the most institution-like part of a comma-separated segment.
KEYWORDS = (
    "university",
    "college",
    "institute",
    "institut",
    "school",
    "hospital",
    "clinic",
    "centre",
    "center",
    "foundation",
    "trust",
    "ministry",
    "government",
    "council",
    "company",
    "inc",
    "ltd",
    "llc",
    "gmbh",
    "corp",
    "plc",
    "sa",
    "srl",
    "bv",
    "ag",
    "laboratory",
    "lab",
)

# Upper bounds on the per-row list length and on each stored string's length.
MAX_INSTITUTIONS_PER_ROW = 200
MAX_CHARS_PER_INSTITUTION = 150

# Top-level separators between institution segments inside one string.
_SPLIT_RE = re.compile(r"[;|]+")


def _normalize_text(s):
    s = re.sub(r"\s+", " ", s)
    s = re.sub(r"\S+@\S+", "", s)
    s = re.sub(r"https?://\S+", "", s)
    return s.strip(" ,;|")


def _contains_keyword(s):
    """Return True when ``s`` mentions one of the organisation ``KEYWORDS``.

    Keywords of three characters or fewer ("sa", "ag", "bv", "inc", "lab", ...)
    must appear as a whole word (optionally pluralised with a trailing "s"),
    because as bare substrings they match ordinary words and place names such
    as "Pasadena", "Chicago", "Princeton" or "Alabama". Longer keywords are
    still matched as substrings so stems like "institut" keep matching
    "Institution" or "Instituto".
    """
    lowered = s.lower()
    words = None  # tokenised lazily; only short keywords need it
    for keyword in KEYWORDS:
        if len(keyword) > 3:
            if keyword in lowered:
                return True
            continue
        if words is None:
            words = set(re.findall(r"[a-z0-9]+", lowered))
        if keyword in words or keyword + "s" in words:
            return True
    return False


def _try_parse(s):
    s = s.strip()
    if not s or s[0] not in "[{":
        return None
    try:
        return json.loads(s)
    except Exception:
        try:
            return ast.literal_eval(s)
        except Exception:
            return None


def _extract_strings(obj):
    strings = []
    if obj is None:
        return strings
    if isinstance(obj, (list, tuple, set)):
        for item in obj:
            strings.extend(_extract_strings(item))
        return strings
    if isinstance(obj, dict):
        for key, value in obj.items():
            if key in INSTITUTION_KEYS:
                strings.extend(_extract_strings(value))
        return strings
    strings.append(str(obj))
    return strings


def _select_institution_segments(s):
    """Pick the institution-like pieces out of one raw string.

    The cleaned string is split on ``;``/``|``. Within each piece the first
    comma-separated part containing a keyword is kept, falling back to the
    first part. Each kept piece is cleaned again, cut to
    ``MAX_CHARS_PER_INSTITUTION`` characters, and dropped if empty. When none
    of the kept pieces contains a keyword, only the first two are returned.
    """
    text = _normalize_text(s)
    if not text:
        return []

    picked = []
    for chunk in _SPLIT_RE.split(text):
        chunk = chunk.strip()
        if not chunk:
            continue
        fields = [field for field in (raw.strip() for raw in chunk.split(",")) if field]
        if fields:
            with_keyword = next(filter(_contains_keyword, fields), None)
            chunk = with_keyword if with_keyword else fields[0]
        chunk = _normalize_text(chunk)
        if not chunk:
            continue
        if len(chunk) > MAX_CHARS_PER_INSTITUTION:
            chunk = chunk[:MAX_CHARS_PER_INSTITUTION].rstrip()
        picked.append(chunk)

    has_keyword = any(map(_contains_keyword, picked))
    return picked if has_keyword else picked[:2]


def _extract_institutions_from_value(value):
    if value is None or (isinstance(value, float) and str(value) == "nan"):
        return []
    if isinstance(value, str):
        parsed = _try_parse(value)
        if parsed is not None:
            strings = _extract_strings(parsed)
        else:
            strings = [value]
    else:
        strings = _extract_strings(value)

    candidates = []
    for s in strings:
        candidates.extend(_select_institution_segments(s))
    return candidates


def _dedupe_keep_order(items):
    seen = set()
    out = []
    for item in items:
        key = item.lower()
        if key in seen:
            continue
        seen.add(key)
        out.append(item)
    return out


def build_institution_list(row):
    """Build the de-duplicated institution list for one DataFrame row.

    Candidates from every cell are concatenated in cell order, de-duplicated
    case-insensitively, and capped at ``MAX_INSTITUTIONS_PER_ROW`` entries.
    """
    gathered = [
        name
        for cell in row
        for name in _extract_institutions_from_value(cell)
    ]
    return _dedupe_keep_order(gathered)[:MAX_INSTITUTIONS_PER_ROW]


## Build institution lists per row

Create the per-row institution list that is sent to the LLM, plus a compact text preview for inspection.


In [7]:
# Create a list per row for model input and a compact text preview.
# Missing cells are replaced with "" so every cell reaches the parser as a value
# it can handle.
author_institutions = df[author_institution_cols].copy().fillna("")

# One de-duplicated list of candidate strings per publication row.
df["author_institutions_list"] = author_institutions.apply(
    build_institution_list, axis=1
)

# Pipe-joined rendering of the same list, for eyeballing only.
df["author_institutions_text"] = df["author_institutions_list"].map(
    lambda xs: " | ".join(xs)
)

df[["author_institutions_text"]].head()


Unnamed: 0,author_institutions_text
0,East Melbourne | Guangzhou | Melbourne | China...
1,Bristol | Oslo | United Kingdom | Norway | Nor...
2,Exeter | Farmington | Baltimore | United Kingd...
3,Exeter | Farmington | Storrs | Baltimore | Uni...
4,Munich | Leicester | Berlin | United Kingdom |...


## LLM client configuration

Set up OpenAI client and shared request settings.


In [8]:
from openai import OpenAI
import asyncio
import ast
import re

# AsyncOpenAI is treated as optional: when the import fails, HAS_ASYNC is False
# and call_model_async() runs the synchronous client in a worker thread.
try:
    from openai import AsyncOpenAI
    HAS_ASYNC = True
except Exception:
    AsyncOpenAI = None
    HAS_ASYNC = False

# Relative location of the API key file; find_key_path() searches for it.
KEY_REL = Path("keys/openai.txt")


def find_key_path():
    """Locate the API key file relative to the working directory.

    ``KEY_REL`` is checked under the current working directory first and then
    under its parent; the first existing path is returned.

    Raises:
        FileNotFoundError: if the file exists in neither location.
    """
    roots = (Path.cwd(), Path.cwd().parent)
    found = next(
        (path for path in (root / KEY_REL for root in roots) if path.exists()),
        None,
    )
    if found is None:
        raise FileNotFoundError(f"Could not find {KEY_REL}")
    return found


# Read the key from disk and create the sync client (and the async one when
# AsyncOpenAI is importable).
api_key = find_key_path().read_text().strip()
client = OpenAI(api_key=api_key)
async_client = AsyncOpenAI(api_key=api_key) if HAS_ASYNC else None

MODEL = "gpt-5-nano"
SERVICE_TIER = "priority"
# Translates a friendly tier name to the value sent in the request; names not
# listed here are passed through unchanged.
SERVICE_TIER_MAP = {"standard": "default"}

REQUEST_SLEEP = 0.0  # set to a positive number if you need to slow requests
REQUEST_TIMEOUT = 120  # seconds; enforced via asyncio.wait_for in the async batch call


# Determinism controls (best-effort; full determinism is not guaranteed)
# Note: gpt-5-mini does not accept temperature, so leave it as None.
# NOTE(review): MODEL above is gpt-5-nano, not gpt-5-mini — presumably the same
# restriction applies; confirm and update this note.
TEMPERATURE = None
TOP_P = 1
FREQUENCY_PENALTY = 0
PRESENCE_PENALTY = 0
SEED = 12345  # set to None to disable
RESPONSE_FORMAT = {"type": "json_object"}
MAX_OUTPUT_TOKENS = 2000  # increase if batches are large

# System message sent with every classification request.
SYSTEM_PROMPT = (
    "You classify institution strings as academic organizations, non-academic organizations, "
    "and UK-based companies. Only label entries that are clearly organizations. "
    "Ignore locations (cities, countries, states), departments, faculties, and generic terms "
    "like 'Education', 'Facility', or 'Healthcare'. "
    "Academic: universities, colleges, degree-granting schools, and explicitly "
    "university-affiliated hospitals. "
    "Non-academic: companies, industry, private sector, NGOs, charities, foundations, "
    "government agencies, and non-university hospitals or clinics. "
    "UK company: a private-sector company clearly based in the United Kingdom "
    "(explicit UK/United Kingdom/England/Scotland/Wales/Northern Ireland signal or well-known UK company). "
    "If you are unsure about UK or organizational status, do NOT label it."
)

# Extra line added to the user prompt when a batch is retried in strict mode.
STRICT_JSON_INSTRUCTIONS = (
    "Return ONLY valid JSON. Use double quotes for all keys and strings. No trailing commas."
)


## LLM request and parsing helpers

Build prompts, call the API, and parse JSON safely.


In [9]:
def _extract_param(msg, marker):
    if marker not in msg:
        return None
    after = msg.split(marker, 1)[1].strip()
    if not after:
        return None
    if after[0] in ("'", '"'):
        quote = after[0]
        end = after.find(quote, 1)
        if end != -1:
            return after[1:end]
    value = after.split()[0].strip()
    if value.startswith(("'", '"')):
        value = value[1:]
    if value.endswith(("'", '"')):
        value = value[:-1]
    return value


def _remove_unsupported_param(msg, kwargs):
    """Drop from ``kwargs`` the parameter that error message ``msg`` rejects.

    Recognises "Unsupported parameter:" and "unexpected keyword argument"
    messages, plus messages that mention both "service_tier" and "invalid".
    ``kwargs`` is modified in place.

    Returns:
        True if a key was removed, False otherwise.
    """
    for marker in ("Unsupported parameter:", "unexpected keyword argument"):
        name = _extract_param(msg, marker)
        if name and name in kwargs:
            del kwargs[name]
            return True

    tier_rejected = "service_tier" in msg and "invalid" in msg
    if tier_rejected and "service_tier" in kwargs:
        del kwargs["service_tier"]
        return True
    return False


def _build_response_kwargs(messages):
    """Assemble keyword arguments for ``client.responses.create``.

    The Responses API takes the output format under ``text.format`` rather
    than ``response_format``, and does not define ``frequency_penalty``,
    ``presence_penalty`` or ``seed``. Those Chat Completions names are
    therefore not sent here: each would have to be rejected and stripped by
    the retry helper, and the JSON output format would be lost in the process.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        A dict of keyword arguments built from the module-level settings;
        settings that are ``None`` (or an empty service tier) are omitted.
    """
    kwargs = {
        "model": MODEL,
        "input": messages,
    }
    if RESPONSE_FORMAT is not None:
        # Responses API spelling of the Chat Completions ``response_format``.
        kwargs["text"] = {"format": RESPONSE_FORMAT}
    if TEMPERATURE is not None:
        kwargs["temperature"] = TEMPERATURE
    if TOP_P is not None:
        kwargs["top_p"] = TOP_P
    if SERVICE_TIER:
        kwargs["service_tier"] = SERVICE_TIER_MAP.get(SERVICE_TIER, SERVICE_TIER)
    if MAX_OUTPUT_TOKENS is not None:
        kwargs["max_output_tokens"] = MAX_OUTPUT_TOKENS
    return kwargs


def _build_chat_kwargs(messages):
    """Assemble keyword arguments for ``client.chat.completions.create``.

    Starts from the model and messages, then adds each optional setting whose
    module-level value is not ``None``.
    """
    optional = (
        ("response_format", RESPONSE_FORMAT),
        ("temperature", TEMPERATURE),
        ("top_p", TOP_P),
        ("frequency_penalty", FREQUENCY_PENALTY),
        ("presence_penalty", PRESENCE_PENALTY),
        ("seed", SEED),
        ("max_tokens", MAX_OUTPUT_TOKENS),
    )
    kwargs = {"model": MODEL, "messages": messages}
    kwargs.update((name, value) for name, value in optional if value is not None)
    return kwargs


def _call_with_param_retry(call_fn, kwargs):
    for _ in range(5):
        try:
            return call_fn(**kwargs)
        except Exception as exc:
            msg = str(exc)
            if _remove_unsupported_param(msg, kwargs):
                continue
            raise
    raise RuntimeError("Repeated unsupported-parameter errors; check request params.")


def call_model(system_prompt, user_prompt):
    """Send one system+user exchange synchronously and return the reply text.

    Uses ``client.responses`` when the client has that attribute, otherwise
    chat completions.
    """
    messages = [
        {"role": role, "content": content}
        for role, content in (("system", system_prompt), ("user", user_prompt))
    ]
    if not hasattr(client, "responses"):
        # Fallback for older clients
        chat_request = _build_chat_kwargs(messages)
        reply = _call_with_param_retry(client.chat.completions.create, chat_request)
        return reply.choices[0].message.content

    request = _build_response_kwargs(messages)
    return _call_with_param_retry(client.responses.create, request).output_text


async def _acall_with_param_retry(call_fn, kwargs):
    for _ in range(5):
        try:
            return await call_fn(**kwargs)
        except Exception as exc:
            msg = str(exc)
            if _remove_unsupported_param(msg, kwargs):
                continue
            raise
    raise RuntimeError("Repeated unsupported-parameter errors; check request params.")


async def call_model_async(system_prompt, user_prompt):
    """Async counterpart of ``call_model``; returns the reply text.

    Without ``AsyncOpenAI`` the synchronous ``call_model`` is run in a worker
    thread. Otherwise ``async_client.responses`` is used when present, else
    chat completions.
    """
    if not HAS_ASYNC:
        return await asyncio.to_thread(call_model, system_prompt, user_prompt)

    messages = [
        {"role": role, "content": content}
        for role, content in (("system", system_prompt), ("user", user_prompt))
    ]
    if not hasattr(async_client, "responses"):
        chat_request = _build_chat_kwargs(messages)
        reply = await _acall_with_param_retry(
            async_client.chat.completions.create, chat_request
        )
        return reply.choices[0].message.content

    request = _build_response_kwargs(messages)
    reply = await _acall_with_param_retry(async_client.responses.create, request)
    return reply.output_text


def _extract_json_fragment(text):
    starts = [i for i in (text.find("{"), text.find("[")) if i != -1]
    if not starts:
        return None
    start = min(starts)

    stack = []
    in_str = False
    escape = False

    for i in range(start, len(text)):
        ch = text[i]
        if in_str:
            if escape:
                escape = False
            elif ch == "\\":
                escape = True
            elif ch == '"':
                in_str = False
            continue

        if ch == '"':
            in_str = True
            continue
        if ch in "{[":
            stack.append(ch)
            continue
        if ch in "]}":
            if not stack:
                continue
            open_ch = stack.pop()
            if (open_ch == "{" and ch != "}") or (open_ch == "[" and ch != "]"):
                return None
            if not stack:
                return text[start : i + 1]

    return None


def _parse_json_object(raw):
    """Decode model output that should contain JSON.

    Candidates are the stripped text and, if different, the first balanced
    bracket fragment inside it. Every candidate is tried with ``json.loads``;
    if all fail, every candidate is tried with ``ast.literal_eval``.

    Raises:
        The last decoding exception when no candidate can be parsed.
    """
    text = raw.strip()
    attempts = [text]
    fragment = _extract_json_fragment(text)
    if fragment and fragment not in attempts:
        attempts.append(fragment)

    failure = None
    parsers = (
        (json.loads, json.JSONDecodeError),
        (ast.literal_eval, Exception),
    )
    for parse, expected_errors in parsers:
        for attempt in attempts:
            try:
                return parse(attempt)
            except expected_errors as exc:
                failure = exc

    raise failure


def _build_batch_prompt(rows, strict=False):
    prompt = (
        "Input is a JSON array of rows. Each row is a list of institution strings. "
        "For each row, return an object with keys: non_academic_indices, academic_indices, "
        "uk_company_indices (each a list of 0-based integer positions in that row). "
        "Only include indices when you are confident the string is an organization of that type. "
        "uk_company_indices must be a subset of non_academic_indices. "
        "If there are none, use empty lists. "
        'Return exactly: {"results": [...]} with the same length and order as the input.'
    )
    prompt += "\n"
    if strict:
        prompt += STRICT_JSON_INSTRUCTIONS + "\n"
    prompt += f"Input:\n{json.dumps(rows, ensure_ascii=True)}"
    return prompt


def _extract_results(data):
    if isinstance(data, list):
        return data
    if isinstance(data, dict):
        return data.get("results")
    return None


def classify_non_academic_batch(rows, strict=False):
    """Classify one batch of rows synchronously.

    Returns the raw per-row result list from the model, or ``[]`` for an empty
    batch without calling the model.

    Raises:
        ValueError: if the response has no results or their count differs
            from ``len(rows)``.
    """
    if not rows:
        return []

    reply = call_model(SYSTEM_PROMPT, _build_batch_prompt(rows, strict=strict))
    results = _extract_results(_parse_json_object(reply))
    if results is None:
        raise ValueError("Missing 'results' in model response.")
    if len(results) != len(rows):
        raise ValueError("Batch result length mismatch.")
    return results


async def classify_non_academic_batch_async(rows, strict=False):
    """Classify one batch of rows asynchronously.

    The model call is bounded by ``REQUEST_TIMEOUT`` seconds. Returns the raw
    per-row result list, or ``[]`` for an empty batch without calling the model.

    Raises:
        TimeoutError: if the request exceeds ``REQUEST_TIMEOUT``.
        ValueError: if the response has no results or their count differs
            from ``len(rows)``.
    """
    if not rows:
        return []

    request = call_model_async(SYSTEM_PROMPT, _build_batch_prompt(rows, strict=strict))
    try:
        reply = await asyncio.wait_for(request, timeout=REQUEST_TIMEOUT)
    except asyncio.TimeoutError as exc:
        raise TimeoutError(f"Request timed out after {REQUEST_TIMEOUT}s") from exc

    results = _extract_results(_parse_json_object(reply))
    if results is None:
        raise ValueError("Missing 'results' in model response.")
    if len(results) != len(rows):
        raise ValueError("Batch result length mismatch.")
    return results


## Batch classification settings

Configure batching, output paths, and prep the rows for classification.


In [10]:
# Run classification with caching and batching to reduce API calls.
# If you are missing tqdm, run:
# %pip install tqdm
import warnings
from tqdm import TqdmWarning

warnings.filterwarnings("ignore", category=TqdmWarning)
from tqdm.auto import tqdm
import asyncio

MAX_ROWS = None  # set to an int to test on a subset
BATCH_SIZE = 20  # unique rows sent per request
MAX_CONCURRENCY = 16  # size of the semaphore limiting in-flight batches
STRICT_RETRY = False  # retry a failed batch once with the strict-JSON reminder
MAX_FALLBACK_SPLITS = 1  # how many times a failing batch may be halved

SAVE_PROGRESS_CSV = True
SAVE_PROGRESS_EVERY_N = 5  # write the progress CSV every N finished batches
SAVE_EXCEL_EVERY_N = None  # None disables periodic Excel snapshots
SAVE_EXCEL_AT_END = True

LOAD_CACHE = True  # reuse results already present in the JSONL cache
APPEND_CACHE = True  # append newly classified rows to the JSONL cache

VERBOSE = False  # gates the cache/timing/failure print statements

# Relative path of the final workbook; resolved by resolve_output_path().
OUTPUT_REL = Path("output/non_academic_flagged.xlsx")


def resolve_output_path():
    """Choose where the final workbook is written.

    Returns ``OUTPUT_REL`` under the working directory, or under its parent,
    whichever is the first whose target folder already exists. If neither
    does, falls back to ``non_academic_flagged.xlsx`` in the working directory.
    """
    roots = (Path.cwd(), Path.cwd().parent)
    usable = (
        target
        for target in (root / OUTPUT_REL for root in roots)
        if target.parent.exists()
    )
    fallback = Path.cwd() / "non_academic_flagged.xlsx"
    return next(usable, fallback)


# Sidecar files share the workbook's stem: <stem>.progress.csv and <stem>.cache.jsonl.
OUTPUT_PATH = resolve_output_path()
PROGRESS_CSV_PATH = OUTPUT_PATH.with_suffix(".progress.csv")
CACHE_PATH = OUTPUT_PATH.with_suffix(".cache.jsonl")

# Placeholder result used for empty rows and for rows whose request failed.
BASE_RESULT = {
    "non_academic_indices": [],
    "academic_indices": [],
    "uk_company_indices": [],
}

# Rows (as lists of strings) whose classification ultimately failed.
failed_rows = []

# Work on a copy so the loaded frame is left untouched.
df_work = df.copy()
if MAX_ROWS is not None:
    df_work = df_work.head(MAX_ROWS).copy()

# Each row's institution list becomes a hashable tuple key; identical lists are
# classified once and the result is reused for every matching row.
rows = df_work["author_institutions_list"].tolist()
row_keys = [tuple(r) for r in rows]
unique_keys = list(dict.fromkeys(row_keys))


def make_batches(seq, size):
    """Yield consecutive slices of ``seq`` of length ``size`` (the last may be shorter)."""
    yield from (seq[lo : lo + size] for lo in range(0, len(seq), size))


## Normalize model outputs

Convert the model's indices to institution lists and add the UK company flags. This cell also defines the cache and snapshot helpers used during the run.


In [11]:
def _normalize_indices(value, max_len):
    if not isinstance(value, list):
        return []
    seen = set()
    out = []
    for idx in value:
        try:
            i = int(idx)
        except Exception:
            continue
        if 0 <= i < max_len and i not in seen:
            seen.add(i)
            out.append(i)
    return out


def _indices_to_institutions(row_list, indices):
    return [row_list[i] for i in indices]


def normalize_result(row_list, result):
    """Turn a raw model result for one row into the full output record.

    Index lists are cleaned against the row length, UK-company indices are
    restricted to those also marked non-academic, and indices are mapped back
    to institution strings. Non-dict results are treated as empty.

    Returns:
        A dict with the three index lists, the three institution lists, and
        the 0/1 flags derived from them.
    """
    payload = result if isinstance(result, dict) else {}
    size = len(row_list)

    def clean(field):
        return _normalize_indices(payload.get(field, []), size)

    non_idx = clean("non_academic_indices")
    acad_idx = clean("academic_indices")
    allowed_uk = set(non_idx)
    uk_idx = [i for i in clean("uk_company_indices") if i in allowed_uk]

    non_insts, acad_insts, uk_insts = (
        _indices_to_institutions(row_list, idxs)
        for idxs in (non_idx, acad_idx, uk_idx)
    )

    # "UK only" means at least one organisation was labelled and every labelled
    # organisation is a UK company.
    all_orgs = set(non_idx).union(acad_idx)
    uk_only = int(bool(all_orgs) and all_orgs <= set(uk_idx))

    return {
        "non_academic_indices": non_idx,
        "academic_indices": acad_idx,
        "uk_company_indices": uk_idx,
        "non_academic_institutions": non_insts,
        "academic_institutions": acad_insts,
        "uk_company_institutions": uk_insts,
        "non_academic_flag": int(bool(non_insts)),
        "academic_flag": int(bool(acad_insts)),
        "uk_company_flag": int(bool(uk_insts)),
        "uk_company_only_flag": uk_only,
        "academic_uk_company_collab_flag": int(bool(acad_insts and uk_insts)),
    }


# Rows with no institutions need no API call: give them the empty result up
# front and take the empty key out of the work list.
if () in unique_keys:
    results_map = {(): normalize_result([], BASE_RESULT)}
    unique_keys = [k for k in unique_keys if k]
else:
    results_map = {}

# Keys that must not be appended to the cache file (again).
cache_written_keys = set(results_map.keys())


def load_cache():
    """Load previously classified rows from the JSONL cache into ``results_map``.

    Blank lines, lines that are not valid JSON, and records that are not an
    object with a list-of-strings ``"key"`` are skipped, so a truncated or
    damaged line cannot abort the load or register a nonsensical key. Loaded
    keys are also added to ``cache_written_keys``.

    Returns:
        The number of records loaded (0 when loading is disabled or the cache
        file does not exist).
    """
    if not LOAD_CACHE or not CACHE_PATH.exists():
        return 0
    loaded = 0
    with CACHE_PATH.open(encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except Exception:
                continue
            if not isinstance(record, dict):
                continue
            raw_key = record.get("key", [])
            if not isinstance(raw_key, list) or not all(
                isinstance(part, str) for part in raw_key
            ):
                continue
            key = tuple(raw_key)
            result = record.get("result", {})
            results_map[key] = normalize_result(list(key), result)
            cache_written_keys.add(key)
            loaded += 1
    return loaded


# Load the cache and drop every key that already has a result from the list of
# rows still to be classified.
cached = load_cache()
if cached:
    unique_keys = [k for k in unique_keys if k not in cache_written_keys]
    if VERBOSE:
        print(f"Loaded {cached} cached rows; remaining {len(unique_keys)} unique rows")


def append_cache_entries(batch_keys):
    """Append the results for ``batch_keys`` to the JSONL cache file.

    Keys already written, keys without a result, and keys whose classification
    failed are skipped. Failed rows only hold a placeholder empty result;
    writing it to the cache would make every later run treat the row as done
    and never retry it.

    Args:
        batch_keys: Iterable of row keys (tuples of institution strings).
    """
    if not APPEND_CACHE or not batch_keys:
        return
    failed_keys = {tuple(row) for row in failed_rows}
    with CACHE_PATH.open("a") as handle:
        for key in batch_keys:
            if key in cache_written_keys or key in failed_keys:
                continue
            result = results_map.get(key)
            if result is None:
                continue
            record = {"key": list(key), "result": result}
            handle.write(json.dumps(record, ensure_ascii=True) + "\n")
            cache_written_keys.add(key)


def apply_results_to_df():
    """Write the current results into ``df_work`` as one column per field.

    Rows whose key has no result yet receive the normalised empty result.
    Returns ``df_work``.
    """
    per_row = []
    for key in row_keys:
        known = results_map.get(key)
        per_row.append(
            normalize_result(list(key), BASE_RESULT) if known is None else known
        )

    output_columns = (
        "non_academic_indices",
        "academic_indices",
        "uk_company_indices",
        "non_academic_institutions",
        "academic_institutions",
        "uk_company_institutions",
        "non_academic_flag",
        "academic_flag",
        "uk_company_flag",
        "uk_company_only_flag",
        "academic_uk_company_collab_flag",
    )
    for column in output_columns:
        df_work[column] = [entry[column] for entry in per_row]
    return df_work


def save_snapshot(batch_keys=None, progress=False, excel=False):
    """Persist results: cache entries, and optionally the CSV and/or workbook.

    Args:
        batch_keys: Keys whose results should be appended to the cache.
        progress: Write the progress CSV (only if ``SAVE_PROGRESS_CSV`` is on).
        excel: Write the Excel workbook to ``OUTPUT_PATH``.
    """
    if batch_keys:
        append_cache_entries(batch_keys)
    if not (progress or excel):
        return

    # Both file outputs are written from the refreshed working frame.
    apply_results_to_df()
    if progress and SAVE_PROGRESS_CSV:
        df_work.to_csv(PROGRESS_CSV_PATH, index=False)
    if excel:
        df_work.to_excel(OUTPUT_PATH, index=False)


## Run classification and save incrementally

Classify each batch via the LLM. Each finished batch is appended to the cache; the progress CSV is rewritten every `SAVE_PROGRESS_EVERY_N` batches.


In [12]:
async def classify_batch_with_fallback(batch_rows, depth=0):
    """Classify a batch, degrading gracefully when the request fails.

    On failure the batch is halved and each half retried while ``depth`` is
    below ``MAX_FALLBACK_SPLITS``. Once splitting is exhausted, an optional
    strict-JSON retry is made (``STRICT_RETRY``). Rows that still fail are
    recorded in ``failed_rows`` and given ``BASE_RESULT``.

    Returns:
        One result per input row, in order.
    """
    if not batch_rows:
        return []

    def placeholders():
        return [BASE_RESULT for _ in batch_rows]

    if all(len(r) == 0 for r in batch_rows):
        return placeholders()

    try:
        return await classify_non_academic_batch_async(batch_rows)
    except Exception:
        pass  # fall through to the recovery strategies below

    can_split = len(batch_rows) > 1 and depth < MAX_FALLBACK_SPLITS
    if can_split:
        mid = len(batch_rows) // 2
        left = await classify_batch_with_fallback(batch_rows[:mid], depth + 1)
        right = await classify_batch_with_fallback(batch_rows[mid:], depth + 1)
        return left + right

    if STRICT_RETRY:
        try:
            return await classify_non_academic_batch_async(batch_rows, strict=True)
        except Exception:
            pass

    failed_rows.extend(batch_rows)
    return placeholders()


async def process_batches(batches):
    """Classify all batches concurrently and persist results as they finish.

    At most ``MAX_CONCURRENCY`` batches are in flight at once. After each
    batch its results are stored in ``results_map`` and a snapshot is saved
    under a lock so concurrent tasks cannot interleave their file writes. The
    progress bar is closed even when a batch raises.

    Args:
        batches: List of batches, each a list of row keys.

    Returns:
        The shared ``results_map``.
    """
    semaphore = asyncio.Semaphore(MAX_CONCURRENCY)
    save_lock = asyncio.Lock()
    batch_counter = 0

    pbar = tqdm(
        total=len(unique_keys),
        desc="Classifying",
        unit="row",
        mininterval=0.1,
        miniters=1,
    )

    async def handle_batch(batch_keys):
        nonlocal batch_counter
        batch_rows = [list(k) for k in batch_keys]
        async with semaphore:
            batch_results = await classify_batch_with_fallback(batch_rows)
        if len(batch_results) != len(batch_keys):
            raise ValueError("Batch result length mismatch.")
        for key, result in zip(batch_keys, batch_results):
            results_map[key] = normalize_result(list(key), result)
        if REQUEST_SLEEP:
            await asyncio.sleep(REQUEST_SLEEP)
        async with save_lock:
            batch_counter += 1
            do_progress = (
                SAVE_PROGRESS_CSV
                and SAVE_PROGRESS_EVERY_N
                and batch_counter % SAVE_PROGRESS_EVERY_N == 0
            )
            do_excel = (
                SAVE_EXCEL_EVERY_N
                and batch_counter % SAVE_EXCEL_EVERY_N == 0
            )
            save_snapshot(batch_keys, progress=do_progress, excel=bool(do_excel))
        pbar.update(len(batch_keys))

    try:
        tasks = [asyncio.create_task(handle_batch(batch)) for batch in batches]
        if tasks:
            await asyncio.gather(*tasks)
    finally:
        # Close the bar on failure too, so a stale widget is not left behind.
        pbar.close()
    return results_map


# Split the rows that still need classifying into request-sized batches.
batches = list(make_batches(unique_keys, BATCH_SIZE))
start = time.time()

# Top-level await: relies on the notebook kernel's running event loop.
if batches:
    results_map = await process_batches(batches)

# Optional throughput summary (only printed when VERBOSE is on).
elapsed = time.time() - start
if row_keys and VERBOSE:
    req_count = len(batches)
    row_rate = len(unique_keys) / elapsed if elapsed > 0 else 0
    req_rate = req_count / elapsed if elapsed > 0 else 0
    print(
        f"Processed {len(unique_keys)} API rows in {elapsed:.1f}s "
        f"across {req_count} requests (~{row_rate:.2f} rows/s, ~{req_rate:.2f} req/s)"
    )

# Failed rows were given empty results; with VERBOSE off this is not reported.
if failed_rows and VERBOSE:
    print(f"Warning: {len(failed_rows)} rows failed to parse and were set empty.")


## Final snapshot

Write the final results after all batches complete.


In [13]:
# Write the final progress CSV and, if SAVE_EXCEL_AT_END is set, the workbook;
# then display where the workbook lives.
save_snapshot(progress=True, excel=SAVE_EXCEL_AT_END)
OUTPUT_PATH


PosixPath('/home/porco/Dropbox/20_years_of_ukb/output/non_academic_flagged.xlsx')