In [1]:
from prognosis_model import embd_model, lora_model
import pandas as pd
import numpy as np
from tqdm import tqdm, trange
import os

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import Adam, AdamW
from torchinfo import summary

from utils import make_time_bins
from utils import encode_survival, mtlr_neg_log_likelihood, make_optimizer
from utils import mtlr_survival, mtlr_risk, roc_auc_at_times, brier_score_at_times
from prognosis_model import embd_model, lora_model

from ct_clip import CTCLIP
from transformer_maskgit import CTViT
from transformers import BertTokenizer, BertModel
from lifelines.utils import concordance_index
from data_inference_hector import Hector_Dataset_emb, Hector_Dataset

from peft import get_peft_config, get_peft_model, LoraConfig, TaskType

seed = 42
torch.manual_seed(seed) 
generator = torch.Generator().manual_seed(seed)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

tokenizer = BertTokenizer.from_pretrained('microsoft/BiomedVLP-CXR-BERT-specialized',do_lower_case=True)
text_encoder = BertModel.from_pretrained("microsoft/BiomedVLP-CXR-BERT-specialized")

text_encoder.resize_token_embeddings(len(tokenizer))
text_encoder.to(device)

image_encoder = CTViT(
    dim = 512,
    codebook_size = 8192,
    image_size = 480,
    patch_size = 20,
    temporal_patch_size = 10,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 32,
    heads = 8
)

image_encoder.to(device)

clip = CTCLIP(
    image_encoder = image_encoder,
    text_encoder = text_encoder,
    dim_image = 294912,
    dim_text = 768,
    dim_latent = 512,
    extra_latent_projection = False,         # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
    use_mlm=False,
    downsample_image_embeds = False,
    use_all_token_embeds = False,
)

clip.load("/home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/CT-CLIP/CT-CLIP_v2.pt")
clip.to(device)


  from .autonotebook import tqdm as notebook_tqdm
  @autocast(enabled = False)
  @autocast(enabled = False)
  return torch.load(checkpoint_file, map_location="cpu")
  pt = torch.load(str(path))


