In [1]:
import os
import random
from collections import OrderedDict

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import torch
from monai.transforms import Compose, ScaleIntensity, EnsureChannelFirst, ToTensor, Rotate90

from dataloader import CTScanData
from vit_model import VisionTransformer3D, UpsampleAttentionMap

In [2]:
# Set device: use the GPU when PyTorch can see one, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
# Function to remove 'module.' prefix from state_dict keys
def remove_module_prefix(state_dict):
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            new_state_dict[k[7:]] = v
        else:
            new_state_dict[k] = v
    return new_state_dict

In [4]:
# Load model
# Architecture hyperparameters: they must match the configuration the checkpoint
# below was trained with, otherwise load_state_dict (strict by default) raises.
in_channels = 1
d_model = 256
feedforward_dim = 512
num_heads = 8  # NOTE(review): never passed to VisionTransformer3D below; if the constructor takes a head count, its default is what is used -- confirm it equals 8
patch_size = 16
num_layers = 8
num_classes = 2

model = VisionTransformer3D(in_channels, d_model, feedforward_dim, patch_size, num_classes=num_classes, num_layers=num_layers)
# map_location remaps tensors saved from a GPU onto `device`, so the checkpoint
# also loads on a CPU-only machine (where `device` falls back to 'cpu').
state_dict = torch.load('./Data/results_2/epoch_430_model.pt', map_location=device)
# Keys presumably carry a 'module.' prefix from a DataParallel/DDP-wrapped
# training model; strip it so they match the bare model's parameter names.
state_dict = remove_module_prefix(state_dict)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

