# LLM

In [None]:
import openai
import sqlite3
import tiktoken
import math
import time
import re
import os
import base64

In [None]:
# =========================
# GitHub-friendly CONFIG
# =========================
from pathlib import Path
import os
import openai

# Project root: the current directory unless PFAS_PROJECT_DIR is set.
PROJECT_DIR = Path(os.environ.get("PFAS_PROJECT_DIR", ".")).resolve()

# Inputs (e.g., SQLite DB, term lists) are read from ./data
DATA_DIR = PROJECT_DIR / "data"

# Generated outputs (tables/figures) are written below ./outputs;
# both directories are created eagerly so later cells can write into them.
OUTPUT_DIR = PROJECT_DIR / "outputs"
FIG_DIR = OUTPUT_DIR / "figures"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Main SQLite database: ./data/PFAS_TK_articles_data.db
DB_PATH = DATA_DIR / "PFAS_TK_articles_data.db"

# Optional default dataset input: ./data/pfas_dataset.csv
PFAS_DATASET_CSV = DATA_DIR / "pfas_dataset.csv"

# -------------------------
# LLM configuration (read from the environment, never hard-coded)
# -------------------------
#   LLM_API_KEY   required
#   LLM_BASE_URL  optional custom endpoint
#   LLM_MODEL     optional, defaults to "gpt-4o-mini"
LLM_API_KEY = os.environ.get("LLM_API_KEY")
LLM_BASE_URL = os.environ.get("LLM_BASE_URL")
LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini")

# Fail fast: nothing below can work without a key.
if not LLM_API_KEY:
    raise ValueError(
        "Missing LLM_API_KEY env var. Set it before running (do NOT paste keys into the notebook)."
    )

# Only pass base_url when one was configured so the SDK default applies otherwise.
client_kwargs = {"api_key": LLM_API_KEY}
client_kwargs.update({"base_url": LLM_BASE_URL} if LLM_BASE_URL else {})
client = openai.OpenAI(**client_kwargs)

In [None]:
# Smoke test: a two-turn chat to confirm the client and model respond.
messages = [
    {
        "role": "system",
        "content": "You are a Researcher that extracts relevant data from research papers relating to Toxicokinetics",
    },
    {"role": "user", "content": "1+1+54="},
]

completion = client.chat.completions.create(model=LLM_MODEL, messages=messages)
print(completion.choices[0].message)

if completion:
    # Feed the assistant reply back into the history and ask a follow-up.
    messages.extend([
        completion.choices[0].message,
        {"role": "user", "content": "1-1+34="},
    ])
    completion = client.chat.completions.create(model=LLM_MODEL, messages=messages)
    print(completion.choices[0].message)
    print(completion)

In [None]:
# Smoke test with the relevance-screening system prompt (two-turn chat).
messages = [
    {
        "role": "system",
        "content": "You are a Researcher that tells if a research papers is relatied to Toxicokinetics or not",
    },
    {"role": "user", "content": "1+1+54="},
]

completion = client.chat.completions.create(model=LLM_MODEL, messages=messages)
print(completion.choices[0].message)

if completion:
    # Feed the assistant reply back into the history and ask a follow-up.
    messages.extend([
        completion.choices[0].message,
        {"role": "user", "content": "1-1+34="},
    ])
    completion = client.chat.completions.create(model=LLM_MODEL, messages=messages)
    print(completion.choices[0].message)
    print('----')
    print(completion)

## Checking relevance

In [None]:
def ChatGptApi(messages):
    """Send ``messages`` to the configured chat model and return the first choice.

    The request goes through the module-level ``client`` using the OpenAI
    SDK v1 interface (``client.chat.completions.create``), the same call the
    other cells in this notebook use. The legacy module-level
    ``openai.ChatCompletion.create`` entry point is not available in the v1
    SDK, so it must not be used alongside ``openai.OpenAI(...)``.

    Args:
        messages: list of chat message dicts (``role`` / ``content``).

    Returns:
        The first element of ``completion.choices``. The reply is capped at
        10 tokens and sampled with temperature 0.
    """
    completion = client.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        temperature=0,
        max_tokens=10,
    )
    return completion.choices[0]

In [None]:
# Relevance-screening prompt template.
# The two bracketed placeholders, "[Insert Paper Title Here]" and
# "[Insert Paper Abstract Here]", are filled with str.replace by the screening
# loop before the prompt is sent. The model is asked to answer 'Yes' or 'No'.
prompt_1="""Given the title and abstract of a research paper, determine if the paper is relevant to the fields of toxicokinetics or pharmacokinetics.

**Toxicokinetics** involves the study of how toxic substances are absorbed, distributed, metabolized, and excreted in living organisms. This includes:
- Measuring concentrations of toxic substances in biological tissues over time.
- Modeling how these substances move through the body.
- Studying the effects of these substances on living organisms.

**Pharmacokinetics** involves the study of how drugs are absorbed, distributed, metabolized, and excreted in living organisms. This includes:
- Measuring concentrations of drugs in biological tissues over time.
- Modeling how these drugs move through the body.
- Studying the effects of these drugs on living organisms.

**Relevance Criteria:** A paper is considered relevant to toxicokinetics or pharmacokinetics if it includes any of the following:
- Discussions on the absorption, distribution, metabolism, or excretion of substances (whether toxic or drugs) in living organisms.
- Measurements or models of concentrations of these substances in biological tissues over time.
- Effects of these substances on living organisms.
- Bioaccumulation, biomagnification, or movement through food webs.
- Impact on health risks or safety due to the presence of these substances in the environment or organisms.
- Any involvement of toxic substances (e.g., PFAS) or drugs in biological tissues or organisms, even if the primary focus is on environmental impact or remediation.
- PFAS is involved

Title: [Insert Paper Title Here]

Abstract: [Insert Paper Abstract Here]

Make the program high recall, it's fine if I have false positives but I want to get all the Yes(cover all articles).

Based on these definitions and criteria, is this paper relevant to toxicokinetics or pharmacokinetics or involvement of toxins (PFAS) or PFAS involved or biological tissues? Answer 'Yes' if the content aligns with these fields, otherwise answer 'No.'"""

In [None]:
from openai import OpenAI
import time

# set up client once
client = OpenAI(api_key=LLM_API_KEY, base_url=LLM_BASE_URL if LLM_BASE_URL else None)

def isRelevantPaper(prompt):
    """Ask the model whether a paper is relevant and map the answer to a flag.

    Args:
        prompt: fully rendered screening prompt (title and abstract filled in).

    Returns:
        A ``(flag, answer)`` tuple where ``answer`` is the lower-cased,
        stripped model reply and ``flag`` is:
          * ``1``  - the reply contains the word "yes"
          * ``0``  - the reply contains the word "no" (and not "yes")
          * ``2``  - neither word is present (uncertain)
        If every attempt raises, ``("Error", "")`` is returned.
    """
    messages = [
        {"role": "system", "content": "You are a Researcher that tells if a research paper is related to Toxicokinetics of PFAS compounds or not."},
        {"role": "user", "content": prompt},
    ]

    max_tries = 3
    for attempt in range(max_tries):
        try:
            completion = client.chat.completions.create(
                model=LLM_MODEL,
                messages=messages,
                temperature=0,
            )
            # ``content`` may be None; treat that as an empty (uncertain) answer
            # instead of raising and burning retries on it.
            result = (completion.choices[0].message.content or "").lower().strip()
        except Exception as e:
            print("Error occurred:", e)
            # No point in waiting after the final failed attempt.
            if attempt < max_tries - 1:
                time.sleep(5)
            continue

        print("Model output:", result)

        # Match whole words only: a plain substring test would read "no" out
        # of words such as "not" or "unknown".
        words = set(re.findall(r"[a-z]+", result))
        if "yes" in words:
            return 1, result
        if "no" in words:
            return 0, result
        return 2, result  # uncertain

    return "Error", ""

In [None]:
# Inspect articles_data: print its column names, then dump every row.
conn = sqlite3.connect(str(DB_PATH))
cursor = conn.cursor()

# PRAGMA table_info yields one row per column; index 1 holds the column name.
columns = cursor.execute("PRAGMA table_info(articles_data);").fetchall()
col_names = [column_info[1] for column_info in columns]
print("Columns in articles_data:", col_names)

rows = cursor.execute("SELECT* FROM articles_data;").fetchall()
print(rows)

conn.close()

In [None]:
# Screen every article with the relevance prompt and store the verdict in
# articles_data.is_related (the value returned by isRelevantPaper).
conn = sqlite3.connect(str(DB_PATH))
cursor = conn.cursor()

cursor.execute("SELECT id, pmid, doi, title, abstract FROM articles_data")
rows = cursor.fetchall()

not_related_pmids = []  # store PMIDs where is_related = 0

try:
    # `row_id` rather than `id`: avoid shadowing the builtin at notebook scope.
    for row_id, pmid, doi, title, abstract in rows:
        prompt_temp = prompt_1.replace("[Insert Paper Title Here]", title or "")
        prompt_temp = prompt_temp.replace("[Insert Paper Abstract Here]", abstract or "")

        isRelevant, res = isRelevantPaper(prompt_temp)

        # Commit per row so progress survives an interruption mid-run.
        cursor.execute("UPDATE articles_data SET is_related=? WHERE id=?", (isRelevant, row_id))
        conn.commit()

        # Remember the rejected papers so they can be reviewed by hand below.
        if isRelevant == 0:
            not_related_pmids.append(pmid)

        print(f"PMID {pmid}: is_related={isRelevant}")
finally:
    # Release the database even if a request or update fails part-way.
    conn.close()

# Double check the not related articles manually
if not_related_pmids:
    print("\nPMIDs marked as NOT related:")
    for pmid in not_related_pmids:
        print(pmid)

In [None]:
import pandas as pd
# Dump the whole articles_data table into an Excel workbook under OUTPUT_DIR.
output_path = str(OUTPUT_DIR / "articles_data_list.xlsx")

conn = sqlite3.connect(str(DB_PATH))
df = pd.read_sql_query("SELECT * FROM articles_data", conn)
df.to_excel(output_path, index=False)
conn.close()

print("Exported articles_data to:", output_path)

## Prompts used for data extraction

In [None]:
import json

# Re-save organs_tissues_list.txt (JSON text stored with a .txt suffix)
# as an indented JSON file next to it.
with open(str(DATA_DIR / "organs_tissues_list.txt"), "r", encoding="utf-8") as f:
    data = f.read()

# The file content must already be valid JSON.
organs_dict = json.loads(data)

with open(str(DATA_DIR / "organs_tissues_list.json"), "w", encoding="utf-8") as f:
    f.write(json.dumps(organs_dict, indent=4, ensure_ascii=False))

print("Converted to organs_tissues_list.json")

In [None]:
import json

