In [2]:
import os
import pandas as pd
import numpy as np
import torch
import random
from tqdm import tqdm
import argparse
import json
from PIL import Image
from openai import OpenAI
import re
from transformers import ViTFeatureExtractor 
import ast
import torchvision.transforms as transforms
from utils import *
# Synchronous CUDA kernel launches (clearer error locations) and restrict the
# process to GPUs 2 and 3.
os.environ.update({
    'CUDA_LAUNCH_BLOCKING': '1',
    'CUDA_VISIBLE_DEVICES': '2,3',
})

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def text2embedding(client, model, text):
    """Embed one string and return its embedding vector.

    Sends ``text`` as a single-element batch to ``client.embeddings.create``
    using the ``model`` id, and returns the ``embedding`` of the first result.
    """
    result = client.embeddings.create(model=model, input=[text])
    first_item = result.data[0]
    return first_item.embedding

# Connection to the OpenAI-compatible embedding server. The key and endpoint can
# be overridden via EMBED_API_KEY / EMBED_API_BASE so they need not be
# hard-coded. The defaults reproduce the original local setup. Project-specific
# names are used (not OPENAI_API_KEY) so a real OpenAI key is never sent to the
# local server by accident.
openai_api_key = os.environ.get("EMBED_API_KEY", "abc123")
openai_api_base = os.environ.get("EMBED_API_BASE", "http://localhost:8002/v1")
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
# Use the first model the server advertises as the embedding model.
# NOTE: `model` is rebound to a ViT network in the Train section; re-run this
# cell before re-running any embedding loop.
models = client.models.list()
if not models.data:
    raise RuntimeError(f"No models are served at {openai_api_base}")
model = models.data[0].id

# ReXGradient-160K

In [None]:
def expand_medical_data_to_dataframe(data_dict):
    """Explode one medical record dict into a DataFrame with one row per image.

    The number of rows equals ``len(data_dict['ImagePath'])``. For the per-image
    fields listed in ``image_fields``, row ``i`` receives element ``i`` of the
    list, or ``None`` when the value is not a list or is shorter than ``i + 1``.
    Every other field is copied unchanged into every row.

    Args:
        data_dict: dict holding one record's fields; must contain 'ImagePath'.

    Returns:
        pd.DataFrame: the expanded rows, columns in ``data_dict`` key order.
    """
    image_fields = ['ImagePath', 'ImageModality', 'ImageShape', 'ImageBodyPart', 'ImageViewPosition']

    def _value_for_row(key, value, position):
        # Non-image fields are shared by every row of the record.
        if key not in image_fields:
            return value
        if isinstance(value, list) and position < len(value):
            return value[position]
        return None

    records = [
        {key: _value_for_row(key, value, position) for key, value in data_dict.items()}
        for position in range(len(data_dict['ImagePath']))
    ]
    return pd.DataFrame(records)
def process_multiple_medical_records(data_list):
    """Expand many medical record dicts and stack them into one DataFrame.

    Each record is exploded with ``expand_medical_data_to_dataframe`` (one row
    per image). The pieces are concatenated with a fresh 0..n-1 index.

    Args:
        data_list: iterable of record dicts.

    Returns:
        pd.DataFrame: all expanded rows. When ``data_list`` is empty this is an
        empty DataFrame; ``pd.concat`` on an empty list would raise ValueError.
    """
    frames = [expand_medical_data_to_dataframe(record) for record in data_list]
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)

def standardize_view_position_direct(df, column_name='ImageViewPosition', mapping=None):
    """Normalize view-position spellings in ``df[column_name]`` via a lookup dict.

    Values that are keys of ``mapping`` are replaced by the mapped value. All
    other values are left unchanged. The input frame is not modified; a copy
    is returned.

    NOTE: the MIMIC section of this notebook redefines this function with a
    different default column and mapping. Whichever definition ran last is the
    one in effect.

    Args:
        df: input DataFrame.
        column_name: column holding the view positions.
        mapping: optional dict overriding the default spellings below, so the
            same function can serve other datasets.

    Returns:
        pd.DataFrame: copy of ``df`` with ``column_name`` standardized.
    """
    if mapping is None:
        mapping = {
            'PA': 'PA',
            'POSTERO_ANTERIOR': 'PA',
            'AP': 'AP',
            'ANTERO_POSTERIOR': 'AP',
            'AP AXIAL': 'AP'
        }

    df_standardized = df.copy()
    original = df_standardized[column_name]
    # map() yields NaN for unmapped values; fillna restores the original value.
    df_standardized[column_name] = original.map(mapping).fillna(original)
    return df_standardized

