In [None]:
%matplotlib inline

## Imports
import glob
import os
import pickle
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from pandarallel import pandarallel
from PIL import Image
from stanza.server import CoreNLPClient
from tqdm import tqdm
# import matplotlib.pyplot as plt
# import stanza
#import cv2

pandarallel.initialize(progress_bar=3)

## Experiments

### Generating df experiments

In [None]:
# Dataset roots on the shared cluster volume and the project working directory.
MIMIC_CXR_JPG = '/vol/aimspace/projects/physionet/mimic_cxr/mimic-cxr-jpg_2-0-0'
MIMIC_CXR_REPORTS = '/vol/aimspace/projects/physionet/mimic_cxr/mimic-cxr_2-0-0'
PROJECT_DIR = '/vol/aimspace/projects/practical_WS2425/vision_language'


def jpg_to_report_path(x):
    """Map a path under the JPG release to the matching path under the reports release."""
    return x.replace('mimic-cxr-jpg_2-0-0', 'mimic-cxr_2-0-0')


In [None]:
ls $MIMIC_CXR_REPORTS/files

In [None]:
patient_folders = []
# TODO change to 10-20
# for i in range(10,20):
#     patient_folders.extend(glob.glob(f'{MIMIC_CXR_DIR}/files/p{i}/p{i}*/'))
patient_folders.extend(glob.glob(f'{MIMIC_CXR_REPORTS}/files/p10/p10*/'))
print(len(patient_folders))

In [None]:
image_files = []
for pf in tqdm(patient_folders):
    image_files.extend(glob.glob(f'{pf}/s*/*.jpg'))
print(len(image_files))

In [None]:
image_files = []
for pf in tqdm(patient_folders):
    image_files.extend(glob.glob(f'{pf}/s*/*.jpg'))
print(len(image_files))

In [6]:
text_reports = []

In [None]:
ls {patient_folders[0]}

In [None]:
with open(f'{patient_folders[0]}/s56834987.txt', 'r') as f:
    print(f.read())

In [None]:
# study_ids_folders = glob.glob(f'{MIMIC_CXR_JPG}/files/p*/p*/s*/')
study_ids_folders = glob.glob(f'{MIMIC_CXR_JPG}/files/p10/p10*/s*/')
study_ids_folders[:4]

In [None]:
# create dataframe with patient_folder, patient_id, study_id from paths
text_reports_df = pd.DataFrame(np.vstack(pd.Series(study_ids_folders).parallel_apply(lambda x: x.split('/')[-4:-1]).values), columns=['patient_folder', 'patient_id', 'study_id'])
text_reports_df.head()

In [None]:
text_reports_df.shape

In [None]:
def _section_after(text, header):
    '''Return everything after the first whole-word `header` in `text` ('' if absent).'''
    match = re.search(rf"^([\w\W]+?)\b{header}\b([\w\W]+?)$", text)
    return match.group(2) if match else ''


def get_findings_impressions(df, path):
    '''
    Extract the FINDINGS and IMPRESSION sections of one MIMIC-CXR report.

    Parameters
    ----------
    df : row (pd.Series or mapping) with 'patient_folder', 'patient_id' and
        'study_id'; the report read is
        {path}/files/{patient_folder}/{patient_id}/{study_id}.txt
    path : root folder of the reports release.

    Returns
    -------
    (findings, impression) : tuple of str, each '' when the section is missing.
        The text is returned raw (leading ':' and newlines are kept).
    '''
    report_path = f'{path}/files/{df["patient_folder"]}/{df["patient_id"]}/{df["study_id"]}.txt'
    with open(report_path, encoding='utf-8') as f:
        data = f.read()

    findings, impression = '', ''
    # Preferred layout: ... FINDINGS <findings> IMPRESSION <impression>
    both = re.search(r"^([\w\W]+?)\bFINDINGS\b([\w\W]+?)\bIMPRESSION\b([\w\W]+?)$", data)
    if both:
        findings, impression = both.group(2), both.group(3)
    # Fall back to each header on its own (section missing or headers in reverse order).
    if not findings:
        findings = _section_after(data, 'FINDINGS')
    if not impression:
        impression = _section_after(data, 'IMPRESSION')
    return findings, impression
f, i = get_findings_impressions(text_reports_df.iloc[0], MIMIC_CXR_REPORTS)
print(f)
print(i)

In [None]:
pandarallel.initialize(progress_bar=1)
text_reports_df['finding'], text_reports_df['impression'] = zip(*text_reports_df.parallel_apply(lambda x: get_findings_impressions(x, MIMIC_CXR_REPORTS), axis=1))
text_reports_df.head()

In [None]:
pandarallel.initialize(progress_bar=False)
print('Removing reports with no findings nor impressions')
mask_to_drop = text_reports_df.parallel_apply(lambda x: len(x['findings'])==0 and len(x['impression'])==0, axis=1)
text_reports_df = text_reports_df[~mask_to_drop]
print(f'Removed {mask_to_drop.sum()} reports with no findings nor impressions\n')

