In [1]:
import os
import pandas as pd

from tqdm import tqdm
from transformers import pipeline, set_seed
from transformers import BioGptTokenizer, BioGptForCausalLM
from aug.bert import *

# Root of the local MIMIC-Eye dataset. Set the MIMIC_EYE_PATH environment
# variable to point elsewhere; the previous hard-coded Windows path is kept as
# the fallback so existing setups behave exactly as before.
MIMIC_EYE_PATH = os.environ.get("MIMIC_EYE_PATH", "F:\\mimic-eye")

# REFLACX lesion label columns handed to get_prompt_for_mask. Commented-out
# entries are labels currently excluded from the prompt; they are kept here so
# they can be re-enabled easily.
REFLACX_LESION_LABEL_COLS = [
    # "Fibrosis",
    # "Quality issue",
    # "Wide mediastinum",
    # "Fracture",
    # "Airway wall thickening",

    ######################
    # "Hiatal hernia",
    # "Acute fracture",
    # "Interstitial lung disease",
    # "Enlarged hilum",
    # "Abnormal mediastinal contour",
    # "High lung volume / emphysema",
    # "Pneumothorax",
    # "Lung nodule or mass",
    # "Groundglass opacity",
    ######################
    "Pulmonary edema",
    "Enlarged cardiac silhouette",
    "Consolidation",
    "Atelectasis",
    "Pleural abnormality",
    # "Support devices",
]


# CheXpert label columns (all carry the "_chexpert" suffix). Not referenced by
# the visible cells of this notebook — NOTE(review): confirm it is still needed.
CHEXPERT_LABEL_COLS = [
    "Atelectasis_chexpert",
    "Cardiomegaly_chexpert",
    "Consolidation_chexpert",
    "Edema_chexpert",
    "Enlarged Cardiomediastinum_chexpert",
    "Fracture_chexpert",
    "Lung Lesion_chexpert",
    "Lung Opacity_chexpert",
    "No Finding_chexpert",
    "Pleural Effusion_chexpert",
    "Pleural Other_chexpert",
    "Pneumonia_chexpert",
    "Pneumothorax_chexpert",
    "Support Devices_chexpert",
]


In [2]:
from transformers import AutoTokenizer, AutoModel, DistilBertForMaskedLM

# Hugging Face checkpoint used for both the tokenizer and the masked-LM model.
CLINICAL_BERT_CHECKPOINT = "medicalai/ClinicalBERT"

# Seed *before* loading. Previously set_seed ran only after the model had been
# built, so any weights that from_pretrained initialises randomly (e.g. parts of
# the masked-LM head missing from the checkpoint — TODO confirm via the load
# warnings) were not reproducible between runs.
set_seed(0)
tokenizer = AutoTokenizer.from_pretrained(CLINICAL_BERT_CHECKPOINT)
model = DistilBertForMaskedLM.from_pretrained(CLINICAL_BERT_CHECKPOINT)

In [3]:
# Wrap the masked-LM model and its tokenizer in a fill-mask pipeline; it is
# queried below with prompts that contain a [MASK] placeholder.
mask_filler = pipeline(task="fill-mask", model=model, tokenizer=tokenizer)

# Fix the transformers RNG seed for the remainder of the run.
set_seed(0)

In [4]:
# Sanity check: the three most likely tokens for a numeric [MASK] slot.
demo_prompt = "The average blood pressure is [MASK] mmHg."
mask_filler(demo_prompt, top_k=3)

[{'score': 0.038724757730960846,
  'token': 10197,
  'token_str': '20',
  'sequence': 'the average blood pressure is 20 mmhg.'},
 {'score': 0.03354965150356293,
  'token': 121,
  'token_str': '0',
  'sequence': 'the average blood pressure is 0 mmhg.'},
 {'score': 0.02289050817489624,
  'token': 10218,
  'token_str': '18',
  'sequence': 'the average blood pressure is 18 mmhg.'}]

In [5]:
# Load the REFLACX clinical table and add a Celsius copy of the temperature.
# The (F - 32) * 5 / 9 formula implies the source column is Fahrenheit.
df = pd.read_csv('./spreadsheets/reflacx_clinical.csv').assign(
    temperature_c=lambda frame: (frame['temperature'] - 32) * 5 / 9
)

In [6]:
# Vital-sign columns to augment, each mapped to the descriptive name that is
# handed to get_prompt_for_mask (presumably spliced into the prompt text).
feature_to_name_map = {
    "temperature_c": "body temperature in degrees Celsius",
    "heartrate": "heart rate in beats per minute",
    "resprate": "respiratory rate in breaths per minute",
    "o2sat": "peripheral oxygen saturation (%)",
    "sbp": "systolic blood pressure (mmHg)",
    "dbp": "diastolic blood pressure (mmHg)",
}

# Features are processed in the dict's insertion order.
features_to_aug = list(feature_to_name_map)

In [7]:
# Passed to get_prompt_for_mask, and also selects the output file:
# True -> cb_aug_report.csv, False -> cb_aug_text.csv.
report_format = False

In [8]:
# Observed (min, max) of every feature; passed to get_generated_value,
# presumably to bound which predicted numbers are accepted.
aug_feature_range = {
    feat: (df[feat].min(), df[feat].max()) for feat in features_to_aug
}

# Create every output column up front, filled with None, so rows for which no
# value is found keep None.
for feat in features_to_aug:
    df[f"aug_{feat}"] = None

# One pass over all rows per feature: build the masked prompt for that row and
# feature, ask the fill-mask pipeline for a value, and store it.
for feat in features_to_aug:
    print(f"Resolving {feat}")
    out_col = f"aug_{feat}"
    feat_range = aug_feature_range[feat]

    for row_idx, row in tqdm(df.iterrows(), total=df.shape[0]):
        masked_prompt = get_prompt_for_mask(
            MIMIC_EYE_PATH,
            row,
            REFLACX_LESION_LABEL_COLS,
            feature_to_name_map,
            feat,
            report_format=report_format,
        )
        predicted = get_generated_value(
            mask_filler, masked_prompt, feat_range, top_k=100,
        )
        if predicted is None:
            print(f"Couldn't find value for [{row_idx}] prompt: {masked_prompt}")
        df.at[row_idx, out_col] = predicted

Resolving temperature_c


100%|██████████| 799/799 [01:06<00:00, 11.96it/s]


Resolving heartrate


100%|██████████| 799/799 [01:07<00:00, 11.82it/s]


Resolving resprate


100%|██████████| 799/799 [01:07<00:00, 11.79it/s]


Resolving o2sat


100%|██████████| 799/799 [01:07<00:00, 11.92it/s]


Resolving sbp


100%|██████████| 799/799 [01:06<00:00, 12.04it/s]


Resolving dbp


100%|██████████| 799/799 [01:07<00:00, 11.83it/s]


In [9]:
# Convert the augmented Celsius temperatures back to Fahrenheit.
# Rows where get_generated_value found nothing hold None in this object column;
# pd.to_numeric turns those into NaN. The previous
# `.apply(lambda c: (c*1.8)+32)` raised TypeError on the first None.
df["aug_temperature"] = pd.to_numeric(df["aug_temperature_c"]) * 1.8 + 32

# The output file name depends on which prompt style was used.
output_path = (
    './spreadsheets/cb_aug_report.csv'
    if report_format
    else './spreadsheets/cb_aug_text.csv'
)
df.to_csv(output_path)