In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import json
import re
import os
import sys
from sklearn.metrics import f1_score, precision_score, recall_score
from nltk.translate.bleu_score import sentence_bleu

from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

In [2]:
sys.path.append('/data/dangnguyen/report_generation/report-generation/')
from CXRMetric.CheXbert.src.label import label

In [3]:
# Path to the CheXbert checkpoint; passed to `label` throughout the notebook.
CHEXBERT_PATH = '/data/dangnguyen/report_generation/models/chexbert.pth'

# Alphabetical ordering of the 14 labels. Used by labels_to_eng and to select label
# columns by name (e.g. row[cxr_labels], val_viz[cxr_labels]).
cxr_labels = [
        'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
        'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 'Lung Opacity',
        'No Finding', 'Pleural Effusion', 'Pleural Other', 'Pneumonia',
        'Pneumothorax', 'Support Devices']

# Second ordering: the column order of label matrices returned by `label`
# (see the note in compute_f1 / run_eval.py). "No Finding" is last, which is why
# label matrices are sliced with [:, :-1] to drop it.
cxr_labels_2 = ['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',\
'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis',\
'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices', 'No Finding']

# converts a label vector to English
def labels_to_eng(labels, label_names=None):
    """Convert a binary label vector into a comma-separated list of condition names.

    Args:
        labels: Iterable of label values aligned positionally with ``label_names``
            (a list, a NumPy array, or a pandas Series such as ``row[cxr_labels]``).
            A condition is included when its value equals 1 (1.0 also matches).
        label_names: Condition names in the same order as ``labels``. Defaults to the
            module-level ``cxr_labels`` (alphabetical ordering).

    Returns:
        str: Each positive condition followed by ``', '`` (e.g. ``'Cardiomegaly, Edema, '``),
        or ``''`` if none is positive. The trailing separator is kept for backward
        compatibility: callers strip it with ``[:-2]``.
    """
    if label_names is None:
        label_names = cxr_labels
    # zip iterates *values*, so a Series indexed by label name is consumed by position
    # without relying on Series.__getitem__'s deprecated integer-position fallback.
    return ''.join(name + ', ' for value, name in zip(labels, label_names) if value == 1)

# Computes the exact match accuracy of the generated reports
def exact_match(gt, pred):
    """Exact-match accuracy between paired ground-truth and predicted reports.

    Before comparison, each report has every character that is neither a word character
    nor whitespace removed, is lower-cased, and is split on whitespace. So
    "No effusion." matches "no  effusion".

    Args:
        gt: Sequence of ground-truth report strings.
        pred: Sequence of predicted report strings, paired positionally with ``gt``.

    Returns:
        (em_acc, matches): the fraction of matching pairs (NaN when there are no pairs)
        and a boolean NumPy array with one entry per pair.

    Raises:
        ValueError: if ``gt`` and ``pred`` differ in length. ``zip`` previously dropped
            the extra items silently, which misaligned the denominator.
    """
    gt = list(gt)
    pred = list(pred)
    if len(gt) != len(pred):
        raise ValueError('gt and pred must have the same length ({} != {})'.format(len(gt), len(pred)))

    def _normalize(report):
        return re.sub(r"[^\w\s]", "", report).lower().split()  # removes special characters

    matches = np.array([_normalize(gt_rp) == _normalize(pred_rp) for gt_rp, pred_rp in zip(gt, pred)],
                       dtype=bool)
    # Explicit NaN for empty input instead of a 0/0 RuntimeWarning.
    em_acc = matches.mean() if matches.size else float('nan')
    return em_acc, matches

# Computes the positive and negative F1
def compute_f1(df_gt, df_pred):
    """Positive, negative and "pragmatic" macro-F1 between CheXbert labels of two report sets.

    Each input is written to a temporary CSV and labeled with CheXbert via ``label``.
    CheXbert label values: 0 unmentioned, 1 positive, 2 negative, 3 uncertain.
    The last label column (No Finding) is dropped. Each label matrix is then split into:
      - a positive matrix: 1 stays 1; 2 and 3 become 0;
      - a negative matrix: 2 becomes 1; 1 and 3 become 0.

    Args:
        df_gt: Reference reports, either a DataFrame with a "report" column or a Series
            named "report" (the CSV written for CheXbert must have a "report" header).
        df_pred: Reports to evaluate, in the same format and row order as ``df_gt``.

    Returns:
        (pos_f1, neg_f1, prag_f1, y_gt_neg, y_gt, y_pred_neg, y_pred):
        - pos_f1, neg_f1: macro-F1 on the positive and negative matrices (zero_division=1).
        - prag_f1: the mean of pos_f1 and neg_f1.
        - the four label matrices (rows = reports, columns = cxr_labels_2[:-1]),
          returned for debugging.
    """
    import tempfile

    def _label_reports(reports, path):
        # CheXbert reads reports from disk; its output follows the 2nd ordering (see run_eval.py).
        reports.to_csv(path, index=False)
        y = np.array(label(CHEXBERT_PATH, path)).T
        return y[:, :-1]  # excluding No Finding

    def _split_pos_neg(y):
        # Binarize one label matrix into (positive, negative) indicator matrices.
        y_neg = y.copy()
        y_neg[(y_neg == 1) | (y_neg == 3)] = 0
        y_neg[y_neg == 2] = 1
        y_pos = y.copy()
        y_pos[(y_pos == 2) | (y_pos == 3)] = 0
        return y_pos, y_neg

    # A private temporary directory replaces the fixed ./*_pre-chexbert.csv files and the
    # `os.system('rm ...')` cleanup. Nothing is left behind even if labeling raises, and
    # concurrent runs cannot overwrite each other's inputs.
    with tempfile.TemporaryDirectory() as tmp_dir:
        y_gt_all = _label_reports(df_gt, os.path.join(tmp_dir, 'gt_pre-chexbert.csv'))
        y_pred_all = _label_reports(df_pred, os.path.join(tmp_dir, 'pred_pre-chexbert.csv'))

    y_gt, y_gt_neg = _split_pos_neg(y_gt_all)
    y_pred, y_pred_neg = _split_pos_neg(y_pred_all)

    assert y_gt.shape == y_pred.shape, \
        'gt and pred label matrices differ in shape: {} vs {}'.format(y_gt.shape, y_pred.shape)

    pos_f1 = f1_score(y_gt, y_pred, average='macro', zero_division=1)
    neg_f1 = f1_score(y_gt_neg, y_pred_neg, average='macro', zero_division=1)
    prag_f1 = np.mean([pos_f1, neg_f1])

    # also returning the labels matrices for debugging
    return pos_f1, neg_f1, prag_f1, y_gt_neg, y_gt, y_pred_neg, y_pred

