In [1]:
import pandas as pd
from datasets.dataset_dict import DatasetDict
from datasets import Dataset
import re
import numpy as np
from pandarallel import pandarallel
# Initialise pandarallel (with progress bars) so that DataFrame.parallel_apply,
# used by the section-parsing cells further down, is available.
pandarallel.initialize(progress_bar=True)

INFO: Pandarallel will run on 10 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


In [2]:
from multiprocessing import Pool
import random
from os import listdir
import os
import pandas as pd
import time
from pathlib import Path
import re
from tqdm import tqdm

# Load Dataset

## Merge and Save

In [3]:
def merge_two_dfs_no_dup_cols(df1, df2, merge_col):
    """Join ``df2`` onto ``df1`` while keeping a single copy of shared columns.

    Only the columns of ``df2`` that ``df1`` does not already have (plus the
    join keys) take part in the merge, so no ``_x``/``_y`` suffixed duplicates
    are produced. For shared non-key columns the values of ``df1`` win.

    Args:
        df1: Left DataFrame.
        df2: Right DataFrame.
        merge_col: List of column names to join on.

    Returns:
        The result of ``df1.merge(...)`` with pandas' default join type.
    """
    new_cols = [name for name in df2.columns.difference(df1.columns)]
    right_side = df2[new_cols + merge_col]
    return df1.merge(right_side, on=merge_col)

In [3]:
# Build one merged DataFrame per dataset split and collect them in `dfs`, keyed by split name.
dfs = {}

# NOTE(review): the third split is stored under 'test_phase_1', while later cells index
# dfs['test'] and dfs['test_2'] — confirm where those keys are populated.
for subset in ['train', 'valid', 'test_phase_1']:
    # keep_default_na=False stops pandas from converting empty / "NA"-like strings to NaN.
    df_discharge_target = pd.read_csv(f'./physionet.org/files/discharge-me/1.2/{subset}/discharge_target.csv.gz', keep_default_na=False)
    df_discharge = pd.read_csv(f'./physionet.org/files/discharge-me/1.2/{subset}/discharge.csv.gz', keep_default_na=False)
    df_radiology = pd.read_csv(f'./physionet.org/files/discharge-me/1.2/{subset}/radiology.csv.gz', keep_default_na=False)
    df_diagnoses_ed = pd.read_csv(f'./physionet.org/files/discharge-me/1.2/{subset}/diagnosis.csv.gz', keep_default_na=False)
    df_triage_ed = pd.read_csv(f'./physionet.org/files/discharge-me/1.2/{subset}/triage.csv.gz', keep_default_na=False)
    df_stays_ed = pd.read_csv(f'./physionet.org/files/discharge-me/1.2/{subset}/edstays.csv.gz', keep_default_na=False)
#     dfs[subset] = df_discharge.merge(df_discharge_target, on=['note_id', 'hadm_id'])    

    # Discharge notes + their target sections, then one row per (note, radiology report).
    merged_df = merge_two_dfs_no_dup_cols(df_discharge, df_discharge_target, ['hadm_id'])
    merged_df = merge_two_dfs_no_dup_cols(merged_df, df_radiology.rename(columns={'text': 'radiology_text'}), ['hadm_id'])
    # Group on every column except 'radiology_text' so the report texts of one note
    # collapse into a single list-valued cell.
    merge_col = merged_df.drop(columns=['radiology_text']).columns.tolist()
    merged_df = merged_df.groupby(merge_col).agg({
        'radiology_text': lambda x: x.tolist()
    }).reset_index()
    
    # ED stay + triage + diagnosis tables, joined on (stay_id, subject_id).
    merged_df_2 = merge_two_dfs_no_dup_cols(df_stays_ed, df_triage_ed, ['stay_id', 'subject_id'])
    merged_df_2 = merge_two_dfs_no_dup_cols(merged_df_2, df_diagnoses_ed, ['stay_id', 'subject_id'])
    # Group on the non-diagnosis columns and gather each diagnosis column into a list.
    merge_col_2 = [col for col in merged_df_2.columns if col not in df_diagnoses_ed.drop(columns=['stay_id', 'subject_id']).columns]
    agg_col_2 = {col: lambda x: x.tolist() for col in df_diagnoses_ed.drop(columns=['stay_id', 'subject_id']).columns}
    merged_df_2 = merged_df_2.groupby(merge_col_2).agg(agg_col_2).reset_index()
    
    # No `on=` is given, so pandas joins on every column name the two frames share.
    final_merged_df = merged_df.merge(merged_df_2)
    dfs[subset] = final_merged_df

## Remove Target from the Input Text

