In [1]:
import numpy as np
import pickle
import copy
import torch
from torch.nn import functional as F
import os
import skimage.io
import torchvision.transforms as T
import cv2
import torchxrayvision as xrv
from tqdm import tqdm
from PIL import Image
from kornia.augmentation import RandomCrop

In [2]:
def rename_original_file(file_path):
    """Rename a file so that '_original' is inserted right before its extension.

    Example: '/data/train.pkl' -> '/data/train_original.pkl'.

    Only the file name's final extension is used as the split point, so
    directory names that happen to contain '.pkl' are left alone, and files
    with other extensions are renamed too (e.g. 'a.txt' -> 'a_original.txt')
    instead of being renamed onto themselves.

    Args:
        file_path: path of the existing file to rename.

    Returns:
        The new path of the file.

    Raises:
        OSError (e.g. FileNotFoundError): propagated from ``os.rename``.
    """
    root, ext = os.path.splitext(file_path)
    new_file_path = f"{root}_original{ext}"
    os.rename(file_path, new_file_path)
    return new_file_path

In [3]:
import openai
from openai import OpenAI
# The OpenAI API key is read from the environment rather than hardcoded here:
# keys written into notebook cells leak through version control and shared copies.
# If OPENAI_API_KEY is unset, api_key is None and the OpenAI client falls back to
# its own environment lookup (and errors out if no key is found).
api_key = os.environ.get('OPENAI_API_KEY')

In [4]:
# One shared OpenAI client, used by augment_text() for every chat-completion request.
client = OpenAI(api_key=api_key)

In [5]:
def augment_time_series(data, mean=0, std=0.1, rng=None):
    """Return ``data`` plus element-wise i.i.d. Gaussian noise.

    Args:
        data: numeric array (anything with a ``.shape`` is used as-is), an
            array-like such as a nested list (converted with ``np.asarray``),
            or None.
        mean: mean of the Gaussian noise (default 0).
        std: standard deviation of the Gaussian noise (default 0.1). The same
            std is applied to every element. NOTE(review): this assumes all
            features are on comparable (e.g. standardized) scales; confirm.
        rng: optional ``np.random.Generator`` for reproducible noise. When None,
            the legacy global ``np.random`` state is used, exactly as before.

    Returns:
        A new array ``data + noise`` with the same shape, or None when ``data``
        is None. NaN entries stay NaN.
    """
    if data is None:
        return data
    # Keep objects that already expose a shape (ndarray, DataFrame, ...) untouched
    # so the result type follows the input; only plain sequences are converted.
    values = data if hasattr(data, 'shape') else np.asarray(data)
    sampler = np.random if rng is None else rng
    noise = sampler.normal(mean, std, values.shape)
    return values + noise

