In [63]:
import requests
import pandas as pd
import numpy as np
import json
import os


# Get the icd codes

In [64]:
link = 'https://www.cms.gov/files/document/valid-icd-10-list.xlsx'

In [65]:
# Download the xlsx file
response = requests.get(link)
with open('./data/icd10_codes.xlsx', 'wb') as file:
    file.write(response.content)

# Open the xlsx file
icd_codes = pd.read_excel('./data/icd10_codes.xlsx')
print(icd_codes.head())

    CODE            SHORT DESCRIPTION (VALID ICD-10 FY2024)  \
0   A000  Cholera due to Vibrio cholerae 01, biovar chol...   
1   A001    Cholera due to Vibrio cholerae 01, biovar eltor   
2   A009                               Cholera, unspecified   
3  A0100                         Typhoid fever, unspecified   
4  A0101                                 Typhoid meningitis   

              LONG DESCRIPTION (VALID ICD-10 FY2024) NF EXCL  
0  Cholera due to Vibrio cholerae 01, biovar chol...     NaN  
1    Cholera due to Vibrio cholerae 01, biovar eltor     NaN  
2                               Cholera, unspecified     NaN  
3                         Typhoid fever, unspecified     NaN  
4                                 Typhoid meningitis     NaN  


In [66]:
icd_codes = icd_codes.drop('NF EXCL', axis=1)
icd_codes.columns = ['code', 'short_description','long_description']
icd_codes = icd_codes.drop('short_description', axis=1) 
icd_codes=icd_codes.dropna()
print(icd_codes.isna().mean()) # Check for missing values

code                0.0
long_description    0.0
dtype: float64


# Create batch files

In [67]:
# Create batch files 
# Each batch is limited to 50 thousand rows

batch_file  =  icd_codes.copy()
batch_file_name = 'icd_codes_batch'
num_files = len(batch_file) // 50000 + 1
for num_file in range(num_files):
    output_file = f'./batch_files/{batch_file_name}_part{num_file}.jsonl'
    # make sure that the file does not exist, so don't add to an existing file
    if os.path.exists(output_file):
        os.remove(output_file)
    with open(output_file, 'a') as file:
        for index, row in batch_file.iloc[50000*num_file : min(50000*(num_file+1),len(batch_file))].iterrows():
            payload = {
                "custom_id":f"custom_id_{index}",
                "method": "POST",
                "url": "/v1/embeddings",
                "body": {
                    "input": row["long_description"],
                    "model": "text-embedding-3-large",
                    "encoding_format": "float"
                }
            }
            file.write(json.dumps(payload) + '\n')

    # Sanity check
    with open(output_file, 'r') as file:
        for line in file.readlines()[:2]:
            print(line)


{"custom_id": "custom_id_0", "method": "POST", "url": "/v1/embeddings", "body": {"input": "Cholera due to Vibrio cholerae 01, biovar cholerae", "model": "text-embedding-3-large", "encoding_format": "float"}}

{"custom_id": "custom_id_1", "method": "POST", "url": "/v1/embeddings", "body": {"input": "Cholera due to Vibrio cholerae 01, biovar eltor", "model": "text-embedding-3-large", "encoding_format": "float"}}

{"custom_id": "custom_id_50002", "method": "POST", "url": "/v1/embeddings", "body": {"input": "Nondisplaced spiral fracture of shaft of unspecified fibula, initial encounter for open fracture type IIIA, IIIB, or IIIC", "model": "text-embedding-3-large", "encoding_format": "float"}}

{"custom_id": "custom_id_50003", "method": "POST", "url": "/v1/embeddings", "body": {"input": "Nondisplaced spiral fracture of shaft of unspecified fibula, subsequent encounter for closed fracture with routine healing", "model": "text-embedding-3-large", "encoding_format": "float"}}



# Run the batch embeddings 

### set up openai environement

In [68]:
# set up the openai environment
from openai import OpenAI
# OPENAI_API_KEY = '<YOUR_API_KEY>'
# os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
client = OpenAI()

In [69]:
# create the batch files for the batch job
batch_folder = './batch_files'
batch_input_files = []
for file in os.listdir(batch_folder):
    batch_input_files.append(client.files.create(
        file=open(f'{batch_folder}/{file}', "rb"),
        purpose="batch"
    ))

In [70]:
# create the batch job
batch_file_ids= [batch_file.id for batch_file in batch_input_files] # we get the ids of the batch files
job_creations = []
for i,file_id in enumerate(batch_file_ids):
    job_creations.append(client.batches.create(
    input_file_id=file_id,
    endpoint="/v1/embeddings",
    completion_window="24h", # currently only 24h is supported 
    metadata={
      "description": f"part_{i}_icd_embeddings"
    }
    ))

In [71]:
# WE can see here the jobs created, they start with validation
for job in job_creations:
    print(job)

# we extract the ids for the job to check the status
job_ids = [job.id for job in job_creations]