# Borrowing this function from /CXRMetric/run_eval.py
def prep_reports(reports):
    """Tokenize reports for BLEU.

    Each report is converted with str(), lower-cased, and "." is split off as its own
    token. The text is split on single spaces and empty tokens are discarded.
    Returns one token list per report.
    """
    tokenized = []
    for report in reports:
        spaced = str(report).lower().replace(".", " .")
        tokenized.append([tok for tok in spaced.split(" ") if tok != ""])
    return tokenized

# Computes BLEU-2
def bleu_2(df_gt, df_pred):
    """Mean sentence-level BLEU-2 over paired reports.

    Each row of ``df_gt`` is paired with the row of ``df_pred`` that has the same index
    label (``.loc``), not the same position. Both frames need a ``report`` column.
    Reports are tokenized with ``prep_reports``.
    """
    bigram_weights = (1/2, 1/2)  # equal unigram/bigram weights -> BLEU-2
    per_report = []
    for idx, gt_row in df_gt.iterrows():
        reference = prep_reports([gt_row['report']])[0]
        hypothesis = prep_reports([df_pred.loc[idx]['report']])[0]
        per_report.append(sentence_bleu([reference], hypothesis, weights=bigram_weights))
    return np.mean(per_report)

In [5]:
# Formatting finetuning data into instructions
# clean data
# Validation indications: keep (study_id, report), rename report -> indication, NaN -> ''.
df_ind = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/val_indications_clean.csv')[['study_id','report']].drop_duplicates()
df_ind = df_ind.rename(columns={'report': 'indication'}).fillna('')

# Validation impressions, renamed report -> impression.
df_imp_chexb = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/val_ind_imp.csv')[['study_id','report']].drop_duplicates()
df_imp_chexb = df_imp_chexb.rename(columns={'report': 'impression'})

# Inner join on study_id. NOTE(review): the rendered output shows repeated study_ids,
# so one study can produce several rows here.
df_ind_imp = df_imp_chexb.merge(df_ind, on='study_id')
df_ind_imp

Unnamed: 0,study_id,impression,indication
0,52139270,"Bilateral pleural effusions, cardiomegaly and ...",_..
1,52139270,"Bilateral pleural effusions, cardiomegaly and ...",_..
2,52309364,"Mild to large bilateral, right greater than le...",F with oxygen requirement // evaluate for pulm...
3,53282957,Moderate pulmonary edema with moderate to larg...,"History: ___F status post fall, bradycardic //..."
4,53836463,"CHF, slightly worse than on the prior study.",_..
...,...,...,...
1832,52815959,Mild interstitial edema. Left basilar opacity...,___-year-old male with confusion. Evaluate for...
1833,53502057,Emphysema with mild congestion and edema. Bib...,___-year-old man with chest pain and shortness...
1834,53502057,Emphysema with mild congestion and edema. Bib...,___-year-old man with chest pain and shortness...
1835,58368837,No acute intrathoracic process.,"Shoulder pain, evaluate for infiltrate.."


In [15]:
# NUM_SAMPLES = 12800
# Fraction of rows to sample; with sample_pc = 1 this is a seeded full shuffle.
sample_pc = 1
df_ind_imp_sample = df_ind_imp.sample(frac=sample_pc, random_state=42)
df_ind_imp_sample

Unnamed: 0,study_id,impression,Enlarged Cardiomediastinum,Cardiomegaly,Lung Opacity,Lung Lesion,Edema,Consolidation,Pneumonia,Atelectasis,Pneumothorax,Pleural Effusion,Pleural Other,Fracture,Support Devices,No Finding,indication
7592,54817529,Mild pulmonary vascular congestion without ove...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"___ year old woman with crackles bilaterally, ..."
12051,51209748,There is a left subclavian pacer with single l...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,"___ year old man with AF, tachy-brady syndrome..."
17353,55605922,No acute cardiopulmonary abnormality.,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,History: ___M with neutrapenic fever // ?PNA
17873,58709083,No acute disease.,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,Pregnancy with chest pain and dyspnea.
18496,58027429,Interstitial lung disease with small pleural e...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,___-year-old female with lethargy. Evaluate f...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11284,55829081,Endotracheal tube has its tip 2.6 cm above the...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,_
11964,57836239,Slight increase in size of small left pleural ...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,_
5390,58757139,No acute findings. Mild cardiomegaly.,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,_
860,52981560,Mild pulmonary edema has resolved. No evidenc...,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"___ year old man with obesity, OSA, ESRD on HD..."


In [None]:
# NOTE(review): the file name says 80pc, but sample_pc above is 1; confirm which
# fraction this file is meant to hold.
df_ind_imp_sample.to_csv('/data/dangnguyen/report_generation/mimic_data/finetune_llm/finetune_imp_clean_80pc.csv', index=False)

In [None]:
# validation set
# Validation rows plus predicted labels from val_pred_viz.csv. concat(axis=1) pairs the
# two freshly read frames row by row, so both files must list rows in the same order.
df_ind_imp_sample = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/val_ind_imp.csv').fillna('')
val_viz = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/finetune_llm/val_pred_viz.csv')

df_ind_imp_sample = pd.concat([df_ind_imp_sample, val_viz[cxr_labels]], axis=1)
df_ind_imp_sample = df_ind_imp_sample.rename(columns={'report':'impression'})
df_ind_imp_sample