In [12]:
def remove_output_from_input(row):
    """Store in ``row['new_text']`` the note text with both target sections removed.

    For each target column (``brief_hospital_course``, ``discharge_instructions``)
    the section body is deleted from ``row['text']`` and then the section header
    (with any newlines directly after it) is deleted as well. Finally the run of
    newlines (each optionally followed by one space) in front of
    "Followup Instructions" is normalised to exactly three newlines.

    Args:
        row: Mapping/Series with ``text``, ``brief_hospital_course`` and
            ``discharge_instructions`` string entries.

    Returns:
        The same ``row`` object, with ``new_text`` added.
    """
    cleaned = row['text']
    for target_col, header_rex in (
        ('brief_hospital_course', r'Brief Hospital Course:\n*'),
        ('discharge_instructions', r'Discharge Instructions:\n*'),
    ):
        cleaned = cleaned.replace(row[target_col], '')
        cleaned = re.sub(header_rex, r'', cleaned, flags=re.DOTALL)

    cleaned = re.sub(r'(\n ?)+(Followup Instructions)', r'\n\n\n\2', cleaned, flags=re.DOTALL)

    row['new_text'] = cleaned
    return row
    
# NOTE(review): `discharge` is not defined in any cell shown above, so this line depends on
# leftover kernel state and fails on a fresh kernel; the following cell applies the same
# function to every frame in `dfs`.
discharge = discharge.apply(remove_output_from_input, axis=1)

In [16]:
# Add the target-free 'new_text' column to every split.
for domain, df in dfs.items():
    dfs[domain] = df.apply(remove_output_from_input, axis=1)

In [17]:
# Expose the cleaned text under the name used by the rest of the notebook.
for domain, df in dfs.items():
    dfs[domain] = df.rename(columns={'new_text': 'processed_text'})

## Remove Rows Whose Target Sections Have Fewer than 10 Words

In [18]:
# Training rows whose discharge-instructions target has fewer than 10 words.
dfs['train'].loc[dfs['train']['discharge_instructions_word_count'].lt(10)]

Unnamed: 0,note_id,subject_id,hadm_id,note_type,note_seq,charttime,storetime,text,brief_hospital_course,brief_hospital_course_word_count,...,o2sat,pain,resprate,sbp,temperature,seq_num,icd_code,icd_version,icd_title,processed_text


In [19]:
# Training rows whose brief-hospital-course target has fewer than 10 words.
dfs['train'].loc[dfs['train']['brief_hospital_course_word_count'].lt(10)]

Unnamed: 0,note_id,subject_id,hadm_id,note_type,note_seq,charttime,storetime,text,brief_hospital_course,brief_hospital_course_word_count,...,o2sat,pain,resprate,sbp,temperature,seq_num,icd_code,icd_version,icd_title,processed_text


In [20]:
# Training rows where either target section has fewer than 10 words.
dfs['train'].loc[
    dfs['train']['discharge_instructions_word_count'].lt(10)
    | dfs['train']['brief_hospital_course_word_count'].lt(10)
]

Unnamed: 0,note_id,subject_id,hadm_id,note_type,note_seq,charttime,storetime,text,brief_hospital_course,brief_hospital_course_word_count,...,o2sat,pain,resprate,sbp,temperature,seq_num,icd_code,icd_version,icd_title,processed_text


In [21]:
# Keep, in every split, only rows where both target sections have at least 10 words.
for domain, df in dfs.items():
    mask = (
        df['discharge_instructions_word_count'].ge(10)
        & df['brief_hospital_course_word_count'].ge(10)
    )
    dfs[domain] = df.loc[mask]

## Export some sample discharge summaries

In [None]:
from docx import Document
from docx.shared import Pt

# Write the first `top` training notes to .docx files, one paragraph per text line,
# in a fixed-width font.
top = 10
output_dir = 'discharge_summary_samples'
# Saving into a directory that does not exist raises FileNotFoundError,
# so create it up front.
os.makedirs(output_dir, exist_ok=True)

for i, row in dfs['train'].reset_index(drop=True).head(top).iterrows():
    document = Document()

    # Restyle the document's 'Normal' style: monospace font, single spacing,
    # no extra space after paragraphs.
    style = document.styles['Normal']
    font = style.font
    font.name = 'Courier New'
    font.size = Pt(10.5)
    style.paragraph_format.line_spacing = 1
    style.paragraph_format.space_after = Pt(0)

    # Empty lines are written too, so blank lines in the note are kept.
    for line in row['processed_text'].split("\n"):
        document.add_paragraph(line, style=style)

    document.save(f'{output_dir}/discharge_summary_{i}.docx')

# Section Extraction (Parsing)

In [251]:
from collections import OrderedDict