VisionTransformer3D(
  (patch_embedding): PatchEmbedding3D(
    (proj): Conv3d(1, 256, kernel_size=(16, 16, 16), stride=(16, 16, 16))
  )
  (transformerlayers): ModuleList(
    (0-7): 8 x TransformerBlock(
      (attention_layer): Attention(
        (qkv_layer): Linear(in_features=256, out_features=768, bias=False)
        (dropout_layer): Dropout(p=0.0, inplace=False)
        (softmax_layer): Softmax(dim=-1)
        (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (linear_layer): Linear(in_features=256, out_features=256, bias=True)
      )
      (feedforward_layer): FeedForward(
        (net): Sequential(
          (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=256, out_features=512, bias=True)
          (2): GELU(approximate='none')
          (3): Dropout(p=0.0, inplace=False)
          (4): Linear(in_features=512, out_features=256, bias=True)
          (5): Dropout(p=0.0, inplace=False)
        )
      )
    )


In [5]:

# Load attention map. map_location='cpu' lets a map that was saved from GPU
# tensors load on a CPU-only machine.
attn_map_path = './Data/results_2/279_attn_map.pt'
attention_map = torch.load(attn_map_path, map_location='cpu').cpu().numpy()
print(attention_map.shape)

# Collapse the leading axis. For the saved map (printed shape (1, 224, 224, 224))
# this is a size-1 axis, so the result is still a 3D volume (224, 224, 224),
# not a 2D projection; the "_2d" names are kept because later cells use them.
attention_map_2d = attention_map.sum(axis=0)

# Min-max normalize to [0, 1]. Guard the constant-map case, where
# max == min would otherwise produce 0/0 = NaN everywhere.
attn_min = np.min(attention_map_2d)
attn_range = np.max(attention_map_2d) - attn_min
if attn_range > 0:
    attention_map_2d_normalized = (attention_map_2d - attn_min) / attn_range
else:
    attention_map_2d_normalized = np.zeros_like(attention_map_2d)
attention_map_2d_normalized.shape

(1, 224, 224, 224)


(224, 224, 224)

In [6]:
# Load data: the spreadsheet of scans, wrapped in the project's CTScanData dataset.
df = pd.read_excel('./Data/image_data.xlsx')

# Transforms handed to CTScanData: rescale intensities, add a leading channel
# axis (the raw volume has none), then convert to a tensor. Rotate90 is not applied.
train_transforms = Compose(
    [
        ScaleIntensity(),
        EnsureChannelFirst(channel_dim="no_channel"),
        ToTensor(),
    ]
)
train_dataset = CTScanData(df, transform=train_transforms)


In [7]:
# The spreadsheet's unnamed first column appears to be a row index written at
# export time; give it a usable name. This mutates `df` in place (same as the
# original rename), so any reference CTScanData may hold sees the new name too.
df.columns = ['idx' if column == 'Unnamed: 0' else column for column in df.columns]
df.columns

Index(['idx', 'ID', 'image', 'seg', 'Age', 'target'], dtype='object')

In [8]:
# Show which row indices ('idx') belong to each target class, class 1 first.
for position, label in enumerate((1, 0)):
    if position:
        print('*' * 127)
    print(f'TARGET: {label}')
    # .tolist() yields plain Python ints, so the list prints as [1, 3, ...]
    # rather than [np.int64(1), ...] under NumPy >= 2.
    print(df.loc[df['target'] == label, 'idx'].tolist())

TARGET: 1
[1, 3, 4, 6, 9, 11, 12, 14, 16, 20, 21, 22, 23, 30, 32, 34, 45, 46, 47, 49, 52, 54, 61, 63, 64, 67, 71, 75, 76, 78, 79, 80, 83, 85, 87, 89]
*******************************************************************************************************************************
TARGET: 0
[0, 2, 5, 7, 8, 10, 13, 15, 17, 18, 19, 24, 25, 26, 27, 28, 29, 31, 33, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 48, 50, 51, 53, 55, 56, 57, 58, 59, 60, 62, 65, 66, 68, 69, 70, 72, 73, 74, 77, 81, 82, 84, 86, 88]


In [9]:
# Draw one random class-1 subject (target == 1) and load its preprocessed volume.
# NOTE(review): `random` is never seeded, so each run may pick a different subject.
# Using the 'idx' value both as df's index label and as the dataset position
# assumes they coincide (the listing above shows 'idx' covering 0..89, matching
# a default RangeIndex).
idx_class1 = random.choice(df.loc[df['target'] == 1, 'idx'].tolist())
class1_id = df.loc[idx_class1, 'ID']
class1_age = df.loc[idx_class1, 'Age']
print(f"idx: {idx_class1}, ID: {class1_id}, Age: {class1_age}")
image_class1, _ = train_dataset[idx_class1]
image_class1 = image_class1.unsqueeze(0).to(device)  # Add batch dimension and move to device
image_class1.size()

idx: 52, ID: 53, Age: 75


torch.Size([1, 1, 224, 224, 224])

In [10]:
# Draw one random class-0 subject (target == 0) and load its preprocessed volume.
# NOTE(review): `random` is never seeded, so each run may pick a different subject.
# Using the 'idx' value both as df's index label and as the dataset position
# assumes they coincide (the listing above shows 'idx' covering 0..89, matching
# a default RangeIndex).
idx_class0 = random.choice(df.loc[df['target'] == 0, 'idx'].tolist())
class0_id = df.loc[idx_class0, 'ID']
class0_age = df.loc[idx_class0, 'Age']
print(f"idx: {idx_class0}, ID: {class0_id}, Age: {class0_age}")
image_class0, _ = train_dataset[idx_class0]
image_class0 = image_class0.unsqueeze(0).to(device)  # Add batch dimension and move to device
image_class0.size()

idx: 27, ID: 28, Age: 58


torch.Size([1, 1, 224, 224, 224])

In [11]:
# Independent copy of the normalized map, so later edits to final_attn_map
# cannot leak back into attention_map_2d_normalized.
final_attn_map = np.copy(attention_map_2d_normalized, order='C', subok=True)
final_attn_map.shape

(224, 224, 224)

In [12]:
def map_values(value):
    """Bin a normalized value in [0, 1] into an integer level 0-10.

    Bins are closed on the left and open on the right:
    [0, 0.05) -> 0, [0.05, 0.1) -> 1, [0.1, 0.2) -> 2, ..., [0.8, 0.9) -> 9,
    and [0.9, 1.0] -> 10 (1.0 itself included).

    Args:
        value: A scalar, expected to lie in [0, 1] (e.g. one voxel of a
            min-max normalized attention map).

    Returns:
        int: The level 0-10, or -1 for any value outside [0, 1]
        (negative, greater than 1, or NaN).
    """
    # One range check catches negatives, values above 1 and NaN (every NaN
    # comparison is False), so they all get the -1 sentinel. Previously
    # negative values fell into level 0.
    if not (0.0 <= value <= 1.0):
        return -1
    upper_edges = (0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
    for level, upper in enumerate(upper_edges):
        if value < upper:
            return level
    return 10


In [13]:
# Bin every voxel of the normalized map into integer levels with map_values.
# np.vectorize is a convenience wrapper around a Python-level loop, so this is
# slow on a 224**3 volume. otypes=[int] fixes the output dtype up front instead
# of inferring it from an extra trial call on the first element, and lets the
# call succeed on an empty array.
vectorized_map_values = np.vectorize(map_values, otypes=[int])
famap = vectorized_map_values(final_attn_map)
# map_values returns -1 as an out-of-range sentinel; none is expected here.
assert (famap >= 0).all(), "map_values returned -1 (value out of range or NaN)"

In [14]:
# Distinct attention levels present in the binned map, in sorted order.
np.unique(famap.ravel())

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])

In [15]:
# save files: write the two sampled, preprocessed volumes as NIfTI for viewing.
# The affine is borrowed from one resampled scan; this assumes all resampled
# volumes share that voxel grid -- TODO confirm.
img_load = nib.load("./Data/Resampled/1_resampled_img.nii.gz")
out_dir = './Data/view_attn_maps'
# nib.save does not create missing directories, so make sure the target exists.
os.makedirs(out_dir, exist_ok=True)

# class0: squeeze drops the batch and channel axes -> (224, 224, 224)
image_class0_nii = nib.Nifti1Image(image_class0.cpu().numpy().squeeze(), img_load.affine)
nib.save(image_class0_nii, f'{out_dir}/image_class0_id{class0_id}_age{class0_age}.nii.gz')

# class1
image_class1_nii = nib.Nifti1Image(image_class1.cpu().numpy().squeeze(), img_load.affine)
nib.save(image_class1_nii, f'{out_dir}/image_class1_id{class1_id}_age{class1_age}.nii.gz')

# attn map (currently disabled).
# NOTE(review): famap is a NumPy integer array (int64 on typical platforms);
# recent nibabel versions warn on or reject int64 image data without an explicit
# dtype, so cast it (levels are -1..10), e.g. famap.astype(np.int16), if re-enabling.
# attn_map_nii = nib.Nifti1Image(famap, img_load.affine)
# nib.save(attn_map_nii, './Data/view_attn_maps/final_attention_map.nii.gz')