# Re-save species_list.txt (JSON text stored with a .txt suffix) as an
# indented JSON file next to it.
with open(str(DATA_DIR / "species_list.txt"), "r", encoding="utf-8") as f:
    data = f.read()

# Keep the species data under its own name. Storing it in `organs_dict`
# (as a copy-paste of the organs cell did) would silently replace the organs
# data for any cell that runs before the JSON files are reloaded.
species_dict = json.loads(data)

with open(str(DATA_DIR / "species_list.json"), "w", encoding="utf-8") as f:
    json.dump(species_dict, f, indent=4, ensure_ascii=False)

print("Converted to species_list.json")

In [None]:
import json, re
from pathlib import Path

# Load your PFAS list
pfas = json.loads(Path(str(DATA_DIR / "pfas_search_terms.json")).read_text(encoding="utf-8"))

def make_target_chem_block(pfas_dict, max_syn=6):
    """
    Render one "canonical → synonym; synonym; ..." line per PFAS entry.

    Synonyms are stringified and stripped, blank ones are dropped, and at most
    ``max_syn`` of the remaining ones are listed. Lines are joined with newlines.
    """
    def _render(name, synonyms):
        kept = []
        for raw in synonyms:
            label = str(raw).strip()
            if label:
                kept.append(label)
        return f"{name} → " + "; ".join(kept[:max_syn])

    return "\n".join(_render(name, synonyms) for name, synonyms in pfas_dict.items())

target_chem_block = make_target_chem_block(pfas)

In [None]:
import json

# Locations of the canonical species and organ/tissue lists
species_path = str(DATA_DIR / "species_list.json")
organs_path = str(DATA_DIR / "organs_tissues_list.json")

# Parse both files; the prints below call .keys(), so each is expected to hold a JSON object.
with open(species_path, "r", encoding="utf-8") as f:
    species_dict = json.loads(f.read())

with open(organs_path, "r", encoding="utf-8") as f:
    organs_dict = json.loads(f.read())

print("JSONs loaded")
print("Species categories:", species_dict.keys())
print("Organs categories:", organs_dict.keys())

In [None]:
# Extraction prompt template.
# NOTE: this is a plain string literal (not an f-string), so every brace
# expression below is literal text until something substitutes it:
#   {article_text}, {target_chem_block},
#   {list(species_dict.keys())}, {list(organs_dict.keys())}
# The last two are not simple named fields, so str.format(species_list=...,
# organs_list=...) cannot fill them; substitute them explicitly
# (e.g. with str.replace) before sending the prompt.
prompt_main_data_extraction_v2 = """Task: Read the research article text and extract the requested items in the exact schema below.

Article Text:
{article_text}

Important constraints (read carefully):
• Chemicals: Report ONLY chemicals that appear in the “Target PFAS list” provided below (or any of their listed synonyms/CAS/SMILES). If none match, write “No target PFAS mentioned”.
• Canonical names: When a match occurs, output the chemical’s CANONICAL name exactly as shown in the list, not the synonym found in the text. Optionally include the matched synonym in parentheses.
• Keep answers concise and factual; do not infer.

Target PFAS list (canonical → synonyms & identifiers):
{target_chem_block}

—— OUTPUT SCHEMA ——
General Extraction
- Chemicals:
  Canonical PFAS names only (see above rule)
- Species:
  Only extract species names from the provided canonical list:
  {list(species_dict.keys())}
- Plasma Consideration:
  [Yes/No]
- Organs and Tissues Involved:
  Only extract from the provided canonical list:
  {list(organs_dict.keys())}
- Research Outcome:
  [1–3 sentences; include numeric metrics if present]
- Type of Study:
  [Human/Animal/In-vitro]

If Type of Study = Human, also provide:
- Study Type:
- Location:
- Gender of Subjects:
- Age of Subjects:
- Number of Subjects:
- Experimental Samples Involved:

If Type of Study = Animal, also provide:
- Study Type: [Experimental/Observational, species names only]
- Gender of Subjects:
- Age at Exposure:
- Route of Exposure:

Formatting rules:
• Use the headings exactly as above, in the same order.
• If an item is not reported, write “Not stated”.
• Do not include any other commentary.

"""

## Data Extraction

In [None]:
available_encodings = tiktoken.list_encoding_names()
print(available_encodings)