# Ordered map: canonical section name -> regex fragment matching that section's header.
# The order is significant: parse_sections looks for the "next" section among the
# entries that follow the current one.
# Many fragments accept runs of underscores in place of a word; presumably this covers
# headers partly masked by de-identification — TODO confirm.
input_sections = OrderedDict([
    ('Allergies', 'Allergies'),
    ('Chief Complaint', '(?:Chief|_+) Complaint'),
    ('Major Surgical or Invasive Procedure', '(?:Major |_+ *)(?:Surgical |_+ *)(?:or |_+ *)(?:Invasive|_+ *) Procedure'),
    ('History of Present Illness', '(?:History|_+) of Present Illness'),
    ('Past Medical History', '(?:Past|_+) Medical History'),
    ('Social History', '(?:Social|_+) History'),
    ('Family History', '(?:Family|_+) History'),
    ('Physical Exam', 'Physical [A-Za-z_]+'),
    ('Pertinent Results', '(?:Pertinent|_+) Results'),
    ('Brief Hospital Course', 'Brief Hospital Course'),
    ('Medications on Admission', '[A-Za-z_]+ on Admission'),
    ('Discharge Medications', '[A-Za-z_]+ Medications'),
    ('Discharge Disposition', '[A-Za-z_]+ Disposition'),
    ('Discharge Diagnosis', '[A-Za-z_]+ Diagnosis'),
    ('Discharge Condition', '[A-Za-z_]+ Condition')
])

In [None]:
# SECTIONS REPORT
# For every section name, print and record the fraction of notes in each split whose
# raw text contains that name. `s_list` ends up with one Series per section.
s_list = []
for section in input_sections.keys():
    print("======", section.upper(), "======")
    # Give the empty Series an explicit dtype: without it pandas falls back to object
    # dtype (and older versions warn), although only float fractions are stored.
    s = pd.Series(dtype=float, name=section)
    for subset in ['train', 'valid', 'test']:
        size = dfs[subset].shape[0]
        # `section` is a display name, not a pattern, so match it literally.
        contains_section = dfs[subset]['text'].str.contains(section, regex=False)
        filtered_size = int(contains_section.sum())
        # Guard against an empty split instead of raising ZeroDivisionError.
        ratio = filtered_size / size if size else float('nan')
        print(subset.upper(), size, filtered_size, ratio)
        s[subset] = ratio
    s_list += [s]
    print()

In [None]:
# One row per section (the Series names), one column per split.
pd.DataFrame(s_list)

## Parse

In [32]:
def parse_sections(row, sections=None):
    """Extract each known section of a discharge summary into its own column.

    For every ``(name, header_regex)`` pair the note in ``row['text']`` is searched
    for ``header: body <blank line(s)> next-header:``. The LAST match is stored in
    ``row[name.replace(" ", "_")]`` as the tuple of regex groups
    ``(matched header, body, last separator, next header)``; ``np.nan`` is stored
    when the section is not found.

    Args:
        row: Mapping/Series with a ``text`` entry holding the full note.
        sections: Optional ordered mapping of section name -> header regex.
            Defaults to the notebook-level ``input_sections``.

    Returns:
        The same ``row`` object with one added entry per section.
    """
    if sections is None:
        sections = input_sections

    # These sections are delimited by the next KNOWN section header rather than by any
    # capitalised "Header:" line (presumably because their bodies can contain such
    # lines themselves — TODO confirm).
    bounded_sections = ('Pertinent Results', 'Physical Exam', 'Brief Hospital Course',
                        'Past Medical History', 'Social History', 'Family History')

    discharge_summary = row['text']
    # Built once per row instead of once per section.
    section_patterns = list(sections.values())

    for i, (section_name, section) in enumerate(sections.items()):
        next_section = None
        if section_name in bounded_sections:
            # Pick the first later section that actually occurs after this one. If none
            # occurs, the loop leaves `next_section` at the last candidate.
            for next_section in section_patterns[i + 1:]:
                # Only existence matters, so re.search replaces the former re.findall,
                # which collected every match just to test for non-emptiness.
                if re.search(section + ".+\n" + next_section, discharge_summary, re.DOTALL):
                    break

        if next_section is not None:
            rex = r'(%s?):\s*\n{0,2}(.+?)\s*(\n\s*){1,10}(%s):\n' % (section, next_section)
        else:
            # Generic terminator: any capitalised "Header:" line. Also used when a bounded
            # section has no later entry in `sections`, instead of relying on a stale or
            # unbound `next_section`.
            rex = r'(%s?):\s*\n{0,2}(.+?)\s*(\n\s*){1,10}((?:[A-Z][a-z ]+)+):' % (section)

        section_ext = re.findall(rex, discharge_summary, re.DOTALL)
        section_col_name = section_name.replace(" ", "_")
        row[section_col_name] = section_ext[-1] if section_ext else np.nan

    return row

In [35]:
# Parse sections for the training split, in parallel via pandarallel.
# The splits are processed in separate cells rather than in one loop:
# for subset in ['train', 'valid', 'test']:
for subset in ['train']:
    dfs[subset] = dfs[subset].parallel_apply(parse_sections, axis=1)

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=6858), Label(value='0 / 6858'))), …

In [36]:
# Parse sections for the validation split.
for subset in ['valid']:
    dfs[subset] = dfs[subset].parallel_apply(parse_sections, axis=1)

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=1468), Label(value='0 / 1468'))), …

In [37]:
# Parse sections for the test split.
# NOTE(review): the loading cell stores this split as 'test_phase_1', not 'test' — confirm
# where dfs['test'] comes from.
for subset in ['test']:
    dfs[subset] = dfs[subset].parallel_apply(parse_sections, axis=1)

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=1474), Label(value='0 / 1474'))), …

