In [1]:
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
import warnings
warnings.filterwarnings("ignore")


import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn

from tensorboardX import SummaryWriter

import utils
from scheduler import create_scheduler
from optim import create_optimizer
from dataset.dataset import MedKLIP_Dataset
from models.model_MedKLIP import MedKLIP
from models.tokenization_bert import BertTokenizer

2024-02-20 15:21:27.524570: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-02-20 15:21:27.548242: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-20 15:21:27.548267: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-20 15:21:27.549223: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-20 15:21:27.554053: I tensorflow/core/platform/cpu_feature_guar

In [2]:
def get_tokenizer(tokenizer, target_text):

    target_tokenizer = tokenizer(
        list(target_text),
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )

    return target_tokenizer

In [3]:
config = yaml.load(open('configs/Pretrain_MedKLIP.yaml', "r"), Loader=yaml.Loader)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Total CUDA devices: ", torch.cuda.device_count())
torch.set_default_tensor_type("torch.FloatTensor")
cudnn.benchmark = True

start_epoch = 0
max_epoch = config["schedular"]["epochs"]
warmup_steps = config["schedular"]["warmup_epochs"]

print("Creating book")
json_book = json.load(open('data_file/observation explanation.json', "r"))
disease_book = [json_book[i] for i in json_book]
ana_book = [
    "It is located at " + i
    for i in [
        "trachea",
        "left_hilar",
        "right_hilar",
        "hilar_unspec",
        "left_pleural",
        "right_pleural",
        "pleural_unspec",
        "heart_size",
        "heart_border",
        "left_diaphragm",
        "right_diaphragm",
        "diaphragm_unspec",
        "retrocardiac",
        "lower_left_lobe",
        "upper_left_lobe",
        "lower_right_lobe",
        "middle_right_lobe",
        "upper_right_lobe",
        "left_lower_lung",
        "left_mid_lung",
        "left_upper_lung",
        "left_apical_lung",
        "left_lung_unspec",
        "right_lower_lung",
        "right_mid_lung",
        "right_upper_lung",
        "right_apical_lung",
        "right_lung_unspec",
        "lung_apices",
        "lung_bases",
        "left_costophrenic",
        "right_costophrenic",
        "costophrenic_unspec",
        "cardiophrenic_sulcus",
        "mediastinal",
        "spine",
        "clavicle",
        "rib",
        "stomach",
        "right_atrium",
        "right_ventricle",
        "aorta",
        "svc",
        "interstitium",
        "parenchymal",
        "cavoatrial_junction",
        "cardiopulmonary",
        "pulmonary",
        "lung_volumes",
        "unspecified",
        "other",
    ]
]
tokenizer = BertTokenizer.from_pretrained(config["text_encoder"])
ana_book_tokenizer = get_tokenizer(tokenizer, ana_book).to(device)
disease_book_tokenizer = get_tokenizer(tokenizer, disease_book).to(device)
print("Creating model")


Total CUDA devices:  1
Creating book
Creating model


In [5]:
model = MedKLIP(config, ana_book_tokenizer, disease_book_tokenizer, mode="train").to('cuda:0')
model_after = MedKLIP(config, ana_book_tokenizer, disease_book_tokenizer, mode="train").to('cuda:0')
model.load_state_dict(torch.load('/home/wenrui/Projects/MIMIC/MedKLIP/runs/baseline/checkpoint_state.pth')['model'], strict=False)
model_after.load_state_dict(torch.load('/home/wenrui/Projects/MIMIC/MedKLIP/runs/resnet_pcl/2024-02-04_22-13-41/checkpoint_98.pth')['model'], strict=False)
ana_book = model.ana_book
disease_book = model.disease_book
disease_embedding = model.disease_embedding_layer
disease_embedding_after = model_after.disease_embedding_layer
ana_book_before = disease_embedding(ana_book)
disease_book_before = disease_embedding(disease_book)
ana_book_after = disease_embedding_after(ana_book)
disease_book_after = disease_embedding_after(disease_book)

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