CTCLIP(
  (text_transformer): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.25, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.25, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementw

In [None]:
class emb_gen(nn.Module):
    def __init__(self, clip, device, num_time_bins):
        super(emb_gen, self).__init__()
        self.clip = clip
        self.device = device
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.img_embd = nn.Sequential(
            nn.Linear(512, 512),
            nn.GELU(),
            nn.Linear(512, 512),
            nn.LayerNorm(512),
        )
        self.text_embd = nn.Sequential(
            nn.Linear(768, 512),
            nn.GELU(),
            nn.Linear(512, 512),
            nn.LayerNorm(512),
        )

        self.fuse = nn.Sequential(
            nn.Conv1d(
                in_channels=512 * 2,
                out_channels=self.hidden_dim,
                kernel_size=3,
                padding=1,
            ),
            nn.GELU(),
            nn.Conv1d(
                in_channels=self.hidden_dim,
                out_channels=self.hidden_dim,
                kernel_size=3,
                padding=1,
            ),
        )
        self.mtlr = MTLR(in_features=512, num_time_bins=num_time_bins)


    def forward(self, text, image):
        self.clip.eval()
        with torch.no_grad():
            img, text = self.clip(text, image, self.device, prognosis = True)
    
        img = self.avgpool(img.permute(0, 3, 1, 2)).squeeze(-1).squeeze(-1)
        text = text.mean(dim=1)

        img = self.img_embd(img)
        text = self.text_embd(text)

        fuse = torch.cat([img, text], dim=1) 
        fuse = self.fuse(fuse.unsqueeze(2))
        pred = self.mtlr(fuse.squeeze(2))
        
        return pred

In [10]:
n=88
text_emb=tokenizer(hect_dataset[n][1], return_tensors="pt", padding="max_length", truncation=True, max_length=512).to(device)
clip.eval()
with torch.no_grad():
    emb = clip(text_emb, hect_dataset[n][0].unsqueeze(0).cuda(), device, prognosis = True)

In [20]:
img_embd = nn.Sequential(
    nn.Linear(512, 512),
    nn.GELU(),
    nn.Linear(512, 512),
    nn.LayerNorm(512),
)
text_embd = nn.Sequential(
    nn.Linear(768, 512),
    nn.GELU(),
    nn.Linear(512, 512),
    nn.LayerNorm(512),
)

fuse = nn.Sequential(
    nn.Conv1d(
        in_channels=512 * 2,
        out_channels=512,
        kernel_size=3,
        padding=1,
    ),
    nn.GELU(),
    nn.Conv1d(
        in_channels=512,
        out_channels=512,
        kernel_size=3,
        padding=1,
    ),
)

avgpool = nn.AdaptiveAvgPool2d((1, 1))

img_embd = img_embd.to(device)
text_embd = text_embd.to(device)
fuse = fuse.to(device)

In [21]:
emb[0].shape, emb[1].shape

img = avgpool(emb[0].permute(0, 3, 1, 2)).squeeze(-1).squeeze(-1)
text = emb[1].mean(dim=1)

img = img_embd(img)
text = text_embd(text)

feat = torch.cat([img, text], dim=1) 

fuse(feat.unsqueeze(2)).shape

torch.Size([1, 512, 1])

In [22]:
img.shape, text.shape, feat.unsqueeze(2).shape

(torch.Size([1, 512]), torch.Size([1, 512]), torch.Size([1, 1024, 1]))

In [14]:
emb[1].shape

torch.Size([1, 512, 768])

In [None]:
image_encoder()

In [2]:
df = pd.read_csv("/home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/TNM_hector_prompts.csv")

num_time_bins = 12
time_bins = make_time_bins(df['RFS'].values, event=df['Relapse'].values, num_bins=num_time_bins)

peft_config = LoraConfig(
    inference_mode=False, r=8, lora_alpha=64, lora_dropout=0.2, target_modules=["to_q", "to_kv"]
)

model = lora_model(clip, device, peft_config, num_time_bins)


## Dataloader

In [3]:
from data_inference_hector import Hector_Dataset_emb, Hector_Dataset

hect_dataset = Hector_Dataset(data_folder = "/home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/valid_preprocessed_hector/",  
                csv_file ="/home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/TNM_hector_prompts.csv")

hect_dataset_emb = Hector_Dataset_emb(emd_path = '/opt/sagemaker/new_home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/save/embeddings_new_TNM.npy',  
                csv_file ="/opt/sagemaker/new_home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/TNM_hector_prompts.csv")


In [3]:
len(hect_dataset), len(hect_dataset_emb)

(409, 409)

In [3]:
hect_dataset[45][4], hect_dataset_emb[45][4]

('CHUM-055_ct_roi.npz', 'CHUM-055_ct_roi.npz')

In [21]:
hect_dataset[0][0].max(), hect_dataset[0][0].min()

(tensor(1.), tensor(-1.))

In [22]:
hect_dataset_emb[0][0].max(), hect_dataset_emb[0][0].min()

(tensor(0.1154), tensor(-0.1333))

In [25]:
hect_dataset[0][1]

"Patient Information and Clinical Summary:\n\nThe patient is an 82-year-old male with a weight of 80.0 kg. Information regarding the patient's alcohol consumption and performance status is not available. The patient's HPV status is also not specified. The patient has undergone chemotherapy. There is no available information about any surgical interventions.\n\nTNM Staging:\n\nAccording to the 7th edition of the TNM staging system, the patient is classified as T2, N2, M0, which corresponds to a TNM group IV. This indicates a locally advanced disease with regional lymph node involvement but no distant metastasis.\n\nConclusion:\n\nIn summary, this is an 82-year-old male patient with a history of chemotherapy treatment for a cancer classified as T2N2M0, TNM group IV, according to the 7th edition of the TNM staging system. Further information regarding the patient's alcohol consumption, performance status, HPV status, and surgical history is required for a more comprehensive assessment."

In [29]:
hect_dataset[0][0].unsqueeze(0).shape

torch.Size([1, 1, 240, 480, 480])

In [3]:
n=407

In [4]:
image_encoder.eval()
with torch.no_grad():
    encd = image_encoder(hect_dataset[n][0].unsqueeze(0).to(device), return_encoded_tokens=True)
encd.shape, encd.max(), encd.min()

torch.Size([1, 13824, 512]) tensor(14.8312, device='cuda:0') tensor(-10.7521, device='cuda:0')
after torch.Size([1, 13824, 512]) tensor(0.8041, device='cuda:0') tensor(-0.7761, device='cuda:0')


(torch.Size([1, 24, 24, 24, 512]),
 tensor(0.8041, device='cuda:0'),
 tensor(-0.7761, device='cuda:0'))

In [5]:
with torch.no_grad():
    encd = image_encoder(hect_dataset[n][0].unsqueeze(0).cuda(), return_encoded_tokens=True)
encd.shape, encd.max(), encd.min()

torch.Size([1, 13824, 512]) tensor(14.8312, device='cuda:0') tensor(-10.7521, device='cuda:0')
after torch.Size([1, 13824, 512]) tensor(0.8041, device='cuda:0') tensor(-0.7761, device='cuda:0')


(torch.Size([1, 24, 24, 24, 512]),
 tensor(0.8041, device='cuda:0'),
 tensor(-0.7761, device='cuda:0'))

In [5]:
encd = image_encoder(hect_dataset[n][0].unsqueeze(0).cuda(), return_encoded_tokens=True)
encd.shape, encd.max(), encd.min()

(torch.Size([1, 24, 24, 24, 512]),
 tensor(0.7667, device='cuda:0', grad_fn=<MaxBackward1>),
 tensor(-0.7625, device='cuda:0', grad_fn=<MinBackward1>))

In [4]:
text_emb=tokenizer(hect_dataset[n][1], return_tensors="pt", padding="max_length", truncation=True, max_length=512).to(device)
clip.eval()
with torch.no_grad():
    emb = clip(text_emb, hect_dataset[n][0].unsqueeze(0).cuda(), device, embed = True)

emb[0].shape, emb[1].shape, emb[0].max(), emb[0].min(), emb[1].max(), emb[1].min(), hect_dataset[n][4]

torch.Size([1, 13824, 512]) tensor(14.8312, device='cuda:0') tensor(-10.7521, device='cuda:0')
after torch.Size([1, 13824, 512]) tensor(0.8041, device='cuda:0') tensor(-0.7761, device='cuda:0')
text embd torch.Size([1, 768]) tensor(1.2392, device='cuda:0') tensor(-1.0307, device='cuda:0')
image embd torch.Size([1, 294912]) tensor(0.4192, device='cuda:0') tensor(-0.3291, device='cuda:0')
text latents torch.Size([1, 512]) tensor(0.1407, device='cuda:0') tensor(-0.1061, device='cuda:0')
image latents torch.Size([1, 512]) tensor(0.1090, device='cuda:0') tensor(-0.1142, device='cuda:0')


(torch.Size([1, 512]),
 torch.Size([1, 512]),
 tensor(0.1090, device='cuda:0'),
 tensor(-0.1142, device='cuda:0'),
 tensor(0.1407, device='cuda:0'),
 tensor(-0.1061, device='cuda:0'),
 'HMR-034_ct_roi.npz')

In [5]:
text_emb=tokenizer(hect_dataset[n][1], return_tensors="pt", padding="max_length", truncation=True, max_length=512).to(device)
clip.eval()
with torch.no_grad():
    emb = clip(text_emb, hect_dataset[n][0].unsqueeze(0).cuda(), device, embed = True)

emb[0].shape, emb[1].shape, emb[0].max(), emb[0].min(), emb[1].max(), emb[1].min(), hect_dataset[n][4]

torch.Size([1, 13824, 512]) tensor(14.8312, device='cuda:0') tensor(-10.7521, device='cuda:0')
after torch.Size([1, 13824, 512]) tensor(0.8041, device='cuda:0') tensor(-0.7761, device='cuda:0')
text embd torch.Size([1, 768]) tensor(1.2392, device='cuda:0') tensor(-1.0307, device='cuda:0')
image embd torch.Size([1, 294912]) tensor(0.4192, device='cuda:0') tensor(-0.3291, device='cuda:0')
text latents torch.Size([1, 512]) tensor(0.1407, device='cuda:0') tensor(-0.1061, device='cuda:0')
image latents torch.Size([1, 512]) tensor(0.1090, device='cuda:0') tensor(-0.1142, device='cuda:0')


(torch.Size([1, 512]),
 torch.Size([1, 512]),
 tensor(0.1090, device='cuda:0'),
 tensor(-0.1142, device='cuda:0'),
 tensor(0.1407, device='cuda:0'),
 tensor(-0.1061, device='cuda:0'),
 'HMR-034_ct_roi.npz')

In [6]:
hect_dataset_emb[n][0].shape, hect_dataset_emb[n][1].shape, hect_dataset_emb[n][0].max(), hect_dataset_emb[n][0].min(), hect_dataset_emb[n][1].max(), hect_dataset_emb[n][1].min()

(torch.Size([512]),
 torch.Size([512]),
 tensor(0.1090),
 tensor(-0.1142),
 tensor(0.1407),
 tensor(-0.1061))

# Fold

In [None]:
from data_inference_hector import Hector_Dataset_emb

hect_dataset = Hector_Dataset_emb(emd_path = '/opt/sagemaker/new_home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/save/embeddings_new_exp_.npy',  
                csv_file ="/opt/sagemaker/new_home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/TNM_hector_prompts.csv")

In [27]:
import pandas as pd

df = pd.read_csv("/opt/sagemaker/new_home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/TNM_hector_prompts.csv")

In [14]:
num_folds = 5

# Shuffle indices
num_rows = len(df)
folds = np.tile(np.arange(num_folds), num_rows // num_folds + 1)[:num_rows]  
np.random.shuffle(folds)

In [19]:
#  get the number of samples in each fold

np.unique(folds, return_counts=True)

(array([0, 1, 2, 3, 4]), array([82, 82, 82, 82, 81]))

In [20]:
# Assign to DataFrame
df['fold'] = folds

    PatientID Gender  Age  Weight  Tobacco  Alcohol  Performance status  \
0    CHUM-001      M   82    80.0      NaN      NaN                 NaN   
1    CHUM-002      M   73    55.0      NaN      NaN                 NaN   
2    CHUM-006      M   65   101.0      NaN      NaN                 NaN   
3    CHUM-007      F   70    80.0      NaN      NaN                 NaN   
4    CHUM-008      F   67    91.0      NaN      NaN                 NaN   
..        ...    ...  ...     ...      ...      ...                 ...   
404   HMR-028      M   73    87.0      NaN      NaN                 NaN   
405   HMR-029      M   57     NaN      NaN      NaN                 NaN   
406   HMR-030      M   70     NaN      NaN      NaN                 NaN   
407   HMR-034      F   85     NaN      NaN      NaN                 NaN   
408   HMR-040      F   61    53.0      NaN      NaN                 NaN   

     HPV status (0=-, 1=+)  Surgery  Chemotherapy  Relapse   RFS  TNM edition  \
0                 

In [22]:
df['fold'].value_counts()

fold
0    82
2    82
1    82
3    82
4    81
Name: count, dtype: int64

In [23]:
df.to_csv("/opt/sagemaker/new_home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/TNM_hector_prompts.csv")

In [1]:
from data_inference_hector import Hector_Dataset_emb


hect_dataset = Hector_Dataset_emb(emd_path = '/opt/sagemaker/new_home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/save/embeddings_new_exp_.npy',  
                csv_file ="/opt/sagemaker/new_home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/TNM_hector_prompts.csv")


In [2]:
train_dataset, val_dataset = hect_dataset.train_val_split(fold=0)