In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm, trange
import os

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import Adam
from torchinfo import summary

from utils import make_time_bins
from utils import encode_survival, mtlr_neg_log_likelihood, make_optimizer
from utils import mtlr_survival, mtlr_risk
from prognosis_model import embd_model, lora_model

from ct_clip import CTCLIP
from transformer_maskgit import CTViT
from transformers import BertTokenizer, BertModel
from lifelines.utils import concordance_index
from data_inference_hector import Hector_Dataset_emb, Hector_Dataset

from peft import get_peft_config, get_peft_model, LoraConfig, TaskType


seed = 42
torch.manual_seed(seed) 
generator = torch.Generator().manual_seed(seed)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

tokenizer = BertTokenizer.from_pretrained('microsoft/BiomedVLP-CXR-BERT-specialized',do_lower_case=True)
text_encoder = BertModel.from_pretrained("microsoft/BiomedVLP-CXR-BERT-specialized")

text_encoder.resize_token_embeddings(len(tokenizer))
text_encoder.to(device)

image_encoder = CTViT(
    dim = 512,
    codebook_size = 8192,
    image_size = 480,
    patch_size = 20,
    temporal_patch_size = 10,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 32,
    heads = 8
)

image_encoder.to(device)

clip = CTCLIP(
    image_encoder = image_encoder,
    text_encoder = text_encoder,
    dim_image = 294912,
    dim_text = 768,
    dim_latent = 512,
    extra_latent_projection = False,         # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
    use_mlm=False,
    downsample_image_embeds = False,
    use_all_token_embeds = False,
)

clip.load("/home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/CT-CLIP/CT-CLIP_v2.pt")
clip.to(device)

hect_dataset = Hector_Dataset(data_folder = "/home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/valid_preprocessed_hector/",  
                csv_file ="/home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/final_hector_with_text.csv")


train_size = int(0.8 * len(hect_dataset))  # 80% for training
test_size = len(hect_dataset) - train_size  # 20% for testing
train_dataset, test_dataset = random_split(hect_dataset, [train_size, test_size], generator=generator)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

df = pd.read_csv("/home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/final_hector_with_text.csv")
time_bins = make_time_bins(df['RFS'].values, event = df['Relapse'].values)
num_time_bins = len(time_bins)

peft_config = LoraConfig(
    inference_mode=False, r=8, lora_alpha=64, lora_dropout=0.2, target_modules=["to_q", "to_kv"]
)

model = lora_model(clip, device, peft_config, num_time_bins)
model.to(device)

  from .autonotebook import tqdm as notebook_tqdm
  @autocast(enabled = False)
  @autocast(enabled = False)
  return torch.load(checkpoint_file, map_location="cpu")
  pt = torch.load(str(path))