In [38]:
# Parse sections for the second test split.
# NOTE(review): no visible cell before this one creates dfs['test_2'] — confirm its origin.
for subset in ['test_2']:
    dfs[subset] = dfs[subset].parallel_apply(parse_sections, axis=1)

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=1099), Label(value='0 / 1099'))), …

## Save Data

In [49]:
# Drop a leftover '__index_level_0__' column if one is present; errors='ignore'
# makes this a no-op when the column is absent.
for subset in ['train', 'valid', 'test']:
    dfs[subset] = dfs[subset].drop(columns=['__index_level_0__'], errors='ignore')

In [50]:
# Wrap the three pandas splits as datasets under their split names.
discharge_data = DatasetDict({
    split_name: Dataset.from_pandas(dfs[split_name])
    for split_name in ('train', 'valid', 'test')
})

  if _pandas_api.is_sparse(col):


In [51]:
# Marker output; presumably used to confirm the cells above finished running.
print("SUCCESS")

SUCCESS


In [52]:
# Persist the parsed splits; the "Read Processed Data" cells reload from this directory.
discharge_data.save_to_disk("discharge-data-save-sections-extracted")

Saving the dataset (0/6 shards):   0%|          | 0/68574 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/14675 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/14731 [00:00<?, ? examples/s]

## Read Processed Data

In [3]:
# Reload the dataset written by save_to_disk, so the parsing cells need not be re-run.
discharge_data = DatasetDict.load_from_disk("./discharge-data-save-sections-extracted")

In [4]:
# Rebuild `dfs` (split name -> DataFrame) from the reloaded DatasetDict.
dfs = {}
for subset, data in discharge_data.items():
    # With the "pandas" format set, slicing the dataset yields a DataFrame.
    data.set_format("pandas")
    # Full slice: take all rows of the split.
    dfs[subset] = data[:]

## Post-processing

In [5]:
from collections import OrderedDict

# Same table as in the "Section Extraction (Parsing)" part, repeated here so the
# post-processing cells can run after a kernel restart without re-running the parsing cells.
# NOTE(review): the two copies must be kept in sync by hand.
input_sections = OrderedDict([
    ('Allergies', 'Allergies'),
    ('Chief Complaint', '(?:Chief|_+) Complaint'),
    ('Major Surgical or Invasive Procedure', '(?:Major |_+ *)(?:Surgical |_+ *)(?:or |_+ *)(?:Invasive|_+ *) Procedure'),
    ('History of Present Illness', '(?:History|_+) of Present Illness'),
    ('Past Medical History', '(?:Past|_+) Medical History'),
    ('Social History', '(?:Social|_+) History'),
    ('Family History', '(?:Family|_+) History'),
    ('Physical Exam', 'Physical [A-Za-z_]+'),
    ('Pertinent Results', '(?:Pertinent|_+) Results'),
    ('Brief Hospital Course', 'Brief Hospital Course'),
    ('Medications on Admission', '[A-Za-z_]+ on Admission'),
    ('Discharge Medications', '[A-Za-z_]+ Medications'),
    ('Discharge Disposition', '[A-Za-z_]+ Disposition'),
    ('Discharge Diagnosis', '[A-Za-z_]+ Diagnosis'),
    ('Discharge Condition', '[A-Za-z_]+ Condition')
])
# Display the column names derived from the section names (spaces -> underscores).
[col.replace(" ", "_") for col in input_sections.keys()]

['Allergies',
 'Chief_Complaint',
 'Major_Surgical_or_Invasive_Procedure',
 'History_of_Present_Illness',
 'Past_Medical_History',
 'Social_History',
 'Family_History',
 'Physical_Exam',
 'Pertinent_Results',
 'Brief_Hospital_Course',
 'Medications_on_Admission',
 'Discharge_Medications',
 'Discharge_Disposition',
 'Discharge_Diagnosis',
 'Discharge_Condition']

In [7]:
# Column names written by parse_sections: section names with spaces replaced by underscores.
sections_col = [col.replace(" ", "_") for col in input_sections.keys()]
sections_col

['Allergies',
 'Chief_Complaint',
 'Major_Surgical_or_Invasive_Procedure',
 'History_of_Present_Illness',
 'Past_Medical_History',
 'Social_History',
 'Family_History',
 'Physical_Exam',
 'Pertinent_Results',
 'Brief_Hospital_Course',
 'Medications_on_Admission',
 'Discharge_Medications',
 'Discharge_Disposition',
 'Discharge_Diagnosis',
 'Discharge_Condition']

In [8]:
# Each parsed section cell holds the regex groups (header, body, separator, next header);
# keep only the body, i.e. element 1.
# NOTE(review): 'test_2' is not among the splits saved to / reloaded from disk in the cells
# above — confirm where dfs['test_2'] is populated.
for subset in ['train', 'valid', 'test', 'test_2']:
    for col in sections_col:
        mask = pd.notnull(dfs[subset][col])
        # A plain string is a body that was already unpacked: leave it untouched so that
        # re-running this cell does not replace every body with its second character.
        dfs[subset].loc[mask, col] = dfs[subset].loc[mask, col].apply(
            lambda x: x if isinstance(x, str) else x[1]
        )

