In [27]:
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
import warnings
warnings.filterwarnings("ignore")


import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn

from tensorboardX import SummaryWriter

import utils
from scheduler import create_scheduler
from optim import create_optimizer
from dataset.dataset2 import MedKLIP_Dataset
from models.model_MedKLIP import MedKLIP
from models.tokenization_bert import BertTokenizer

import nltk

In [28]:
def get_tokenizer(tokenizer, target_text, max_length=128):
    """Tokenize a collection of strings into fixed-length, padded PyTorch tensors.

    Args:
        tokenizer: Tokenizer callable (here the project's ``BertTokenizer``) that
            accepts ``padding``/``truncation``/``max_length``/``return_tensors``.
        target_text: Any iterable of strings. It is materialized with ``list()``,
            so tuples, dict views, generators, etc. are accepted.
        max_length: Length every entry is padded/truncated to. Defaults to 128,
            the value that was previously hard-coded.

    Returns:
        The tokenizer's output for ``return_tensors="pt"``. Callers move it to a
        device with ``.to(device)``, so it is presumably a ``BatchEncoding``-like
        object — confirm against the tokenizer implementation.
    """
    return tokenizer(
        list(target_text),
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

In [29]:
# Build the MedKLIP model with its anatomy/disease prompt books and restore pretrained weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Context managers close the config/JSON handles instead of leaking them.
# NOTE(review): yaml.Loader is the full (non-safe) loader; acceptable only because
# the config is a local, trusted file.
with open("configs/Pretrain_MedKLIP.yaml", "r") as config_file:
    config = yaml.load(config_file, Loader=yaml.Loader)
with open("data_file/observation explanation.json", "r") as book_file:
    json_book = json.load(book_file)

# One explanation string per observation, in the JSON file's key order.
disease_book = list(json_book.values())

# Anatomical-location prompts, one sentence per landmark name.
ana_book = [
    "It is located at " + i
    for i in [
        "trachea",
        "left_hilar",
        "right_hilar",
        "hilar_unspec",
        "left_pleural",
        "right_pleural",
        "pleural_unspec",
        "heart_size",
        "heart_border",
        "left_diaphragm",
        "right_diaphragm",
        "diaphragm_unspec",
        "retrocardiac",
        "lower_left_lobe",
        "upper_left_lobe",
        "lower_right_lobe",
        "middle_right_lobe",
        "upper_right_lobe",
        "left_lower_lung",
        "left_mid_lung",
        "left_upper_lung",
        "left_apical_lung",
        "left_lung_unspec",
        "right_lower_lung",
        "right_mid_lung",
        "right_upper_lung",
        "right_apical_lung",
        "right_lung_unspec",
        "lung_apices",
        "lung_bases",
        "left_costophrenic",
        "right_costophrenic",
        "costophrenic_unspec",
        "cardiophrenic_sulcus",
        "mediastinal",
        "spine",
        "clavicle",
        "rib",
        "stomach",
        "right_atrium",
        "right_ventricle",
        "aorta",
        "svc",
        "interstitium",
        "parenchymal",
        "cavoatrial_junction",
        "cardiopulmonary",
        "pulmonary",
        "lung_volumes",
        "unspecified",
        "other",
    ]
]
tokenizer = BertTokenizer.from_pretrained(config["text_encoder"])
ana_book_tokenizer = get_tokenizer(tokenizer, ana_book).to(device)
disease_book_tokenizer = get_tokenizer(tokenizer, disease_book).to(device)
print("Creating model")
model = MedKLIP(config, ana_book_tokenizer, disease_book_tokenizer, mode="train")

CKPT_PATH = "/home/wenrui/Projects/MIMIC/MedKLIP/runs/dual_stream/2024-02-14_22-44-14/checkpoint_64.pth"
# The model is still on the CPU here, so load tensors to the CPU. Without map_location,
# a checkpoint saved from CUDA fails to load on the CPU fallback device selected above.
ckpt = torch.load(CKPT_PATH, map_location="cpu")
# strict=False tolerates key mismatches; report them so a silently partial load is visible.
incompatible = model.load_state_dict(ckpt["model"], strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(
        f"Checkpoint partially loaded: {len(incompatible.missing_keys)} missing keys, "
        f"{len(incompatible.unexpected_keys)} unexpected keys"
    )

model = model.to(device)

Creating model


Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


text feature extractor: emilyalsentzer/Bio_ClinicalBERT
Image feature extractor: resnet50


In [30]:
# Data pipeline for the single-batch inspection cell below.
# NOTE(review): the validation annotation file is read with mode="train" — presumably
# so the training-time sample fields (labels, indices, matrix) are produced; confirm.
val_datasets = MedKLIP_Dataset(
    "../setting/rad_graph_metric_validate_local.json",
    "../setting/landmark_observation_adj_mtx.npy",
    mode="train",
)

# Single, shuffled samples; collate_fn is left at DataLoader's default.
loader_options = dict(
    batch_size=1,
    shuffle=True,
    drop_last=True,
    num_workers=30,
    pin_memory=True,
)
val_dataloader = DataLoader(val_datasets, **loader_options)

In [31]:
from PIL import Image
from torchvision import transforms

# Single-example preprocessing: resize to 224x224, convert to a tensor and normalize
# with the per-channel mean/std values below (the standard ImageNet statistics).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# NOTE(review): this image is under .../MIMIC/Data/physionet.org/..., while the annotation
# JSON displayed later points at .../MIMIC/physionet.org/... — confirm which root is current.
SAMPLE_IMAGE_PATH = "/home/wenrui/Projects/MIMIC/Data/physionet.org/files/mimic-cxr-jpg/2.0.0/files/p10/p10000032/s50414267/02aa804e-bde0afdd-112c0b34-7bc16630-4e384014.jpg"
SAMPLE_REPORT_PATH = "/home/wenrui/Projects/MIMIC/CXR_NOTE/p10000032/s50414267.txt"

# The with-block closes PIL's underlying file handle; convert() returns a new,
# fully loaded image, so the pixels remain usable after the file is closed.
with Image.open(SAMPLE_IMAGE_PATH) as raw_image:
    pil_image = raw_image.convert("RGB")

# The with-block already closes the file, so no explicit close() is needed.
with open(SAMPLE_REPORT_PATH, encoding="utf-8") as f:
    text = f.read()

# Add a leading batch dimension: (1, 3, 224, 224) on the target device.
image = transform(pil_image).unsqueeze(0).to(device)


In [None]:
# Pull one shuffled batch and run a single forward pass to inspect the model outputs.
# next(iter(...)) replaces the for/enumerate/break pattern: an empty loader now raises
# StopIteration instead of silently leaving every variable below undefined.
sample = next(iter(val_dataloader))
images = sample["image"].to(device)
labels_e = sample["label_e"].to(device)
labels_p = sample["label_p"].to(device)
index_e = sample["index_e"].to(device)
index_p = sample["index_p"].to(device)
matrix = sample["matrix"].to(device)
text = sample["txt"]
print(text)

# Inspection only: no backward pass is taken here, so skip building the autograd graph.
# NOTE(review): model.eval() is never called in this notebook and is_train=True is passed,
# so any dropout/normalization layers run in training mode — confirm that is intended.
with torch.no_grad():
    loss, x_e, ws_e, x_p, ws_p = model(
        images,
        labels_e=labels_e,
        labels_p=labels_p,
        matrix=matrix,
        sample_index_e=index_e,
        sample_index_p=index_p,
        is_train=True,
        text_gen=True,
        no_cl=config["no_cl"],
        exclude_class=config["exclude_class"],
    )

ERROR:tornado.general:SEND Error: Host unreachable


In [33]:
# Test-split annotations keyed by image id; the record displayed below has
# img_path, txt_path and labels_id fields. The context manager closes the file handle.
with open("/home/wenrui/Projects/MIMIC/MedKLIP/setting/rad_graph_metric_test_local.json", "r") as ann_file:
    ann = json.load(ann_file)

In [34]:
# Iterating a dict yields its keys in insertion order, so this is the list of annotation keys.
image_list = list(ann)

In [36]:
ann[image_list[0]]  # first test-split record; its keys are img_path, txt_path, labels_id (no "text" key)

{'img_path': '/home/wenrui/Projects/MIMIC/physionet.org/files/mimic-cxr-jpg/2.0.0/files/p10/p10032725/s50331901/687754ce-7420bfd3-0a19911f-a27a3916-9019cd53.jpg',
 'txt_path': '/home/wenrui/Projects/MIMIC/CXR_NOTE/p10032725/s50331901.txt',
 'labels_id': 7892}

In [4]:
import json
import os

# Sanity check: which test-split annotation records point at a report (txt_path)
# that does not exist on this machine?
with open("/home/wenrui/Projects/MIMIC/MedKLIP/setting/rad_graph_metric_test_local.json", "r") as ann_file:
    dic = json.load(ann_file)
img_list = list(dic.keys())

missing = [
    (key, dic[key]["txt_path"])
    for key in img_list
    if not os.path.exists(dic[key]["txt_path"])
]
no = len(missing)

# Print a bounded sample plus a summary instead of one line per missing file.
for key, txt_path in missing[:10]:
    print(key, txt_path)
print(f"{no} of {len(img_list)} txt_path entries are missing")

In [2]:
no  # number of test-split records whose txt_path does not exist

4881

In [3]:
len(img_list)  # total test-split records; if this equals `no`, every txt_path is missing

4881