In [None]:
def process_findings_impressions(df, reports_dir=None):
    '''
    Parse findings/impressions for every row of `df` and store them in place.

    Adds (or overwrites) the columns 'findings' and 'impressions' (plural,
    unlike the 'impression' column written by an earlier cell).

    Parameters
    ----------
    df : DataFrame with 'patient_folder', 'patient_id' and 'study_id' columns.
    reports_dir : root of the reports release; defaults to MIMIC_CXR_REPORTS.

    Returns
    -------
    The same DataFrame object, mutated in place.
    '''
    if reports_dir is None:
        reports_dir = MIMIC_CXR_REPORTS
    if df.empty:
        # Nothing to parse; also avoids reading result columns 0/1, which an
        # apply over zero rows does not produce.
        df['findings'], df['impressions'] = pd.Series(dtype=object), pd.Series(dtype=object)
        return df
    # One pass over the rows; each call returns a (findings, impression) tuple,
    # expanded into result columns 0 and 1.
    results = df.parallel_apply(
        lambda x: get_findings_impressions(x, reports_dir),
        axis=1,
        result_type='expand',
    )
    df['findings'], df['impressions'] = results[0], results[1]
    return df

# Enable progress bar
tqdm.pandas()
# Process in single pass
text_reports_df = process_findings_impressions(text_reports_df)
text_reports_df.head()

In [None]:
pandarallel.initialize(progress_bar=False)
problem_df = text_reports_df[text_reports_df.parallel_apply(lambda x: not (x['findings']!='' or x['impressions']!=''), axis=1)]
problem_df.shape

In [None]:
r = problem_df.iloc[1]

In [None]:
for _, r in problem_df.sample(10).iterrows():
    with open(f'{MIMIC_CXR_REPORTS}/files/{r["patient_folder"]}/{r["patient_id"]}/{r["study_id"]}.txt') as f:
        data = f.read()

    print(data)
    
    print("-"*100)

In [None]:
# Remove rows with empty findings and impressions
print('Before removing samples:', text_reports_df.shape)
has_empty_findings_or_impressions = text_reports_df.parallel_apply(lambda x: len(x['findings'])!=0 or len(x['impression'])!=0, axis=1)
text_reports_df = text_reports_df[has_empty_findings_or_impressions]
print('After removing samples:', text_reports_df.shape)

In [None]:
(~has_empty_findings_or_impressions).sum()

In [None]:
text_reports_df.apply(lambda x: len(x['findings'])==0 or len(x['impression'])==0, axis=1).value_counts()

In [None]:
# import stanza
# CORE_NLP_DIR = '/vol/aimspace/projects/practical_WS2425/vision_language/CORE_NLP'
# stanza.install_corenlp(dir=CORE_NLP_DIR)
# os.environ['CORENLP_HOME'] = CORE_NLP_DIR



# No more CoreNLP server since pipeline is faster and easier to use


In [None]:
# Tokenize-only stanza pipeline, used below for sentence splitting.
# Models are read from the project-local STANZA_DIR; uncomment the download
# line once to fetch them there.
import stanza 
STANZA_DIR = f'{PROJECT_DIR}/stanza_resources'
#stanza.download('en', model_dir=STANZA_DIR + '/stanza_en')
nlp = stanza.Pipeline(lang='en', processors='tokenize',
                      model_dir=STANZA_DIR + '/stanza_en',
                    )

In [None]:
# Test the pipeline on the findings of the first report
doc = nlp(text_reports_df.iloc[0]['findings'])
for i, sentence in enumerate(doc.sentences):
    # print each sentence's tokens joined by single spaces (approximate reconstruction of the sentence)
    print(i)
    print(' '.join([word.text for word in sentence.words]))

In [None]:
# Token count of the second sentence (IndexError if the report has fewer than 2 sentences)
len(doc.sentences[1].words)

In [None]:
# def get_sentences(tokenized_sentences):
#     '''
#     Old version, needs to be replaced since it does not work with the new stanza library, e.g. stanza.Pipeline
#     '''
#     final_sentences = []
#     for sent in tokenized_sentences:
#         if len(sent.token)>3:
#             ## Convert tokens to a string and replace the intermediate new lines
#             final_sent = ''.join(list(map(lambda x: x.before.replace('\n', '')+x.word, list(sent.token)))).strip()
#             if final_sent.startswith(":"):
#                 final_sent = final_sent.replace(":", "").strip()
#             final_sentences.append(final_sent)
    
#     return final_sentences

def clean_text(sentence: 'stanza.models.common.doc.Sentence'):
    '''
    Normalize a stanza sentence into a single-line string.

    Every run of whitespace (including the newlines of the report layout) is
    collapsed to one space, so words wrapped across lines stay separated, and
    one leading colon (left over from a "FINDINGS:" / "IMPRESSION:" header)
    is removed.

    Parameters
    ----------
    sentence : object exposing a ``text`` attribute (a stanza Sentence).

    Returns
    -------
    str : the cleaned sentence text.
    '''
    final_sentence = ' '.join(sentence.text.split())
    if final_sentence.startswith(":"):
        final_sentence = final_sentence[1:].strip()

    return final_sentence