# Diagnose Summarization

In [24]:
# Section columns for which calculate_word_count computes a "<column>_Word_Count" value.
extractive_field = ['Physical_Exam', 'Pertinent_Results']

In [26]:
def calculate_word_count(row, fields=None):
    """Add a ``<field>_Word_Count`` entry to ``row`` for each field.

    The count is the number of pieces obtained by splitting the text on single
    spaces (so consecutive spaces produce empty pieces that are counted too);
    a null value counts as 0.

    Args:
        row: Mapping/Series holding the text fields.
        fields: Optional iterable of field names. Defaults to the notebook-level
            ``extractive_field`` list.

    Returns:
        The same ``row`` object with the count entries added.
    """
    if fields is None:
        fields = extractive_field

    for col in fields:
        text = row[col]
        row[col + "_Word_Count"] = len(text.split(" ")) if pd.notnull(text) else 0

    return row

In [27]:
# Add the word-count columns to each split (row-wise apply).
for subset in ['train', 'valid', 'test']:
    dfs[subset] = dfs[subset].apply(calculate_word_count, axis=1)

In [28]:
# Largest Physical Exam word count in the training split.
dfs['train']['Physical_Exam_Word_Count'].max()

1168

In [29]:
# Summary statistics of the Pertinent Results word count in the training split.
dfs['train']['Pertinent_Results_Word_Count'].describe()

count    29096.000000
mean       440.001306
std        446.843602
min          1.000000
25%        161.000000
50%        311.000000
75%        567.000000
max       6670.000000
Name: Pertinent_Results_Word_Count, dtype: float64

In [28]:
# Marker output; presumably used to confirm the cells above finished running.
print("SUCCESS")

SUCCESS


In [29]:
# Start over with only the two test splits, loaded from pickles.
# NOTE(review): this replaces `dfs`, so any 'train'/'valid' frames built earlier are no
# longer reachable through it. No visible cell writes these .pkl files — confirm their origin.
# pd.read_pickle can execute arbitrary code: only load pickles from a trusted source.
dfs = {}
dfs['test'] = pd.read_pickle("./discharge_data_test_set.pkl")
dfs['test_2'] = pd.read_pickle("./discharge_data_test_set_2.pkl")

In [5]:
def get_completion(model, prompt, max_tokens, temperature=0, max_new_tokens=500):
    """Request one text completion from the configured OpenAI-compatible endpoint.

    NOTE(review): a later cell defines ``get_completion`` again without
    ``max_new_tokens``; whichever definition ran last is the one in effect.

    Args:
        model: Model name sent to the server.
        prompt: Full prompt text.
        max_tokens: Forwarded as ``max_tokens``.
        temperature: Degree of randomness of the model's output. It is now forwarded
            to the API; previously the call always sent 0 regardless of this argument.
        max_new_tokens: Forwarded unchanged as ``max_new_tokens``
            (NOTE(review): confirm the server accepts this field).

    Returns:
        The text of the first returned choice.
    """
    response = openai.Completion.create(
        model=model,
        prompt=prompt,
        max_tokens=max_tokens,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        request_timeout=1200
    )
    return response.choices[0].text

In [30]:
import openai
import os
import pandas as pd
import time

# Point the OpenAI client at a server on localhost instead of the hosted API.
openai.api_key = "EMPTY" # placeholder value; presumably the local server does not check keys — TODO confirm
openai.api_base = "http://localhost:8000/v1"
# NOTE(review): a later cell's output shows `model` as 'checkpoint-11838', so it was
# reassigned somewhere not shown in this notebook.
model = "Mistral-7B-Instruct-v0.2-GPTQ-Brief-Hospital-Course"
prompt = "Once upon a time"

In [31]:
# Smoke test: request a short completion to check that the endpoint responds.
completion = openai.Completion.create(model=model, prompt=prompt, max_tokens=64, temperature=0)
# Print the prompt followed by the generated continuation.
print(prompt + completion.choices[0].text)

Once upon a time[there was] [a] [p] [p] [p] [p] [p] [p] [p] [p] [p] [p] [p] [p] [p] [p] [p] [p] [p] [p] [p]


# Target Section Summarization

## Brief Hospital Course Summarization

### Inference (API)

