In [1]:
import nibabel as nib
from monai.data.utils import correct_nifti_header_if_necessary
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
import os
import numpy as np
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
lamed_model_path = "/import/c4dm-04/siyoul/Med3DLLM/checkpoint/amosmm_chatgpt_phi2_0210@bs2_acc1_ep16_lr2e5_ws2_fused/checkpoint-132000"

def load_model(lamed_model_path, enable_lora=False):
    tokenizer = AutoTokenizer.from_pretrained(
        lamed_model_path,
        model_max_length=2048,
        padding_side="right",
        use_fast=False,
        pad_token="<unk>",
        trust_remote_code=True
    )
    print(tokenizer.additional_special_tokens)
    print(tokenizer.additional_special_tokens_ids)
    if tokenizer.unk_token is not None and tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token
    
    if enable_lora:
        
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=find_all_linear_names(lamed_model),
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        print("Adding LoRA adapters only on LLM.")
        lamed_model = get_peft_model(lamed_model, lora_config)
        # lamed_model.print_trainable_parameters()
        print("Load weights with LoRA")
        state_dict = torch.load(lamed_model_path, map_location="cpu")
        lamed_model.load_state_dict(state_dict, strict=True)
        print("Merge weights with LoRA")
        lamed_model = lamed_model.merge_and_unload()
    else:
        lamed_model = AutoModelForCausalLM.from_pretrained(
        lamed_model_path,
        trust_remote_code=True,
        )
    lamed_model = lamed_model.to("cpu")
    lamed_model.eval()
    return tokenizer, lamed_model

tokenizer, lamed_model = load_model(lamed_model_path)

['<im_patch>', '<bx_start>', '<bx_end>']
[50297, 50296, 50295]