def get_sentences(document: 'stanza.Document', min_words: int = 3):
    '''
    Return the cleaned text of every sentence of a stanza document that has
    at least `min_words` words; shorter fragments (presumably stray headers)
    are dropped.

    Parameters
    ----------
    document : stanza Document produced by ``nlp(...)``; only ``.sentences``,
        ``sentence.words`` and ``sentence.text`` are used.
    min_words : minimum number of words a sentence needs to be kept (default 3).

    Returns
    -------
    list[str] : cleaned sentences in document order (see ``clean_text``).
    '''
    return [clean_text(sentence)
            for sentence in document.sentences
            if len(sentence.words) >= min_words]

get_sentences(doc)

In [None]:
text_reports_df['findings_tokenized_sentences'] = None
text_reports_df['impressions_tokenized_sentences'] = None

In [None]:
r

In [None]:
# Old version with CoreNLPClient
# with CoreNLPClient(properties={
#       'annotators': 'tokenize'
#   }, be_quiet=True) as client:
#     for i, (idx, r) in tqdm(enumerate(text_reports_df.iterrows()), total=text_reports_df.shape[0]):
#         cnlp_out = client.annotate(r['findings'])
#         text_reports_df.loc[idx, 'findings_tokenized_sentences'] = get_sentences(cnlp_out.sentence)
#         cnlp_out = client.annotate(r['impressions'])
#         text_reports_df.loc[idx, 'impressions_tokenized_sentences'] = get_sentences(cnlp_out.sentence)
#         if (i%100==0):
#             print(i)

# TODO is it faster to not use gpu for nlp but pandarallel
# New version with stanza.Pipeline
tqdm.pandas()
# for findings:
#text_reports_df['findings_tokenized_sentences'] = text_reports_df['findings'].progress_apply(lambda x: get_sentences(nlp(x)))
# for impressions:
text_reports_df['impressions_tokenized_sentences'] = text_reports_df['impression'].progress_apply(lambda x: get_sentences(nlp(x)))


In [None]:
# # Cleaning findings and impressions
# def clean_text(text):
#     # Replace newlines with spaces
#     text = text.replace('\n', ' ')
#     # Remove leading colon if present
#     return text[1:] if text.startswith(':') else text

# # Apply the cleaning function to both columns
# text_reports_df['findings'] = text_reports_df['findings'].apply(clean_text)
# text_reports_df['impressions'] = text_reports_df['impressions'].apply(clean_text)

In [None]:
text_reports_df.head()

In [None]:
# Spot-check the sentence list of one study (row 2)
text_reports_df.iloc[2]['impressions_tokenized_sentences']

In [None]:
ls $PROJECT_DIR

In [None]:
# Checkpoint: pickle the parsed + sentence-split reports to
# {PROJECT_DIR}/data/interims/tokenized_text_reports.pkl
# Define the directory and file path
# directory = 'data/processed/tokenized_reports'
directory = f'{PROJECT_DIR}/data/interims'

os.makedirs(directory, exist_ok=True)

file_path = os.path.join(directory, 'tokenized_text_reports.pkl')

print("Target Directory:", directory)

# Save the DataFrame
with open(file_path, 'wb') as f:
    pickle.dump(text_reports_df, f)

# NOTE(review): the verification below is commented out, so it does not define
# `data`; a later "Test the functions" cell reads a variable named `data`.
# # Verifying the file was saved correctly
# with open(file_path, 'rb') as f:
#     data = pickle.load(f)

# # If it's a DataFrame, display the first few rows
# if isinstance(data, pd.DataFrame):
#     print(data.head())
# else:
#     # If it's not a DataFrame, just print the data
#     print(data)

In [None]:
# Function to generate image file paths and check their existence
def generate_image_path(row, root=None):
    '''
    List the JPG files of one study.

    Parameters
    ----------
    row : mapping/Series with 'patient_folder', 'patient_id' and 'study_id'.
    root : root folder of the JPG release; defaults to MIMIC_CXR_JPG.

    Returns
    -------
    list[str] : sorted paths of the *.jpg files in
        {root}/files/{patient_folder}/{patient_id}/{study_id}. The list is empty
        (and a message is printed) when the study folder has no images.
    '''
    if root is None:
        root = MIMIC_CXR_JPG
    directory_path = f'{root}/files/{row["patient_folder"]}/{row["patient_id"]}/{row["study_id"]}'

    # glob order is filesystem-dependent; sort so repeated runs give the same lists.
    image_files = sorted(glob.glob(os.path.join(directory_path, '*.jpg')))

    if not image_files:
        print("No files found in directory: " + directory_path)

    return image_files

# Add a new column with image file paths
# (one list of *.jpg paths per study, found by generate_image_path under MIMIC_CXR_JPG)
text_reports_df['image_files'] = text_reports_df.parallel_apply(generate_image_path, axis=1)
text_reports_df.head(10)


In [None]:
# How many images are there per sample?
# (distribution of the number of JPGs per study; 0 means the study folder had none)
text_reports_df['image_files'].apply(len).value_counts()