In [None]:
# original data
# Training impressions with CheXbert labels; 'Report Impression' renamed to 'impression'.
imp_chexb = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/train_gt_imp_chexbert.csv').drop_duplicates()
imp_chexb = imp_chexb.rename(columns={'Report Impression':'impression'})
imp_chexb

In [None]:
# Existing finetuning file; its study_ids define the 10 percent subset further below.
ft_imp = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/finetune_llm/finetune_gt_imp_clean.csv')
ft_imp

In [None]:
# Training indications (study_id, indication).
df_ind = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/train_indications.csv')[['study_id','report']].drop_duplicates()
df_ind = df_ind.rename(columns={'report':'indication'})
df_ind

In [None]:
# Training impressions (study_id, impression).
df_imp = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/train_gt_imp.csv')[['study_id','report']].drop_duplicates()
df_imp = df_imp.rename(columns={'report':'impression'})
df_imp

In [None]:
assert list(df_ind['study_id']) == list(df_imp['study_id'])

In [None]:
df_rp = pd.concat([df_ind, df_imp['impression']], axis=1)
df_rp = df_rp.loc[~df_rp['impression'].isna()].fillna('')
df_rp

In [None]:
df_rp_n_percent = df_rp.loc[df_rp['study_id'].isin(ft_imp['study_id'])] # 10 percent
df_rp_n_percent

In [None]:
# other percents
sample_percent = 0.8
df_rp_n_percent = df_rp.sample(frac=sample_percent)
df_rp_n_percent

In [None]:
# Attach CheXbert labels; NaN and -1 values (in any column) become 0.
# NOTE(review): merge() has no `on=`, so it joins on every column name the two frames
# share. Confirm those are only the intended key(s).
df_rp_chexb = df_rp_n_percent.merge(imp_chexb).replace([np.nan, -1], 0)
df_rp_chexb

In [None]:
# The instruction-building loop below reads df_ind_imp_sample; this replaces the
# validation frame built earlier with the training subset.
df_ind_imp_sample = df_rp_chexb

In [39]:
finetune_data = []
# instruction = 'Write a radiology report responding to the indication. Include all given positive labels.'
# instruction = 'Write a radiology report that includes all given positive labels.'
instruction = 'Write a radiology report responding to the indication and positive labels. \
Make medical recommendations as necessary.'

# One {instruction, input, output} record per row.
for _, row in df_ind_imp_sample.iterrows():
    ind = row['indication']
    imp = row['impression']
    # labels_to_eng leaves a trailing ', '; [:-2] removes it.
    labels = labels_to_eng(row[cxr_labels])[:-2]
    
    inp = 'Indication: {}.\nPositive labels: {}'.format(ind, labels)
#     inp = 'Positive labels: {}'.format(labels) # to test the effect of the indication
    
    sample = {
        'instruction': instruction,
        'input': inp,
        'output': imp
    }
    finetune_data.append(sample)

In [40]:
# Number of finetuning records built above.
len(finetune_data)

19000

In [41]:
# Spot-check one random record (unseeded, so a different one each run).
np.random.choice(finetune_data)

{'instruction': 'Write a radiology report responding to the indication and positive labels. Make medical recommendations as necessary.',
 'input': 'Indication: ___M with hx of epilepsy presenting s/p seizure. R/o infection  // Pneumonia?.\nPositive labels: No Finding',
 'output': 'No acute cardiopulmonary process. '}

In [42]:
# NOTE(review): the file name says 10pc; confirm it matches the subset actually used above.
outpath = '/data/dangnguyen/report_generation/mimic_data/finetune_llm/finetune_imp_rec_10pc.json'

with open(outpath, 'w') as json_file:
    json.dump(finetune_data, json_file)

In [None]:
# Getting all unique sentences for cleaning
imp_sen = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/mimic_train_impressions_sentence.csv')
imp_sen

In [None]:
# Unique sentence texts (presumably one sentence per row; the file also has a sentence_id column).
imp_sen_uniq = imp_sen[['report']].drop_duplicates().reset_index(drop=True)
imp_sen_uniq

In [None]:
# Split the unique sentences into num_batches equal partitions plus one final partition
# holding the remainder. NUM_PARTITIONS in the aggregation cell must equal num_batches + 1.
num_batches = 6
batch_size = len(imp_sen_uniq) // num_batches
for i in range(num_batches+1):
    start = i * batch_size
    # The last partition always runs to the end. Previously it stopped at
    # (num_batches + 1) * batch_size and dropped rows whenever the remainder exceeded
    # batch_size, which can happen for small inputs.
    end = (i + 1) * batch_size if i < num_batches else len(imp_sen_uniq)
    partition = imp_sen_uniq.iloc[start:end]
    partition.to_csv('/data/dangnguyen/report_generation/mimic_data/report_cleaning/clean_all/partition_{}.csv'.format(i+1), index=False)

In [None]:
# Making a test set comprising of indication, GT report, and labels
test_gt_imp = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/test_gt_imp.csv')
# CheXbert labels in the 2nd ordering; NaN and -1 become 0.
chexbert_labels = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/test_gt_imp_chexbert.csv')[cxr_labels_2].fillna(0).replace(-1, 0)
# Positional concat: both CSVs must list reports in the same order.
test_gt_imp_chexb = pd.concat([test_gt_imp, chexbert_labels], axis=1)
test_gt_imp_chexb

In [None]:
test_gt_ind = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/test_indications.csv')[['study_id','report']].drop_duplicates()
test_gt_ind = test_gt_ind.rename(columns={'report':'indication'})
test_gt_ind

In [None]:
# Inner join on study_id; missing values become ''.
test_ind_imp_chexb = test_gt_imp_chexb.merge(test_gt_ind, on='study_id').fillna('').reset_index(drop=True)
test_ind_imp_chexb

In [None]:
cols = ['study_id','indication','report'] + cxr_labels
test_ind_imp_chexb[cols].to_csv('/data/dangnguyen/report_generation/mimic_data/test_ind_imp_chexb.csv', index=False)