Some weights of LamedPhiForCausalLM were not initialized from the model checkpoint at /import/c4dm-04/siyoul/Med3DLLM/checkpoint/amosmm_chatgpt_phi2_0210@bs2_acc1_ep16_lr2e5_ws2_fused/checkpoint-132000 and are newly initialized: ['model.seg_projector.0.bias', 'model.seg_projector.0.weight', 'model.seg_projector.2.bias', 'model.seg_projector.2.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
image_file_path = "/import/c4dm-04/siyoul/Med3DLLM/datasets/AMOS-MM/imagesVa/amos_0008.nii.gz"
from src.utils.data_transforms import val_transforms
image = [val_transforms(image_file_path), val_transforms(image_file_path)]
image = torch.stack(image)
print("Input Image shape",image.shape)



Input Image shape torch.Size([2, 1, 32, 256, 256])


In [4]:
from src.utils.linear_3d_transform import Linear3DTransform
l3d_t = Linear3DTransform()



In [5]:
image = [l3d_t(image_file_path), l3d_t(image_file_path)]
image = torch.stack(image)
print("Input Image shape",image.shape)

Input Image shape torch.Size([2, 8, 32, 256, 256])


In [6]:
# B for batch size, C for channel, D for depth, H for height, W for width
B, C, D, H, W = image.shape
print("Input Image shape",image.shape)
image = image.view(B * C, 1, image.shape[-3], image.shape[-2], image.shape[-1])
print("Input Image shape",image.shape)
img_emb = lamed_model.encode_images(image)
print("Image embedding shape",img_emb.shape)
img_emb = img_emb.view(B, C, img_emb.shape[-2], img_emb.shape[-1])
print("Image embedding shape",img_emb.shape)

Input Image shape torch.Size([2, 8, 32, 256, 256])
Input Image shape torch.Size([16, 1, 32, 256, 256])
Image embedding shape torch.Size([16, 256, 2560])
Image embedding shape torch.Size([2, 8, 256, 2560])


In [7]:
from src.model.linear_3d_tokenizer.lin3dt import Linear3DTokenizer
embed_size = 2560
num_heads = 8
num_layers = 4
top_k = 1024
l3d_tokenizer = Linear3DTokenizer(
    embed_size=embed_size,
    num_heads=num_heads,
    num_layers=num_layers,
    top_k=top_k,
    use_multi_scale=True,
    num_3d_query_token=256,
    hidden_size=2560
)

In [8]:
t_token = tokenizer(["This is a test sentence","This is a test sentence"], add_special_tokens=False, max_length=-1, truncation=True, padding="max_length", return_tensors="pt", padding_side="right")["input_ids"]
print("Text token shape",t_token.shape)
t_token = lamed_model.model.embed_tokens(t_token)
print("Text token shape",t_token.shape)

We need to remove 6 to truncate the input but the first sequence has a length 5. 
We need to remove 6 to truncate the input but the first sequence has a length 5. 


Text token shape torch.Size([2, 5])
Text token shape torch.Size([2, 5, 2560])


In [9]:
num_video_query_token = 8
hidden_size = 2560
query_tokens = torch.randn(2, 256, 2560)

output = l3d_tokenizer(v_token=img_emb, t_token=t_token)
l3d_tokenizer
print("Output shape",output.shape)

Output shape torch.Size([1, 256, 2560])


In [17]:
img_emb[0] - img_emb[1]

metatensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 

In [10]:
from src.model.linear_3d_tokenizer.svr import SpatioTemporalVisualTokenRefinerModel
from src.model.linear_3d_tokenizer.tta import TextConditionTokenAggregatorModel
embed_size = 2560
num_heads = 8
num_layers = 4
top_k = 1024
svr_model = SpatioTemporalVisualTokenRefinerModel(embed_size=embed_size, num_heads=num_heads, num_layers=num_layers, top_k=top_k, use_multi_scale=True)


In [11]:
# Example input: (batch_size, num_frames, num_tokens, embed_size)
video_data = torch.randn(1, 15, 256, embed_size)  
svr_output = svr_model(video_data)
print(svr_output.shape)

torch.Size([1, 1792, 2560])


In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SpatioTemporalLayer(nn.Module):
    def __init__(self, embed_size, num_heads):
        super().__init__()
        # Spatial attention: attends to tokens within each frame.
        self.spatial_attn = nn.MultiheadAttention(embed_size, num_heads, batch_first=True)
        # Temporal attention: attends across frames for each token position.
        self.temporal_attn = nn.MultiheadAttention(embed_size, num_heads, batch_first=True)
        self.norm_spatial = nn.LayerNorm(embed_size)
        self.norm_temporal = nn.LayerNorm(embed_size)

    def forward(self, x):
        # x shape: (B, T, N, E)
        B, T, N, E = x.shape

        # --- Spatial Attention ---
        # Reshape to combine batch and frame dimensions: (B*T, N, E)
        x_spatial = x.view(B * T, N, E)
        attn_out, _ = self.spatial_attn(x_spatial, x_spatial, x_spatial)
        attn_out = attn_out.view(B, T, N, E)
        x = self.norm_spatial(x + attn_out)

        # --- Temporal Attention ---
        # Permute so that tokens (N) become batch “instances” over T frames:
        # (B, N, T, E) then reshape to (B*N, T, E)
        x_temporal = x.permute(0, 2, 1, 3).contiguous().view(B * N, T, E)
        attn_out, _ = self.temporal_attn(x_temporal, x_temporal, x_temporal)
        # Restore shape: (B, N, T, E) then permute back to (B, T, N, E)
        attn_out = attn_out.view(B, N, T, E).permute(0, 2, 1, 3)
        x = self.norm_temporal(x + attn_out)
        return x

class SpatioTemporalTokenRefiner(nn.Module):
    def __init__(self, embed_size, num_heads, num_layers, top_k, use_multi_scale):
        super().__init__()
        self.layers = nn.ModuleList(
            [SpatioTemporalLayer(embed_size, num_heads) for _ in range(num_layers)]
        )
        # Linear layer to score the significance of each token.
        self.score_fc = nn.Linear(embed_size, 1)
        self.top_k = top_k
        self.use_multi_scale = use_multi_scale

    def forward(self, x):
        # x shape: (B, T, N, E)
        B, T, N, E = x.shape

        # Apply a stack of spatio–temporal attention layers.
        for layer in self.layers:
            x = layer(x)

        # Compute significance scores for each token.
        # Resulting shape: (B, T, N)
        scores = self.score_fc(x).squeeze(-1)
        # Flatten frame and token dimensions: (B, T*N)
        scores = scores.view(B, -1)
        # Select top_k tokens (indices are flattened over T*N).
        _, topk_indices = torch.topk(scores, self.top_k, dim=1)
        # Convert flat indices back into frame and token indices.
        frame_indices = topk_indices // N  # integer division
        token_indices = topk_indices % N

        # Gather the top_k tokens for each batch.
        batch_idx = torch.arange(B, device=x.device).unsqueeze(1)
        selected_tokens = x[batch_idx, frame_indices, token_indices]  # (B, top_k, E)
        print(selected_tokens.shape)
        # Optionally apply multi–scale pooling over the token sequence.
        if self.use_multi_scale:
            pooled_tokens = []
            # Example scales (kernel sizes) for 1D pooling.
            for scale in [1, 2, 4]:
                if selected_tokens.size(1) >= scale:
                    # Pool along the token (sequence) dimension.
                    pooled = F.avg_pool1d(selected_tokens.transpose(1, 2),
                                          kernel_size=scale, stride=scale)
                    pooled = pooled.transpose(1, 2)  # shape: (B, new_tokens, E)
                    pooled_tokens.append(pooled)
            # Concatenate pooled outputs along the token dimension.
            selected_tokens = torch.cat(pooled_tokens, dim=1)

        return selected_tokens  # (B, S, E), with S depending on top_k and pooling


In [15]:
svr_model = SpatioTemporalTokenRefiner(
    embed_size=embed_size, 
    num_heads=num_heads, 
    num_layers=num_layers, 
    top_k=top_k, 
    use_multi_scale=True
    )

# Example input: (batch_size, num_frames, num_tokens, embed_size)
video_data = torch.randn(2, 15, 256, embed_size)  
svr_output = svr_model(video_data)
print(svr_output.shape)

torch.Size([2, 1024, 2560])
torch.Size([2, 1792, 2560])


In [63]:
svr_model = SpatioTemporalTokenRefiner(
    embed_size=embed_size, 
    num_heads=num_heads, 
    num_layers=num_layers, 
    top_k=top_k, 
    use_multi_scale=True
    )

# Example input: (batch_size, num_frames, num_tokens, embed_size)
video_data = torch.randn(1, 15, 256, embed_size)  
svr_output = svr_model(video_data)
print(svr_output.shape)

torch.Size([1, 1792, 2560])


In [21]:
num_heads = 8
num_layers = 4
tta_model = TextConditionedAggregator(embed_size, num_layers, num_heads)

In [27]:
query = torch.randn(2, 256, embed_size)
# Example input: (batch_size, num_tokens, embed_size)
visual_data = torch.randn(2, 1792, embed_size)
text_data = torch.randn(2, 600, embed_size)
output = tta_model(query, visual_data, text_data)
print(output.shape)

torch.Size([2, 256, 2560])


In [20]:
##########################################
# 2. Text–Conditioned Token Aggregation
##########################################
class AggregationLayer(nn.Module):
    def __init__(self, embed_size, num_heads):
        super().__init__()
        # Even if the query is a single token, we use self–attention for consistency.
        self.self_attn = nn.MultiheadAttention(embed_size, num_heads, batch_first=True)
        # Cross–attention layers with visual tokens and text tokens.
        self.cross_attn_visual = nn.MultiheadAttention(embed_size, num_heads, batch_first=True)
        self.cross_attn_text = nn.MultiheadAttention(embed_size, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_size)

    def forward(self, query, visual_tokens, text_tokens):
        # query shape: (B, 1, E)
        # Self–attention on the query.
        self_attn_out, _ = self.self_attn(query, query, query)
        query = self.norm(query + self_attn_out)

        # Cross–attention with visual tokens.
        cross_out, _ = self.cross_attn_visual(query, visual_tokens, visual_tokens)
        query = self.norm(query + cross_out)

        # Cross–attention with text tokens.
        cross_out, _ = self.cross_attn_text(query, text_tokens, text_tokens)
        query = self.norm(query + cross_out)
        return query  # (B, 1, E)

class TextConditionedAggregator(nn.Module):
    def __init__(self, embed_size, num_heads, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(
            [AggregationLayer(embed_size, num_heads) for _ in range(num_layers)]
        )
        # Final cross–attention layer to “compress” the visual tokens.
        self.final_attn = nn.MultiheadAttention(embed_size, num_heads, batch_first=True)
        self.norm_final = nn.LayerNorm(embed_size)

    def forward(self, query, visual_tokens, text_tokens):
        # query: (B, 1, E); visual_tokens: (B, S, E); text_tokens: (B, L, E)
        out = query
        for layer in self.layers:
            out = layer(out, visual_tokens, text_tokens)
        # Final cross–attention with visual tokens.
        attn_out, _ = self.final_attn(out, visual_tokens, visual_tokens)
        out = self.norm_final(out + attn_out)
        # Squeeze the sequence dimension (which is 1) so output shape becomes (B, E)
        return out.squeeze(1)

##########################################
# 3. Combined Linear3DTokenizer Module
##########################################
class Linear3DTokenizer(nn.Module):
    """
    This module takes the CT image tokens (from a ViT3D model) and compresses
    them into a compact visual token. It then aggregates that with text tokens
    (and a visual query token) to produce a final token for the LLM.
    
    Input shapes:
      - v_token: (B, T, N, E)  (CT image embedding tokens over frames and patches)
      - v_query: (B, 1, E)      (a visual query token)
      - t_token: (B, L, E)      (text tokens from a prompt or caption)
      
    Output shape:
      - align_token: (B, E)
    """
    def __init__(self, embed_size, num_heads, num_layers, top_k, use_multi_scale):
        super().__init__()
        self.visual_token_refiner = SpatioTemporalTokenRefiner(
            embed_size, num_heads, num_layers, top_k, use_multi_scale
        )
        self.text_conditioned_aggregator = TextConditionedAggregator(
            embed_size, num_heads, num_layers
        )

    def forward(self, v_query, v_token, t_token):
        # First, refine (compress) the visual tokens from the CT image.
        # v_token: (B, T, N, E) --> visual_tokens: (B, S, E)
        visual_tokens = self.visual_token_refiner(v_token)
        # Then aggregate the visual query, refined visual tokens, and text tokens.
        # v_query: (B, 1, E), t_token: (B, L, E) --> align_token: (B, E)
        align_token = self.text_conditioned_aggregator(v_query, visual_tokens, t_token)
        return align_token