In [6]:
def augment_text(original_text, temperature=0, model="gpt-3.5-turbo", max_tokens=1024, openai_client=None):
    """Paraphrase a chest X-ray report with an OpenAI chat model.

    Args:
        original_text: report text to rewrite. None or a blank string is
            returned unchanged without calling the API.
        temperature: sampling temperature. The default 0 makes output
            (near-)deterministic. Augmenting the same report several times,
            as happens when one minority-class item is oversampled, then yields
            (near-)identical rewrites. Pass a value > 0 for diverse augmentations.
        model: chat model name.
        max_tokens: completion token limit; longer rewrites are cut off.
        openai_client: client object to use; defaults to the notebook-level
            ``client``.

    Returns:
        The rewritten report, or ``original_text`` if the model returned empty
        or no content.

    NOTE(review): the model is only *asked* to preserve medical facts. The sample
    rewrite shown in this notebook's output changed the COMPARISON line and the
    nasogastric-tube position, so validate rewrites before training on them.
    """
    if original_text is None or (isinstance(original_text, str) and not original_text.strip()):
        return original_text
    if openai_client is None:
        openai_client = client

    prompt=f'''
    Rewrite the following chest X-ray report using different wording while maintaining all medical facts and implications. 
    Please ensure the generated text remains in the same format as the input. 
    Only return the rewritten report without any additional commentary: {original_text}'''

    my_message = [
        {'role': 'system', 'content': 'You are a radiologist rewriting a chest X-ray report.'},
        {'role': 'user', 'content': prompt}
    ]

    response = openai_client.chat.completions.create(
        model=model,
        messages=my_message,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    rewritten = response.choices[0].message.content
    # The message content can be None or empty. Returning that would replace the
    # report with None and break tokenization downstream, so keep the original instead.
    if rewritten is None or not rewritten.strip():
        return original_text
    return rewritten

In [7]:
from transformers import AutoTokenizer, AutoModel, logging
# Only errors are logged by transformers from here on. This also hides the warning
# transformers emits when checkpoint weights are missing and freshly initialised.
logging.set_verbosity_error()

# Hugging Face Hub checkpoints used by get_biobert_embeddings().
BIOBERT_CHECKPOINT = "emilyalsentzer/Bio_ClinicalBERT"
LONGFORMER_CHECKPOINT = "yikuan8/Clinical-Longformer"

biobert_tokenizer = AutoTokenizer.from_pretrained(BIOBERT_CHECKPOINT)
biobert_model = AutoModel.from_pretrained(BIOBERT_CHECKPOINT)
longformer_tokenizer = AutoTokenizer.from_pretrained(LONGFORMER_CHECKPOINT)
longformer_model = AutoModel.from_pretrained(LONGFORMER_CHECKPOINT)

In [8]:
def get_biobert_embeddings(text, device, modelname='longformer'):
    """Encode text with Bio_ClinicalBERT or Clinical-Longformer and return its embeddings.

    Args:
        text: string (or list of strings) to encode. Input is padded/truncated to
            512 tokens for 'biobert' and 1024 tokens for 'longformer'.
        device: torch device (or device string). The selected model is moved there
            on every call.
        modelname: 'biobert' or 'longformer' (default).

    Returns:
        (embeddings, hidden_embeddings): the model's ``pooler_output`` and
        ``last_hidden_state`` tensors. They are computed under ``torch.no_grad()``,
        so they carry no autograd graph.

    Raises:
        ValueError: if ``modelname`` is not 'biobert' or 'longformer'.
    """
    if modelname == 'biobert':
        tokenizer, model, max_length = biobert_tokenizer, biobert_model, 512
    elif modelname == 'longformer':
        tokenizer, model, max_length = longformer_tokenizer, longformer_model, 1024
    else:
        # Without this, an unknown name fell through and failed later with UnboundLocalError.
        raise ValueError(f"Unknown modelname {modelname!r}; expected 'biobert' or 'longformer'.")

    # Only the model actually used needs to live on `device`.
    model.to(device)
    tokens_pt = tokenizer(
        text, return_tensors="pt", padding="max_length", max_length=max_length, truncation=True
    ).to(device)

    # Inference only: skip building the autograd graph, which is large for 1024-token inputs.
    with torch.no_grad():
        outputs = model(**tokens_pt)

    # NOTE(review): if a checkpoint was saved without pooler weights, pooler_output comes from
    # a freshly initialised pooler. The load-time warning is silenced by
    # logging.set_verbosity_error(). Consider last_hidden_state[:, 0] instead; confirm.
    return outputs.pooler_output, outputs.last_hidden_state

In [9]:
def process_batch(chunk_batch, device):
    """Embed each text chunk on its own and collect the pooled embeddings.

    Every chunk goes through ``get_biobert_embeddings`` (default backbone).
    Only the first return value (the pooled embedding) is kept, and it is
    detached, moved to the CPU and converted to a NumPy array.

    Returns:
        list of NumPy arrays, one per chunk, in input order.
    """
    return [
        get_biobert_embeddings(chunk, device)[0].detach().cpu().numpy()
        for chunk in chunk_batch
    ]

In [10]:
def augment_image(image_path, image_size=(2539, 3050), shift=100, suffix='_augmented'):
    """Write a randomly cropped copy of an image next to the original and return its path.

    The crop offset is drawn uniformly from ``np.random`` over all positions where the
    crop fits. Pixels are sliced directly, so values and dtype are preserved exactly
    (no float round-trip), and only the first two axes (rows, columns) are cropped.

    Args:
        image_path: path of the source image. It is never overwritten.
        image_size: (height, width) of the crop. Each side is clamped to the image's
            own size, so smaller images do not fail. If None, the crop is ``shift``
            pixels smaller than the image along both axes.
        shift: crop margin in pixels, used only when ``image_size`` is None.
        suffix: text inserted before the file extension of the output path. Pass a
            distinct suffix per call when augmenting the same image several times;
            otherwise later calls overwrite earlier outputs.

    Returns:
        Path of the written augmented image.
    """
    # Context manager closes the underlying file handle once pixels are loaded.
    with Image.open(image_path) as img:
        pixels = np.array(img)
    height, width = pixels.shape[:2]

    if image_size is None:
        crop_h, crop_w = max(height - shift, 1), max(width - shift, 1)
    else:
        crop_h, crop_w = min(image_size[0], height), min(image_size[1], width)

    top = np.random.randint(0, height - crop_h + 1)
    left = np.random.randint(0, width - crop_w + 1)
    cropped = np.ascontiguousarray(pixels[top:top + crop_h, left:left + crop_w])

    # Derive the output name from the real extension so a non-.jpg input can never
    # map back onto (and overwrite) the source file.
    root, ext = os.path.splitext(image_path)
    augmented_image_path = f"{root}{suffix}{ext}"
    Image.fromarray(cropped).save(augmented_image_path)
    return augmented_image_path

In [11]:
def generate_cxr_feats(image_paths, model):
    """Extract pooled DenseNet feature vectors for a list of chest X-ray images.

    Each image is read and scaled with ``xrv.datasets.normalize(img, 255)``. Images
    with colour channels are averaged to a single channel. Each image is then resized
    to 224x224 and stacked into a batch of shape (n_images, 1, 224, 224). The batch is
    passed through ``model.features``, then ReLU and global average pooling.

    Args:
        image_paths: non-empty sequence of image file paths. Images are assumed to be
            8-bit, since normalization uses a max value of 255. An empty sequence
            raises ValueError from ``np.stack``.
        model: DenseNet-style model exposing ``.features`` with at least one
            parameter. The batch is placed on the device of its first parameter.

    Returns:
        np.ndarray of shape (n_images, n_features).

    NOTE(review): this function does not call ``model.eval()``; it assumes the caller's
    model is already in eval mode (BatchNorm running statistics). Confirm.
    """
    img_list = []
    for image_path in image_paths:
        img = skimage.io.imread(image_path)
        img = xrv.datasets.normalize(img, 255)
        if img.ndim == 3:
            # Colour-encoded file: average the colour channels into one grey channel.
            # Otherwise the stacked batch would not be (n, 1, 224, 224).
            img = img[..., :3].mean(axis=2)
        img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA)
        img_list.append(img[np.newaxis, ...])  # add channel dimension -> (1, 224, 224)

    img_batch = np.stack(img_list, axis=0)  # (n_images, 1, 224, 224)
    # Follow the model's device instead of the default CUDA device, so a model on
    # e.g. cuda:3 or the CPU works.
    device = next(model.parameters()).device
    img_batch = torch.from_numpy(img_batch).float().to(device)

    with torch.no_grad():
        feats = model.features(img_batch)
        feats = F.relu(feats, inplace=True)
        feats = F.adaptive_avg_pool2d(feats, (1, 1))
        densefeatures_batch = feats.cpu().numpy().reshape(len(img_list), -1)

    return densefeatures_batch