lora_model(
  (clip): PeftModel(
    (base_model): LoraModel(
      (model): CTCLIP(
        (text_transformer): BertModel(
          (embeddings): BertEmbeddings(
            (word_embeddings): Embedding(30522, 768, padding_idx=0)
            (position_embeddings): Embedding(512, 768)
            (token_type_embeddings): Embedding(2, 768)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.25, inplace=False)
          )
          (encoder): BertEncoder(
            (layer): ModuleList(
              (0-11): 12 x BertLayer(
                (attention): BertAttention(
                  (self): BertSelfAttention(
                    (query): Linear(in_features=768, out_features=768, bias=True)
                    (key): Linear(in_features=768, out_features=768, bias=True)
                    (value): Linear(in_features=768, out_features=768, bias=True)
                    (dropout): Dropout(p=0.25, inplace=False)
             

In [69]:
model.load_state_dict(torch.load("/opt/sagemaker/new_home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/save/model_weights/weight_095.pth"))
print("Validation of the model")
model.eval()
pred_risk_all = []
relapse_all = []
RFS_all = []
pred_survival_all = []
with torch.no_grad():
    for img_emb, text_emb, relapse, RFS, _ in test_loader:
        img_emb = img_emb.to(device)
        text_emb=tokenizer(text_emb, return_tensors="pt", padding="max_length", truncation=True, max_length=512).to(device)

        y_pred = model(text_emb, img_emb)
        pred_survival = mtlr_survival(y_pred).cpu().numpy()
        pred_risk = mtlr_risk(y_pred).cpu().numpy()

        pred_risk_all.append(pred_risk.item()) 
        relapse_all.append(relapse.item())
        RFS_all.append(RFS.item())
        pred_survival_all.append(list(pred_survival[0]))

ci = concordance_index(RFS_all, -np.array(pred_risk_all), event_observed=relapse_all)
print(f"Concordance Index: {ci:.4f}")

  model.load_state_dict(torch.load("/opt/sagemaker/new_home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/save/model_weights/weight_095.pth"))


Validation of the model
Concordance Index: 0.4298


In [71]:
np.array(pred_survival_all)

array([[1.        , 0.9954841 , 0.98311925, ..., 0.84875095, 0.7843362 ,
        0.4419284 ],
       [0.99999994, 0.9804237 , 0.9392827 , ..., 0.6399225 , 0.5471909 ,
        0.31601977],
       [1.        , 0.9896096 , 0.96635306, ..., 0.7328142 , 0.6472257 ,
        0.36006597],
       ...,
       [0.99999994, 0.9957256 , 0.9821786 , ..., 0.7130777 , 0.64634717,
        0.29794365],
       [1.        , 0.9987982 , 0.99350846, ..., 0.928804  , 0.8949395 ,
        0.50034034],
       [1.        , 0.999854  , 0.99934024, ..., 0.98197997, 0.9433391 ,
        0.6599177 ]], dtype=float32)

In [81]:
# get only those RFS_all values for which relapse is 1

{f"auc_{t}": auc[i] for i, t in enumerate(eval_times)}

{'auc_266': 0.43548387096774194,
 'auc_405': 0.30461922596754054,
 'auc_782': 0.4206426484907498}

In [83]:
np.array([RFS_all[i] for i in range(len(RFS_all)) if relapse_all[i] == 1])

array([1362,  400,  323,  435,  196,  393, 2315,   88,  846,  526,  920,
       4425,  410,  202,  592,  248,  330,   96])

In [106]:
# check if any elemt is greater than 1

for i in range(len(pred_survival_all)):
    if np.array(pred_survival_all)[i].max() > 1:
        print(i, np.array(pred_survival_all)[i].max())

In [105]:
pred_survival_all[15][0]

# replace all the values greater than 1 with 1

for i in range(len(pred_survival_all)):
    for j in range(len(pred_survival_all[i])):
        if pred_survival_all[i][j] > 1:
            pred_survival_all[i][j] = 1

In [99]:
np.array(pred_survival_all)[0].max()

1.0

In [107]:
eval_times = np.quantile(np.array([RFS_all[i] for i in range(len(RFS_all)) if relapse_all[i] == 1]), [.25, .5, .75]).astype(int)

bs = brier_score_at_times(np.array(RFS_all), np.array(pred_survival_all), np.array(relapse_all), eval_times)
auc = roc_auc_at_times(np.array(RFS_all), np.array(pred_survival_all), np.array(relapse_all), eval_times)

In [108]:
metrics = []

metrics.append({
    "model": "mtlr",
    **{f"bs_{t}": bs[i] for i, t in enumerate(eval_times)},
    **{f"auc_{t}": auc[i] for i, t in enumerate(eval_times)}
})

pd.DataFrame(metrics).round(3)


Unnamed: 0,model,bs_266,bs_405,bs_782,auc_266,auc_405,auc_782
0,mtlr,0.051,0.092,0.139,0.405,0.305,0.421


In [4]:
from sklearn.metrics import brier_score_loss, roc_auc_score

def compute_metric_at_times(metric, time_true, prob_pred, event_observed, score_times):
    """Helper function to evaluate a metric at given timepoints."""
    scores = []
    for time, pred in zip(score_times, prob_pred.T):
        target = time_true > time
        uncensored = target | event_observed.astype(bool)
        scores.append(metric(target[uncensored], pred[uncensored]))
        
    return scores


def brier_score_at_times(time_true, prob_pred, event_observed, score_times):
    scores = compute_metric_at_times(brier_score_loss, 
                                     time_true,
                                     prob_pred,
                                     event_observed,
                                     score_times)
    return scores


def roc_auc_at_times(time_true, prob_pred, event_observed, score_times):
    scores = compute_metric_at_times(roc_auc_score, 
                                     time_true,
                                     prob_pred, 
                                     event_observed,
                                     score_times)
    return scores

## CHecking weights

In [1]:
import torch
ct_weight = torch.load('/home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/CT-CLIP/CT-CLIP_v2.pt')
my_w = torch.load('/opt/sagemaker/new_home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/save/model_weights_new/weight_020.pth')

  ct_weight = torch.load('/home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/CT-CLIP/CT-CLIP_v2.pt')
  my_w = torch.load('/opt/sagemaker/new_home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/save/model_weights_new/weight_020.pth')


In [2]:
# ct_weight.keys()

# rename the keys of the model weights. add "clip.base_model.model." to all the keys

ct_weight_new = {}
for k in ct_weight.keys():
    ct_weight_new["clip.base_model.model."+k] = ct_weight[k]

In [4]:
ct_weight_new.keys()

dict_keys(['clip.base_model.model.temperature', 'clip.base_model.model.text_transformer.embeddings.position_ids', 'clip.base_model.model.text_transformer.embeddings.word_embeddings.weight', 'clip.base_model.model.text_transformer.embeddings.position_embeddings.weight', 'clip.base_model.model.text_transformer.embeddings.token_type_embeddings.weight', 'clip.base_model.model.text_transformer.embeddings.LayerNorm.weight', 'clip.base_model.model.text_transformer.embeddings.LayerNorm.bias', 'clip.base_model.model.text_transformer.encoder.layer.0.attention.self.query.weight', 'clip.base_model.model.text_transformer.encoder.layer.0.attention.self.query.bias', 'clip.base_model.model.text_transformer.encoder.layer.0.attention.self.key.weight', 'clip.base_model.model.text_transformer.encoder.layer.0.attention.self.key.bias', 'clip.base_model.model.text_transformer.encoder.layer.0.attention.self.value.weight', 'clip.base_model.model.text_transformer.encoder.layer.0.attention.self.value.bias', 'cli

In [30]:
list(my_w.keys())[228]

'clip.base_model.model.visual_transformer.enc_spatial_transformer.layers.0.1.to_q.base_layer.weight'

In [31]:
list(ct_weight_new.keys())[228]

'clip.base_model.model.visual_transformer.enc_spatial_transformer.layers.0.1.to_q.weight'

In [33]:
for i, k in enumerate(ct_weight_new.keys()):
    try:
        if not torch.equal(ct_weight_new[k], my_w[k]):
            print(k, i)
    except:
        print(k)

clip.base_model.model.visual_transformer.enc_spatial_transformer.layers.0.1.to_q.weight
clip.base_model.model.visual_transformer.enc_spatial_transformer.layers.0.1.to_kv.weight
clip.base_model.model.visual_transformer.enc_spatial_transformer.layers.1.1.to_q.weight
clip.base_model.model.visual_transformer.enc_spatial_transformer.layers.1.1.to_kv.weight
clip.base_model.model.visual_transformer.enc_spatial_transformer.layers.2.1.to_q.weight
clip.base_model.model.visual_transformer.enc_spatial_transformer.layers.2.1.to_kv.weight
clip.base_model.model.visual_transformer.enc_spatial_transformer.layers.3.1.to_q.weight
clip.base_model.model.visual_transformer.enc_spatial_transformer.layers.3.1.to_kv.weight
clip.base_model.model.visual_transformer.enc_temporal_transformer.layers.0.1.to_q.weight
clip.base_model.model.visual_transformer.enc_temporal_transformer.layers.0.1.to_kv.weight
clip.base_model.model.visual_transformer.enc_temporal_transformer.layers.1.1.to_q.weight
clip.base_model.model.vi

In [17]:
i

228