In [None]:
def resize_image(image_path, target_size=(256, 256)):
    '''
    Load an image and make a downscaled copy of it.

    Parameters
    ----------
    image_path : path of the image file.
    target_size : (max_width, max_height) bounding box. The copy keeps its
        aspect ratio and is never upscaled (PIL ``thumbnail`` semantics).

    Returns
    -------
    (original_image, resized_image) : both fully loaded PIL images; the file
        is closed before returning.
    '''
    # Read the pixels inside the context manager so the file handle is released
    # (a bare Image.open keeps it open until garbage collection).
    with Image.open(image_path) as img:
        original_image = img.copy()

    resized_image = original_image.copy()
    resized_image.thumbnail(target_size)

    return original_image, resized_image

In [None]:
def display_images(original, resized):
    '''Show `original` (left) and `resized` (right) side by side, titled and without axes.'''
    fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(10, 5))
    panels = ((ax_left, original, 'Original Image'),
              (ax_right, resized, 'Resized Image'))
    for ax, picture, caption in panels:
        ax.imshow(picture)
        ax.set_title(caption)
        ax.axis('off')
    plt.show()

In [None]:
# Test the functions
text_reports_df = data
test_image_path = text_reports_df.iloc[0]['image_files'][0]
print(test_image_path)

original, resized = resize_image(test_image_path)
print(original.size, resized.size)  

# Display the images
display_images(original, resized)


In [None]:
from PIL import Image
import os
import pickle
import pandas as pd
from pandarallel import pandarallel
pandarallel.initialize(progress_bar=3)

base_directory = f'{PROJECT_DIR}/data/processed/test'
file_path = os.path.join(PROJECT_DIR, 'data/dataframes/report_images_resized_df.pkl')

with open(file_path, 'rb') as f:
     data = pickle.load(f)
     
data.drop('image_files_resized', axis=1, inplace=True)
data.drop('impressions', axis=1, inplace=True)
data.drop('findings', axis=1, inplace=True)
data.head()


In [None]:
def resize_image(img_path: str, max_size: int = 256) -> Image.Image:
    '''
    Return a copy of the image at `img_path` downscaled so that its larger
    side is at most `max_size` (aspect ratio kept, never upscaled).

    NOTE: this redefinition shadows the earlier `resize_image`, which returned
    an (original, resized) tuple; from here on a single image is returned.
    '''
    # Copy inside the context manager so the file handle is closed promptly.
    with Image.open(img_path) as img:
        img_copy = img.copy()
    img_copy.thumbnail((max_size, max_size))
    return img_copy

def resize_and_save_images(row, target_size=256, base_dir=None):
    '''
    Resize every image of one study into a mirrored folder layout.

    Outputs go to <base_dir>/<patient_folder>/<patient_id>/<study_id>/<name>_resized.jpg.
    Images whose resized file already exists are not recomputed.

    Parameters
    ----------
    row : mapping/Series with 'patient_folder', 'patient_id', 'study_id' and
        'image_files' (list of source image paths).
    target_size : maximum size of the larger side of the resized image.
    base_dir : output root; defaults to the global ``base_directory`` at call time.

    Returns
    -------
    list[str] : one resized-image path per input image, in input order.
    '''
    if base_dir is None:
        base_dir = base_directory
    save_dir = os.path.join(base_dir, row["patient_folder"], row["patient_id"], row["study_id"])
    # Ensure the directory exists
    os.makedirs(save_dir, exist_ok=True)

    resized_paths = []
    for path in row['image_files']:
        stem = os.path.splitext(os.path.basename(path))[0]
        resized_image_path = os.path.join(save_dir, f'{stem}_resized.jpg')
        resized_paths.append(resized_image_path)
        # Skip work already done on a previous run.
        if not os.path.exists(resized_image_path):
            resize_image(path, target_size).save(resized_image_path)

    return resized_paths

# Test resize function
#tmp = resize_image(data.iloc[0]['image_files'][0])
data['images_resized_names'] = data.parallel_apply(resize_and_save_images, axis=1)
data.head()


In [None]:
base_directory = f'{PROJECT_DIR}/data/processed/report_images'

text_reports_df['image_files_resized'] = [[] for _ in range(len(text_reports_df))]
# Iterate over each row in the DataFrame
for idx, row in tqdm(text_reports_df.iterrows(), total=text_reports_df.shape[0]):

    resized_paths = []  # List to store resized image paths for the current row
    # Iterate over each image path in the image_files column
    for image_path in row['image_files']:        
        # Extract the image name from the path
        image_name = os.path.basename(image_path)
        
        # Construct the directory path for saving the resized image
        save_directory = os.path.join(base_directory, row["patient_folder"], row["patient_id"], row["study_id"])
        
        # Ensure the directory exists
        os.makedirs(save_directory, exist_ok=True)
        
        # Save the resized image
        resized_image_path = os.path.join(save_directory, f'{os.path.splitext(image_name)[0]}_resized.jpg')
        if not os.path.exists(resized_image_path):
            # Resize the image
            original, resized = resize_image(image_path)
            
            # Save the resized image
            resized.save(resized_image_path)        
            
        # Append the resized image path to the list
        resized_paths.append(resized_image_path)
    
    # Assign the list of resized image paths to the new column
    text_reports_df.at[idx, 'image_files_resized'] = resized_paths

In [None]:
text_reports_df.head()