In [32]:
# GOOD 1 (CoT)
# Prompt template for generating the "Brief Hospital Course" section. It contains a single
# %s placeholder that generate_target_section fills (via `%` formatting) with the
# discharge-summary text. The literal is runtime text: do not reformat it.
brief_hospital_course_prompt = """<s>[INST] In this task, you are provided with a Discharge Summary delimited by triple quotes.
Discharge Summaries are documents that outline the care a patient received during their hospital stay, including diagnoses, treatments, and follow-up care instructions, prepared at the time of a patient's discharge. 
Discharge Summaries are split into various sections and written under a variety of headings, relating to admission, diagnosis and relevant discharge information. But the provided Discharge summary will be missing the \"Brief Hospital Course\". \"Brief Hospital Course\" is a section of the discharge summaries that outlines the key events of a patient's hospital stay, including the progression from admission to discharge. It is written for the subsequent care providers about the critical aspects of the patient.
You are tasked to generate the missing \"Brief Hospital Course\" section in the discharge summary, based on the information of other sections in the discharge summary.
Brief Hospital Course outlines the key events of a patient's hospital stay, including the progression from admission to discharge. It is written for the subsequent care providers about the critical aspects of the patient

The summary should be written in the following structure, by answering some important questions:
1. Initial presentation: Describe the patient's initial presentation, including the main complaint and relevant history.
    * What were the main treatment strategies employed for the patient's conditions during their stay? Include medications adjusted, procedures performed, and any therapeutic interventions.
    * What are the key diagnoses identified during the hospital stay?
2. Treatment course:
    - For each section header named by "#Condition Name", provide a detailed description of each condition, disease, or symptom of the patient by answering the following questions:
        * What is the patient's background relating to the condition, disease, or symptom
        * Describe the treatment strategy, including any medications given, procedures performed, and dietary adjustments.
        * How was the diagnosis reached, including any significant tests or evaluations conducted?
        * What were the significant medical or surgical interventions during the hospital stay, including any procedures, diagnostic tests (e.g., CT Scan, Imaging, Blood Test, MRI), and changes in medication?
        * Were there any complications or additional diagnoses during the hospital stay? How were these addressed and managed?
        * How did the patient's condition progress throughout the hospital stay, including any monitoring of symptoms, response to treatments, and adjustments made to the treatment plan?
        * What were the conditions and considerations for the patient’s discharge? Include the discharge medications, any changes from previous medication regimens, and follow-up care or lifestyle recommendations.
3. Transitional issues: Highlight any transitional care issues addressed during the hospital stay, including changes in medication, dietary adjustments, and specific care instructions.
4. Acute/active issues: Detail the management of acute or active issues encountered during the stay, using the provided structure for each condition.
5. Chronic/stable issues: Summarize how chronic conditions were managed during the stay and any adjustments made to long-term management plans.[/INST]</s>

[INST] Discharge summary: 
\"\"\"%s\"\"\" [/INST]"""

In [33]:
def get_completion(model, prompt, max_tokens, temperature=0):
    """Request one text completion from the configured OpenAI-compatible endpoint.

    NOTE(review): this redefines the ``get_completion`` from an earlier cell
    (which also took ``max_new_tokens``); this later definition is the one in
    effect once the cell has run.

    Args:
        model: Model name sent to the server.
        prompt: Full prompt text.
        max_tokens: Forwarded as ``max_tokens``.
        temperature: Degree of randomness of the model's output. It is now forwarded
            to the API; previously the call always sent 0 regardless of this argument.

    Returns:
        The text of the first returned choice.
    """
    response = openai.Completion.create(
        model=model,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        request_timeout=2000
    )
    return response.choices[0].text

In [34]:
def generate_target_section(my_prompt, data, model="checkpoint-87288"):
    """Fill the prompt template with ``data`` and request a completion, with retries.

    Up to 5 attempts are made. After a failed attempt the function sleeps 30 s when
    the error message contains "limit reached for", otherwise 5 s. An error whose
    message contains "exceeded your current quota" is re-raised immediately.

    Args:
        my_prompt: ``%``-style template with one ``%s`` placeholder.
        data: Text substituted into the template.
        model: Model name passed to ``get_completion``.

    Returns:
        The generated text, or ``None`` when every attempt failed.

    Raises:
        Exception: The quota error described above, with its original traceback.
    """
    final_prompt = my_prompt % data

    for _ in range(5):
        try:
            return get_completion(model, final_prompt, 1000, temperature=0)
        except Exception as e:
            # The former `if e: ... else: raise e` wrapper was dead code: built-in
            # exception instances are always truthy.
            message = str(e).lower()
            if "exceeded your current quota" in message:
                # Retrying cannot help; bare `raise` keeps the original traceback.
                raise
            print(e)
            print('Timeout error, retrying...')
            # Back off longer when the server reports a rate limit.
            time.sleep(30 if "limit reached for" in message else 5)

    print('API is not responding, moving on...')
    return None