Batch(id='batch_keC2nwmzuJ8PmgqInkn3Hcg7', completion_window='24h', created_at=1720811130, endpoint='/v1/embeddings', input_file_id='file-SxQJNEeK6kHcqVLb4t855vWi', object='batch', status='validating', cancelled_at=None, cancelling_at=None, completed_at=None, error_file_id=None, errors=None, expired_at=None, expires_at=1720897530, failed_at=None, finalizing_at=None, in_progress_at=None, metadata={'description': 'part_0_icd_embeddings'}, output_file_id=None, request_counts=BatchRequestCounts(completed=0, failed=0, total=0))
Batch(id='batch_zOIjkAp9PDkGYVgRXHvbdrpx', completion_window='24h', created_at=1720811130, endpoint='/v1/embeddings', input_file_id='file-EYv1ENRK0HBxRNsmlJ7rvh5J', object='batch', status='validating', cancelled_at=None, cancelling_at=None, completed_at=None, error_file_id=None, errors=None, expired_at=None, expires_at=1720897530, failed_at=None, finalizing_at=None, in_progress_at=None, metadata={'description': 'part_1_icd_embeddings'}, output_file_id=None, request_c

In [83]:
import time
fail_flag = False
finished = set()
while True:
    # we check the status of the jobs
    for job_id in job_ids:
        job  = client.batches.retrieve(job_id)
        if job.status == "failed":
            print(f"Job {job_id} has failed with error {job.errors}")
            fail_flag = True
            break
        elif job.status == 'in_progress':
            print(f'Job {job_id} is in progress, {job.request_counts.completed}/{job.request_counts.total} requests completed')
        elif job.status == 'finalizing':
            print(f'Job {job_id} is finalizing, waiting for the output file id')
        elif job.status == "completed":
            print(f"Job {job_id} has finished")
            finished.add(job_id)
        else:
            print (f'Job {job_id} is in status {job.status}')
        
    if fail_flag == True or len(finished) == len(job_ids):
        break
    time.sleep(600)

# When the job is finished we can see the output file id that will be used to extract the output files
output_files_ids= []
if fail_flag ==  False:
    for job_id in job_ids:
        output_files_ids.append(client.batches.retrieve(job_id).output_file_id)

Job batch_keC2nwmzuJ8PmgqInkn3Hcg7 has finished
Job batch_zOIjkAp9PDkGYVgRXHvbdrpx is finalizing, waiting for the output file id
Job batch_keC2nwmzuJ8PmgqInkn3Hcg7 has finished
Job batch_zOIjkAp9PDkGYVgRXHvbdrpx is finalizing, waiting for the output file id
Job batch_keC2nwmzuJ8PmgqInkn3Hcg7 has finished
Job batch_zOIjkAp9PDkGYVgRXHvbdrpx is finalizing, waiting for the output file id
Job batch_keC2nwmzuJ8PmgqInkn3Hcg7 has finished
Job batch_zOIjkAp9PDkGYVgRXHvbdrpx has finished


In [84]:
for job_id in job_ids:
        job  = client.batches.retrieve(job_id)
        print(f'{job.request_counts.failed}/{job.request_counts.total} requests failed in job {job_id}')   

0/23199 requests failed in job batch_keC2nwmzuJ8PmgqInkn3Hcg7
0/50000 requests failed in job batch_zOIjkAp9PDkGYVgRXHvbdrpx


# Extract the embedding files

In [86]:
output_files = []
for output_file_id in output_files_ids:
    output_file = client.files.content(output_file_id).text
    output_files.append(output_file)
    output_file_split = output_file.split('\n')
    print(len(output_file_split))

23200
50001


In [87]:
embedding_results = []
for file in output_files:
    for line in file.split('\n')[:-1]:
            data =json.loads(line)
            custom_id = data.get('custom_id')
            embedding = data['response']['body']['data'][0]['embedding']
            embedding_results.append([custom_id, embedding])


embedding_results = pd.DataFrame(embedding_results, columns=['custom_id', 'embedding'])

In [88]:
icd_codes = icd_codes.reset_index()
icd_codes = icd_codes.rename(columns={'index':'id'})
embedding_results['id'] = embedding_results['custom_id'].apply(lambda x: int(x.split('custom_id_')[1]))
icd_codes_with_embedding = icd_codes.merge(embedding_results[['id','embedding']], on='id', how='left') 

In [89]:
icd_codes_with_embedding.to_csv('./data/icd_codes_with_embedding.csv', index=False)

In [90]:
icd_codes_with_embedding.head()

Unnamed: 0,id,code,long_description,embedding
0,0,A000,"Cholera due to Vibrio cholerae 01, biovar chol...","[-0.01081124, 0.016794948, -0.0016598871, 0.01..."
1,1,A001,"Cholera due to Vibrio cholerae 01, biovar eltor","[-0.020291328, 0.020674184, 0.0033783433, 0.02..."
2,2,A009,"Cholera, unspecified","[-0.005399217, 0.010569167, -0.00032957003, 0...."
3,3,A0100,"Typhoid fever, unspecified","[-0.011074301, 0.013406112, 0.0035753243, -0.0..."
4,4,A0101,Typhoid meningitis,"[0.011390618, -0.0014092546, 0.0023590443, -0...."