In [None]:
# For dumb parallelism: aggregating the outputs and filtering broken sentences
# Must equal num_batches + 1 from the partitioning cell (partition_1 ... partition_7).
NUM_PARTITIONS = 7
og_rps = []
cleaned_rps = []

for i in range(NUM_PARTITIONS):
    df_og = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/report_cleaning/clean_all/partition_{}.csv'.format(i+1))
    df = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/report_cleaning/clean_all/partition_{}/rewrite7_intermediate.csv'.format(i+1))
    og_rps.append(df_og)
    cleaned_rps.append(df)
    
og_rp = pd.concat(og_rps, axis=0).fillna('').reset_index(drop=True)
clean_rp = pd.concat(cleaned_rps, axis=0).fillna('').reset_index(drop=True)

# Original and rewritten sentences side by side. This assumes each rewrite file keeps
# its partition's row order and length.
clean_rp = clean_rp.rename(columns={'report':'llm_rewritten'})
clean_sens = pd.concat([og_rp, clean_rp], axis=1)
clean_sens

In [None]:
df_imp_sen = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/mimic_train_impressions_sentence.csv')[['study_id','sentence_id','report']].drop_duplicates()
df_imp_sen

In [None]:
# Map every (study_id, sentence_id) back to its rewrite via the sentence text.
# Sentences without a match get NaN in llm_rewritten.
df_imp_sen_clean = df_imp_sen.merge(clean_sens, how='left', on='report')
df_imp_sen_clean

In [None]:
clean_rp = df_imp_sen_clean

In [None]:
clean_rp.to_csv('/data/dangnguyen/report_generation/mimic_data/report_cleaning/clean_all/imp_sen_cleaned_all.csv', index=False)

In [None]:
# Sentences that the left merge could not match have NaN in llm_rewritten. Fill them
# with '' first so the string operations below don't fail on floats; the last line
# then maps '' to '_'.
rewritten = clean_rp['llm_rewritten'].fillna('').astype(str)
# Drop underscores (e.g. the '___' placeholders visible in the outputs above).
clean_rp['llm_rewritten'] = rewritten.str.replace('_', '', regex=False)
# Number of whitespace-separated words in the rewritten sentence.
clean_rp['clean_length'] = clean_rp['llm_rewritten'].str.split().str.len()
clean_rp = clean_rp.replace(['REMOVED', ''], '_')
clean_rp

In [None]:
# Write the rewritten sentences under a 'report' header for CheXbert.
# NOTE(review): assumes ./tmp already exists.
clean_rp[['llm_rewritten']].rename(columns={'llm_rewritten':'report'}).to_csv('./tmp/clean_all_pre_chexbert.csv', index=False)

In [None]:
y_gt = label(CHEXBERT_PATH, './tmp/clean_all_pre_chexbert.csv')
y_gt = np.array(y_gt).T
# Collapse negative (2) and uncertain (3) onto 1, so any nonzero value marks a mention.
y_gt[(y_gt == 2) | (y_gt == 3)] = 1

In [None]:
# filtering sentences with fewer than 3 words (clean_length < 3) and no finding, i.e., meaningless sentences
# is_no_finding: no label is mentioned in any column except the last one (No Finding).
is_no_finding = ~y_gt[:, :-1].any(axis=1)
clean_rp['no_finding'] = is_no_finding
# The rows being dropped, kept for inspection.
clean_rp_no_finding = clean_rp.loc[clean_rp['no_finding'] & (clean_rp['clean_length'] < 3)]
clean_rp_clean = clean_rp.loc[(~clean_rp['no_finding']) | (clean_rp['clean_length'] >= 3)]
clean_rp_clean

In [None]:
uniq_ids = clean_rp_clean['study_id'].unique()

In [None]:
# combining sentences into reports (labeling happens in a later cell)
# Concatenate each study's rewritten sentences, each followed by a space, in their
# original row order. One groupby pass replaces a per-study boolean filter over the
# whole frame (quadratic in the number of rows). It also stops overwriting the
# notebook-level `df_rp` with the last study's rows.
sentences_by_study = (
    clean_rp_clean
    .groupby('study_id', sort=False)['llm_rewritten']
    .agg(lambda sentences: ''.join(s + ' ' for s in sentences))
)
clean_reports = [sentences_by_study.loc[study_id] for study_id in uniq_ids]

In [None]:
# One row per study: study_id and its combined cleaned report.
df_clean_rps = pd.DataFrame(uniq_ids, columns=['study_id'])
df_clean_rps['report'] = clean_reports
df_clean_rps

In [None]:
df_clean_rps.to_csv('/data/dangnguyen/report_generation/mimic_data/finetune_llm/finetune_gt_imp_all.csv', index=False)

In [None]:
df_clean_rps = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/finetune_llm/finetune_gt_imp_all.csv')
df_clean_rps

In [None]:
# removing reports without any finding
# Positives only (2/3 -> 0). Unlike the sentence filter earlier, the No Finding column
# is kept here, so a report whose only positive label is No Finding is not removed.
y_gt = label(CHEXBERT_PATH, '/data/dangnguyen/report_generation/mimic_data/finetune_llm/finetune_gt_imp_all.csv')
y_gt = np.array(y_gt).T
y_gt[(y_gt == 2) | (y_gt == 3)] = 0
all_zeros = ~y_gt.any(axis=1)

In [None]:
df_clean_rps = df_clean_rps.loc[~all_zeros].reset_index(drop=True)
df_clean_rps

In [None]:
y_gt_clean = y_gt[~all_zeros]
len(y_gt_clean)

In [None]:
# NOTE(review): this overwrites the file the labeling cell above reads. Re-running that
# cell afterwards labels the already-filtered reports.
df_clean_rps.to_csv('/data/dangnguyen/report_generation/mimic_data/finetune_llm/finetune_gt_imp_all.csv', index=False)

In [None]:
df_clean_rps = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/finetune_llm/finetune_gt_imp_all.csv')

