In [1]:
from openai import OpenAI
import json
import os
import dotenv

In [2]:
# Pull OPENAI_API_KEY (and any other settings) from a local .env file into
# os.environ, then authenticate the OpenAI client with that key.
dotenv.load_dotenv()
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
)

In [3]:
# Load every PubMed record plus the author/affiliation results already
# extracted, then report how many of each were loaded.
with open("../data/pubmed_abstracts_20250519.json") as abstracts_fh:
    b = json.load(abstracts_fh)

with open("../data/pubmed_authors_20250502.json") as authors_fh:
    authors = json.load(authors_fh)

for loaded in (b, authors):
    print(len(loaded))

16524
13589


In [4]:
# Remove entries in b that are already in authors.
# Build the set of processed PMIDs once: the previous version rebuilt the full
# list of author PMIDs for every entry, making this filter O(len(b) * len(authors)).
processed_pmids = {a['pmid'] for a in authors}
b = [entry for entry in b if entry['pmid'] not in processed_pmids]
len(b)

7827

In [5]:
# Peek at the first record still awaiting processing to check which fields it has.
b[0]

{'pmid': '14656017',
 'ab': "Classical homocystinuria is associated with arterial vascular diseases and venous thrombosis. In the last decade, many studies, including some prospective studies, have been published indicating that moderate hyperhomocysteinaemia is also a risk factor for venous thrombosis. The 677C>T mutation in the methylenetetrahydrofolate reductase (MTHFR) gene is an important cause of mild hyperhomocysteinaemia. Recent metaanalyses show an elevated risk of venous thrombosis for subjects with the TT-genotype. Based on the concept of 'Mendelian randomisation', this observation supports the hypothesis that hyperhomocysteinaemia is a causal risk factor for venous thrombosis. The results of one homocysteine-lowering trial regarding venous thrombosis are awaited at the end of 2003. In this paper the current evidence for hyperhomocysteinaemia as a risk factor for venous thrombosis is being discussed.",
 'pub_date': '2003-12-06',
 'title': 'Hyperhomocysteinaemia as a risk fac

Send one abstract at a time to the OpenAI API. This works, but it is very slow and more expensive.

In [6]:
auth_prompt = {"role": "user", "content": """Extract the university name and country from this text. Provide the result in json format with one field for the 'institution' and one field for the 'country'. If the country is not mentioned, provide an empty string. If the institution is not mentioned, provide an empty string. If the institution is mentioned but not the country, provide an empty string for the country. For the institution, retain only the university name and no department names etc."""}

def openai_prompt_auth(author_affil, pmid):
    """Extract institution and country from one affiliation string.

    Sends a single chat-completion request (one API call per record) and
    parses the model's reply as JSON.

    Parameters
    ----------
    author_affil : str
        Raw author-affiliation text for one PubMed record.
    pmid : str or int
        Identifier copied into the result under the ``'pmid'`` key.

    Returns
    -------
    dict
        The parsed JSON reply (the prompt asks for ``'institution'`` and
        ``'country'``) with ``'pmid'`` added.

    Raises
    ------
    json.JSONDecodeError
        If the reply is not valid JSON even after removing Markdown fences.
    """
    # Drop characters that cannot be encoded as UTF-8 (e.g. lone surrogates).
    # The former ``bytes(s, 'utf-8').decode('utf-8', 'ignore')`` round trip was a
    # no-op for valid text and raised UnicodeEncodeError on exactly these characters.
    clean_affil = author_affil.encode('utf-8', 'ignore').decode('utf-8')
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "system", "content": "You are a helpful assistant."},
                  {"role": "user", "content": clean_affil},
                  auth_prompt],
    )
    content = response.choices[0].message.content
    # The model may wrap its JSON in a ```json ... ``` fence; remove it before parsing.
    content = content.replace("```json", "").replace("```", "").strip()
    o = json.loads(content)
    o['pmid'] = pmid
    return o

# Try the one-at-a-time extractor on the first pending record, using a dummy PMID.
first_record = b[0]
author_affil = first_record['author_affil']
openai_prompt_auth(author_affil, pmid=12345678)

# result = []
# for i in range(len(a)):
#     print(i)
#     if 'ab' not in a[i].keys():
#         continue
#     try:        
#         o = openai_prompt_auth(a[i]['author_affil'], a[i]['pmid'])
#         result.append(o)
#     except:
#         continue
#     if i % 100 == 0:
#         with open("data/pubmed_authors.json", "w") as f:
#             json.dump(result, f)

# with open("data/pubmed_authors.json", "w") as f:
#     json.dump(result, f)


# result = []
# for i in range(len(b)):
#     print(i)
#     if 'ab' not in b[i].keys():
#         continue
#     try:        
#         o = openai_prompt(b[i]['ab'], b[i]['pmid'])
#         result.append(o)
#     except:
#         continue
#     if i % 100 == 0:
#         with open("../data/abstract_summary_20250502.json", "w") as f:
#             json.dump(result, f)

# with open("../data/abstract_summary_20250502.json", "w") as f:
#     json.dump(result, f)

{'institution': 'University Medical Center Nijmegen',
 'country': 'The Netherlands',
 'pmid': 12345678}

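The Batch API is cheaper and avoids a slow request-by-request loop, although results can take up to 24 hours (the `completion_window` used below). The steps are: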
1. For each batch create a .jsonl file with the requests
2. Use the openai api to send the requests in the .jsonl file
3. Retrieve the responses, parse and save them

In [None]:
def openai_prompt_auth_batch(abstracts, jsonl_file_root, batch_size=500):
    """
    Write OpenAI Batch API request files (JSONL) for author-affiliation extraction.

    One chat-completion request is written per record that has an 'author_affil'
    field; records without it, or with malformed fields, are skipped.

    abstracts: list of PubMed record dicts; each needs 'pmid' and, to be included, 'author_affil'
    jsonl_file_root: path prefix for the request files, e.g. "../data/author_processing_20250502". ".<batch>.jsonl" will be appended to this prefix
    batch_size: number of records per request file

    Returns the list of request-file paths, in batch order (empty for empty input).
    """
    auth_prompt = {"role": "user", "content": """Extract the university name and country from this text. Provide the result in json format with one field for the 'institution' and one field for the 'country'. If the country is not mentioned, provide an empty string. If the institution is not mentioned, provide an empty string. If the institution is mentioned but not the country, provide an empty string for the country. For the institution, retain only the university name and no department names etc."""}

    # Ceiling division. The old ``len // batch_size + 1`` produced an extra, empty
    # request file whenever len(abstracts) was a multiple of batch_size (and one
    # for empty input); an empty input file cannot be submitted as a batch.
    num_batches = (len(abstracts) + batch_size - 1) // batch_size
    print("Number of batches: {}".format(num_batches))
    batch_file_names = ["{}.{}.jsonl".format(jsonl_file_root, i) for i in range(num_batches)]

    for batch, jsonl_file in enumerate(batch_file_names):
        start = batch * batch_size
        end = min(start + batch_size, len(abstracts))
        print("Processing batch {} of {}".format(batch, num_batches))

        with open(jsonl_file, "w") as f:
            for record in abstracts[start:end]:
                if 'author_affil' not in record:
                    continue
                try:
                    # Drop characters that cannot be encoded as UTF-8 (e.g. lone surrogates).
                    affil = record['author_affil'].encode('utf-8', 'ignore').decode('utf-8')
                    o = {
                        # custom_id is sent as a string; parsed results use it as the pmid.
                        'custom_id': str(record['pmid']),
                        "method": "POST",
                        "url": "/v1/chat/completions",
                        "body": {
                            "model": "gpt-3.5-turbo",
                            "messages": [
                                {"role": "system", "content": "You are a helpful assistant."},
                                {"role": "user", "content": affil},
                                auth_prompt
                            ],
                            "max_tokens": 1000
                        }
                    }
                    f.write(json.dumps(o) + "\n")
                except (KeyError, TypeError, AttributeError, ValueError):
                    # Best effort: skip a malformed record rather than abort the whole file.
                    continue
    return batch_file_names
# Write the request files for every pending record (500 per file); the returned
# list of .jsonl paths is uploaded in the next cell.
author_batches = openai_prompt_auth_batch(b, "../data/author_processing_20250519_batch", batch_size=500)

Number of batches: 16
Processing batch 0 of 16
Processing batch 1 of 16
Processing batch 2 of 16
Processing batch 3 of 16
Processing batch 4 of 16
Processing batch 5 of 16
Processing batch 6 of 16
Processing batch 7 of 16
Processing batch 8 of 16
Processing batch 9 of 16
Processing batch 10 of 16
Processing batch 11 of 16
Processing batch 12 of 16
Processing batch 13 of 16
Processing batch 14 of 16
Processing batch 15 of 16


In [None]:
# Upload each request file to OpenAI with purpose "batch" and keep the returned
# file objects; their IDs are used to create the batch jobs below.
batch_input_file = []
for path in author_batches:
    # Context manager closes each handle after upload (previously every
    # handle opened here was left open).
    with open(path, "rb") as request_fh:
        batch_input_file.append(
            client.files.create(file=request_fh, purpose="batch")
        )
    

In [75]:
# Sanity check: the ID assigned to the first uploaded request file.
batch_input_file[0].id

'file-4AhPiq8QHmGsMBaVQX5Cht'

In [None]:
# Create one Batch API job per uploaded request file. Only the batch IDs are
# kept, because retrieve_batch_status() passes each entry straight to
# client.batches.retrieve(), which takes an ID string rather than a Batch object.
batches = []
for i, input_file in enumerate(batch_input_file):
    created = client.batches.create(
        input_file_id=input_file.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={
            # NOTE(review): this date tag is hard-coded and does not match the
            # 20250519 request files written above; update it for each run.
            "author processing": "20250502",
            "batch": str(i),
            "batch_input_file": author_batches[i]
        }
    )
    batches.append(created.id)

In [80]:
# Batch IDs hard-coded from an earlier run (the last one's metadata reads
# 'author processing': '20250502') so results can be fetched after a kernel restart.
# NOTE(review): this overwrites the IDs created in the cell above; replace the
# list with those IDs when processing a new run.
batches = ["batch_682a3e53bfe88190acc35829d83d1fe5", "batch_682a3e541f9881908dfbe48e0fe60e24", "batch_682a3e546d0481909180ac02cd18abd0", "batch_682a3e54b49881908a435724f4452912", "batch_682a3e54fab08190a5fe9496d9d70459", "batch_682a3e5579e08190bcb1f5b1fc4d4021", "batch_682a3e55df80819091dc54a586d6de08", "batch_682a3e5646bc819092a794c144a72fdd", "batch_682a3e56cc948190b41cec169671bf81", "batch_682a3e5767708190ae0d37ae17ff0712", "batch_682a3e57c6cc8190a9b8c88f5c5cf5ae"]
# Inspect the most recent batch. This previously indexed with ``i`` leaked from an
# earlier loop, which goes out of range once more request files exist than IDs here.
batch = client.batches.retrieve(batches[-1])
batch

Batch(id='batch_682a3e57c6cc8190a9b8c88f5c5cf5ae', completion_window='24h', created_at=1747598935, endpoint='/v1/chat/completions', input_file_id='file-Q7KuX3yoZmPMPF1hG1D8yR', object='batch', status='in_progress', cancelled_at=None, cancelling_at=None, completed_at=None, error_file_id=None, errors=None, expired_at=None, expires_at=1747685335, failed_at=None, finalizing_at=None, in_progress_at=1747598937, metadata={'author processing': '20250502', 'batch': '10', 'batch_input_file': '../data/author_processing_20250502_batch.10.jsonl'}, output_file_id=None, request_counts=BatchRequestCounts(completed=388, failed=0, total=391))

In [87]:
def retrieve_batch_status(batches):
    """Poll every batch job and print how many are in each status.

    batches: list of batch IDs accepted by client.batches.retrieve
    Returns the output file IDs (one per batch, in input order) when every
    batch has status "completed"; otherwise returns None.
    """
    fetched = [client.batches.retrieve(batch_id) for batch_id in batches]
    output_ids = [job.output_file_id for job in fetched]

    # Tally statuses; the dict keeps first-seen order for the printout.
    tally = {}
    for job in fetched:
        tally[job.status] = tally.get(job.status, 0) + 1
    for status_name, how_many in tally.items():
        print(f"Status: {status_name}, Count: {how_many}")

    if not all(job.status == "completed" for job in fetched):
        return None
    return output_ids

# None until every batch reports status "completed"; then the list of output file IDs.
batch_output_files = retrieve_batch_status(batches)
batch_output_files


Status: completed, Count: 11


['file-PnhwmGJ4jGi4LskGmPFqrf',
 'file-HhaBVhwUk1PGNVD37KMeYa',
 'file-1mXH6mEATQd95gV2gV4QYq',
 'file-7T3z5EGF5GqKXzutVWCaqg',
 'file-BoXbm2a27gmAXQFEVc1WqP',
 'file-WaZZoEyQcHyLMrPzhKirKP',
 'file-LeKvjxhhvcgLVAaqCoPkJD',
 'file-DwrRcG4SjushNSn1G7hmYc',
 'file-CN9XN5iem4kg459PFG99yY',
 'file-WzY1JrCsJj7iKrUBP1BqZi',
 'file-9wXcRtASgUFXHfhsDa3ZpT']

In [None]:
def read_output(input):
    """Parse a model reply as JSON after removing Markdown code fences.

    input: assistant message content, possibly wrapped in ```json ... ```
           (the name shadows the builtin; kept so keyword callers still work)
    Returns the decoded JSON value.
    Raises json.JSONDecodeError (a ValueError) if the cleaned text is not JSON.
    """
    # remove markdown code blocks
    text = input.replace("```json", "").replace("```", "")
    # A bare str.strip() already removes spaces, tabs, newlines, carriage returns,
    # form feeds and vertical tabs, which made the former chain of single-character
    # strips no-ops. Null characters are not whitespace, so strip them separately,
    # then strip again in case they were hiding whitespace.
    text = text.strip().strip("\0").strip()
    return json.loads(text)

# Download each completed batch's output file and parse every response into a
# record of {'pmid': custom_id} merged with the parsed JSON reply.
if batch_output_files is not None:
    # Get results from the batches
    batch_results = []
    for i, output_file_id in enumerate(batch_output_files):
        print("Processing batch {}".format(i))
        # Fixed: this previously read ``file_response.text``, a name never bound
        # here (the download was stored in ``batch``), raising NameError on a fresh kernel.
        file_response = client.files.content(output_file_id)
        batch_output = [json.loads(line) for line in file_response.text.splitlines() if line.strip()]
        for x in batch_output:
            try:
                pmid = {"pmid": x["custom_id"]}
                cont = read_output(x["response"]["body"]["choices"][0]["message"]["content"])
                batch_results.append({**pmid, **cont})
            except (KeyError, IndexError, TypeError, ValueError):
                # Responses missing the expected fields, or replies that are not a
                # JSON object, are logged and skipped.
                print("Error in response for pmid {}".format(x["custom_id"]))
                continue
        # Checkpoint after every batch so partial progress is saved.
        # NOTE(review): the output file name still carries the 20250502 date tag.
        with open("../data/author_processing_20250502.json", "w") as f:
            json.dump(batch_results, f)
else:
    print("Not all batches are completed yet.")
    batch_results = []


{'institution': 'Nanchang University', 'country': 'China'}
{'institution': 'Shanxi Medical University', 'country': 'China'}
{'institution': 'Sun Yat-sen University', 'country': 'China'}
{'institution': 'Chung Shan Medical University', 'country': 'Taiwan'}
{'institution': 'Soochow University', 'country': 'China'}
{'institution': 'Changzhou Maternal and Child Health Care Hospital', 'country': 'China'}
{'institution': 'Binzhou Medical University', 'country': 'China'}
{'institution': 'Huazhong University of Science and Technology', 'country': 'China'}
{'institution': 'Westlake University', 'country': 'China'}
{'institution': 'Harbin Medical University', 'country': 'China'}
{'institution': 'Jiangnan University', 'country': 'China'}
{'institution': 'Jilin University', 'country': 'China'}
{'institution': 'Vanderbilt University', 'country': 'USA'}
{'institution': 'Shahid Beheshti University of Medical Sciences', 'country': 'Iran'}
{'institution': 'University of Cambridge', 'country': 'UK'}
{'i

In [89]:
# Merge the newly extracted records into the previously saved ones; a new result
# replaces any older record with the same PMID (old records first, then new ones).
# NOTE(review): this reads pubmed_authors.json, whereas `authors` was first loaded
# from pubmed_authors_20250502.json; confirm which file is the intended base.
with open("../data/pubmed_authors.json") as f:
    authors = json.load(f)
# Set of new PMIDs built once, instead of rebuilding a list for every old record.
new_pmids = {y['pmid'] for y in batch_results}
authors_full = [x for x in authors if x['pmid'] not in new_pmids]
authors_full.extend(batch_results)
len(authors_full)

13589

In [90]:
# Persist the merged author/affiliation records.
# NOTE(review): this overwrites the same file the notebook loads as `authors` at the top.
with open("../data/pubmed_authors_20250502.json", "w") as merged_fh:
    json.dump(authors_full, merged_fh)