In [35]:
def target_section_summarization(root_path, target_section_prompt, output_col_name, domain, domain_df, model, save_step=10):
    """Generate one target section per row of ``domain_df``, with resumable checkpoints.

    Checkpoints live in ``{root_path}/{domain}/`` as ``{domain}_{n}.pkl`` (first ``n``
    rows processed) and ``{domain}_done.pkl`` (all rows processed). Each holds only
    ``hadm_id`` and ``output_col_name``.

    * If a ``done`` file exists it is loaded and returned as-is, so the result then
      has just those two columns.
    * Otherwise processing resumes after the highest numbered checkpoint, or starts
      from the first row.

    Args:
        root_path: Directory under which the per-domain folder is created.
        target_section_prompt: ``%s`` template passed to ``generate_target_section``.
        output_col_name: Name of the column receiving the generated text.
        domain: Identifier used for the sub-folder and the checkpoint file names.
        domain_df: Frame with at least ``hadm_id`` and ``processed_text`` columns.
        model: Model name passed to ``generate_target_section``.
        save_step: A checkpoint is written every ``save_step`` rows.

    Returns:
        ``domain_df`` with ``output_col_name`` inserted as the first column, or the
        two-column frame loaded from the ``done`` checkpoint.
    """
    src_path = f"{root_path}/{domain}"
    Path(src_path).mkdir(parents=True, exist_ok=True)

    # Only files named exactly like our checkpoints are considered. Splitting every file
    # name on "_" / "." broke on unrelated entries (e.g. ".ipynb_checkpoints") and on
    # domains that themselves contain "_".
    checkpoint_rex = re.compile(r"%s_(\d+|done)\.pkl" % re.escape(str(domain)))
    postfix = [found.group(1)
               for found in map(checkpoint_rex.fullmatch, listdir(src_path))
               if found]

    extractions = []
    start = 0
    if 'done' in postfix:
        print(domain, ": ", "Loaded saved file. Done")
        new_domain_df = pd.read_pickle(f"{src_path}/{domain}_done.pkl")
        return new_domain_df
    elif len(postfix) > 0:
        last_index = max(int(idx) for idx in postfix)
        last_domain_df = pd.read_pickle(f"{src_path}/{domain}_{last_index}.pkl")
        extractions = last_domain_df[output_col_name].tolist()
        start = last_index
        print(domain, "Loaded saved file. Continuing")
    else:
        print(domain, "Start new process.")

    for i, (_, row) in tqdm(enumerate(domain_df.iterrows()), total=domain_df.shape[0]):
        if i < start:
            # Already covered by the checkpoint that was loaded.
            continue

        discharge_summary_data = row['processed_text']
        extraction = generate_target_section(target_section_prompt, discharge_summary_data, model)
        time.sleep(0.3)
        extractions += [extraction]

        if (i + 1) % save_step == 0:
            # Copy the slice so that adding the column never touches `domain_df`.
            checkpoint = domain_df.iloc[:i + 1][['hadm_id']].copy()
            checkpoint[output_col_name] = extractions
            checkpoint.to_pickle(f"{src_path}/{domain}_{i + 1}.pkl")

    # len(extractions) equals the number of processed rows and, unlike the loop index,
    # is defined even when `domain_df` is empty.
    new_domain_df = domain_df.iloc[:len(extractions)].copy()
    new_domain_df.insert(0, output_col_name, extractions)
    new_domain_df[['hadm_id', output_col_name]].to_pickle(f"{src_path}/{domain}_done.pkl")
    return new_domain_df

### Phase 2

In [36]:
# Work on the phase-2 test split, longest inputs first.
# NOTE: `df` is the same object as dfs['test_2'] here, so the word-count column is added
# to dfs['test_2'] as well; sort_values then returns a new, separate frame.
df = dfs['test_2']
# Word count = number of pieces when splitting on single spaces.
df['processed_text_word_count'] = df['processed_text'].apply(lambda x: len(x.split(" ")))
df = df.sort_values(by=['processed_text_word_count'], ascending=False)

In [37]:
# Length threshold in words (alternatives that were tried are kept commented out).
thres = 1000
# thres = 1200
# thres = 1500
# category 1: fewer than `thres` words; category 0: everything else.
df['category'] = df['processed_text_word_count'].apply(lambda x: 1 if x < thres else 0)

In [38]:
# category 2: notes with 1000 to 1300 words (both bounds inclusive).
# NOTE(review): the lower bound is the literal 1000, not `thres`; the two agree only while
# thres == 1000.
mask = (df['processed_text_word_count'] >= 1000) & (df['processed_text_word_count'] <= 1300)
df.loc[mask, 'category'] = 2

In [39]:
# mask = (df['processed_text_word_count'] >= 1000) & (df['processed_text_word_count'] <= 1150)
# df.loc[mask, 'category'] = 3

In [51]:
# Number of notes per length category.
df['category'].value_counts()

category
0    5049
1    3436
2    2500
Name: count, dtype: int64

#### R128_64