In [None]:
# Save the dataframe where the reload cells look for it:
# they read f'{PROJECT_DIR}/data/dataframes/report_images_resized_df.pkl'.
directory = f'{PROJECT_DIR}/data/dataframes'

os.makedirs(directory, exist_ok=True)

file_path = os.path.join(directory, 'report_images_resized_df.pkl')

print("Target Directory:", directory)

# Save the DataFrame
with open(file_path, 'wb') as f:
    pickle.dump(text_reports_df, f)


In [None]:
# Reload the dataframe
def load_dataframe_from_pickle(file_path):
    '''
    Unpickle `file_path` and return it when it is a DataFrame.

    Returns None, after printing a message, when the file is missing, cannot
    be unpickled, or holds something other than a DataFrame. Only use on
    trusted files: unpickling can execute arbitrary code.
    '''
    try:
        with open(file_path, 'rb') as handle:
            loaded = pickle.load(handle)
    except FileNotFoundError:
        print(f"File not found: {file_path}")
        return None
    except Exception as e:
        print(f"An error occurred while loading the file: {e}")
        return None

    if not isinstance(loaded, pd.DataFrame):
        print("The loaded data is not a DataFrame.")
        return None
    return loaded


In [None]:
# Example usage
# Path of the resized-report dataframe; the MIMICCXRDataset example below also reads it.
file_path = f'{PROJECT_DIR}/data/dataframes/report_images_resized_df.pkl'
# text_reports_df = load_dataframe_from_pickle(file_path) 

# text_reports_df.head()


In [None]:
class MIMICCXRDataset:
    '''
    Minimal, framework-free view over a pickled report/image DataFrame.

    Integer indexing (Python or numpy integers) returns one row as a
    pd.Series; slice indexing returns a DataFrame. NOTE: a torch-based class
    with the same name is defined further down and shadows this one once run.
    '''

    def __init__(self, df_file, client_id=None):
        # Case-insensitive so '.PKL' is accepted too (as the torch dataset does).
        assert os.path.exists(df_file) and df_file.lower().endswith('.pkl'), "File must exist and be a .pkl file"

        with open(df_file, 'rb') as f:
            self.df = pickle.load(f)

        # Stored for callers; not used by this class.
        self.client_id = client_id

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # np.integer covers indices coming from numpy (e.g. np.arange, samplers).
        if isinstance(idx, (int, np.integer, slice)):
            return self.df.iloc[idx]
        raise TypeError("Index must be an integer or a slice")

# Example usage
dataset = MIMICCXRDataset(file_path)

# Access the first 10 records
print(dataset[:10])

In [None]:
# NOTE(review): '/MIMIC_CXR_JPG/' below is a literal path, not the MIMIC_CXR_JPG
# variable defined at the top -- presumably a mount point from another environment
# (later cells map it to /vast/...). Confirm before running on this cluster.
# Distribution of the number of JPGs found per study under that root.
text_reports_df.apply(lambda x: glob.glob(f'/MIMIC_CXR_JPG/files/{x["patient_folder"]}/{x["patient_id"]}/{x["study_id"]}/*.jpg'), axis=1).apply(len).value_counts()

In [None]:
# Overwrites the 'image_files' column computed earlier with paths under '/MIMIC_CXR_JPG/'.
text_reports_df['image_files'] = text_reports_df.apply(lambda x: glob.glob(f'/MIMIC_CXR_JPG/files/{x["patient_folder"]}/{x["patient_id"]}/{x["study_id"]}/*.jpg'), axis=1)

In [None]:
# One row per image: a study with k images becomes k rows (report columns repeated);
# a study with no images keeps a single row with NaN in 'image_files'.
final_image_text_df = text_reports_df.explode("image_files").reset_index(drop=True)

In [None]:
# Not idempotent: on a second run 'image_files' no longer exists and rename does nothing.
final_image_text_df.rename(columns={'image_files':'image_fname'}, inplace=True)

In [None]:
final_image_text_df.head()

In [None]:
# Persist the image-level dataframe to the shared public folder.
with open('/scratch/tm3647/public/mimic_image_text_df.pkl', 'wb') as f:
    pickle.dump(final_image_text_df, f)

### Reload Dataframe

In [None]:
with open('/scratch/tm3647/public/mimic_image_text_df.pkl', 'rb') as f:
    final_image_text_df = pickle.load(f)

In [None]:
final_image_text_df.head()

#### Resize images and store them on disk (later packed into a squashfs image)

In [None]:
# NOTE(review): despite its name, this column holds a bool -- whether the
# '<name>_resized.jpg' file already exists under the /vast mirror -- not a file name.
final_image_text_df['resized_image_fname'] = final_image_text_df['image_fname'].apply(lambda x: os.path.exists((os.path.splitext(x)[0]+'_resized.jpg').replace('/MIMIC_CXR_JPG/', '/vast/tm3647/physionet.org/files/mimic-cxr-jpg/2.0.0/')))