In [None]:
# Attach the per-report label matrix (CheXbert 2nd ordering). `columns` must be a flat
# list; wrapping it in another list (columns=[cxr_labels_2]) builds a MultiIndex.
# concat(axis=1) pads with NaN on a length mismatch, so check the alignment explicitly.
assert len(y_gt_clean) == len(df_clean_rps), 'label rows and reports are misaligned'
labels = pd.DataFrame(y_gt_clean, columns=cxr_labels_2)
df_clean_rps = pd.concat([df_clean_rps, labels], axis=1)
df_clean_rps.columns = ['study_id', 'report'] + cxr_labels_2
df_clean_rps

In [None]:
# Attach the cleaned report (as llm_rewritten) and its labels to every training impression row.
imp_reports = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/train_gt_imp.csv')
imp_reports_clean = imp_reports.merge(df_clean_rps.rename(columns={'report':'llm_rewritten'}), on='study_id')
imp_reports_clean

In [None]:
cols = ['study_id', 'subject_id', 'report', 'llm_rewritten'] + cxr_labels_2
imp_reports_clean[cols].to_csv('/data/dangnguyen/report_generation/mimic_data/finetune_llm/finetune_gt_imp_all_dup.csv', index=False)

In [None]:
# clean_rp = imp_reports_clean[['study_id','report','llm_rewritten']].drop_duplicates().reset_index(drop=True)
clean_rp = clean_rp_clean.reset_index(drop=True)

clean_rp

In [None]:
# Cleaning evaluation set: 'report' is the original text, 'cleaned' the reference
# (manual) cleaning.
df_gt = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/report_cleaning/test_cleaning_gt_200.csv').fillna('')
flan_t5 = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/report_cleaning/deepspeed/all_rules/all_rules_intermediate.csv').fillna('')

og = df_gt[['report']]
gt = df_gt[['cleaned']].rename(columns={'cleaned':'report'})

# 'REMOVED' cells become empty strings.
flan_t5 = flan_t5.replace('REMOVED', '')
# NOTE(review): with the rename below commented out, flan_t5['report'] is the CSV's own
# 'report' column. Confirm that column holds the LLM output rather than the original text.
# flan_t5 = flan_t5[['llm_rewritten']].rename(columns={'llm_rewritten':'report'})
flan_t5

In [None]:
flan_t5_all = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/report_cleaning/val_flan_all_rules.csv')
flan_t5_all = flan_t5_all[['llm_rewritten']].rename(columns={'llm_rewritten':'report'})
flan_t5_all = flan_t5_all.replace('REMOVED', '')
flan_t5_all

In [None]:
# cxr_pro = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/report_cleaning/val_cxr_pro_clean.csv').fillna('')
cxr_pro = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/report_cleaning/test_cxr_pro_clean.csv').fillna('')
cxr_pro

In [None]:
# Cleaning reports using GILBERT
def get_pipe(device="cuda"):
    """Build a Hugging Face token-classification pipeline for the GILBERT model.

    Args:
        device: Device the model is moved to before the pipeline is built. The default
            "cuda" preserves the previous behavior; pass "cpu" on machines without a GPU.

    Returns:
        A ``transformers`` pipeline with ``aggregation_strategy="simple"``. It yields
        dicts carrying at least ``entity_group`` and ``word`` (as read by remove_priors).
    """
    model_name = "rajpurkarlab/gilbert"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    return pipeline(task="token-classification", model=model.to(device),
                    tokenizer=tokenizer, aggregation_strategy="simple")

def remove_priors(pipe, report):
    """Keep only the words GILBERT tags as KEEP, sentence by sentence.

    The report is split on "." and blank pieces are skipped. For each remaining piece:
      - the KEEP-tagged words are joined with spaces;
      - "redemonstrate" is rewritten to "demonstrate";
      - the result goes through str.capitalize(), which also lower-cases the rest of
        the sentence;
      - ". " is appended.
    A piece with no KEEP words still contributes ". ".

    Returns the concatenated sentences with surrounding whitespace stripped.
    """
    pieces = []
    for chunk in report.split("."):
        if not chunk.strip():
            continue
        kept_words = [ent['word'] for ent in pipe(chunk) if ent['entity_group'] == 'KEEP']
        sentence = " ".join(kept_words).strip()
        sentence = sentence.replace("redemonstrate", "demonstrate").capitalize()
        pieces.append(sentence + ". ")
    return "".join(pieces).strip()

In [None]:
# Run GILBERT over every CXR-PRO report; each report is printed as a progress trace.
pipe = get_pipe()
reports = list(cxr_pro['report'])

clean_reports = []
for rp in reports:
    print(rp)
    rp_cleaned = remove_priors(pipe, rp)
    clean_reports.append(rp_cleaned)

In [None]:
# NOTE(review): requires len(clean_reports) == len(clean_rp). clean_rp is whatever was
# assigned most recently (e.g. clean_rp_clean above), which may not be the evaluation set.
clean_rp['cxr_pro'] = clean_reports
clean_rp

In [None]:
# clean_rp.to_csv('/data/dangnguyen/report_generation/mimic_data/report_cleaning/cleaning_baselines_200.csv', index=False)

In [None]:
# Reference cleaning and the outputs of each cleaning method, as plain lists.
gt_clean = list(gt['report'])
llm_clean = list(flan_t5['report'])
# The exact-match cells below use llm_all_clean and pro_clean; leaving these two lines
# commented out made those cells fail with NameError.
llm_all_clean = list(flan_t5_all['report'])
pro_clean = list(cxr_pro['report'])

In [None]:
# Exact-match accuracy of each cleaning method against the reference cleaning.
llm_acc, llm_matches = exact_match(gt_clean, llm_clean)
llm_acc

In [None]:
llm_all_acc, all_matches = exact_match(gt_clean, llm_all_clean)
llm_all_acc

In [None]:
pro_acc, pro_matches = exact_match(gt_clean, pro_clean)
pro_acc

