In [1]:
from data_inference_hector import Hector_Dataset_ct_pt

# Input locations handed to the dataset: the folder of volumes and the
# CSV of TNM prompts.
hector_data_folder = "/home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/valid_preprocessed_hector/"
hector_csv_file = "/home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/TNM_hector_prompts.csv"

# NOTE(review): the class name suggests a paired CT + PT dataset; confirm the
# item layout in data_inference_hector before relying on index positions.
hect_dataset = Hector_Dataset_ct_pt(
    data_folder=hector_data_folder,
    csv_file=hector_csv_file,
)


In [5]:
# Fetch the first item ONCE: every `hect_dataset[0]` runs the dataset's
# __getitem__ again, so indexing it six times would rebuild the same sample
# six times just to display it.
first_sample = hect_dataset[0]

# Shapes of the two tensor entries, followed by the remaining four fields as-is.
first_sample[0].shape, first_sample[1].shape, first_sample[2], first_sample[3], first_sample[4], first_sample[5]

(torch.Size([1, 240, 480, 480]),
 torch.Size([1, 240, 480, 480]),
 "Patient Information and Clinical Summary:\n\nThe patient is an 82-year-old male with a weight of 80.0 kg. Information regarding the patient's alcohol consumption and performance status is not available. The patient's HPV status is also not specified. The patient has undergone chemotherapy. There is no available information about any surgical interventions.\n\nTNM Staging:\n\nAccording to the 7th edition of the TNM staging system, the patient is classified as T2, N2, M0, which corresponds to a TNM group IV. This indicates a locally advanced disease with regional lymph node involvement but no distant metastasis.\n\nConclusion:\n\nIn summary, this is an 82-year-old male patient with a history of chemotherapy treatment for a cancer classified as T2N2M0, TNM group IV, according to the 7th edition of the TNM staging system. Further information regarding the patient's alcohol consumption, performance status, HPV status, and s

In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm, trange
import os

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import Adam, AdamW
from torchinfo import summary

from utils import make_time_bins
from utils import encode_survival, mtlr_neg_log_likelihood, make_optimizer
from utils import mtlr_survival, mtlr_risk, roc_auc_at_times, brier_score_at_times
from prognosis_model import model_ctpt

from ct_clip import CTCLIP
from transformer_maskgit import CTViT
from transformers import BertTokenizer, BertModel
from lifelines.utils import concordance_index
from data_inference_hector import Hector_Dataset_emb, Hector_Dataset

from peft import get_peft_config, get_peft_model, LoraConfig, TaskType


# ---------------------------------------------------------------------------
# Reproducibility and device selection
# ---------------------------------------------------------------------------
seed = 42
torch.manual_seed(seed)
generator = torch.Generator().manual_seed(seed)

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# ---------------------------------------------------------------------------
# Text branch: tokenizer and BERT encoder come from the same checkpoint id
# ---------------------------------------------------------------------------
bert_checkpoint = "microsoft/BiomedVLP-CXR-BERT-specialized"
tokenizer = BertTokenizer.from_pretrained(bert_checkpoint, do_lower_case=True)
text_encoder = BertModel.from_pretrained(bert_checkpoint)

# Keep the embedding table the same size as the tokenizer vocabulary.
text_encoder.resize_token_embeddings(len(tokenizer))
text_encoder.to(device)

# ---------------------------------------------------------------------------
# Image branch: CTViT encoder
# ---------------------------------------------------------------------------
vit_kwargs = dict(
    dim=512,
    codebook_size=8192,
    image_size=480,
    patch_size=20,
    temporal_patch_size=10,
    spatial_depth=4,
    temporal_depth=4,
    dim_head=32,
    heads=8,
)
image_encoder = CTViT(**vit_kwargs)
image_encoder.to(device)

# ---------------------------------------------------------------------------
# CT-CLIP wrapper around both encoders, then pretrained weights
# ---------------------------------------------------------------------------
clip = CTCLIP(
    image_encoder=image_encoder,
    text_encoder=text_encoder,
    # 294912 == 24 * 24 * 512 == (image_size // patch_size) ** 2 * dim.
    # NOTE(review): presumably the flattened spatial token grid; confirm in CTCLIP.
    dim_image=294912,
    dim_text=768,
    dim_latent=512,
    # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
    extra_latent_projection=False,
    use_mlm=False,
    downsample_image_embeds=False,
    use_all_token_embeds=False,
)

clip.load("/home/Mohammad.Qazi@mbzuai.ac.ae/project/ct_rate/CT-CLIP/CT-CLIP_v2.pt")
clip.to(device)

# ---------------------------------------------------------------------------
# Prognosis model on top of CLIP; num_time_bins is passed straight through
# ---------------------------------------------------------------------------
num_time_bins = 12
model = model_ctpt(clip, device, num_time_bins)

  from .autonotebook import tqdm as notebook_tqdm
  @autocast(enabled = False)
  @autocast(enabled = False)
  return torch.load(checkpoint_file, map_location="cpu")
  pt = torch.load(str(path))


In [2]:
# Step 3: Initialize new layers manually
def initialize_weights(layer):
    """Re-initialise one module in place; intended for ``module.apply(...)``.

    * ``nn.Conv2d`` / ``nn.Linear``: Xavier-uniform weight, zero bias.
    * ``nn.LayerNorm``: weight set to ones, bias set to zeros.
    * Any other module type is left untouched.

    Parameters that are absent are skipped instead of raising: a layer built
    with ``bias=False`` has ``bias is None``, and a ``LayerNorm`` built with
    ``elementwise_affine=False`` has both ``weight`` and ``bias`` set to
    ``None``.

    Args:
        layer: the ``nn.Module`` to (re)initialise.

    Returns:
        None. The module's parameters are modified in place.
    """
    # NOTE(review): only Conv2d and Linear are matched here; other convolution
    # classes (e.g. Conv1d / Conv3d) keep whatever initialisation they had.
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(layer.weight)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
    elif isinstance(layer, nn.LayerNorm):
        # Both affine parameters are optional on LayerNorm, so guard each one.
        if layer.weight is not None:
            nn.init.ones_(layer.weight)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)