In [None]:
# NOTE(review): `df` is not used by the visible ReX cells and is overwritten in
# the MIMIC section.
df = pd.read_csv('/data/ReXGradient-160K/metadata/train_metadata.csv')
with open('/data/ReXGradient-160K/metadata/train_metadata_view_position.json', 'r', encoding='utf-8') as file:
    json_data = json.load(file)
# One row per image, view positions normalized, frontal (AP/PA) views only.
df2 = standardize_view_position_direct(
    process_multiple_medical_records(list(json_data.values()))
)
df2 = df2[df2['ImageViewPosition'].isin(['AP', 'PA'])]

In [None]:
# Embed each ReXGradient report (Findings + Impression) of the frontal-view rows.
# Missing sections (None/NaN) are embedded as empty text rather than the literal
# strings "None"/"nan", which would otherwise pollute the embeddings.
def _section_text(value):
    return '' if pd.api.types.is_scalar(value) and pd.isna(value) else value

embedding_rows = []
for _, row in tqdm(df2.iterrows(), total=len(df2)):
    note = "Findings: {} \nImpression: {}".format(
        _section_text(row['Findings']), _section_text(row['Impression'])
    )
    embedding = text2embedding(client, model, note)
    embedding_rows.append(embedding)

# MIMIC-CXR

In [None]:
def standardize_view_position_direct(df, column_name='ViewPosition', mapping=None):
    """Normalize view-position spellings in ``df[column_name]`` via a lookup dict.

    Values that are keys of ``mapping`` are replaced by the mapped value. All
    other values are left unchanged. The input frame is not modified; a copy
    is returned.

    NOTE: this redefines the function of the same name from the ReX section
    (different default column and mapping) and shadows it from here on.

    Args:
        df: input DataFrame.
        column_name: column holding the view positions.
        mapping: optional dict overriding the default spellings below.

    Returns:
        pd.DataFrame: copy of ``df`` with ``column_name`` standardized.
    """
    if mapping is None:
        mapping = {
            'PA': 'PA',
            'PA LLD': 'PA',
            'PA RLD': 'PA',
            'AP': 'AP',
            'AP AXIAL': 'AP',
            'AP LLD': 'AP',
            'AP RLD': 'AP'
        }

    df_standardized = df.copy()
    original = df_standardized[column_name]
    # map() yields NaN for unmapped values; fillna restores the original value.
    df_standardized[column_name] = original.map(mapping).fillna(original)
    return df_standardized

def load_text(path):
    """Read a report text file and return the body after its FINAL REPORT line.

    Everything up to and including the first "FINAL REPORT" marker (followed by
    a newline and a space) is dropped. The single space that indents each
    following line is then removed.

    The previous ``split(marker)[1]`` had two problems: it raised IndexError when
    the marker was missing, aborting long embedding loops, and it cut the text at
    a second marker occurrence. Now the whole remainder is kept, and a file
    without the marker is returned in full (de-indented).

    Args:
        path: path to the report text file (read as UTF-8).

    Returns:
        str: the de-indented report body.
    """
    with open(path, 'r', encoding='utf-8') as file:
        file_content = file.read()
    _, marker_found, body = file_content.partition("FINAL REPORT\n ")
    if not marker_found:
        body = file_content
    return body.replace('\n ', '\n')

def text_processing(full_text):
    """Trim a report so it begins at its first FINDINGS: or IMPRESSION: header.

    The first occurrence of each header is located (case-sensitive). The text
    is sliced from whichever header appears earlier. If neither header is
    present, the text is returned unchanged.
    """
    header_patterns = (r"FINDINGS:(.*?)", r"IMPRESSION:(.*?)")
    header_starts = []
    for pattern in header_patterns:
        hit = re.search(pattern, full_text, re.DOTALL)
        if hit is not None:
            header_starts.append(hit.start())
    if not header_starts:
        return full_text
    return full_text[min(header_starts):]

In [None]:
df = standardize_view_position_direct(
    pd.read_csv('/data/mimic3_cxr_jpg/mimic-cxr-dataset.csv')
)
# Keep frontal (PA/AP) views. reset_index() without drop=True keeps the old
# index as an 'index' column and gives a fresh 0..n-1 positional index.
df = df[df['ViewPosition'].isin(["PA", "AP"])].reset_index()

In [None]:
# Embed MIMIC-CXR reports (text trimmed to start at FINDINGS/IMPRESSION).
# MAX_REPORTS replaces the leftover debug `if idx == 10: break`. 11 reproduces
# that cap; set it to None to embed every row.
MAX_REPORTS = 11
reports_df = df if MAX_REPORTS is None else df.head(MAX_REPORTS)