In [None]:
# computing F1 and BLEU
# Empty reports become '_' before they are written out for CheXbert / BLEU.
og = og.replace('', '_')
gt = gt.replace('', '_')
flan_t5 = flan_t5.replace('', '_')
# flan_t5_all = flan_t5_all.replace('', '_')
# cxr_pro = cxr_pro.replace('', '_')

In [None]:
# Label agreement between the ORIGINAL reports (og) and the cleaned ones: measures how
# well cleaning preserves CheXbert labels. Both Series are named 'report'.
pos_f1, neg_f1, _, gt_neg, gt_np, pred_neg, pred_np = compute_f1(og['report'], flan_t5['report'])
# pos_f1, neg_f1, _, gt_neg, gt_np, pred_neg, pred_np = compute_f1(og['report'], flan_t5_all['report'])
# pos_f1, neg_f1, _, gt_neg, gt_np, pred_neg, pred_np = compute_f1(og['report'], cxr_pro['report'])
print('{}\n{}'.format(pos_f1, neg_f1))

In [None]:
# BLEU-2 against the reference cleaning (gt); rows are matched by index label.
bleu = bleu_2(gt, flan_t5)
# bleu = bleu_2(gt, flan_t5_all)
# bleu = bleu_2(gt, cxr_pro)
bleu

In [None]:
# looking into failure cases
# Rows where any positive label differs between the original and the cleaned report.
pos_diff = np.logical_xor(gt_np, pred_np)
pos_diff_agg = np.any(pos_diff, axis=1)

llm_pos_diff = gt.loc[pos_diff_agg]
llm_pos_diff

In [None]:
# NOTE(review): gt only has a 'report' column, so .cleaned and .llm_rewritten below fail
# unless llm_pos_diff is built from a frame that has them. idx must also be an index
# label present in llm_pos_diff.
idx = 121
print('GT: {}\n\nManual: {}\n\nLLM: {}'.format(llm_pos_diff.loc[idx].report, 
                                               llm_pos_diff.loc[idx].cleaned, 
                                               llm_pos_diff.loc[idx].llm_rewritten))

In [None]:
# Positive-label rows (ground truth) for the disagreeing reports. Uses gt_np, the label
# matrix from compute_f1; `gt` is the report DataFrame and produced all-NaN columns here.
pd.DataFrame(gt_np[pos_diff_agg], columns=cxr_labels_2[:-1])

In [None]:
# Positive-label rows (prediction) for the disagreeing reports. `pred` was never defined
# (NameError); the matching label matrix returned by compute_f1 is pred_np.
pd.DataFrame(pred_np[pos_diff_agg], columns=cxr_labels_2[:-1])

In [None]:
# Rows where any negative label differs.
neg_diff = np.logical_xor(gt_neg, pred_neg)
neg_diff_agg = np.any(neg_diff, axis=1)

# NOTE(review): the boolean mask must have exactly one entry per row of clean_rp.
llm_neg_diff = clean_rp.loc[neg_diff_agg]
llm_neg_diff

In [None]:
idx = 73
print('GT: {}\n\npred: {}'.format(llm_neg_diff.loc[idx].report, llm_neg_diff.loc[idx].cleaned))

In [None]:
pd.DataFrame(gt_neg[neg_diff_agg], columns=cxr_labels_2[:-1])

In [None]:
pd.DataFrame(pred_neg[neg_diff_agg], columns=cxr_labels_2[:-1])

In [None]:
# Step-by-step evaluation
# Exact-match accuracy of each rewrite rule on its own evaluation file
# ('cleaned' = reference, 'llm_rewritten' = model output, 'REMOVED' -> '').
accuracies = []

NUM_RULES = 7
for i in range(NUM_RULES):
    rule = i + 1
    data_path = '/data/dangnguyen/report_generation/mimic_data/finetune_llm/rewrite{}_cleaned.csv'.format(rule)
    clean_rp = pd.read_csv(data_path).fillna('')
    clean_rp = clean_rp.replace('REMOVED', '')
    
    gt_clean = list(clean_rp['cleaned'])
    llm_clean = list(clean_rp['llm_rewritten'])
    acc, _ = exact_match(gt_clean, llm_clean)
    accuracies.append(acc)

In [None]:
accuracies

In [None]:
rules = ['rewrite' + str(i+1) for i in range(NUM_RULES)]
rules

In [None]:
res = pd.DataFrame({'rule': rules, 'em_acc': accuracies})
res

In [None]:
# Same metric on the rewrite{rule}_intermediate.csv outputs (presumably the result of
# applying rules 1..rule in sequence; the column is named incremental_acc).
incre_acc = []

NUM_RULES = 7
for i in range(NUM_RULES):
# for i in range(6, 7):
    rule = i + 1
    data_path = '/data/dangnguyen/report_generation/mimic_data/report_cleaning/rewrite{}_intermediate.csv'.format(rule)

    clean_rp = pd.read_csv(data_path).fillna('')
    clean_rp = clean_rp.replace('REMOVED', '')
    
#     llm_rewritten = clean_rp['llm_rewritten']
#     clean_rp['llm_rewritten'] = ['' if 'REMOVED' in rp else rp for rp in llm_rewritten]
    
    gt_clean = list(clean_rp['cleaned'])
    llm_clean = list(clean_rp['llm_rewritten'])
    acc, _ = exact_match(gt_clean, llm_clean)
    incre_acc.append(acc)

In [None]:
rules = ['rewrite' + str(i+1) for i in range(NUM_RULES)]
res = pd.DataFrame({'rule': rules, 'incremental_acc': incre_acc})
res

In [None]:
# Checking length of longest sentence to estimate inference batch size
imp_sen = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/mimic_train_impressions_sentence.csv')
imp_sen_uniq = imp_sen[['study_id','subject_id','report']].drop_duplicates()
imp_sen_uniq

In [None]:
# Sentence length in characters.
imp_sen_uniq['length'] = imp_sen_uniq['report'].apply(lambda x: len(x))
imp_sen_uniq

