In [7]:
# Run this cell: 
# The lines below will instruct jupyter to reload imported modules before 
# executing code cells. This enables you to quickly iterate and test revisions
# to your code without having to restart the kernel and reload all of your 
# modules each time you make a code change in a separate python file.

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
import os
# Change path to project root
if os.getcwd().endswith("notebooks"):
    os.chdir(os.path.dirname(os.getcwd()))
print(os.getcwd())

/Users/shloknatarajan/stanford/research/daneshjou/AutoGKB


In [9]:
from src.inference import Generator
from src.prompts import PromptGenerator
from src.article_parser import MarkdownParser
from typing import List

In [10]:
article_title = MarkdownParser(pmcid="PMC11730665").parse().title
article_text = MarkdownParser(pmcid="PMC11730665").parse().article_text

print(article_title)
print(article_text)

[32m2025-06-14 14:27:20.483[0m | [1mINFO    [0m | [36msrc.annotation_extraction.article_parser[0m:[36m__init__[0m:[36m32[0m - [1mGetting article text from PMCID: PMC11730665[0m
[32m2025-06-14 14:27:20.492[0m | [1mINFO    [0m | [36msrc.annotation_extraction.article_parser[0m:[36mremove_references_section[0m:[36m79[0m - [1mRemoved References section from article text[0m
[32m2025-06-14 14:27:20.493[0m | [1mINFO    [0m | [36msrc.annotation_extraction.article_parser[0m:[36m__init__[0m:[36m32[0m - [1mGetting article text from PMCID: PMC11730665[0m
[32m2025-06-14 14:27:20.494[0m | [1mINFO    [0m | [36msrc.annotation_extraction.article_parser[0m:[36mremove_references_section[0m:[36m79[0m - [1mRemoved References section from article text[0m


In [11]:
DRUG_SUMMARY = """
You are an expert pharmacogenomics researcher reading and extracting annotations from the following article

\n\n{article_text}\n\n

From this article:
First, briefly summarize the methods and results of the study
Next, create a list of extract all variants that have a studied described effect on drug response.
For each variant <-> drug relationship in the list output the following:
Pair: The (variant, drug(s)) pair
Variant: The Variant / Haplotypes (ex. rs2909451, CYP2C19*1, CYP2C19*2, *1/*18, etc.)
Gene: The gene group of the variant (ex. DPP4, CYP2C19, KCNJ11, etc.)
Allele: Specific allele or genotype if different from variant (ex. TT, *1/*18, del/del, etc.)
Relationship Description: Describe the drug(s) in this relationship
Variant Effect: Describe the outcome/effect found from the variant (drug efficacy, metabolism, toxicity, dosage, etc.)
Statistical Analysis: Describe the statistical analysis used and the reported p-values
Population Info: Describe the population of the study participants for this variant
Notes: Describe any other useful information included on this variant to understanding the study results.

These 9 attributes should be separately repeated for every distinct (variant, drug(s)) relationship. For every attribute,
include the information as well as a quote from the article the information was concluded from.
"""

In [13]:
VARIANT_LIST_PROMPT = """
You are an expert pharmacogenomics researcher reading and extracting annotations from the following article:

{article_text}

From this article, note down ALL discussed variants/haplotypes (ex. rs113993960, CYP1A1*1, etc.). Include information on the gene group and allele (if present). Your output format should be a list of the variants with the following attributes:
Variant: The Variant / Haplotypes (ex. rs2909451, CYP2C19*1, CYP2C19*2, *1/*18, etc.)
Gene: The gene group of the variant (ex. DPP4, CYP2C19, KCNJ11, etc.)
Allele: Specific allele or genotype if different from variant (ex. TT, *1/*18, del/del, etc.)
"""

In [23]:
from dotenv import load_dotenv
load_dotenv()
from src.inference import VariantList, Variant
import json
from loguru import logger
from src.utils import save_output
from src.components.all_variants import extract_all_variants


In [24]:
variant_list = extract_all_variants(pmcid="PMC11730665")

[32m2025-06-14 14:32:19.229[0m | [1mINFO    [0m | [36msrc.annotation_extraction.article_parser[0m:[36m__init__[0m:[36m32[0m - [1mGetting article text from PMCID: PMC11730665[0m
[32m2025-06-14 14:32:19.232[0m | [1mINFO    [0m | [36msrc.annotation_extraction.article_parser[0m:[36mremove_references_section[0m:[36m79[0m - [1mRemoved References section from article text[0m


In [177]:
variant_list

[{'variant_id': 'rs2909451', 'gene': 'DPP4', 'allele': 'TT'},
 {'variant_id': 'rs4664443', 'gene': 'DPP4', 'allele': 'GG'},
 {'variant_id': 'rs3765467', 'gene': 'GLP1R', 'allele': 'AG'},
 {'variant_id': 'rs2285676', 'gene': 'KCNJ11', 'allele': 'CC'},
 {'variant_id': 'rs163184', 'gene': 'KCNQ1', 'allele': 'GG'},
 {'variant_id': 'rs7754840', 'gene': 'CDKAL1', 'allele': 'CG'},
 {'variant_id': 'rs756992', 'gene': 'CDKAL1', 'allele': 'AG'},
 {'variant_id': 'rs1799853', 'gene': 'CYP2C9', 'allele': 'TT'},
 {'variant_id': 'rs1057910', 'gene': 'CYP2C9', 'allele': 'GG'}]

In [25]:
pmcid_grouped = json.load(open("data/variantAnnotations/annotations_by_pmcid.json"))
true_variant_list = {}
for paper in pmcid_grouped:
    variants = []
    variants.extend(paper['var_drug_ann'])
    variants.extend(paper['var_pheno_ann'])
    variants.extend(paper['var_fa_ann'])
    pmcid = paper['pmcid']
    variant_list = []
    for variant in variants:
        gene = variant['Gene']
        variant_id = variant['Variant/Haplotypes']
        allele = variant['Alleles']
        try:
            parsed_variant = Variant(variant_id=variant_id, gene=gene, allele=allele)
            variant_list.append(parsed_variant)
        except Exception as e:
            logger.error(f"Error parsing variant {variant_id} for PMCID {pmcid}: {e}")
            logger.error(f"Variant ID: {variant_id}")
            logger.error(f"Gene: {gene}")
            logger.error(f"Allele: {allele}")
    true_variant_list[pmcid] = variant_list

In [207]:
true_variant_list_json = {}
for key in true_variant_list:
    true_variant_list_json[key] = []
    for variant in true_variant_list[key]:
        true_variant_list_json[key].append(variant.model_dump())

In [208]:
true_variant_list_json

{'PMC5712579': [{'variant_id': 'HLA-B*35:08',
   'gene': 'HLA-B',
   'allele': '*35:08'},
  {'variant_id': 'HLA-B*39:01', 'gene': 'HLA-B', 'allele': '*39:01'},
  {'variant_id': 'HLA-B*15:02', 'gene': 'HLA-B', 'allele': '*15:02'},
  {'variant_id': 'HLA-B*44:03', 'gene': 'HLA-B', 'allele': '*44:03'},
  {'variant_id': 'HLA-A*02:07', 'gene': 'HLA-A', 'allele': '*02:07'},
  {'variant_id': 'HLA-A*33:03', 'gene': 'HLA-A', 'allele': '*33:03'}],
 'PMC3202555': [{'variant_id': 'rs1801272',
   'gene': 'CYP2A6',
   'allele': 'AA + AT'},
  {'variant_id': 'rs1801272', 'gene': 'CYP2A6', 'allele': 'AT + TT'},
  {'variant_id': 'CYP2A6*1, CYP2A6*12', 'gene': 'CYP2A6', 'allele': '*12'},
  {'variant_id': 'CYP2A6*1, CYP2A6*1x2', 'gene': 'CYP2A6', 'allele': '*1x2'},
  {'variant_id': 'CYP2B6*1, CYP2B6*9', 'gene': 'CYP2B6', 'allele': '*9'},
  {'variant_id': 'rs28399433', 'gene': 'CYP2A6', 'allele': 'A'},
  {'variant_id': 'CYP2B6*1, CYP2B6*4', 'gene': 'CYP2B6', 'allele': '*4'},
  {'variant_id': 'rs8192789', 'g

In [213]:
# save true variant list to file
json.dump(true_variant_list_json, open("data/benchmark/true_variant_list.json", "w"), indent=2)

## Compare extract variants list to true_variant_list

In [290]:
from tqdm import tqdm
from src.utils import compare_lists

In [277]:
for pmcid in extracted_variants.keys():
    experimental = [x["variant_id"] for x in extracted_variants[pmcid]]
    ground_truth = [x["variant_id"] for x in true_variants[pmcid]]
    # print differences
    compare_lists(ground_truth, experimental)

Experimental List:
[31mrs1057910[0m [31mrs1799853[0m [31mrs9923231[0m [31mrs1043550[0m [31mrs12714145[0m [31mrs1051740[0m [31mrs2108622[0m [31mrs6046[0m [31mrs11676382[0m [31mrs699664[0m [31mrs1799853[0m [31mrs1057910[0m [31mrs11676382[0m [31mrs6046[0m [31mrs2108622[0m [31mrs699664[0m [31mrs1043550[0m [31mrs12714145[0m [31mrs1051740[0m [31mrs9923231[0m

Ground Truth List:
[31mrs113993960[0m [31mCYP2C19*1[0m [31mCYP2C19*2[0m [31mCYP2C19*3[0m [31mCYP2C19*4[0m [31mCYP2C19*5[0m [31mCYP2C19*6[0m [31mCYP2C19*7[0m [31mCYP2C19*8[0m [31mCYP2C19*9[0m [31mCYP2C19*10[0m [31mCYP2C19*11[0m [31mCYP2C19*12[0m [31mCYP2C19*13[0m [31mCYP2C19*14[0m [31mCYP2C19*15[0m [31mCYP2C19*16[0m [31mCYP2C19*17[0m [31mCYP2C19*18[0m [31mCYP2C19*19[0m [31mCYP2C19*20[0m [31mCYP2C19*21[0m [31mCYP2C19*22[0m [31mCYP2C19*23[0m [31mCYP2C19*24[0m [31mCYP2C19*25[0m [31mCYP2C19*26[0m [31mCYP2C19*27[0m [31mCYP2C19*28[0m [31mCYP2C19*

In [254]:
print(extracted_variants.keys())
print(true_variants.keys())

dict_keys(['PMC6714673', 'PMC3553682', 'PMC4270923', 'PMC4557249', 'PMC4220464', 'PMC4730664'])
dict_keys(['PMC6714673', 'PMC3553682', 'PMC4270923', 'PMC4557249', 'PMC4220464', 'PMC4730664'])


In [268]:
[x['variant_id'] for x in ground_truth[pmcid]]

TypeError: list indices must be integers or slices, not str

In [274]:
true_variants = json.load(open("data/benchmark/true_variant_list.json"))

In [276]:
true_variants["PMC11730665"]

[{'variant_id': 'rs2909451', 'gene': 'DPP4', 'allele': 'TT'},
 {'variant_id': 'rs2285676', 'gene': 'KCNJ11', 'allele': 'CC'},
 {'variant_id': 'rs163184', 'gene': 'KCNQ1', 'allele': 'GG'},
 {'variant_id': 'rs7754840', 'gene': 'CDKAL1', 'allele': 'CG'},
 {'variant_id': 'rs4664443', 'gene': 'DPP4', 'allele': 'GG'},
 {'variant_id': 'rs1799853', 'gene': 'CYP2C9', 'allele': 'TT'},
 {'variant_id': 'rs3765467', 'gene': 'GLP1R', 'allele': 'AG'},
 {'variant_id': 'rs6923761', 'gene': 'GLP1R', 'allele': 'AA'}]

In [295]:
extracted_variants["PMC11730665"] = extract_variants_list(pmcid="PMC11730665", debug=True)

[32m2025-06-13 20:08:55.192[0m | [1mINFO    [0m | [36msrc.annotation_extraction.article_parser[0m:[36m__init__[0m:[36m32[0m - [1mGetting article text from PMCID: PMC11730665[0m
[32m2025-06-13 20:08:55.194[0m | [1mINFO    [0m | [36msrc.annotation_extraction.article_parser[0m:[36mremove_references_section[0m:[36m79[0m - [1mRemoved References section from article text[0m
[32m2025-06-13 20:08:55.194[0m | [34m[1mDEBUG   [0m | [36msrc.annotation_extraction.components[0m:[36mextract_variants_list[0m:[36m18[0m - [34m[1mModel: gpt-4o, Temperature: 0.1[0m
[32m2025-06-13 20:08:55.195[0m | [34m[1mDEBUG   [0m | [36msrc.annotation_extraction.components[0m:[36mextract_variants_list[0m:[36m19[0m - [34m[1mPMCID: PMC11730665[0m


In [296]:
experimental = [x["variant_id"] for x in extracted_variants["PMC11730665"]]
ground_truth = [x["variant_id"] for x in true_variants["PMC11730665"]]
# print differences
compare_lists(ground_truth, experimental, "PMC11730665")

Experimental List:
[32mrs2909451[0m [32mrs2285676[0m [32mrs163184[0m [32mrs7754840[0m [32mrs4664443[0m [32mrs1799853[0m [32mrs3765467[0m [32mrs6923761[0m

Ground Truth List:
[32mrs2909451[0m [32mrs4664443[0m [32mrs3765467[0m [32mrs6923761[0m [32mrs163184[0m [32mrs2285676[0m [32mrs7754840[0m [31mrs756992[0m [32mrs1799853[0m [31mrs1057910[0m


(8, 0, 0, 2)