text feature extractor: emilyalsentzer/Bio_ClinicalBERT
Image feature extractor: resnet50


Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


text feature extractor: emilyalsentzer/Bio_ClinicalBERT
Image feature extractor: resnet50


In [6]:
ana_name = [
    "trachea",
    "left_hilar",
    "right_hilar",
    "hilar_unspec",
    "left_pleural",
    "right_pleural",
    "pleural_unspec",
    "heart_size",
    "heart_border",
    "left_diaphragm",
    "right_diaphragm",
    "diaphragm_unspec",
    "retrocardiac",
    "lower_left_lobe",
    "upper_left_lobe",
    "lower_right_lobe",
    "middle_right_lobe",
    "upper_right_lobe",
    "left_lower_lung",
    "left_mid_lung",
    "left_upper_lung",
    "left_apical_lung",
    "left_lung_unspec",
    "right_lower_lung",
    "right_mid_lung",
    "right_upper_lung",
    "right_apical_lung",
    "right_lung_unspec",
    "lung_apices",
    "lung_bases",
    "left_costophrenic",
    "right_costophrenic",
    "costophrenic_unspec",
    "cardiophrenic_sulcus",
    "mediastinal",
    "spine",
    "clavicle",
    "rib",
    "stomach",
    "right_atrium",
    "right_ventricle",
    "aorta",
    "svc",
    "interstitium",
    "parenchymal",
    "cavoatrial_junction",
    "cardiopulmonary",
    "pulmonary",
    "lung_volumes",
    "unspecified",
    "other",
]
disease_name = [
    "normal",
    "clear",
    "sharp",
    "sharply",
    "unremarkable",
    "intact",
    "stable",
    "free",
    "effusion",
    "opacity",
    "pneumothorax",
    "edema",
    "atelectasis",
    "tube",
    "consolidation",
    "process",
    "abnormality",
    "enlarge",
    "tip",
    "low",
    "pneumonia",
    "line",
    "congestion",
    "catheter",
    "cardiomegaly",
    "fracture",
    "air",
    "tortuous",
    "lead",
    "disease",
    "calcification",
    "prominence",
    "device",
    "engorgement",
    "picc",
    "clip",
    "elevation",
    "expand",
    "nodule",
    "wire",
    "fluid",
    "degenerative",
    "pacemaker",
    "thicken",
    "marking",
    "scar",
    "hyperinflate",
    "blunt",
    "loss",
    "widen",
    "coll_eapse",
    "density",
    "emphysema",
    "aerate",
    "mass",
    "crowd",
    "infiltrate",
    "obscure",
    "deformity",
    "hernia",
    "drainage",
    "distention",
    "shift",
    "stent",
    "pressure",
    "lesion",
    "finding",
    "borderline",
    "hardware",
    "dilation",
    "chf",
    "redistribution",
    "aspiration",
    "tail_abnorm_obs",
    "excluded_obs",
]

In [7]:
ana_book_after = ana_book_after / ana_book_after.norm(dim=1, keepdim=True)
disease_book_after = disease_book_after / disease_book_after.norm(dim=1, keepdim=True)
ana_book_before = ana_book_before / ana_book_before.norm(dim=1, keepdim=True)
disease_book_before = disease_book_before / disease_book_before.norm(dim=1, keepdim=True)
logits_after = torch.mm(ana_book_after, disease_book_after.T).cpu().detach().numpy()
logits_before = torch.mm(ana_book_before, disease_book_before.T).cpu().detach().numpy()

In [8]:
# import seaborn as sns
# import numpy as np
# from sklearn.manifold import TSNE
# import matplotlib.pyplot as plt