# Re-initialise the two sub-modules of the visual transformer that are trained
# from scratch here (PT patch embedding and the modality-merging block).
visual_transformer = model.clip.visual_transformer
visual_transformer.to_patch_emb_pt.apply(initialize_weights)
visual_transformer.merge_modalities.apply(initialize_weights)

# Step 4: decide trainability in a single pass over the parameters.
# A parameter is trainable iff its name contains one of the substrings below
# (the re-initialised visual sub-modules plus the img_embd / text_embd / fuse /
# mtlr parts of the model); every other parameter is frozen.
for name, param in model.named_parameters():
    param.requires_grad = any(
        key in name
        for key in (
            "to_patch_emb_pt",
            "merge_modalities",
            "img_embd",
            "text_embd",
            "fuse",
            "mtlr",
        )
    )

In [3]:
# Sanity check: list every parameter that will receive gradients.
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue  # frozen parameter, nothing to report
    print(f"{name}: requires_grad={param.requires_grad}")

clip.visual_transformer.to_patch_emb_pt.1.weight: requires_grad=True
clip.visual_transformer.to_patch_emb_pt.1.bias: requires_grad=True
clip.visual_transformer.to_patch_emb_pt.2.weight: requires_grad=True
clip.visual_transformer.to_patch_emb_pt.2.bias: requires_grad=True
clip.visual_transformer.to_patch_emb_pt.3.weight: requires_grad=True
clip.visual_transformer.to_patch_emb_pt.3.bias: requires_grad=True
clip.visual_transformer.merge_modalities.conv.weight: requires_grad=True
clip.visual_transformer.merge_modalities.conv.bias: requires_grad=True
clip.visual_transformer.merge_modalities.norm.weight: requires_grad=True
clip.visual_transformer.merge_modalities.norm.bias: requires_grad=True
img_embd.0.weight: requires_grad=True
img_embd.0.bias: requires_grad=True
img_embd.2.weight: requires_grad=True
img_embd.2.bias: requires_grad=True
img_embd.3.weight: requires_grad=True
img_embd.3.bias: requires_grad=True
text_embd.0.weight: requires_grad=True
text_embd.0.bias: requires_grad=True
text_e