In [None]:
# Longest sentence (in characters) among the unique sentences.
max_length = imp_sen_uniq['length'].max()
# idxmax selects the longest row directly. The previous `length > 500` filter with
# .item() raised ValueError unless exactly one sentence was longer than 500 characters.
longest_rp = imp_sen_uniq.loc[imp_sen_uniq['length'].idxmax(), 'report']
max_length

In [None]:
longest_rp

In [None]:
# 64 copies of the longest sentence: a worst-case batch for sizing inference batches.
longest_batch = [longest_rp for _ in range(64)]

In [None]:
long_batch = pd.DataFrame(longest_batch, columns=['report'])
long_batch.to_csv('./test_long_batch.csv', index=False)

In [None]:
imp_sen_test = imp_sen[:100]
imp_sen_test['length'] = imp_sen_test['report'].apply(lambda x: len(x))
max_length = imp_sen_test['length'].max()
max_length

In [None]:
imp_sen_uniq.loc[imp_sen_uniq['length'] > 400]

In [None]:
study_ids = imp_sen['study_id'].sample(20000)

In [None]:
# All sentences of the sampled studies. study_ids was sampled from sentence rows, so one
# study can be drawn more than once; the repeated rows are removed by drop_duplicates below.
imp_sen = imp_sen.set_index('study_id')
imp_sen_toclean = imp_sen.loc[study_ids]
imp_sen = imp_sen.reset_index()

imp_sen_toclean

In [None]:
imp_sen_toclean = imp_sen_toclean.reset_index()

In [None]:
imp_sen_toclean_uniq = imp_sen_toclean[['study_id','sentence_id','report']].drop_duplicates()
imp_sen_toclean_uniq

In [None]:
imp_sen_toclean_uniq.to_csv('/data/dangnguyen/report_generation/mimic_data/report_cleaning/train_gt_imp_sen_72k_uniq.csv', index=False)

In [None]:
# Evaluating report-cleaning models in an unsupervised way

# Types of unavailable info:
# - Comparison to previous studies
# - Previous medical procedures
# - Communication
# - Recommendations
# - Image view

# Cleaning models:
# - No clean
# - Flan-T5 (ours)
# - GILBERT
# - XrayGPT

# Generation models:
# - X-LLaMA
# - CXR-ReDonE
# - CXR-RePaiR
# - MedCLIP
# - XrayGPT zeroshot
# - Retrieval

def compute_unavailable_prop(reports):
    """Fraction of reports mentioning each type of information unavailable from the image.

    A report "has" a type when its lower-cased text contains any of that type's keywords
    as a substring. Output order: prior-study comparison, prior procedures,
    communication, recommendations, image view.

    Args:
        reports: Sequence of report strings. It is iterated once per type.

    Returns:
        np.ndarray of shape (5,): the per-type proportion of reports (NaN when
        ``reports`` is empty).
    """
    keyword_groups = [
        # comparison with prior studies
        ["compar","interval","new","increas","worse","chang",
         "persist","improv","resol","disappear",
         "prior","stable","previous","again","remain","remov",
         "similar","earlier","decreas","recurr","redemonstrate"],
        # prior procedures
        ["status"],
        # communication
        ["findings","commun","report","convey","relay","enter","submit"],
        # recommendations
        ["recommend","suggest","should"],
        # image view (short view names are space-padded so they don't match inside words)
        [" ap "," pa "," lateral ","view"],
    ]

    # rows = info types, columns = reports; 1 if the report contains any keyword of the type
    indicator = np.array([
        [1 if any(kw in text for kw in group) else 0 for text in (r.lower() for r in reports)]
        for group in keyword_groups
    ])
    return indicator.sum(axis=1) / indicator.shape[1]

In [None]:
# Bar labels (same order as compute_unavailable_prop's output) and the sources compared.
column_names = ['Prior study','Prior procedures','Communication','Recommendations','View']
train_names = ['Train','Flan-T5','GILBERT','XrayGPT']
test_names = ['Test','CXR-RePaiR','CXR-ReDonE','XrayGPT']

# Training impressions, in train_names order.
train_clean = ['/data/dangnguyen/report_generation/mimic_data/train_gt_imp.csv',
               '/data/dangnguyen/report_generation/mimic_data/finetune_llm/finetune_gt_imp_all_dup.csv',
               '/data/mimic_data/cxr-pro/mimic_train_impressions.csv',
               '/data/mimic_data/train_gt_xraygpt_imp.csv']

test_gen = ['/data/dangnguyen/report_generation/mimic_data/mimic_test_impressions.csv',
            '/data/dangnguyen/report_generation/mimic_data/finetune_llm/baselines/test_cxr-repair_imp.csv',
            '/data/dangnguyen/report_generation/mimic_data/finetune_llm/baselines/test_cxr-redone_imp.csv',
            '/data/dangnguyen/report_generation/mimic_data/finetune_llm/baselines/test_xraygpt-zeroshot_imp.csv']

In [None]:
# Flan-T5-cleaned training set; its study_ids define the common subset below.
flan_t5 = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/finetune_llm/finetune_gt_imp_all_dup.csv')
flan_t5

In [None]:
# color_pal = ['brown','orangered','orange','gold']
color_pal = ['chocolate','firebrick','darkorange','gold']

In [None]:
# Proportions for each training source, restricted to the studies in the cleaned set
# so every source is measured on the same studies.
train_results = []
for filename in train_clean:
    df = pd.read_csv(filename)[['study_id','report']].drop_duplicates().dropna()
    df_sample = df.loc[df['study_id'].isin(flan_t5['study_id'].tolist())]
    print(len(df_sample))
    props = compute_unavailable_prop(df_sample['report'].to_list())
    train_results.append(props)

In [None]:
df_res_train = pd.DataFrame(train_names, columns=['model'])
df_props_train = pd.DataFrame(train_results, columns=column_names)
df_res_train = pd.concat([df_res_train, df_props_train], axis=1)
df_res_train

In [None]:
df_res_train = df_res_train.set_index('model')
df_res_train