# plt.figure(figsize=(150, 120))
# plt.subplot(221)
# ax = sns.heatmap(logits, linecolor="white", linewidth=0.5, annot=True, fmt=".2f")
# ax.set_xticklabels(disease_name, rotation=90)
# ax.set_yticklabels(ana_name, rotation=0)
# plt.subplot(222)
# ax = sns.heatmap(logits_reshape, linecolor="white", linewidth=0.5, annot=True, fmt=".2f")
# ax.set_xticklabels(disease_name, rotation=90)
# ax.set_yticklabels(ana_name, rotation=0)
# plt.subplot(223)
# ax = sns.heatmap(logits_reshape-logits, linecolor="white", linewidth=0.5, annot=True, fmt=".2f")
# ax.set_xticklabels(disease_name, rotation=90)
# ax.set_yticklabels(ana_name, rotation=0)
# plt.show()

In [9]:


# disease_book_before = torch.concatenate([ana_book_before, disease_book_before], dim=0)

# disease_book = torch.concatenate([ana_book, disease_book], dim=0)
# labels = ana_name + disease_name
# # data_np = disease_book.cpu().detach().numpy()
# # data_np_reshape = disease_book_before.cpu().detach().numpy()
# # tsne = TSNE(n_components=2, random_state=0)
# # data_embedded = tsne.fit_transform(data_np)
# # data_embedded_reshape = tsne.fit_transform(data_np_reshape)
# # lists = disease_name + ana_name
# # labels = [str(i) for i in range(len(lists))]


# # plt.figure(figsize=(20, 10))
# # plt.subplot(121)
# # for i, (x, y) in enumerate(data_embedded):
# #     plt.scatter(x, y)
# #     plt.text(x, y, labels[i], fontsize=12)
# # plt.subplot(122)
# # for i, (x, y) in enumerate(data_embedded_reshape):
# #     plt.scatter(x, y)
# #     plt.text(x, y, labels[i], fontsize=12)
# # plt.show()

In [10]:
# disease_book = disease_book / disease_book.norm(dim=1, keepdim=True)
# disease_book_before = disease_book_before / disease_book_before.norm(dim=1, keepdim=True)
# disease_map = torch.mm(disease_book, disease_book.T)
# disease_map = disease_map.cpu().detach().numpy()
# disease_map_reshape = torch.mm(disease_book_before, disease_book_before.T)
# disease_map_reshape = disease_map_reshape.cpu().detach().numpy()

In [11]:
# import seaborn as sns


# plt.figure(figsize=(150, 120))
# plt.subplot(221)
# ax = sns.heatmap(disease_map, linecolor="white", linewidth=0.5, annot=True, fmt=".2f")
# ax.set_xticklabels(labels, rotation=90)
# ax.set_yticklabels(labels, rotation=0)
# plt.subplot(222)
# ax_reshape = sns.heatmap(disease_map_reshape, linecolor="white", linewidth=0.5, annot=True, fmt=".2f")
# ax_reshape.set_xticklabels(labels, rotation=90)
# ax_reshape.set_yticklabels(labels, rotation=0)
# plt.subplot(223)
# ax = sns.heatmap(disease_map_reshape-disease_map, linecolor="white", linewidth=0.5, annot=True, fmt=".2f")
# ax.set_xticklabels(labels, rotation=90)
# ax.set_yticklabels(labels, rotation=0)
# plt.show()


In [12]:
import numpy as np
gts = np.load('/home/wenrui/Projects/MIMIC/MedKLIP/setting/landmark_observation_adj_mtx.npy')

In [13]:
gts[gts<0] = 0


In [14]:
gts = gts.sum(axis=0)

In [15]:
gts = gts / gts.sum()

In [16]:
# np.save('/mnt/a/MedKLIP/gts.npy', gts)

In [17]:
((logits_after - gts) ** 2).mean()

0.5216448513432825

In [18]:
((logits_before - gts) ** 2).mean()

0.5647437311018451

In [19]:
epsilon = 1e-8
np.sum(logits_before * np.log(logits_before / (gts + epsilon) + epsilon), axis=1).mean()

782.1597881518973

In [20]:
np.sum(logits_after * np.log(logits_after / (gts + epsilon) + epsilon), axis=1).mean()

749.032755664473