In [None]:
def count_tokens(input_string, model):
    """Return the number of tiktoken tokens in ``input_string`` for ``model``.

    Args:
        input_string: text to tokenize.
        model: model name used to pick the tokenizer.

    Returns:
        Token count as an int. When tiktoken has no mapping for ``model``
        the count is computed with the ``cl100k_base`` encoding and is
        therefore only an approximation for that model.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # encoding_for_model raises KeyError for model names it cannot map;
        # an approximate count is more useful here than a crash.
        encoding = tiktoken.get_encoding("cl100k_base")

    # disallowed_special=() makes special-token look-alikes in article text
    # (e.g. "<|endoftext|>") count as ordinary text instead of raising.
    return len(encoding.encode(input_string, disallowed_special=()))

In [None]:
# Load every row of articles_blob_data into `rows` for inspection.
query = '''
    SELECT * from articles_blob_data
'''

conn = sqlite3.connect(str(DB_PATH))
cursor = conn.cursor()
rows = cursor.execute(query).fetchall()
conn.close()

In [None]:
# Fixed set of 16 PMIDs used to try out prompt_main_data_extraction_v2.
# NOTE(review): the earlier comment said "randomly select one PMID", but nothing
# here is random; later cells reassign PMID_unprocessed.
PMID_unprocessed = {'417812','2906022', '3684504','6759044',  '6819698', '7299662', '7337230', '7639356', '7849929', '7884142', '7974521', '8229349', '8250967', '8516773', '9269454', '9862284'}

In [None]:
PMID_unprocessed = {'11519538', '11719891', '11855757', '15328768', '15366585', '16466536', '20556880', '2093123', '2374085', '25454233', '30411895', '30528102', '30631142', '31549993', '31568513', '32495786', '32897586', '33017053', '33382826', '33605484', '33647664', '34854961', '35138827', '35324171', '35580034', '37220906', '37984148', '38801906', '39504592', '39542374', '39556161', '40239480', '417812'}

In [None]:
PMID_unprocessed = {'2781142', '2901469', '3089945', '3098413', '3140922', '3179494', '3575876', '3759552', '6470567', '6515128', '6547941', '6787865',
                    '7153234', '7784553', '7849907', '7849908', '7849919', '7849920', '7978424', '8047981', '8079710', '8079711', '8097775', '8134252',
                    '8158684', '8333387', '8336905', '8429782', '8482678', '8611942', '8632927', '8646338', '8795104', '8870975', '8890829', '9084907',
                    '9177987', '9185039', '9219835', '9341714', '9383597', '9419473', '9625556', '9655737'}

In [None]:
# Collect the PMIDs of every article flagged as related (is_related = 1).
conn = sqlite3.connect(str(DB_PATH))
cursor = conn.cursor()

pmid_rows = cursor.execute("SELECT pmid FROM articles_data WHERE is_related = 1").fetchall()

# Each result row is a 1-tuple; keep the PMID as a string.
PMID_unprocessed = [str(pmid_value) for (pmid_value,) in pmid_rows]

print("Fetched PMIDs:", PMID_unprocessed)
print("Total count:", len(PMID_unprocessed))

conn.close()

In [None]:
# Resume support: keep only the PMIDs that come after the cutoff PMID.
PMID_unprocessed = list(map(str, PMID_unprocessed))

cutoff = "32790152"
try:
    cutoff_index = PMID_unprocessed.index(cutoff)
except ValueError:
    # Cutoff not present: leave the list untouched.
    print(f"{cutoff} not found in PMID_unprocessed")
else:
    # The cutoff entry itself is excluded.
    PMID_unprocessed = PMID_unprocessed[cutoff_index + 1:]

print("New total count:", len(PMID_unprocessed))
print("First 10 PMIDs after cutoff:", PMID_unprocessed[:10])

In [None]:
# Fetch the full-text rows for the PMIDs still to be processed and report
# any that have no matching entry in articles_data.
conn = sqlite3.connect(str(DB_PATH))
cursor = conn.cursor()

input_formating = {}

# sqlite3 needs a sequence of parameters, one "?" per PMID
pmid_tuple = tuple(PMID_unprocessed)

query = '''
    SELECT * FROM articles_blob_data WHERE PMID IN ({placeholders})
'''.format(placeholders=','.join('?' for _ in pmid_tuple))

# The PMIDs are bound as query parameters
pmid_unprocessed_rows = cursor.execute(query, pmid_tuple).fetchall()

for pmid_unprocessed_row in pmid_unprocessed_rows:
    # Column 0 is assumed to be the PMID and column 3 the article body.
    pmid, body = pmid_unprocessed_row[0], pmid_unprocessed_row[3]

    present_pmid_count = cursor.execute(
        'SELECT COUNT(*) FROM articles_data WHERE pmid = ?', (pmid,)
    ).fetchone()[0]
    if present_pmid_count != 0:
        continue

    # Only reported for now; nothing is inserted for these rows.
    print("Not in relavent document")
    print(body[:10])  # first 10 characters/bytes of the body

conn.close()

In [None]:
import pandas as pd
df_preview = pd.DataFrame(pmid_unprocessed_rows)
print(df_preview.head())

In [None]:
# Build the extraction prompt(s) for every fetched article.
# Result: input_formating[pmid] -> list of ready-to-send prompts (one per
# partition when the article is longer than the context budget).
# No database access is needed here; the rows were fetched by an earlier cell.
context_length = 128000

# Fill the placeholders that do not depend on the article once, up front.
# The template is a plain string, so these would otherwise reach the model
# as literal "{...}" text.
prompt_template = (
    prompt_main_data_extraction_v2
    .replace("{target_chem_block}", target_chem_block)
    .replace("{list(species_dict.keys())}", str(list(species_dict.keys())))
    .replace("{list(organs_dict.keys())}", str(list(organs_dict.keys())))
)

input_formating = {}

for i, row in enumerate(pmid_unprocessed_rows):
    pmid = row[0]  # Assuming pmid is in the first column
    article_text = row[3]  # Assuming the article body is at index 3

    if article_text is None:
        print(i, pmid, "has no article text; skipped")
        continue

    # Blob columns come back as bytes; decode them to text.
    if isinstance(article_text, bytes):
        article_text = article_text.decode('utf-8')

    # NOTE: only the article is counted, not the surrounding prompt text.
    tokens_count = count_tokens(article_text, "gpt-oss-120b")
    if tokens_count > context_length:
        partitions = math.ceil(tokens_count / context_length)
        print(i, tokens_count, len(article_text.split()), row[1], "---------------------------------------------------------", partitions)
        approx_partition_size = math.ceil(len(article_text) / partitions)

        # Contiguous character slices: every character lands in exactly one
        # partition and all of them are stored under the article's own PMID.
        input_formating[pmid] = [
            prompt_template.replace("{article_text}", article_text[start:start + approx_partition_size])
            for start in range(0, len(article_text), approx_partition_size)
        ]
    else:
        print(i, tokens_count, len(article_text.split()), row[1])

        input_formating[pmid] = [prompt_template.replace("{article_text}", article_text)]
        print(pmid, '|', article_text)

In [None]:
def _canonical_block_from_dict_keys(d):
    # d can be a nested dict (e.g., categories -> items) or a flat dict
    # We only list canonical names (keys if flat, items if nested)
    names = []
    for k, v in d.items():
        # If value is a list of names inside a category, include those names
        if isinstance(v, (list, tuple, set)):
            names.extend([str(x).strip() for x in v if str(x).strip()])
        else:
            # Treat top-level keys as canonical names
            names.append(str(k).strip())
    # unique, sorted
    names = sorted({n for n in names if n})
    # pretty list (one per line with dash)
    return "\n  - " + "\n  - ".join(names) if names else "\n  - (none)"

In [None]:
def build_extraction_prompt(article_text, target_chem_block, species_dict, organs_dict):
    """Fill the extraction prompt template for one article.

    Args:
        article_text: article body to insert into the prompt.
        target_chem_block: pre-rendered "canonical → synonyms" PFAS block.
        species_dict: canonical species mapping (flat or category -> names).
        organs_dict: canonical organ/tissue mapping (flat or category -> names).

    Returns:
        The template text with every placeholder substituted.
    """
    species_list = _canonical_block_from_dict_keys(species_dict)
    organs_list = _canonical_block_from_dict_keys(organs_dict)

    # Plain str.replace instead of str.format: the template spells the
    # species/organs placeholders as "{list(species_dict.keys())}" and
    # "{list(organs_dict.keys())}", which str.format cannot fill from keyword
    # arguments. The short "{species_list}" / "{organs_list}" spellings are
    # accepted as well. The article goes in last so that brace text inside it
    # is never mistaken for a placeholder.
    substitutions = (
        ("{target_chem_block}", target_chem_block),
        ("{list(species_dict.keys())}", species_list),
        ("{species_list}", species_list),
        ("{list(organs_dict.keys())}", organs_list),
        ("{organs_list}", organs_list),
        ("{article_text}", article_text),
    )

    prompt = prompt_main_data_extraction_v2
    for placeholder, value in substitutions:
        prompt = prompt.replace(placeholder, str(value))
    return prompt

In [None]:
import time
from openai import OpenAI, RateLimitError, APIConnectionError, APIError

client = OpenAI(api_key=LLM_API_KEY, base_url=LLM_BASE_URL if LLM_BASE_URL else None)

def outputValuesExtractor(prompt_main, model=LLM_MODEL, temperature=0):
    """
    Send one extraction prompt to the chat model and return the reply text.

    API errors (RateLimitError, APIConnectionError, APIError) are retried up
    to 4 attempts with a growing pause (2s, then x1.5 each time). Any failure
    is reported as a string starting with 'ErrorInPFASTK: ' rather than raised.
    """
    system_prompt = "You are an assistant trained to extract specific information from research articles related to pharmacokinetics and toxicokinetics."
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt_main},
    ]

    max_retries = 4
    delay = 2.0
    attempt = 0

    while attempt < max_retries:
        try:
            # max_tokens is deliberately left unset
            resp = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
            )
            return (resp.choices[0].message.content or "").strip()

        except (RateLimitError, APIConnectionError, APIError) as e:
            # Give up with an error string once the last attempt has failed;
            # otherwise wait and try again.
            is_last_attempt = attempt == max_retries - 1
            if is_last_attempt:
                return f"ErrorInPFASTK: {e}"
            time.sleep(delay)
            delay *= 1.5

        except Exception as e:
            # Anything else is not retried.
            return f"ErrorInPFASTK: {e}"

        attempt += 1

In [None]:
import pandas as pd
from tqdm import tqdm

OUTPUT_SEP = "\n<output-seperator>\n"

def build_result_table_simple(input_formating, extractor_func):
    """Run ``extractor_func`` over every prompt and collect the raw outputs.

    Args:
        input_formating: mapping of pmid -> list of prompts (one per partition).
        extractor_func: callable taking a prompt and returning the reply text.

    Returns:
        DataFrame with one row per pmid: ``output`` (first partition's reply),
        ``combined_output`` (all replies joined with OUTPUT_SEP) and the parsed
        columns initialised to None. The first 5 rows are displayed.
    """
    parsed_columns = [
        "chemicals", "species", "plasma_consideration", "organs",
        "research_outcome", "type_of_study",
    ]

    def _make_row(pmid, first_output, combined_output):
        # Parsed fields start empty; they are filled in by a later step.
        row = {"pmid": pmid, "output": first_output, "combined_output": combined_output}
        row.update(dict.fromkeys(parsed_columns))
        return row

    def _safe_extract(prompt):
        # An exception from the extractor becomes an error string for that part.
        try:
            answer = extractor_func(prompt)
        except Exception as e:
            answer = f"ErrorInPFASTK: {e}"
        return answer if answer else ""

    rows = []
    for pmid in tqdm(list(input_formating.keys()), desc="Extracting"):
        prompts = input_formating.get(pmid, [])
        if not prompts:
            rows.append(_make_row(pmid, "", ""))
            continue

        part_outputs = [_safe_extract(p) for p in prompts]
        rows.append(_make_row(pmid, part_outputs[0], OUTPUT_SEP.join(part_outputs)))

    df = pd.DataFrame(rows, columns=["pmid", "output", "combined_output", *parsed_columns])

    display(df.head(5))
    return df

In [None]:
df_result = build_result_table_simple(input_formating, outputValuesExtractor)

In [None]:
import re
import pandas as pd

def _clean_list(text):
    if not text:
        return None
    text = text.strip()
    text = re.sub(r'^\[|\]$', '', text).strip()   # strip surrounding brackets if present
    parts = [p.strip(" -•\t\r\n,") for p in re.split(r'[,;\n]+', text) if p.strip()]
    return parts or None

def _norm(s):
    if s is None:
        return ""
    s = s.replace("**", "")                # drop markdown bold
    s = re.sub(r"[ \t]+", " ", s)          # normalize spaces
    return s

# Headings WITHOUT the trailing colon; we'll accept either "Header" or "Header:"
HEADERS = [
    "chemicals",
    "species",
    "plasma consideration",
    "organs and tissues involved",
    "research outcome",
    "type of study",
    "gender of subjects",
    "age of subjects",
    "route of exposure",
    "number_of_subjects",
    "experimental_samples_involved"
]

def parse_combined_output(text):
    """
    Parse 'combined_output' (or 'output') into fields, tolerating headers
    with or without a trailing colon.
    """
    if not text:
        return { "chemicals": None, "species": None, "plasma_consideration": None,
                 "organs": None, "research_outcome": None, "type_of_study": None,
                 "gender_of_subjects": None, "age_at_exposure": None, "route_of_exposure": None,
                 "number_of_subjects": None, "experimental_samples_involved": None}

    # Normalize and lowercase working copy (same length as original t_raw).
    t_raw = _norm(text)
    t_lc  = t_raw.lower()

    # Find header spans (start & end of the header token itself)
    spans = []
    for h in HEADERS:
        # match e.g. "research outcome" or "research outcome:" (case-insensitive)
        pat = rf"(?s)\b{re.escape(h)}\b\s*:?"
        m = re.search(pat, t_lc)
        if m:
            spans.append((h, m.start(), m.end()))

    if not spans:
        return { "chemicals": None, "species": None, "plasma_consideration": None,
                 "organs": None, "research_outcome": None, "type_of_study": None,
                 "gender_of_subjects": None, "age_at_exposure": None, "route_of_exposure": None,
                 "number_of_subjects": None, "experimental_samples_involved": None}

    # Sort headers by position
    spans.sort(key=lambda x: x[1])

    # Build mapping header -> content between it and the next header
    sections = {}
    for i, (h, start, hdr_end) in enumerate(spans):
        next_start = spans[i+1][1] if i+1 < len(spans) else len(t_raw)
        chunk = t_raw[hdr_end:next_start].strip()
        sections[h] = chunk

    # Extract and clean fields
    chemicals            = _clean_list(sections.get("chemicals"))
    species              = _clean_list(sections.get("species"))
    plasma_consideration = sections.get("plasma consideration")
    if plasma_consideration:
        plasma_consideration = plasma_consideration.splitlines()[0].strip()

    organs               = _clean_list(sections.get("organs and tissues involved"))

    research_outcome     = sections.get("research outcome")
    if research_outcome:
        research_outcome = re.sub(r"[ \t]+", " ", research_outcome).strip()

    type_of_study        = sections.get("type of study")
    if type_of_study:
        type_of_study = re.sub(r'^\[|\]$', '', type_of_study.splitlines()[0]).strip()

    gender = sections.get("gender of subjects")
    if gender:
        gender = re.sub(r"\*\*|_", "", gender.strip())

    age = sections.get("age at exposure")
    if age:
        age = re.sub(r"\*\*|_", "", age.strip())

    route = sections.get("route of exposure")
    if route:
        route = re.sub(r"\*\*|_", "", route.strip())

    subject = sections.get("number of subjects")
    if subject:
        subject = re.sub(r"\*\*|_", "", subject.strip())

    sample = sections.get("experimental samples involved")
    if sample:
        sample = re.sub(r"\*\*|_", "", samples.strip())

    return {
        "chemicals": chemicals,
        "species": species,
        "plasma_consideration": plasma_consideration,
        "organs": organs,
        "research_outcome": research_outcome,
        "type_of_study": type_of_study,
        "gender_of_subjects": gender,
        "age_at_exposure": age,
        "route_of_exposure": route,
        "number_of_subjects": subject,
        "experimental_samples_involved": sample
    }

# --- Parse the raw model output and copy the fields into df_result ---
parsed_columns = [
    "chemicals", "species", "plasma_consideration", "organs",
    "research_outcome", "type_of_study", "gender_of_subjects",
    "age_at_exposure", "route_of_exposure", "number_of_subjects",
    "experimental_samples_involved",
]

# Prefer the combined output; fall back to the first output where it is missing.
source_text = df_result["combined_output"].fillna(df_result["output"])
parsed = source_text.apply(parse_combined_output)
parsed_df = pd.DataFrame(parsed.tolist())

# List-valued fields are flattened to "a; b; c" strings for display/export.
for col in ["chemicals", "species", "organs"]:
    parsed_df[col] = parsed_df[col].apply(
        lambda x: "; ".join(x) if isinstance(x, list) else x
    )

df_result[parsed_columns] = parsed_df[parsed_columns]

display(df_result)

In [None]:
df_result.to_csv(str(OUTPUT_DIR / "df_result_last.csv"), index=False, encoding = "utf-8")

In [None]:
# Rough cost estimate: total prompt tokens per article, then the grand total
# in millions multiplied by 5 (the per-million rate used here).
Token_lengths = []
for input_formats in input_formating:
    single_token_count = sum(
        count_tokens(prompts, "gpt-4-turbo") for prompts in input_formating[input_formats]
    )
    Token_lengths.append(single_token_count)

TokenCountsMillions = sum(Token_lengths) / 1000000
print(f"{TokenCountsMillions} Million Tokens = {TokenCountsMillions * 5} $")

# Figure

## Import

In [None]:
!pip install plotly -q

In [None]:
!pip install pubchempy

In [None]:
!pip install pubchempy ftfy -q

In [None]:
import re, os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import BoundaryNorm
import plotly.graph_objects as go
import matplotlib.cm as cm
import matplotlib.colors as colors

In [None]:
from google.colab import files
uploaded = files.upload()

In [None]:
# Load the dataset corresponding to Table S2
df = pd.read_csv("table_s2.csv")

In [None]:
from google.colab import files
uploaded = files.upload()

In [None]:
# Load the dataset corresponding to Table S1
PFAS430 = pd.read_csv("table_S1.csv")

## Setting

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap, BoundaryNorm
import seaborn as sns
import re

# ----------------------------
# CONFIG
# ----------------------------
# NOTE(review): CSV_PATH is not referenced anywhere in the visible part of this
# notebook (the data frames above are read from "table_s2.csv" / "table_S1.csv")
# — confirm whether it is still needed.
CSV_PATH = "/content/pfas_dataset.csv"   # file path
CHEM_COL = "chemicals"                   # column with the PFAS names (split on ";" by explode_multi)
SPECIES_COL = "species"                  # column with the species names (split on ";" by explode_multi)
ORGANS_COL = "organs"                    # column with the organ/tissue names (split on ";" by explode_multi)
PMID_COL = "pmid"                        # used to count "records"

# Number of categories shown by barplot_topn (and the default for the Top-N heatmap).
TOPN = 20
# NOTE: SAVE_DPI and SAVE_KW are assigned again (300 dpi) in the "Plot helpers"
# section further down in this cell, so the 600 dpi values set here are replaced
# before any plotting helper runs.
SAVE_DPI = 600
SAVE_KW = dict(format="tiff", dpi=SAVE_DPI, pil_kwargs={"compression": "tiff_lzw"})

# Global matplotlib defaults for all figures created after this point.
plt.rcParams.update({
    "figure.dpi": 110,
    "savefig.dpi": SAVE_DPI,
    "font.size": 12,          # base font
    "axes.titlesize": 20,     # title
    "axes.labelsize": 16,     # x/y labels
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
})

# =========================
# SPECIES_NORMALIZE
# =========================
# Lookup table: lower-cased species token -> canonical display label.
# It is passed as `mapping` to clean_category, which looks up the lower-cased
# token by exact match; tokens without an entry fall back to Title Case there.
# A variant that is not listed (e.g. a plural form) therefore becomes its own
# category. The section comments below only organize the table; the position
# of a key has no effect on the lookup.
SPECIES_NORMALIZE = {
    # Humans & Primates
    "human": "Human", "humans": "Human",
    "nonhuman primate": "Nonhuman primate", "non human primate": "Nonhuman primate",
    "monkey": "Monkey", "monkeys": "Monkey",

    # Rodents / Lab Animals
    "rat": "Rat", "rats": "Rat",
    "mouse": "Mouse", "mice": "Mouse", "mice (mus musculus)": "Mouse",
    "hamster": "Hamster", "hamsters": "Hamster",
    "rabbit": "Rabbit", "rabbits": "Rabbit",

    # Domestic Mammals / Livestock
    "dog": "Dog", "dogs": "Dog",
    "cat": "Cat", "cats": "Cat",
    "pig": "Pig", "pigs": "Pig", "porcine": "Pig", "porcine (pig)": "Pig",
    "cattle": "Cattle", "beef cattle": "Cattle", "dairy cows": "Cattle", "dairy cattle": "Cattle", "bovine": "Cattle",
    "sheep": "Sheep", "ram": "Sheep",
    "goat": "Goat", "goats": "Goat",
    "horse": "Horse", "horses": "Horse",
    "camel": "Camel",

    # Poultry / Birds
    "chicken": "Chicken", "chickens": "Chicken", "broiler": "Chicken",
    "duck": "Duck", "ducks": "Duck",
    "turkey": "Turkey",
    "quail": "Quail",
    "bird": "Bird", "birds": "Bird",

    # Fish (General & Specific)
    "fish": "Fish", "fishes": "Fish",
    "zebrafish": "Zebrafish", "zebra fish": "Zebrafish",
    "carp": "Carp", "common carp": "Carp", "crucian carp": "Carp", "grass carp": "Carp", "bighead carp": "Carp",
    "catfish": "Catfish", "yellow cat fish": "Catfish",
    "tilapia": "Tilapia",
    "minnow": "Minnow",
    "goby": "Goby",
    "mudskipper": "Mudskipper",
    "mullet": "Mullet",
    "shad": "Shad",
    "rockfish": "Rockfish",
    "shark": "Shark", "sharks": "Shark",
    "snakehead": "Snakehead",
    "midge": "Midge", # NOTE(review): listed in the fish section (aquatic larvae); kept as its own label — confirm placement

    # Marine Mammals & Large Wildlife
    "dolphin": "Dolphin",
    "porpoise": "Porpoise",
    "whale": "Whale", "beluga whale": "Whale",
    "seal": "Seal", "ringed seal": "Seal",
    "sea lion": "Sea lion",
    "sea otter": "Sea otter",
    "polar bear": "Polar bear", "polar bears": "Polar bear",
    "walrus": "Walrus",
    "otter": "Otter",
    "fox": "Fox",
    "mink": "Mink",
    "caribou": "Caribou",
    "wolf": "Wolf",
    "tiger": "Tiger",

    # Invertebrates (Aquatic & Terrestrial)
    "mussel": "Mussel", "mussels": "Mussel",
    "oyster": "Oyster", "oysters": "Oyster",
    "clam": "Clam", "clams": "Clam",
    "bivalve": "Bivalve", "bivalves": "Bivalve", "bivalvia": "Bivalve",
    "crustacean": "Crustacean", "crustaceans": "Crustacean",
    "gastropod": "Gastropod", "gastropods": "Gastropod",
    "cephalopod": "Cephalopod", "cephalopods": "Cephalopod",
    "shrimp": "Shrimp", "prawn": "Shrimp",
    "crab": "Crab", "blue crab": "Crab",
    "lobster": "Lobster",
    "crayfish": "Crayfish",
    "gammarid": "Gammarid",
    "mysid": "Mysid",
    "amphipod": "Amphipod",
    "diporeia": "Amphipod", # Specific amphipod
    "zooplankton": "Zooplankton",
    "plankton": "Plankton",
    "bentho": "Benthos",
    "lugworm": "Worm",
    "worm": "Worm",
    "earthworm": "Earthworm",
    "snail": "Snail",
    "whelk": "Whelk",
    "limpet": "Limpet",
    "conch": "Conch",
    "abalone": "Abalone",
    "urchin": "Sea urchin",
    "sea star": "Starfish",
    "anemone": "Sea anemone",
    "sea squirt": "Sea squirt",
    "isopod": "Isopod",
    "cricket": "Cricket",
    "honeybee": "Honeybee",

    # Amphibians & Reptiles
    "frog": "Frog", "frogs": "Frog", "bullfrog": "Frog", "tadpole": "Frog",
    "amphibian": "Amphibian",
    "snake": "Snake",

    # Other / Micro
    "bat": "Bat",
    "nematode": "Nematode",
    "bacteria": "Bacteria",
    "algae": "Algae",
    "invertebrates": "Invertebrates",
}

# =========================
# ORGANS_NORMALIZE
# =========================

# Lookup table: lower-cased organ/tissue/matrix token -> canonical display label.
# It is passed as `mapping` to clean_category, which looks up the lower-cased
# token by exact match; tokens without an entry fall back to Title Case there.
# Several distinct tokens deliberately collapse onto one label (e.g. every
# adipose depot -> "Adipose tissue").
ORGANS_NORMALIZE = {
    # Core Organs
    "liver": "Liver", "livers": "Liver", "hepatocytes": "Liver", "hepatopancreas": "Liver", # Hepatopancreas functionally similar in inverts
    "kidney": "Kidney", "kidneys": "Kidney",
    "brain": "Brain", "cortex": "Brain", "hippocampus": "Brain", "cerebellum": "Brain", "cerebral cortex": "Brain",
    "lung": "Lung", "lungs": "Lung",
    "heart": "Heart", "heart muscle": "Heart",
    "spleen": "Spleen",
    "pancreas": "Pancreas",
    "gallbladder": "Gallbladder", "gall bladder": "Gallbladder",

    # Muscle & Flesh
    "muscle": "Muscle", "muscles": "Muscle", "skeletal muscle": "Muscle",
    "gastrocnemius": "Muscle", "gastrocnemius muscle": "Muscle", "thigh": "Muscle",
    "fillet": "Muscle", # Common in fish
    "meat": "Muscle",

    # Aquatic Specific
    "gill": "Gills", "gills": "Gills",
    "swim bladder": "Swim bladder", "swimming bladder": "Swim bladder",
    "fin": "Fins", "fins": "Fins",
    "scale": "Scales",
    "mucus": "Mucus",

    # Gastrointestinal
    "intestine": "Intestine", "intestines": "Intestine", "small intestine": "Intestine", "large intestine": "Intestine",
    "duodenum": "Intestine", "jejunum": "Intestine", "cecum": "Intestine",
    "stomach": "Stomach", "gizzard": "Stomach", "digestive tract": "Gastrointestinal tract", "digestive system": "Gastrointestinal tract",
    "gut": "Gastrointestinal tract", "gastrointestinal tract": "Gastrointestinal tract",
    "gastric mucosa": "Stomach",

    # Blood & Matrices
    "blood": "Blood", "whole blood": "Whole blood",
    "serum": "Serum", "blood serum": "Serum", "fetal serum": "Serum", "maternal serum": "Serum", "infant serum": "Serum",
    "plasma": "Plasma", "blood plasma": "Plasma", "fetal plasma": "Plasma", "neonatal plasma": "Plasma", "arterial plasma": "Plasma", "venous plasma": "Plasma",
    "cord blood": "Cord blood",
    "red blood cells": "Red blood cells", "rbc": "Red blood cells",
    "blood clot": "Blood",
    "haemolymph": "Haemolymph", # Invertebrate blood

    # Excreta / Elimination
    "urine": "Urine",
    "feces": "Feces", "faeces": "Feces", "stool": "Feces",
    "bile": "Bile", "gallbladder bile": "Bile",
    "sweat": "Sweat",

    # Endocrine / Reproductive
    "thyroid": "Thyroid", "thyroid gland": "Thyroid",
    "adrenal": "Adrenal", "adrenals": "Adrenal", "adrenal gland": "Adrenal", "adrenal glands": "Adrenal",
    "pituitary": "Pituitary", "pituitary gland": "Pituitary",
    "testis": "Testis", "testes": "Testis", "leydig cells": "Testis",
    "ovary": "Ovary", "ovaries": "Ovary",
    "uterus": "Uterus",
    "prostate": "Prostate",
    "gonad": "Gonad", "gonads": "Gonad", "reproductive tract": "Gonad",
    "mammary gland": "Mammary gland", "mammary tissue": "Mammary gland", "breast": "Mammary gland", "udder": "Mammary gland",
    "cervix": "Cervix",

    # Development / Pregnancy
    "placenta": "Placenta",
    "amniotic fluid": "Amniotic fluid",
    "fetus": "Fetus",
    "embryo": "Embryo", "whole embryo": "Embryo",
    "yolk": "Egg yolk", "egg yolk": "Egg yolk",
    "albumen": "Egg white",
    "egg": "Egg", "eggs": "Egg",

    # Milk
    "milk": "Milk", "breast milk": "Breast milk", "breastmilk": "Breast milk",

    # Adipose / Skin / Bone / Integument
    "adipose": "Adipose tissue", "adipose tissue": "Adipose tissue", "fat": "Adipose tissue", "blubber": "Adipose tissue",
    "white fat": "Adipose tissue", "white adipose tissue": "Adipose tissue",
    "epididymal fat": "Adipose tissue", "abdominal fat": "Adipose tissue", "perigonadal adipose": "Adipose tissue", "subcutaneous fat": "Adipose tissue",
    "fat pad": "Adipose tissue", "fat pads": "Adipose tissue", "inguinal fat pads": "Adipose tissue",
    "lipid": "Adipose tissue", # Context dependent, but usually implies fat extract
    "skin": "Skin", "integumenta": "Skin",
    "bone": "Bone", "whole bone": "Bone", "femur": "Bone", "cartilage": "Bone",
    "bone marrow": "Bone marrow",
    "hair": "Hair",
    "nail": "Nail",

    # Nervous / Sensory
    "eye": "Eye", "eyes": "Eye", "retina": "Eye",
    "olfactory epithelium": "Olfactory epithelium", "olfactory rosette": "Olfactory epithelium",
    "spinal cord": "Spinal cord",

    # General / Composite / PBPK Buckets
    "carcass": "Carcass", "carcass remainder": "Carcass",
    "whole body": "Whole body", "whole organism": "Whole body", "whole pup": "Whole body",
    "rest of body": "Rest of body", "rest of body tissues": "Rest of body",
    "soft tissue": "Soft tissue", "viscera": "Viscera", "visceral mass": "Viscera",
    "head": "Head",
    "extremities": "Extremities",
    "filtrate": "Filtrate",
    "storage": "Storage",
    "cell": "Cells",
    "tissue": "Tissue",

    # Immune / Lymph
    "thymus": "Thymus",
    "lymph node": "Lymph node", "lymph nodes": "Lymph node", "thoracic lymph nodes": "Lymph node", "mesenteric lymph nodes": "Lymph node",
    "immune": "Immune system",

    # Plants
    "root": "Root",
    "shoot": "Shoot",

    # Invertebrate specifics
    # NOTE(review): the key "hepatopancreas" (Core Organs) maps to "Liver", while
    # "digestive gland" yields the label "Hepatopancreas". The two therefore end
    # up in different categories — confirm that this is intended.
    "digestive gland": "Hepatopancreas", # Often functionally equivalent to liver/pancreas in mollusks
    "mantle": "Mantle",
    "adductor muscle": "Muscle",
    "shell": "Shell",

    # Misc
    "gland": "Gland",
    "tumour": "Tumor",
    "mandible": "Bone",
    "palatal shelves": "Bone",
}

# ======================================
# Color palettes
# ======================================

# 1) Individual Top-N bars (PFAS / species / organs)
# 8 base colors from the rainbow palette
BASE_RAINBOW_HEX = [
    "#f57c6e",  # red-orange
    "#f2b56e",  # orange
    "#fbe79e",  # yellow
    "#84c3b7",  # green-teal
    "#88d7da",  # light cyan
    "#71b8ed",  # blue
    "#b8aeea",  # lavender
    "#f2a8da",  # pink
]

# Build a smooth colormap from these 8 colors
RAINBOW_CMAP = LinearSegmentedColormap.from_list(
    "soft_rainbow", BASE_RAINBOW_HEX, N=256
)

# Sample one evenly spaced color per bar. The count follows TOPN (20 by
# default) because barplot_topn draws at most TOPN bars and slices this list to
# the number of bars; a fixed length of 20 would run short if TOPN were raised.
BAR_PALETTE_INDIV = [
    mcolors.to_hex(RAINBOW_CMAP(x))
    for x in np.linspace(0, 1, TOPN)
]

# 2) Species groups (7 categories including "Other")
BAR_PALETTE_SPECIES_GROUP = [
    "#4E659B",  # 078,101,155
    "#8A8CBF",  # 138,140,191
    "#B8A8CF",  # 184,168,207
    "#E7BCC6",  # 231,188,198
    "#FDCF9E",  # 253,207,158
    "#EFA484",  # 239,164,132
    "#B6766C",  # 182,118,108
]

# 3) Organ groups (10 categories including "Other")
BAR_PALETTE_ORGAN_GROUP = [
    "#E76254",  # 231, 98, 84
    "#EF8A47",  # 239,138, 71
    "#F7AA58",  # 247,170, 88
    "#FFD06F",  # 255,208,111
    "#FFE6B7",  # 255,230,183
    "#AAE6E0",  # 170,230,224
    "#72BCD5",  # 114,188,213
    "#528FAD",  #  82,143,173
    "#376795",  #  55,103,149
    "#1E466E",  #  30, 70,110
]

# Default palette used by barplot_topn when no palette is passed
# (assigned once; it was previously set twice to the same list).
BAR_PALETTE = BAR_PALETTE_INDIV


# =========================
# Parsing / cleaning utils
# =========================
# Separator for multi-valued cells: a ";" together with any surrounding whitespace.
SEMICOLON_SPLIT_RE = re.compile(r"\s*;\s*")
# "<name> (<abbreviation>)" where the parenthetical closes the token.
ABBR_RE = re.compile(r"^(.*?)\s*\(([^)]+)\)\s*$")


def _split_semicolon(cell: str):
    """Split a ';'-separated cell into stripped, non-empty parts.

    Anything that is not a string, and any blank string, gives an empty list.
    """
    if not isinstance(cell, str):
        return []
    stripped_cell = cell.strip()
    if not stripped_cell:
        return []
    pieces = (piece.strip() for piece in SEMICOLON_SPLIT_RE.split(stripped_cell))
    return [piece for piece in pieces if piece]


def clean_chemical(token: str) -> str:
    """
    Normalize one chemical name.

    - 'Full name (ABBR)': return ABBR upper-cased, keeping only letters, digits
      and '-'. If nothing is left of the abbreviation, return the full name in
      Title Case instead.
    - A single word of at most 12 characters is treated as an abbreviation:
      same character filter, upper-cased.
    - Everything else: lower-case, collapse whitespace runs, Title Case.

    An empty or None token gives "".
    """
    text = (token or "").strip()
    if not text:
        return ""

    paren_match = ABBR_RE.match(text)
    if paren_match is not None:
        full_name, raw_abbr = paren_match.group(1), paren_match.group(2)
        abbreviation = re.sub(r"[^A-Za-z0-9\-]", "", raw_abbr.strip())
        if abbreviation:
            return abbreviation.upper()
        return full_name.strip().title()

    looks_like_abbreviation = len(text) <= 12 and " " not in text
    if looks_like_abbreviation:
        filtered = re.sub(r"[^A-Za-z0-9\-]", "", text)
        if filtered:
            return filtered.upper()

    return re.sub(r"\s+", " ", text.lower()).strip().title()

def clean_category(token: str, mapping: dict | None = None) -> str:
    t = (token or "").strip()
    if not t:
        return ""
    low = t.lower()
    if mapping and low in mapping:
        return mapping[low]
    return re.sub(r"\s+", " ", t).strip().title()

def explode_multi(df_in: pd.DataFrame, col: str, cleaner, mapping=None) -> pd.DataFrame:
    """
    Turn a ';'-separated column into long format: one row per (PMID, value).

    Each cell of `col` is split with _split_semicolon, exploded to one value
    per row, cleaned with `cleaner` (called as cleaner(value) when `mapping` is
    None, otherwise cleaner(value, mapping)), and rows whose cleaned value is
    empty are dropped. Only the PMID column and `col` are returned.
    """
    if mapping is None:
        normalize = cleaner
    else:
        def normalize(value):
            return cleaner(value, mapping)

    long_df = df_in[[PMID_COL, col]].copy()
    long_df[col] = long_df[col].apply(_split_semicolon)
    long_df = long_df.explode(col)
    # Cells that split into nothing explode to NaN; turn them into "" so the
    # cleaner always receives a string.
    long_df[col] = long_df[col].fillna("").astype(str).apply(normalize)
    has_value = long_df[col].str.len() > 0
    return long_df[has_value]

# =========================
# Plot helpers
# =========================
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# ==========================================
# CONFIGURATION: Set DPI to 300
# ==========================================
# These assignments replace the SAVE_DPI / SAVE_KW values (600 dpi) set in the
# CONFIG section earlier in this cell; the helpers below read the 300 dpi values.
SAVE_DPI = 300
# Keyword arguments for plt.savefig when writing TIFF files (LZW-compressed).
SAVE_KW = dict(format="tiff", dpi=SAVE_DPI, pil_kwargs={"compression": "tiff_lzw"})

# Update global matplotlib settings for consistency
plt.rcParams.update({
    "savefig.dpi": SAVE_DPI,
    "figure.dpi": 110,  # Screen resolution
})

# ==========================================
# FUNCTION DEFINITION (Saves TIFF & PNG)
# ==========================================
def barplot_topn(counts: pd.Series,
                 title: str,
                 xlabel: str,
                 ylabel: str,
                 save_path: str,
                 palette=None):
    """
    Vertical bar plot of the first TOPN entries of `counts`, saved as TIFF and PNG.

    The full `counts` series is printed first for verification. The bars are
    taken in the order given (no re-sorting happens here).

    Args:
        counts: Series of counts indexed by category label, already ordered.
        title: Figure title; also used in the printed header.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        save_path: Output path. Its extension only decides the TIFF file name:
            a ".tif"/".tiff" path is used as given, any other extension
            (e.g. ".png") is replaced by ".tiff". The PNG is always written
            next to it as "<base>.png".
        palette: List of colors, one per bar; defaults to BAR_PALETTE.

    Returns:
        None. If `counts` is empty, nothing is plotted or saved.
    """
    # Print counts for verification
    print(f"\n[Record counts] {title} (all categories):")
    print(counts.to_string())

    # Prepare data
    top_counts = counts.head(TOPN)
    n = len(top_counts)
    if n == 0:
        # Nothing to draw; taking max() of an empty array below would raise.
        print(f"[Skip] {title}: no categories to plot.")
        return
    if palette is None:
        palette = BAR_PALETTE
    palette_use = palette[:n]

    # Create figure
    fig_w = max(8, min(28, 0.50 * n + 2))
    fig, ax = plt.subplots(figsize=(fig_w, 7))

    # Plot (Vertical). hue mirrors x so that each bar takes its own palette color.
    sns.barplot(
        x=top_counts.index,
        y=top_counts.values,
        hue=top_counts.index,
        palette=palette_use,
        ax=ax,
    )

    # Add labels on top of bars
    for container in ax.containers:
        ax.bar_label(container, fmt='%d', padding=3, fontsize=9)

    # Adjust axes; 15 % headroom keeps the bar labels inside the frame.
    ymax = top_counts.values.max()
    ax.set_ylim(0, ymax * 1.15)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Add Titles/Labels
    ax.set_xlabel(xlabel, fontsize=16, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=16, fontweight='bold')
    ax.set_title(title, fontsize=20, pad=12, fontweight='bold')

    fig.tight_layout()

    # --- SAVE LOGIC ---
    # Derive both file names from the base name. Writing the TIFF to
    # `save_path` unconditionally would, for a ".png" path, put TIFF data in
    # the .png file and then overwrite it with the PNG, leaving no TIFF at all.
    base_name, ext = os.path.splitext(save_path)
    tiff_path = save_path if ext.lower() in (".tif", ".tiff") else f"{base_name}.tiff"
    png_path = f"{base_name}.png"

    fig.savefig(tiff_path, **SAVE_KW)
    print(f"Saved TIFF: {tiff_path}")

    fig.savefig(png_path, format="png", dpi=SAVE_DPI)
    print(f"Saved PNG:  {png_path}")
    # ------------------

    plt.show()

def sort_with_other_last(counts: pd.Series, other_label="Other") -> pd.Series:
    """Return `counts` in descending order with the `other_label` entry, if any, moved to the end."""
    if other_label not in counts.index:
        return counts.sort_values(ascending=False)

    ranked = counts.drop(other_label).sort_values(ascending=False)
    catch_all = counts[[other_label]]
    return pd.concat([ranked, catch_all])

def plot_heatmap_varscale(
    M,
    title,
    xlabel, ylabel,
    save_path,
    mode="log",
    cmap="coolwarm",
    annotate=True,
    dpi_save_kwargs=None
):
    """
    Visualize a skewed count matrix with an alternative color scaling while
    annotating the true counts.

    Args:
        M: DataFrame of counts (cast to int here).
        title, xlabel, ylabel: Figure title and axis labels.
        save_path: File the figure is written to.
        mode: Color scaling, one of
            "abs"      - raw counts,
            "log"      - log10(count + 1),
            "sqrt"     - sqrt(count),
            "clip95"   - raw counts with vmax at the 95th percentile,
            "quantile" - up to 10 quantile bands (BoundaryNorm),
            "row_pct"  - each cell as a percentage of its row total.
        cmap: Colormap name.
        annotate: If True, write the integer count in every cell (for all
            modes; only the colors follow the scaling).
        dpi_save_kwargs: Keyword arguments for plt.savefig; defaults to a
            600 dpi LZW-compressed TIFF.

    Raises:
        ValueError: If `mode` is not one of the values listed above.
    """
    if mode not in ("abs", "log", "sqrt", "clip95", "quantile", "row_pct"):
        raise ValueError("mode must be one of: abs, log, sqrt, clip95, quantile, row_pct")

    if dpi_save_kwargs is None:
        dpi_save_kwargs = dict(format="tiff", dpi=600, pil_kwargs={"compression": "tiff_lzw"})

    counts = M.astype(int).copy()
    if counts.empty:
        # max()/percentile of an empty matrix would raise further down.
        print(f"[Skip] {title}: empty matrix.")
        return

    data = counts.astype(float)
    cbar_label = "Count"
    norm = None
    vmin = vmax = None

    if mode == "abs":
        vmin, vmax = 0, data.values.max()

    elif mode == "log":
        data = np.log10(data + 1.0)
        vmin, vmax = 0, data.values.max()
        cbar_label = "log10(count + 1)"

    elif mode == "sqrt":
        data = np.sqrt(data)
        vmin, vmax = 0, data.values.max()
        cbar_label = "sqrt(count)"

    elif mode == "clip95":
        vmin, vmax = 0, np.percentile(data.values, 95)
        cbar_label = "Count (vmax = 95th pct)"

    elif mode == "quantile":
        # 10 quantile bins; adjust if you want more/less bands
        qs = np.linspace(0, 100, 11)
        bounds = np.unique(np.percentile(data.values.ravel(), qs))
        if len(bounds) < 3:
            # Degenerate case (nearly constant matrix): fall back to fixed
            # edges, keeping them strictly increasing even when max <= 1.
            bounds = np.array([0.0, 1.0, max(2.0, float(data.values.max()))])
        norm = BoundaryNorm(bounds, ncolors=plt.get_cmap(cmap).N, clip=True)

    else:  # mode == "row_pct"
        row_sums = data.sum(axis=1).replace(0, np.nan)
        pct = data.div(row_sums, axis=0) * 100.0
        data = pct.fillna(0.0)  # all-zero rows: 0 % everywhere
        vmin, vmax = 0, 100
        cbar_label = "Row %"

    # figure size
    h, w = data.shape
    fig_w = max(10, min(28, 0.50 * w + 6))
    fig_h = max(8,  min(28, 0.50 * h + 4))

    # Annotate with the integer counts, not with the (float) color values:
    # fmt="d" cannot format floats, and in row_pct mode the cells should show
    # counts while only the colors encode the percentage.
    heatmap_kwargs = dict(
        annot=counts if annotate else False,
        fmt="d",
        cmap=cmap,
        linewidths=0.2,
        linecolor="white",
        cbar_kws={"label": cbar_label},
    )
    if norm is not None:
        heatmap_kwargs["norm"] = norm          # quantile mode: limits come from the norm
    else:
        heatmap_kwargs.update(vmin=vmin, vmax=vmax)

    plt.figure(figsize=(fig_w, fig_h))
    ax = sns.heatmap(data, **heatmap_kwargs)

    ax.set_title(title, pad=12)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    plt.tight_layout()
    plt.savefig(save_path, **dpi_save_kwargs)
    plt.show()

# =========================
# Normalize & explode
# =========================
# Long-format tables with one row per (PMID, value), built from the
# ';'-separated columns of `df`. Chemicals are cleaned with clean_chemical;
# species and organs go through clean_category with their lookup tables.
# The figure cells below read these three frames.
chem_long = explode_multi(df, CHEM_COL, cleaner=clean_chemical)
species_long = explode_multi(df, SPECIES_COL, cleaner=clean_category, mapping=SPECIES_NORMALIZE)
organs_long  = explode_multi(df, ORGANS_COL,  cleaner=clean_category, mapping=ORGANS_NORMALIZE)

## Figure 3

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import pandas as pd
import numpy as np # Added for color calculation

# Name of the publication-year column in `df` (change it to match the dataset).
YEAR_COL = "publication_year"

# Parse the years as numbers, drop entries that cannot be parsed, and count
# the articles per year in chronological order.
year_values = pd.to_numeric(df[YEAR_COL], errors="coerce")
years = year_values.dropna().astype(int)
year_distribution = years.value_counts().sort_index()

# One shade of blue per year; starting at 0.3 keeps the earliest bars visible.
num_years = len(year_distribution)
color_indices = np.linspace(0.3, 1.0, num_years)
gradient_colors = cm.Blues(color_indices)

# Bar chart of articles per year
fig, ax = plt.subplots(figsize=(16, 6))
bars = ax.bar(
    year_distribution.index.astype(str),
    year_distribution.values,
    color=gradient_colors,
)

ax.set_title("Distribution of PFAS Publications by Year", fontsize=20, weight="bold")
ax.set_xlabel("Publication Year", fontsize=16, weight="bold")
ax.set_ylabel("Number of Articles", fontsize=16, weight="bold")

plt.xticks(rotation=75, fontsize=14)
ax.set_ylim(0, year_distribution.max() + 10)

# Hide the top and right frame lines
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

# Write each count just above its bar
for rect, count in zip(bars, year_distribution.values):
    ax.text(
        rect.get_x() + rect.get_width() / 2,
        rect.get_height() + 0.5,
        str(count),
        ha='center',
        va='bottom',
        fontsize=12,
    )

fig.tight_layout()

# Save at 600 dpi as TIFF and PNG
fig.savefig("PFAS_PublicationYear_Distribution.tiff", dpi=600, format="tiff")
fig.savefig("PFAS_PublicationYear_Distribution.png", dpi=600, format="png")

plt.show()

year_distribution

## Figure 4

In [None]:
import json

# species_list.txt holds JSON of the form {group name: [species name, ...]}
with open("species_list.txt", "r", encoding="utf-8") as fh:
    SPECIES_GROUPS = json.load(fh)

# Inverted lookup: stripped, lower-cased species name -> group name.
# If a name is listed under several groups, the last one read wins.
SPECIES_TO_GROUP = {
    member.strip().lower(): group_name
    for group_name, members in SPECIES_GROUPS.items()
    for member in members
}


def map_species_group(name: str) -> str:
    """Return the group of a species name (case-insensitive), or "Other" if it is not listed."""
    key = str(name).strip().lower()
    return SPECIES_TO_GROUP.get(key, "Other")


# Tag every (PMID, species) row with its group
species_long["Species_group"] = species_long[SPECIES_COL].apply(map_species_group)

# A PMID counts once per group, however many species of that group it lists
unique_pmid_species_group = species_long.drop_duplicates([PMID_COL, "Species_group"])
species_group_counts = unique_pmid_species_group["Species_group"].value_counts()

# Descending by count, with the "Other" bucket at the end
species_group_counts_sorted = sort_with_other_last(species_group_counts, other_label="Other")

# Grouped species bar chart, colored with the 7-color species palette
barplot_topn(
    species_group_counts_sorted,
    title="Species Category Count by Group",
    xlabel="Species Group",
    ylabel="Number of Records",
    save_path="Species_grouped.png",
    palette=BAR_PALETTE_SPECIES_GROUP,
)

In [None]:
# organs_tissues_list.txt holds JSON of the form {group name: [organ/tissue name, ...]}
with open("organs_tissues_list.txt", "r", encoding="utf-8") as fh:
    ORGANS_GROUPS = json.load(fh)

# Inverted lookup: stripped, lower-cased organ name -> group name.
# If a name is listed under several groups, the last one read wins.
ORG_TO_GROUP = {
    member.strip().lower(): group_name
    for group_name, members in ORGANS_GROUPS.items()
    for member in members
}


def map_organ_group(name: str) -> str:
    """Return the group of an organ/tissue name (case-insensitive), or "Other" if it is not listed."""
    key = str(name).strip().lower()
    return ORG_TO_GROUP.get(key, "Other")


# Tag every (PMID, organ) row with its group
organs_long["Organ_group"] = organs_long[ORGANS_COL].apply(map_organ_group)

# A PMID counts once per group, however many organs of that group it lists
unique_pmid_organ_group = organs_long.drop_duplicates([PMID_COL, "Organ_group"])
organ_group_counts = unique_pmid_organ_group["Organ_group"].value_counts()

# Descending by count, with the "Other" bucket at the end
organ_group_counts_sorted = sort_with_other_last(organ_group_counts, other_label="Other")

# Grouped organ bar chart, colored with the 10-color organ palette
barplot_topn(
    organ_group_counts_sorted,
    title="Organ Category Count by Group",
    xlabel="Organ Group",
    ylabel="Number of Records",
    save_path="Organs_grouped.png",
    palette=BAR_PALETTE_ORGAN_GROUP,
)

## Figure 5

In [None]:
# === Individual Bar Plots (Top-20) ===

def _records_per_category(long_df, col):
    """Count each category once per PMID and return the counts in descending order."""
    unique_pairs = long_df.drop_duplicates([PMID_COL, col])
    return unique_pairs[col].value_counts().sort_values(ascending=False)


# 1. PFAS
# (barplot_topn also writes "PFAS_category_count_top20_vertical.png")
chem_counts = _records_per_category(chem_long, CHEM_COL)
barplot_topn(
    chem_counts,
    title="PFAS Category Count",
    xlabel="PFAS",
    ylabel="Number of Records",
    save_path="PFAS_category_count_top20_vertical.tiff",
)

# 2. Species
species_counts = _records_per_category(species_long, SPECIES_COL)
barplot_topn(
    species_counts,
    title="Species Category Count",
    xlabel="Species",
    ylabel="Number of Records",
    save_path="Species_individual_top20_vertical.tiff",
)

# 3. Organs
org_counts = _records_per_category(organs_long, ORGANS_COL)
barplot_topn(
    org_counts,
    title="Organs Category Count",
    xlabel="Organ",
    ylabel="Number of Records",
    save_path="Organs_individual_top20_vertical.tiff",
)

## Figure 6

### PFAS x Species

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# ========================
# 1. Build PFAS × Species_group table
# ========================
# One row per distinct (PMID, PFAS, species group) combination: a paper counts
# once for a PFAS/group pair no matter how many species of that group it lists.
pairs_pfas_speciesgrp = (
    chem_long[[PMID_COL, CHEM_COL]]
    .merge(
        species_long[[PMID_COL, "Species_group"]],
        on=PMID_COL,
        how="inner"
    )
    .drop_duplicates(subset=[PMID_COL, CHEM_COL, "Species_group"])
    .rename(columns={CHEM_COL: "PFAS"})
)

pfas_species_counts = (
    pairs_pfas_speciesgrp
    .groupby(["PFAS", "Species_group"])
    .size()
    .reset_index(name="count")
)

# Choose top-N PFAS by total records
TOPN_PFAS_FOR_STACK = 16
top_pfas = (
    pfas_species_counts
    .groupby("PFAS")["count"].sum()
    .sort_values(ascending=False)
    .head(TOPN_PFAS_FOR_STACK)
    .index
)

pfas_species_top = pfas_species_counts[pfas_species_counts["PFAS"].isin(top_pfas)]

pfas_species_pivot = (
    pfas_species_top
    .pivot(index="PFAS", columns="Species_group", values="count")
    .fillna(0)
)

# Order PFAS by total count
pfas_species_pivot = pfas_species_pivot.loc[
    pfas_species_pivot.sum(axis=1).sort_values(ascending=False).index
]

# ========================
# 2. Force a fixed species-group order (stack and legend order)
# ========================
species_order = [
    "Humans",
    "Experimental Animals",
    "Aquatic Species",
    "Livestock",
    "Pets",
    "Wildlife",
    "Other",
]

# Groups present in the data but missing from species_order are left out of the
# figure by the column selection below; report them instead of dropping silently.
unlisted_groups = [c for c in pfas_species_pivot.columns if c not in species_order]
if unlisted_groups:
    print(f"[Warning] Species groups not in species_order are omitted from the plot: {unlisted_groups}")

# Only keep those species groups that actually appear
cols_in_order = [s for s in species_order if s in pfas_species_pivot.columns]

# Reorder columns in pivot table
pfas_species_pivot = pfas_species_pivot[cols_in_order]

# ========================
# 3. Colors from BAR_PALETTE_SPECIES_GROUP
# ========================
# (palette entries are paired with `species_order` by position)
species_color_map = {
    group: color
    for group, color in zip(species_order, BAR_PALETTE_SPECIES_GROUP)
}

species_colors_for_cols = [species_color_map[c] for c in cols_in_order]

# ========================
# 4. Plot stacked bar
# ========================
# Create the axes explicitly and hand them to DataFrame.plot. Calling
# plt.figure() first and then DataFrame.plot() without `ax` opens a second
# figure and leaves the first one empty.
fig, ax = plt.subplots(figsize=(12, 6))
pfas_species_pivot.plot(
    kind="bar",
    stacked=True,
    color=species_colors_for_cols,
    ax=ax
)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_ylabel("Number of Records", fontsize=12, fontweight="bold")
ax.set_xlabel("PFAS", fontsize=12, fontweight="bold")
ax.set_title("Species Group Distribution for Top PFAS", fontsize=14, fontweight="bold")
plt.xticks(rotation=45, ha="right")
# Legend outside the axes, to the right of the bars
ax.legend(
    title="Species Group",
    bbox_to_anchor=(1.02, 1),
    loc="upper left"
)

fig.tight_layout()
fig.savefig("PFAS_SpeciesGroup_stacked.png", dpi=300, format="png")
plt.show()

### PFAS x Organs

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# ========================
# 1. Build PFAS × Organs_group table
# ========================
# One row per distinct (PMID, PFAS, organ group) combination: a paper counts
# once for a PFAS/group pair no matter how many organs of that group it lists.
pairs_pfas_organsgrp = (
    chem_long[[PMID_COL, CHEM_COL]]
    .merge(
        organs_long[[PMID_COL, "Organ_group"]],
        on=PMID_COL,
        how="inner"
    )
    .drop_duplicates(subset=[PMID_COL, CHEM_COL, "Organ_group"])
    .rename(columns={CHEM_COL: "PFAS"})
)

pfas_organs_counts = (
    pairs_pfas_organsgrp
    .groupby(["PFAS", "Organ_group"])
    .size()
    .reset_index(name="count")
)

# Choose top-N PFAS by total records
TOPN_PFAS_FOR_STACK = 16
top_pfas = (
    pfas_organs_counts
    .groupby("PFAS")["count"].sum()
    .sort_values(ascending=False)
    .head(TOPN_PFAS_FOR_STACK)
    .index
)

pfas_organs_top = pfas_organs_counts[pfas_organs_counts["PFAS"].isin(top_pfas)]

pfas_organs_pivot = (
    pfas_organs_top
    .pivot(index="PFAS", columns="Organ_group", values="count")
    .fillna(0)
)

# Order PFAS by total count
pfas_organs_pivot = pfas_organs_pivot.loc[
    pfas_organs_pivot.sum(axis=1).sort_values(ascending=False).index
]

# ========================
# 2. Force organs order
# ========================
# Same list and order as `organ_order` in the Species x Organs cell. Without
# "Common Test Matrices" that group would be dropped from this figure (while
# still counting toward the top-PFAS ranking above), and every later group
# would get a different color here than in the Species x Organs figure.
organs_order = [
    "Major Organs",
    "Blood and Circulation",
    "Other Tissues",
    "Common Test Matrices",
    "Reproductive Organs",
    "Immune System",
    "Digestive System",
    "Respiratory System",
    "Other",
]

# Groups present in the data but missing from organs_order are left out of the
# figure by the column selection below; report them instead of dropping silently.
unlisted_groups = [c for c in pfas_organs_pivot.columns if c not in organs_order]
if unlisted_groups:
    print(f"[Warning] Organ groups not in organs_order are omitted from the plot: {unlisted_groups}")

# Only keep those organs groups that actually appear
cols_in_order = [s for s in organs_order if s in pfas_organs_pivot.columns]

# Reorder columns in pivot table
pfas_organs_pivot = pfas_organs_pivot[cols_in_order]

# ========================
# 3. Colors from BAR_PALETTE_ORGAN_GROUP
# ========================
# (palette entries are paired with `organs_order` by position)
organs_color_map = {
    group: color
    for group, color in zip(organs_order, BAR_PALETTE_ORGAN_GROUP)
}

organs_colors_for_cols = [organs_color_map[c] for c in cols_in_order]

# ========================
# 4. Plot stacked bar
# ========================
# Create the axes explicitly and hand them to DataFrame.plot. Calling
# plt.figure() first and then DataFrame.plot() without `ax` opens a second
# figure and leaves the first one empty.
fig, ax = plt.subplots(figsize=(12, 6))
pfas_organs_pivot.plot(
    kind="bar",
    stacked=True,
    color=organs_colors_for_cols,
    ax=ax
)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_ylabel("Number of Records", fontsize=12, fontweight="bold")
ax.set_xlabel("PFAS", fontsize=12, fontweight="bold")
ax.set_title("Organs Group Distribution for Top PFAS", fontsize=14, fontweight="bold")
plt.xticks(rotation=45, ha="right")
# Legend outside the axes, to the right of the bars
ax.legend(
    title="Organs Group",
    bbox_to_anchor=(1.02, 1),
    loc="upper left"
)

fig.tight_layout()
fig.savefig("PFAS_OrgansGroup_stacked.tiff", **SAVE_KW)
fig.savefig("PFAS_OrgansGroup_stacked.png", dpi=300, format="png")
plt.show()


### Species x Organs

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# ========================
# 1. Build Species_group × Organ_group table
# ========================

# One row per distinct (PMID, species group, organ group) combination.
pairs_speciesgrp_organgrp = (
    species_long[[PMID_COL, "Species_group"]]
    .merge(
        organs_long[[PMID_COL, "Organ_group"]],
        on=PMID_COL,
        how="inner"
    )
    .drop_duplicates(subset=[PMID_COL, "Species_group", "Organ_group"])
)

species_org_counts = (
    pairs_speciesgrp_organgrp
    .groupby(["Species_group", "Organ_group"])
    .size()
    .reset_index(name="count")
)

species_org_matrix = (
    species_org_counts
    .pivot(index="Species_group", columns="Organ_group", values="count")
    .fillna(0)
)

# Order species groups (X-axis categories)
species_order = [
    "Humans",
    "Experimental Animals",
    "Aquatic Species",
    "Livestock",
    "Pets",
    "Wildlife",
    "Other",
]
rows_in_order = [s for s in species_order if s in species_org_matrix.index]

# ========================
# 2. Force organ group order (Legend/Stack order)
# ========================
organ_order = [
    "Major Organs",
    "Blood and Circulation",
    "Other Tissues",
    "Common Test Matrices",
    "Reproductive Organs",
    "Immune System",
    "Digestive System",
    "Respiratory System",
    "Other",
]
cols_in_order_org = [o for o in organ_order if o in species_org_matrix.columns]

# Groups present in the data but missing from the two order lists are left out
# of the figure by the selections below; report them instead of dropping silently.
unlisted_species = [s for s in species_org_matrix.index if s not in species_order]
unlisted_organs = [o for o in species_org_matrix.columns if o not in organ_order]
if unlisted_species or unlisted_organs:
    print(
        "[Warning] Groups omitted from the plot - "
        f"species: {unlisted_species}, organs: {unlisted_organs}"
    )

species_org_matrix = species_org_matrix.loc[rows_in_order]
species_org_matrix = species_org_matrix[cols_in_order_org]

# ========================
# 3. Map organ colors
# ========================
# (palette entries are paired with `organ_order` by position)
organ_color_map = {
    group: color
    for group, color in zip(organ_order, BAR_PALETTE_ORGAN_GROUP)
}
organ_colors_for_cols = [organ_color_map[c] for c in cols_in_order_org]

# ========================
# 4. Stacked Vertical Barplot
# ========================
# Create the axes explicitly and hand them to DataFrame.plot. Calling
# plt.figure() first and then DataFrame.plot() without `ax` opens a second
# figure and leaves the first one empty.
fig, ax = plt.subplots(figsize=(10, 7))

species_org_matrix.plot(
    kind="bar",
    stacked=True,
    color=organ_colors_for_cols,
    ax=ax
)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# --- Axis Labels ---
ax.set_xlabel("Species Group", fontsize=12, fontweight="bold")
ax.set_ylabel("Number of Records", fontsize=12, fontweight="bold")

ax.set_title("Distribution of Organ Groups within Species Groups",
             fontsize=14, fontweight="bold", pad=20)

plt.xticks(rotation=45, ha='right', rotation_mode='anchor')

fig.tight_layout()

# Save as PNG at 300 DPI (this cell writes no TIFF)
fig.savefig("SpeciesGroup_to_OrganGroup_stacked_ver.png", dpi=300, format="png")

plt.show()

## Figure 7

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LogNorm

# -----------------------------
# Font sizes
# -----------------------------
# Shared font sizes for the heatmaps drawn by plot_top20_heatmap; each value is
# passed to matplotlib/seaborn as a fontsize / labelsize argument.
HM_TITLE_FS   = 22  # heatmap title
HM_XLABEL_FS  = 22  # x-axis label
HM_YLABEL_FS  = 22  # y-axis label
HM_TICK_FS    = 16  # x/y tick labels
HM_ANNOT_FS   = 14  # per-cell count annotations
HM_CBAR_LBLFS = 18  # colorbar label
HM_CBAR_TICK  = 14  # colorbar tick labels

# ======================================
# 1) Plot Function (generic y-label)
# ======================================
def _true_count_colorbar_ticks(max_count: int):
    """Return colorbar tick positions and labels that read as true counts.

    The heatmap is colored on ``count + 1`` so that zero counts are valid on a
    log scale. A true count ``c`` therefore sits at ``c + 1`` on the colorbar.
    Ticks are produced for 0 and for every power of ten up to ``max_count``.

    Returns:
        ``(positions, labels)`` - positions on the shifted scale and the
        matching true-count strings.
    """
    counts = [0]
    decade = 1
    while decade <= max_count:
        counts.append(decade)
        decade *= 10
    return [c + 1 for c in counts], [str(c) for c in counts]


def plot_top20_heatmap(ct: pd.DataFrame,
                       x_label: str,
                       y_label: str,
                       filename: str,
                       title_prefix: str,
                       topn: "int | None" = None) -> None:
    """
    Plot a Top-N x Top-N heatmap using log color scaling.
    - Colors use log scale on (count + 1)
    - Annotations show TRUE counts (no +1)
    - Colorbar ticks are labelled with TRUE counts as well

    Args:
        ct: Count matrix (rows x columns). Missing values are treated as 0 and
            values are converted to integers before plotting.
        x_label: Label for the x-axis (the columns of ``ct``).
        y_label: Label for the y-axis (the rows of ``ct``).
        filename: Path passed to ``savefig``; extra options come from the
            module-level ``SAVE_KW``.
        title_prefix: Figure title; also used in the skip message.
        topn: Number of rows and columns to keep, ranked by their totals.
            ``None`` falls back to the module-level ``TOPN``.

    Returns:
        None. The figure is saved, shown and then closed. If nothing is left
        after the Top-N selection a "[Skip]" message is printed instead.
    """
    if topn is None:
        topn = TOPN

    # Select Top-N based on row/col sums
    rows = ct.sum(axis=1).sort_values(ascending=False).head(topn).index
    cols = ct.sum(axis=0).sort_values(ascending=False).head(topn).index
    sub = ct.loc[rows, cols]

    # If matrix is empty, fail gracefully
    if sub.empty:
        print(f"[Skip] {title_prefix}: empty matrix after Top-{topn} filtering.")
        return

    # fmt="d" and the integer vmax below both need integer data; coercing here
    # keeps float or NaN-containing count matrices from raising.
    sub = sub.fillna(0).astype(int)
    max_count = int(sub.to_numpy().max())

    # Dynamic figure size
    n_rows, n_cols = sub.shape
    fig_w = max(10, min(28, 0.50 * n_cols + 6))
    fig_h = max(8,  min(28, 0.50 * n_rows + 4))

    # Color setup
    cmap = sns.color_palette("coolwarm", as_cmap=True)

    # LogNorm requires positive values, hence the +1 shift; vmax is kept above
    # vmin even when every count is 0.
    norm = LogNorm(vmin=1, vmax=max(2, max_count + 1))

    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    sns.heatmap(
        sub + 1,
        annot=sub, fmt="d",
        cmap=cmap,
        norm=norm,
        linewidths=0.25,
        linecolor="white",
        annot_kws={"fontsize": HM_ANNOT_FS},
        cbar_kws={"label": "Count"},
        ax=ax,
    )

    # Title and labels
    ax.set_title(title_prefix, fontsize=HM_TITLE_FS, fontweight="bold", pad=16)
    ax.set_xlabel(x_label, fontsize=HM_XLABEL_FS, fontweight="bold", labelpad=20)
    ax.set_ylabel(y_label, fontsize=HM_YLABEL_FS, fontweight="bold")

    # Tick labels
    ax.tick_params(axis="x", labelsize=HM_TICK_FS)
    ax.tick_params(axis="y", labelsize=HM_TICK_FS)

    # Colorbar styling
    cbar = ax.collections[0].colorbar
    # The colorbar is drawn on the shifted (count + 1) scale. Relabel its ticks
    # so the values shown under "Count" agree with the cell annotations.
    tick_positions, tick_labels = _true_count_colorbar_ticks(max_count)
    cbar.set_ticks(tick_positions)
    cbar.set_ticklabels(tick_labels)
    cbar.minorticks_off()  # log minor ticks would still mark shifted values
    cbar.ax.set_ylabel("Count", fontsize=HM_CBAR_LBLFS, fontweight="bold")
    cbar.ax.tick_params(labelsize=HM_CBAR_TICK)

    fig.tight_layout()
    fig.savefig(filename, **SAVE_KW)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


# ============================================================
# 2) Build matrices and plot
# ============================================================
def _pair_terms(left, left_col, right, right_col, key_col, left_name, right_name):
    """Pair every term in ``left`` with every term in ``right`` sharing a key.

    The two long tables are inner-joined on ``key_col``. Duplicate
    (key, left term, right term) rows are dropped, so each key contributes a
    given pair at most once. The term columns are then renamed to
    ``left_name`` / ``right_name``.
    """
    return (
        left[[key_col, left_col]]
        .merge(right[[key_col, right_col]], on=key_col, how="inner")
        .drop_duplicates(subset=[key_col, left_col, right_col])
        .rename(columns={left_col: left_name, right_col: right_name})
    )


def _count_matrix(pairs, row_name, col_name):
    """Turn a pair table into an integer matrix of pair counts.

    Rows are the values of ``row_name`` and columns the values of
    ``col_name``. Pairs that never occur are filled with 0.
    """
    return (
        pairs
        .groupby([row_name, col_name], as_index=False)
        .size()
        .pivot(index=row_name, columns=col_name, values="size")
        .fillna(0)
        .astype(int)
    )


# -----------------------------
# PFAS × Organs
# -----------------------------
pairs_chem_org = _pair_terms(
    chem_long, CHEM_COL,
    organs_long, ORGANS_COL,
    key_col=PMID_COL,
    left_name="PFAS",
    right_name="Organ",
)

ct_chem_org = _count_matrix(pairs_chem_org, "PFAS", "Organ")

plot_top20_heatmap(
    ct_chem_org,
    x_label="Organ",
    y_label="PFAS",
    filename="PFAS_x_Organs_heatmap_top20x20.png",
    title_prefix="PFAS × Organs"
)

# -----------------------------
# PFAS × Species
# -----------------------------
pairs_chem_species = _pair_terms(
    chem_long, CHEM_COL,
    species_long, SPECIES_COL,
    key_col=PMID_COL,
    left_name="PFAS",
    right_name="Species",
)

ct_chem_species = _count_matrix(pairs_chem_species, "PFAS", "Species")

plot_top20_heatmap(
    ct_chem_species,
    x_label="Species",
    y_label="PFAS",
    filename="PFAS_x_Species_heatmap_top20x20.png",
    title_prefix="PFAS × Species"
)

# -----------------------------
# Species × Organs
# -----------------------------
pairs_species_org = _pair_terms(
    species_long, SPECIES_COL,
    organs_long, ORGANS_COL,
    key_col=PMID_COL,
    left_name="Species",
    right_name="Organ",
)

ct_species_org = _count_matrix(pairs_species_org, "Species", "Organ")

plot_top20_heatmap(
    ct_species_org,
    x_label="Organ",
    y_label="Species",
    filename="Species_x_Organs_heatmap_top20x20.png",
    title_prefix="Species × Organs"
)