In [None]:
# Grouped bars: one group per information type, one bar per source.
ax = df_res_train.T.plot(kind='bar', width=0.8, color=color_pal)
ax.set_ylabel('Proportion of all reports', fontsize=10)
ax.set_xticklabels(df_res_train.columns, rotation=20, fontsize=10)
ax.figure.savefig('/data/dangnguyen/report_generation/plots/train_unavailable_prop.jpg', bbox_inches='tight')

In [None]:
# test_results = []
# for filename in test_gen:
#     df = pd.read_csv(filename).fillna('')
#     props = compute_unavailable_prop(df['report'].to_list())
#     test_results.append(props)

In [None]:
# df_res_test = pd.DataFrame(test_names, columns=['model'])
# df_props_test = pd.DataFrame(test_results, columns=column_names)
# df_res_test = pd.concat([df_res_test, df_props_test], axis=1)
# df_res_test

In [None]:
# df_res_test = df_res_test.set_index('model')
# df_res_test

In [None]:
# ax = df_res_test.T.plot(kind='bar', width=0.8, color=color_pal)
# ax.set_ylabel('Proportion of all reports', fontsize=10)
# ax.set_xticklabels(df_res_test.columns, rotation=20, fontsize=10)
# ax.figure.savefig('/data/dangnguyen/report_generation/plots/test_unavailable_prop.jpg', bbox_inches='tight')

In [None]:
# Looking at X-LLaMA failure cases
val_ind_imp = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/test3_ind_imp_chexbert.csv')
val_viz = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/finetune_llm/test3_pred_viz.csv')
val_gen = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/finetune_llm/llama2_7b_clean/test3_gen_imp_epoch_3.csv')
val_gen = pd.concat([val_gen, val_ind_imp[['indication']], val_viz[cxr_labels_2]], axis=1)
val_gen

In [None]:
finetune_imp = pd.read_csv('/data/dangnguyen/report_generation/mimic_data/finetune_llm/finetune_gt_imp_18k.csv')
finetune_imp

In [None]:
ft_labels = label(CHEXBERT_PATH, '/data/dangnguyen/report_generation/mimic_data/finetune_llm/finetune_gt_imp_18k.csv')

In [None]:
ft_labels = np.array(ft_labels).T[:, :-1]
ft_labels.shape

In [None]:
ft_labels_neg = ft_labels.copy()
ft_labels_neg[(ft_labels_neg == 1) | (ft_labels_neg == 3)] = 0
ft_labels_neg[ft_labels_neg == 2] = 1
ft_neg_mentions = ft_labels_neg.sum(axis=0)
ft_neg_mentions

In [None]:
df_labels = pd.DataFrame(cxr_labels_2[:-1], columns=['label'])
df_labels['ft_neg_total'] = ft_neg_mentions
df_labels

In [None]:
_, _, _, y_gt_neg, y_gt, y_pred_neg, y_pred = compute_f1(val_gen[['original']].rename(columns={'original':'report'}), val_gen['report'])

In [None]:
# Per-label statistics on the negative-mention matrices (one column per label in
# cxr_labels_2[:-1]). The label count is taken from the matrix instead of a hard-coded 13.
n_labels = y_gt_neg.shape[1]
total_neg_mentions = y_gt_neg.sum(axis=0)
total_gen_neg = y_pred_neg.sum(axis=0)
label_f1 = [f1_score(y_gt_neg[:, i], y_pred_neg[:, i], zero_division=0) for i in range(n_labels)]
label_precision = [precision_score(y_gt_neg[:, i], y_pred_neg[:, i], zero_division=0) for i in range(n_labels)]
label_recall = [recall_score(y_gt_neg[:, i], y_pred_neg[:, i], zero_division=0) for i in range(n_labels)]

In [None]:
df_labels['val_neg_total'] = total_neg_mentions
df_labels['gen_neg_total'] = total_gen_neg
df_labels['label_neg_f1'] = label_f1
df_labels['label_precision'] = label_precision
df_labels['label_recall'] = label_recall
df_labels

In [None]:
# Inspect one label: index 8 of cxr_labels_2[:-1] is 'Pneumothorax'.
label_id = 8
tp_keys = (y_gt_neg[:, label_id] == 1) & (y_pred_neg[:, label_id] == 1)
fn_keys = (y_gt_neg[:, label_id] == 1) & (y_pred_neg[:, label_id] == 0)
fp_keys = (y_gt_neg[:, label_id] == 0) & (y_pred_neg[:, label_id] == 1)

In [None]:
# NOTE(review): gen_fp is not de-duplicated, unlike gen_tp/gen_fn, so the printed counts
# are not directly comparable.
gen_tp = val_gen[tp_keys].fillna('').drop_duplicates().reset_index(drop=True)
gen_fn = val_gen[fn_keys].fillna('').drop_duplicates().reset_index(drop=True)
gen_fp = val_gen[fp_keys].fillna('').reset_index(drop=True)

print('TP: {}. FN: {}. FP: {}'.format(len(gen_tp), len(gen_fn), len(gen_fp)))

In [None]:
fn_id = 39
gt_report = gen_fn.loc[fn_id]['original']
indication = gen_fn.loc[fn_id]['indication']
gen_report = gen_fn.loc[fn_id]['report']

print('Indication: {}\n\nGT: {}\n\nGen: {}'.format(indication, gt_report, gen_report))

In [None]:
fp_id = 51
gt_report = gen_fp.loc[fp_id]['original']
indication = gen_fp.loc[fp_id]['indication']
gen_report = gen_fp.loc[fp_id]['report']

print('Indication: {}\n\nGT: {}\n\nGen: {}'.format(indication, gt_report, gen_report))

In [None]:
tp_id = 2

gt_report = gen_tp.loc[tp_id]['original']
indication = gen_tp.loc[tp_id]['indication']
gen_report = gen_tp.loc[tp_id]['report']

print('Indication: {}\n\nGT: {}\n\nGen: {}'.format(indication, gt_report, gen_report))