In [88]:
def replace_pertinent_results_with_radiology(row):
    """Replace the Pertinent Results text inside ``processed_text`` with radiology reports.

    A radiology report is kept when it has an "impression:" part (matched
    case-insensitively) and the first 15 characters of that impression occur in
    ``row['Pertinent_Results']``. The kept reports, joined by a separator line, are
    substituted for the Pertinent Results text in ``row['processed_text']``. Rows whose
    ``Pertinent_Results`` is null are returned unchanged.

    NOTE(review): a later cell defines a function with the same name that caps the
    amount of report text; whichever definition ran last is the one in effect.

    Args:
        row: Mapping/Series with ``Pertinent_Results``, ``radiology_text`` (iterable of
            report strings) and ``processed_text``.

    Returns:
        The same ``row`` object.
    """
    if pd.notnull(row['Pertinent_Results']):
        # The inline flag has to be the first thing in the pattern: placing "(?i)" inside
        # a group is deprecated and is rejected by Python 3.11+.
        rex = r'(?i)(impression:[\s ]*\n{0,2}(.+?)\s*$)'

        new_reports = []
        for report in row['radiology_text']:
            section_ext = re.findall(rex, report, re.DOTALL)
            # section_ext[0][1] is the impression body of the first match.
            if len(section_ext) > 0 and section_ext[0][1][:15] in row['Pertinent_Results']:
                new_reports += [report]

        new_pertinent_results = "=============\n\n".join(new_reports)
        row['processed_text'] = row['processed_text'].replace(row['Pertinent_Results'], new_pertinent_results)
    return row

In [52]:
def replace_pertinent_results_with_radiology(row, max_words=1000):
    """Replace the Pertinent Results text inside ``processed_text`` with radiology reports.

    A radiology report is kept when it has an "impression:" part (matched
    case-insensitively) and the first 15 characters of that impression occur in
    ``row['Pertinent_Results']``. Kept reports are appended, separated by a separator
    line, for as long as the text collected so far stays under ``max_words`` pieces
    (split on single spaces). A trailing separator is always added. Rows whose
    ``Pertinent_Results`` is null are returned unchanged.

    NOTE(review): this redefines the same-named function from the previous cell.

    Args:
        row: Mapping/Series with ``Pertinent_Results``, ``radiology_text`` (iterable of
            report strings) and ``processed_text``.
        max_words: Size cap checked before each further report is appended.

    Returns:
        The same ``row`` object.
    """
    if pd.notnull(row['Pertinent_Results']):
        # The inline flag has to be the first thing in the pattern: placing "(?i)" inside
        # a group is deprecated and is rejected by Python 3.11+.
        rex = r'(?i)(impression:[\s ]*\n{0,2}(.+?)\s*$)'

        new_reports = []
        for report in row['radiology_text']:
            section_ext = re.findall(rex, report, re.DOTALL)
            # section_ext[0][1] is the impression body of the first match.
            if len(section_ext) > 0 and section_ext[0][1][:15] in row['Pertinent_Results']:
                new_reports += [report]

        new_pertinent_results = ""
        for report in new_reports:
            collected_words = len(new_pertinent_results.split(" "))
            if collected_words == 1:
                # "".split(" ") has length 1, so this branch takes the first report
                # without a leading separator (it is also taken while the collected text
                # contains no space at all).
                new_pertinent_results += report
            elif collected_words < max_words:
                new_pertinent_results += ("=============\n\n" + report)
        new_pertinent_results += "=============\n\n"

        row['processed_text'] = row['processed_text'].replace(row['Pertinent_Results'], new_pertinent_results)

    return row

In [53]:
# Swap each note's Pertinent Results text for the matching radiology reports (row-wise).
df = df.apply(replace_pertinent_results_with_radiology, axis=1)

  section_ext = re.findall(rex, report, re.DOTALL)


In [54]:
# Word count (split on single spaces) of the first row's text after the replacement;
# `df` was sorted by the pre-replacement word count, longest first.
print(len(df.iloc[0]['processed_text'].split(" ")))

3491


In [56]:
# Output folder for the generation checkpoints of this run (previous run kept commented out).
# root_path = './discharge_me_output/brief_hospital_course_mistral_finetuned_phase_2_r128_64_new_pertinent_2'
root_path = './discharge_me_output/brief_hospital_course_mistral_finetuned_phase_2_r128_64_new_pertinent_2_plain_best'
# Size of the multiprocessing pool used by the inference cell.
# num_workers = 2
num_workers = 1

In [62]:
# Show which model name will be used.
# NOTE(review): the recorded output is 'checkpoint-11838', which no visible cell assigns.
model

'checkpoint-11838'

In [97]:
# One argument tuple per length category, in the positional order expected by
# target_section_summarization:
# (root_path, prompt, output column, domain, domain_df, model, save_step).
inputs = [(root_path,
           brief_hospital_course_prompt,
           'Brief_Hospital_Course_Generated',
           domain,
           df[df['category'] == domain].reset_index(drop=True),
           model,
           10,
           )
          for domain in [1]]
#           for domain in [0, 1]]
# NOTE(review): start_time is recorded but not read in any visible cell.
start_time = time.time()
# starmap returns one result frame per tuple in `inputs`.
with Pool(num_workers) as processor:
    data = processor.starmap(target_section_summarization, inputs)

1 Loaded saved file. Continuing


100%|████████████████████████████████████████████████████████████████████████████| 3436/3436 [26:31:22<00:00, 27.79s/it]