In [12]:
def load_data(file_path):
    """Deserialize and return the single object stored in the pickle file at ``file_path``.

    Security: ``pickle.load`` can execute arbitrary code embedded in the file, so
    only use this on files produced by this pipeline or other trusted sources.
    """
    with open(file_path, mode='rb') as handle:
        obj = pickle.load(handle)
    return obj

def save_data(data, file_path):
    """Pickle ``data`` to ``file_path``, replacing any existing file atomically.

    The object is first written to a sibling temporary file ('<file_path>.tmp') and
    then moved over ``file_path`` with ``os.replace``. If pickling or writing fails
    part-way, the existing file is left untouched and the temporary file is removed.
    This matters because ``augment_data_conditionally`` writes its result back over
    the dataset it just loaded.

    Args:
        data: any picklable object.
        file_path: destination path (str or path-like).
    """
    tmp_path = f"{os.fspath(file_path)}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(data, f)
        os.replace(tmp_path, file_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

In [None]:
def single_augmentation(patient_data, text_device, cxr_model):
    """Build one synthetic copy of a patient record; ``patient_data`` itself is not mutated.

    All work happens on a deep copy of ``patient_data``:
      * 'irg_ts' and 'reg_ts', when present and not None, get Gaussian noise via
        augment_time_series.
      * Unless 'text_missing' is truthy, every entry of 'text_data' is paraphrased
        with augment_text. 'text_embeddings' is rebuilt as one process_batch([...])
        result (a one-element list) per report.
      * Unless 'cxr_missing' is truthy, every path in 'image_paths' is replaced by
        the cropped copy written by augment_image, and 'cxr_feats' is recomputed
        with generate_cxr_feats.
      * 'ecg_missing' is forced to True so ECG data is ignored for the synthetic copy.

    NOTE(review): augment_image derives its output path only from the source path,
    so augmenting the same record several times overwrites the same image file. The
    'image_paths' of earlier copies then point at the last crop. Their 'cxr_feats'
    are still correct, because they are computed right after each crop.
    """
    augmented = copy.deepcopy(patient_data)

    for series_key in ('irg_ts', 'reg_ts'):
        if augmented.get(series_key) is not None:
            augmented[series_key] = augment_time_series(augmented[series_key])

    if not augmented['text_missing']:
        paraphrased = [augment_text(report) for report in augmented['text_data']]
        augmented['text_data'] = paraphrased
        augmented['text_embeddings'] = [
            process_batch([report], text_device) for report in paraphrased
        ]

    if not augmented['cxr_missing']:
        cropped_paths = [augment_image(path) for path in augmented['image_paths']]
        augmented['image_paths'] = cropped_paths
        augmented['cxr_feats'] = generate_cxr_feats(cropped_paths, cxr_model)

    augmented['ecg_missing'] = True
    return augmented

In [None]:
def count_classes(data):
    """Return a dict mapping each item's 'label' to how many items carry it.

    Keys appear in the order their labels are first seen in ``data``.
    """
    tally = {}
    for record in data:
        label = record['label']
        tally[label] = tally.get(label, 0) + 1
    return tally

def augment_data_conditionally(file_path, text_device, cxr_model, output_path=None):
    """Oversample the minority class (label 1) of a pickled dataset with synthetic copies.

    Each label-1 item is followed in the output by ``num_augmentations`` copies
    produced by ``single_augmentation``, where
    ``num_augmentations = ceil((n_majority - n_minority) / n_minority)``.
    All other items are kept unchanged and in their original order.

    Args:
        file_path: pickle file holding a list of item dicts, each with a 'label' key.
        text_device: device passed through for text-embedding generation.
        cxr_model: model passed through for CXR feature extraction.
        output_path: where to write the augmented list. Defaults to ``file_path``,
            which overwrites the input as before. Overwriting is not idempotent:
            running again on the same file augments already-augmented data.

    Raises:
        ValueError: if the dataset has no label-1 items (including when it is empty).
    """
    import math  # not imported at the top of this notebook

    data = load_data(file_path)
    counts = count_classes(data)
    num_minority = counts.get(1, 0)  # assumes label 1 is the minority class
    # Checked before max() so an empty dataset gets this message rather than max()'s.
    if num_minority == 0:
        raise ValueError("No instances of the minority class found.")
    num_majority = max(counts.values())
    num_augmentations = math.ceil((num_majority - num_minority) / num_minority)

    augmented_data = []
    for item in tqdm(data):
        augmented_data.append(item)
        if item['label'] == 1:
            for _ in range(num_augmentations):
                # single_augmentation deep-copies its input itself; no extra copy needed.
                augmented_data.append(single_augmentation(item, text_device, cxr_model))

    save_data(augmented_data, file_path if output_path is None else output_path)

In [14]:
# Oversample the minority class of this training split.
# NOTE: augment_data_conditionally writes the result back to file_path, so re-running
# this cell augments already-augmented data. Keep a copy of the original file first.
file_path = '/data/wang/junh/datasets/multimodal/augmentation/pid_3mod/train_los_S_stays.pkl'
augment_data_conditionally(file_path, text_device='cuda:3', cxr_model=xrv.models.DenseNet(weights="all").cuda())

100%|██████████| 115/115 [14:40<00:00,  7.66s/it]


In [None]:
# Oversample the minority class of this training split.
# NOTE: augment_data_conditionally writes the result back to file_path, so re-running
# this cell augments already-augmented data. Keep a copy of the original file first.
file_path = '/data/wang/junh/datasets/multimodal/augmentation/pid_3mod/train_ihm_S_stays.pkl'
augment_data_conditionally(file_path, text_device='cuda:3', cxr_model=xrv.models.DenseNet(weights="all").cuda())

 67%|██████▋   | 120/178 [23:11<15:03, 15.58s/it]

In [9]:
# Example usage: pick the second report (and its CXR metadata) of the first stay.
# NOTE(review): `train_stays_list` is not defined in any cell visible here, so this only
# works with kernel state left over from an earlier session. Load it explicitly
# (e.g. with load_data on the relevant split) before running.
original_cxr_report = train_stays_list[0].get('text_data')[1]
metadata = train_stays_list[0].get('cxr_metadata')[1]

In [10]:
# Display the sample report that is paraphrased in the next cell.
original_cxr_report

'INDICATION:  ___ year old man with new ventricular tachycardia status post\nablation also volume overload\n\nTECHNIQUE:  Frontal chest radiographs were obtained with the patient in the\nsupine position.\n\nCOMPARISON:  Radiographs from ___.\n\nFINDINGS: \n\nThe heart continues to be enlarged, and a left cardiac device has its leads\nterminating in appropriate position. There is moderate edema with associated\nright layering profusion. An endotracheal tube terminates in appropriate\nposition, and the nasogastric tube terminates below the view of this\nradiograph.\n\nIMPRESSION: \n\nContinued moderate pulmonary edema with layering right pleural effusion.\n'

In [11]:
# Paraphrase the sample report through the OpenAI API and show the result.
# NOTE(review): in the output below, the rewrite replaced "Radiographs from ___." with
# "Prior radiographs not available". It also dropped the nasogastric tube terminating
# below the field of view. Facts are not reliably preserved.
augmented_text = augment_text(original_cxr_report)
print(f"Augmented Text: {augmented_text}")

Augmented Text: INDICATION:  Chest X-ray of a ___ year old man with recent ventricular tachycardia ablation and volume overload.

TECHNIQUE:  Frontal chest radiographs were taken with the patient lying supine.

COMPARISON:  Prior radiographs not available for comparison.

FINDINGS: 

The heart remains enlarged, and a left cardiac device is appropriately positioned with its leads. Moderate edema is present with right pleural effusion. The endotracheal tube and nasogastric tube are both appropriately positioned.

IMPRESSION: 

Persistent moderate pulmonary edema with right pleural effusion.


In [18]:
# Print the first report of every stay for a quick visual check of the text data.
for stay in train_stays_list:
    first_report = stay['text_data'][0]
    print(first_report)

A chest fluoro was performed without a radiologist present.  2 minutes and 59
seconds of fluoro time was used.  Films were submitted to PACS.


INDICATION:  ___ male with lactic acidosis, severe mitral
regurgitation, and tachypnea.

COMPARISON:  ___.

TECHNIQUE:  Single frontal chest radiograph was obtained portably with the
patient in an upright position.

FINDINGS:  Compared to prior exam, there has been no significant interval
change.  Moderate cardiomegaly and pulmonary vascular congestion persist.  No
focal consolidation is detected on this single view.  There may be trace right
pleural effusion.  No pneumothorax is detected.

IMPRESSION:  Stable chest radiograph.

HISTORY:  CHF.

FINDINGS:  In comparison with study of ___, there is diffuse bilateral
pulmonary opacifications in a pattern consistent with the clinical diagnosis
of congestive failure.  Cardiac silhouette is more prominent, though some of
this may be due to the portable supine position.  The left hemidiaphragm is
poor