In [None]:
def preprocess(img, desired_size=320):
    '''
    Letterbox `img` into a `desired_size` x `desired_size` grayscale ('L') canvas.

    The image is scaled with LANCZOS by desired_size / max(width, height)
    (each side truncated to int), then pasted centred on a new, zero-filled
    (black) canvas.
    '''
    width, height = img.size
    scale = float(desired_size) / max(width, height)
    scaled_w, scaled_h = int(width * scale), int(height * scale)
    scaled = img.resize((scaled_w, scaled_h), Image.Resampling.LANCZOS)

    canvas = Image.new('L', (desired_size, desired_size))
    offset = ((desired_size - scaled_w) // 2, (desired_size - scaled_h) // 2)
    canvas.paste(scaled, offset)
    return canvas

In [None]:
def resize_and_save(record):
    '''
    Letterbox one JPG with `preprocess` and save it as '<name>_resized.jpg'.

    The output path is the input path with '_resized' appended and the
    '/MIMIC_CXR_JPG/' prefix replaced by the /vast mirror. Outputs that
    already exist are skipped, so the function can be re-run safely.

    Parameters
    ----------
    record : mapping/Series with an 'image_fname' key (source JPG path).
    '''
    new_fname = os.path.splitext(record['image_fname'])[0]+'_resized.jpg'
    new_fname = new_fname.replace('/MIMIC_CXR_JPG/', '/vast/tm3647/physionet.org/files/mimic-cxr-jpg/2.0.0/')
    if os.path.exists(new_fname):
        return

    # Decode with PIL (imported at the top); pyvips is never imported in this
    # notebook. convert('L') gives the single-channel uint8 image the previous
    # (height, width) reshape assumed.
    with Image.open(record['image_fname']) as src:
        gray = src.convert('L')

    img = preprocess(gray)
    # The mirrored folder may not exist yet.
    out_dir = os.path.dirname(new_fname)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    img.save(new_fname)

In [None]:
from joblib import Parallel, delayed

In [None]:
# Resize every image with 4 joblib workers. resize_and_save skips files that
# already exist, so the cell can be re-run after an interruption. tqdm wraps the
# task generator, so the bar presumably tracks dispatch rather than completion.
_ = Parallel(4, verbose=2)(delayed(resize_and_save)(r) for _,r in tqdm(final_image_text_df.iterrows(), total=final_image_text_df.shape[0]))

In [None]:
# Spot-check: source file of the last row of the 10000:11000 slice (row 10999, assuming at least 11000 rows)
final_image_text_df.iloc[10000:11000].iloc[-1]['image_fname']

#### Add a column with the resized image file names

In [None]:
# Keep the original JPG path under a new name. Not idempotent: a second run
# would rename the resized-path 'image_fname' column and create a duplicate 'orig_image_fname'.
final_image_text_df.rename(columns={'image_fname':'orig_image_fname'}, inplace=True)

In [None]:
# NOTE(review): these paths keep the '/MIMIC_CXR_JPG/' prefix, whereas resize_and_save
# wrote files under the /vast mirror -- presumably the squashfs built from it is mounted
# at /MIMIC_CXR_JPG; the existence check below verifies that.
final_image_text_df['image_fname'] = final_image_text_df['orig_image_fname'].apply(lambda x: os.path.splitext(x)[0]+'_resized.jpg')

In [None]:
# How many resized files are actually reachable
final_image_text_df['image_fname'].apply(lambda x: os.path.exists(x)).value_counts()

In [None]:
# Visual check of one random resized image
Image.open(final_image_text_df['image_fname'].sample().iloc[0])

#### Write the dataframe to the public folder (tejas)

In [None]:
with open('/scratch/tm3647/public/mimic_image_text_df.pkl', 'wb') as f:
    pickle.dump(final_image_text_df, f)

### PyTorch Dataset

In [None]:
class MIMICCXRDataset(torch.utils.data.Dataset):
    '''
    PyTorch dataset pairing one (resized) chest X-ray with one report sentence.

    Each item loads record['image_fname'] as grayscale and replicates it to 3
    channels as a float tensor in [0, 1] with shape (3, H, W). It then picks one
    random sentence from the study's findings + impressions sentence lists.

    Parameters
    ----------
    df_file : path of a pickled DataFrame with 'image_fname',
        'findings_tokenized_sentences', 'impressions_tokenized_sentences',
        'patient_folder', 'patient_id' and 'study_id' columns.
    config : stored as-is; not used by this class.
    transforms : optional callable applied to the image tensor.
    tokenizer : optional HuggingFace-style tokenizer applied to the sentence.

    Items are dicts with 'image', 'text' and 'tokenized_data'. 'tokenized_data'
    is None when no tokenizer is given; the default collate cannot batch None.
    '''

    def __init__(self,  df_file:str , config=None, transforms=None, tokenizer=None):
        super(MIMICCXRDataset, self).__init__()

        self.config = config
        assert os.path.exists(df_file) and os.path.splitext(df_file)[1].lower()==".pkl", "Check file path exists and has the extension .pkl"

        with open(df_file, 'rb') as f:
            self.df = pickle.load(f)

        self.transforms = transforms
        self.tokenizer = tokenizer

    def __len__(self):
        return self.df.shape[0]

    @staticmethod
    def _as_list(value):
        '''Return `value` as a list; None/NaN (sentences never computed) become [].'''
        if value is None or (isinstance(value, float) and np.isnan(value)):
            return []
        return list(value)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]

        # Decode as single-channel uint8 with PIL (imported at the top of the notebook).
        with Image.open(record['image_fname']) as src:
            gray = np.asarray(src.convert('L'), dtype=np.uint8)
        # (H, W) uint8 -> (3, H, W) float in [0, 1]: the same result as
        # ToTensor()(cvtColor(gray, COLOR_GRAY2RGB)), without needing cv2/torchvision here.
        image = torch.from_numpy(np.stack([gray] * 3, axis=0)).float().div(255.0)

        findings = self._as_list(record.get('findings_tokenized_sentences'))
        impressions = self._as_list(record.get('impressions_tokenized_sentences'))

        find_impres = findings + impressions

        assert len(find_impres)!=0, f"Issue findings/impression of {record['patient_folder']}/{record['patient_id']}/{record['study_id']}"

        # One random sentence per access (global numpy RNG; not seeded here).
        text = np.random.choice(find_impres)

        if self.transforms:
            image = self.transforms(image)

        tokenized_input_data = None
        if self.tokenizer:
            tokenized_input_data = self.tokenizer(text, max_length=128, padding="max_length", truncation=True, return_tensors="pt")

        return {'image':image, 'text': text, 'tokenized_data':tokenized_input_data}