embedding_rows = []
for _, row in tqdm(reports_df.iterrows(), total=len(reports_df)):
    note = load_text('/data/mimic3_cxr_jpg/'+row['path'])
    note = text_processing(note)
    embedding = text2embedding(client, model, note)
    embedding_rows.append(embedding)

10it [00:00, 11.75it/s]


In [None]:
df['study_id'].nunique(dropna=False)

218132

In [None]:
df.shape[0]

243335

# Data loading

In [2]:
def image_path_refine(row):
    """Build the absolute local path of the MIMIC-CXR-JPG image described by ``row``.

    Layout: ``files/p<first two chars of subject_id>/p<subject_id>/s<study_id>/
    <dicom_id>.jpg`` under the local mimic-cxr-jpg-2.0.0 root.

    Args:
        row: mapping (e.g. a DataFrame row) with 'subject_id', 'study_id' and
            'dicom_id'.

    Returns:
        str: absolute path to the JPEG file.
    """
    # Pull the fields out first. The original f-string reused its own quote
    # character inside the replacement fields, which is a SyntaxError before
    # Python 3.12 (PEP 701).
    subject_id = row['subject_id']
    study_id = row['study_id']
    dicom_id = row['dicom_id']
    return (
        '/data/mimic3_cxr_jpg/mimic-cxr-jpg-2.0.0.physionet.org/files/'
        f'p{str(subject_id)[:2]}/p{subject_id}/s{study_id}/{dicom_id}.jpg'
    )

In [3]:
# Combined MIMIC-CXR + ReXGradient training set (currently disabled). To train
# on it, uncomment these lines instead of the dev.csv read below:
# mimic_train_df = pd.read_csv('/data/mimic3_cxr_jpg/train_with_view_embeddings.csv')
# mimic_train_df['ImagePath'] = mimic_train_df.apply(image_path_refine, axis=1)
# rex_train_df = pd.read_csv('/data/ReXGradient-160K/metadata/train_with_view_embeddings.csv')
# rex_train_df['ImagePath'] = rex_train_df['ImagePath'].apply(lambda x : x.replace('../', '/data/ReXGradient-160K/'))
# train_df = pd.concat([mimic_train_df[['ImagePath', 'embeddings']],rex_train_df[['ImagePath', 'embeddings']]], axis=0).reset_index(drop=True)

# Active source: a pre-built CSV (presumably a development subset -- confirm).
train_df = pd.read_csv(
    '/data/code/CXR_embedding_research/dev.csv'
)

In [4]:
# GPU-side augmentation helper and training loader; both come from `utils`.
augment_tool = KorniaGPUAugmentation()
dataloader = create_dataloader(
    train_df,
    label_type='embedding',
    batch_size=32,
    shuffle=True,
    num_workers=4,
)

# Training

In [5]:
from transformers import ViTModel, ViTFeatureExtractor

# Pretrained ViT wrapped by the project's `custom_vit_embed` (from utils), on GPU.
# NOTE(review): this rebinds `model`, which earlier held the embedding-server
# model id used by text2embedding; re-run the client cell before re-embedding.
model = custom_vit_embed(
    ViTModel.from_pretrained('/data/models/vit-base-patch16-384')
).to('cuda')

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.999),
    eps=1e-6,
    weight_decay=0.01,
    amsgrad=False,
)
criterion = torch.nn.CosineEmbeddingLoss()

Some weights of ViTModel were not initialized from the model checkpoint at /data/models/vit-base-patch16-384 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Smoke test: push a single batch through augmentation and the model.
imgs, labels = next(iter(dataloader))
# The second argument's meaning is defined in utils; presumably it enables
# augmentation (TODO confirm).
pixel_values = augment_tool(imgs, True).to('cuda')
# Without augmentation: pixel_values = imgs.to('cuda')
labels = labels.to('cuda')
output = model(pixel_values)

In [7]:
# CosineEmbeddingLoss takes a per-pair target; +1 means "pull output and label
# together". Size it from the batch actually produced instead of hard-coding 32,
# so a short final batch or a different batch_size does not raise a size mismatch.
ones = torch.ones(output.size(0), device=output.device)
loss = criterion(output, labels, ones)

# Text augmentation

In [3]:
# MIMIC training rows with their image paths rebuilt from subject/study/dicom ids.
mimic_train_df = pd.read_csv('/data/mimic3_cxr_jpg/train_with_view_embeddings.csv')
mimic_train_df['ImagePath'] = mimic_train_df.apply(image_path_refine, axis='columns')