#### Exps

In [None]:
mds = MIMICDataset('/scratch/tm3647/public/mimic_image_text_df.pkl')

In [None]:
for i in mds:
    break

In [None]:
transforms.ToTensor()(i['image'].squeeze().numpy())

In [None]:
i['image'].shape

In [None]:
plt.imshow(transforms.ToPILImage()(i['image']))

#### Images write to disk

In [None]:
mds = MIMICCXRDataset('/scratch/tm3647/public/mimic_image_text_df.pkl')

In [None]:
for b in tqdm(mds):
    img = preprocess(Image.fromarray(b[0]))
    new_fname = os.path.splitext(b[1])[0]+'_resized.jpg'
    new_fname = new_fname.replace('/MIMIC_CXR_JPG/', '/vast/tm3647/physionet.org/files/mimic-cxr-jpg/2.0.0/')
    img.save(new_fname)

### PyTorch DataModule

In [None]:
from transformers import AutoModel

In [None]:
from typing import Optional, Dict, Any

In [None]:
from pytorch_lightning import LightningDataModule 
from transformers import (
    AutoConfig,
    AutoTokenizer,
)
from torchvision import transforms
from tqdm.auto import tqdm
import time

In [None]:
class MIMICCXRDataModule(LightningDataModule):
    """LightningDataModule serving MIMICCXRDataset for image-text training.

    Builds the image augmentation pipeline and the tokenizer in ``__init__``,
    creates the dataset in ``setup`` and exposes it through
    ``train_dataloader``. No validation/test loaders are defined.

    Args:
        mimic_cxr_dataset_file: pickled DataFrame passed to MIMICCXRDataset.
        batch_size, num_workers, pin_memory: forwarded to the DataLoader.
        text_model_name: HuggingFace model name/path used to load the tokenizer.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html
    """

    def __init__(
        self,
        mimic_cxr_dataset_file:str,
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
        text_model_name:str = '',
        #cfg: dict = {},
    ):
        # Zero-argument super() so LightningDataModule.__init__ actually runs;
        # a one-argument super(Cls) is unbound and never reaches the base class.
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        ## Image transformations, applied to the (3, H, W) float tensor that
        ## MIMICCXRDataset already returns (hence no ToTensor here)
        self.transforms = transforms.Compose([
            # NOTE(review): `ratio` is the crop's aspect-ratio range; if a crop-area
            # range was intended, that is the `scale` argument -- confirm.
            transforms.RandomResizedCrop(224, ratio=[0.6, 1.0]),
            transforms.RandomAffine(degrees=[-20,20], translate=(0.1,0.1), scale=(0.95, 1.05)),
            transforms.ColorJitter(brightness=(0.6, 1.4), contrast=(0.6, 1.4)),
            #transforms.GaussianBlur(G) ## Not implemented due to no info on kernel size in the paper
            #transforms.ToTensor(),
            transforms.Resize((224,224)),
            # ImageNet channel mean/std
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

        ## Tokenizer for the text encoder
        self.tokenizer = AutoTokenizer.from_pretrained(text_model_name, use_fast=True)

        # Created in setup().
        self.dataset = None

    def prepare_data(self):
        """Download data if needed.
        Do not use it to assign state (self.x = y).
        """
        pass

    def setup(self, stage: Optional[str] = None):
        """Create the dataset.

        Lightning calls this with a ``stage`` argument ('fit', 'test', ...) from
        ``trainer.fit()``/``trainer.test()``; it is accepted and ignored. Each
        call rebuilds the dataset from the pickle.
        """
        self.dataset = MIMICCXRDataset(self.hparams.mimic_cxr_dataset_file, transforms=self.transforms, tokenizer=self.tokenizer)

    def train_dataloader(self):
        """Shuffled DataLoader over the whole dataset (default collate_fn)."""
        from torch.utils.data import DataLoader  # not imported at notebook level

        return DataLoader(
            dataset=self.dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
        )

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass

    def collate_and_tokenize(self, batch):
        """Collate dataset items and tokenize their texts together (dynamic padding).

        Not wired into ``train_dataloader``; pass it as ``collate_fn`` to use it.
        Returns the tokenizer's output object holding 'tokenized_text' (all the
        tokenizer fields, e.g. input_ids / attention_mask), 'images' (the item
        images stacked along a new batch dimension) and 'texts' (list of str).
        """
        images = torch.stack([item['image'] for item in batch])
        texts = [item['text'] for item in batch]

        input_data = self.tokenizer.batch_encode_plus(texts, max_length=128, padding=True, truncation=True, return_tensors="pt")

        # Move every tokenizer field under a single 'tokenized_text' key
        # (the right-hand side pops all current keys before the assignment).
        input_data['tokenized_text'] = {k: input_data.pop(k) for k in list(input_data.keys())}

        input_data['images'] = images
        input_data['texts'] = texts

        return input_data

#### Exps

In [None]:
# DataModule on the public image-text dataframe with the Bio_ClinicalBERT tokenizer.
dm = MIMICCXRDataModule('/scratch/tm3647/public/mimic_image_text_df.pkl', text_model_name='emilyalsentzer/Bio_ClinicalBERT', batch_size=32)

In [None]:
dm.setup()

In [None]:
# Time how long the first batch takes to load (the loop breaks after one batch).
st_time = time.time()
for batch in tqdm(dm.train_dataloader()):
    images, texts = batch['image'], batch['text']
    print(time.time()-st_time)
    st_time = time.time()
    break

#### Scratch: end-to-end forward pass and loss experiments

In [None]:
st_time = time.time()
for batch in tqdm(dm.train_dataloader()):
    images, texts = batch['image'], batch['text']
    print(time.time()-st_time)
    st_time = time.time()
    break

In [None]:
batch['tokenized_text']['input_ids'].shape

In [None]:
batch.keys()

In [None]:
images.shape, len(texts)

In [None]:
text_model_input = dm.tokenizer(batch['texts'][:3], return_tensors="pt", padding=True)
text_model_input

In [None]:
text_model = AutoModel.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')

In [None]:
text_model_output = text_model(**text_model_input)

In [None]:
import torchvision
import torchsummary

In [None]:
image_model = torchvision.models.resnet50()
image_model = torch.nn.Sequential(*(list(image_model.children())[:-1]))
torchsummary.summary(image_model, (3,224,224))

In [None]:
image_fv = image_model(batch['image'][0:3])
image_fv.shape

In [None]:
def mean_pooling(model_output, attention_mask):
    '''
    Average the token embeddings over positions where `attention_mask` is non-zero.

    model_output[0] holds the token embeddings (assumed (batch, seq, hidden));
    the mask is broadcast over the hidden dimension. Per-row token counts are
    clamped to 1e-9 so fully masked rows yield zeros instead of dividing by zero.
    '''
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts

In [None]:
sentence_embeddings = mean_pooling(text_model_output, text_model_input['attention_mask'])

In [None]:
proj_layer1 = torch.nn.Linear(768, 768)
proj_layer2 = torch.nn.Linear(768, 512)

In [None]:
text_proj_output = proj_layer2(proj_layer1(sentence_embeddings))

In [None]:
text_proj_output.shape

In [None]:
img_proj_layer1 = torch.nn.Linear(2048, 1024)
img_proj_layer2 = torch.nn.Linear(1024, 512)

In [None]:
img_proj_output = img_proj_layer2(img_proj_layer1(image_fv.squeeze()))

In [None]:
img_proj_output.shape

In [None]:
from torch.nn.functional import cosine_similarity, pairwise_distance

In [None]:
temperature = 0.1

In [None]:
cosine_similarity(img_proj_output, text_proj_output)

In [None]:
img_proj_output

In [None]:
import torchmetrics

In [None]:
img_text_sim = torchmetrics.functional.pairwise_cosine_similarity(img_proj_output, text_proj_output).detach()
img_text_sim

In [None]:
mean[(-0.0250 / sum(-0.0250, -0.0479, -0.0319)), -0.0247 / sum(-0.0114, -0.0247, -0.0254), ...]

In [None]:
text_img_sim = torchmetrics.functional.pairwise_cosine_similarity(text_proj_output, img_proj_output).detach()
text_img_sim

In [None]:
a = torch.exp(img_text_sim)
a

In [None]:
-torch.log(0.9753 / sum(a[0]))

In [None]:
torch.diag(-torch.nn.functional.log_softmax(img_text_sim, 1))

In [None]:
lam = 0.75

In [None]:
## Final loss fn
torch.mean(lam*torch.diag(-torch.nn.functional.log_softmax(img_text_sim, 1)) + (1-lam)*torch.diag(-torch.nn.functional.log_softmax(text_img_sim, 1)))

In [None]:
vu_sim = cosine_similarity(img_proj_output, text_proj_output)
vu_sim

In [None]:
torch.exp(vu_sim/temperature) / torch.sum(torch.exp())

### .....