In [None]:
import json
import os
import torch
# Load the model configuration dict from the JSON file.
# JSON is UTF-8 by specification; pass the encoding explicitly because open()
# otherwise uses the platform locale encoding (e.g. cp1252 on Windows).
with open(r'C:\Users\Rishabh\Documents\3d-hcct\config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)

In [None]:
from model import ViTForClassfication

# Initialize the model with the `config` dict loaded from config.json.
# (The class name is spelled `ViTForClassfication` as exported by the project's model module.)
# Its parameters are replaced by the checkpoint weights via load_state_dict in the next cell.
model = ViTForClassfication(config=config)

In [None]:
from collections import OrderedDict
# NOTE(review): the filename looks like it may be missing a path separator
# ('training_output_metrics\\HCCT_best_model...') — confirm.
checkpoint_path = r'C:\Users\Rishabh\training_output_metricsHCCT_best_model.pth.tar'
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# Accept both a training checkpoint ({'state_dict': ..., ...}) and a bare state dict.
state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint


def strip_key_prefix(weights, prefix='module.'):
    """Return an OrderedDict copy of ``weights`` with a *leading* ``prefix`` removed from keys.

    Typically used for the 'module.' prefix that nn.DataParallel /
    DistributedDataParallel add to parameter names. Only a prefix is removed:
    ``str.replace`` would also mangle inner occurrences, e.g. turning
    'encoder.submodule.weight' into 'encoder.subweight'.

    Args:
        weights: Mapping of parameter name -> tensor.
        prefix: Prefix to strip when a key starts with it.

    Returns:
        OrderedDict with the same values and order, keys de-prefixed.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in weights.items()
    )


new_state_dict = strip_key_prefix(state_dict)

model.load_state_dict(new_state_dict, strict=True)  # strict: every key must match exactly


In [None]:
import nibabel as nib
import pandas as pd
DataFolder = r'C:\Users\Rishabh\Documents\TrimeseData'
CSVPath = r'C:\Users\Rishabh\Documents\TransBTS\IXI.xlsx'  # an Excel workbook, despite the name
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so that
# Files[i] selects the same subject on every run and machine.
Files = sorted(os.listdir(DataFolder))
# Characters 3:6 of each filename are parsed as the numeric IXI subject ID.
# NOTE(review): assumes every entry in DataFolder follows that naming scheme
# (presumably 'IXI###-...'); any other file here will raise ValueError.
ixi_ids = [int(f[3:6]) for f in Files]
print(ixi_ids)

In [None]:
# Subject metadata spreadsheet (CSVPath points to an .xlsx workbook).
df = pd.read_excel(CSVPath)
# Fail early, with a clear message, if the columns used for the age lookup are missing.
missing_cols = {'IXI_ID', 'AGE'} - set(df.columns)
if missing_cols:
    raise KeyError(f"{CSVPath} is missing expected column(s): {sorted(missing_cols)}")

In [None]:

def white0(image, threshold=0):
    """
    Z-score standardize the voxels whose value exceeds ``threshold``.

    Mean and population standard deviation (ddof=0) are computed over the
    voxels with ``image > threshold`` only. Those voxels are replaced by
    ``(v - mean) / std``; every other voxel keeps its original value.

    Args:
        image: Input image (array-like; converted to float32).
        threshold: Voxels strictly greater than this value are standardized.

    Returns:
        float32 array with the same shape as ``image``. If no voxel exceeds
        ``threshold``, or the selected voxels have zero/undefined spread, an
        all-zero float32 array is returned instead.
    """
    image = np.asarray(image, dtype=np.float32)
    mask = image > threshold

    if not mask.any():
        return np.zeros_like(image, dtype=np.float32)

    # Statistics are accumulated in float64 for precision; only the selected
    # voxels are copied, instead of several full-size float64 temporaries.
    foreground = image[mask].astype(np.float64)
    mean = foreground.mean()
    std = foreground.std()
    if not std > 0:  # also catches NaN (e.g. from inf voxels)
        return np.zeros_like(image, dtype=np.float32)

    standardized = image.copy()  # voxels <= threshold keep their original values
    standardized[mask] = (foreground - mean) / std  # cast back to float32 on assignment
    return standardized

In [None]:
import os, gc
import numpy as np
import nibabel as nib
import torch

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"
model.to(device)
model.eval()  # eval mode: disables dropout (and freezes BatchNorm statistics, if any)

indx = 30  # index into Files selecting the subject to analyse
filename = Files[indx]

file_path = os.path.join(DataFolder, filename)
img = nib.load(file_path)
# Ask nibabel for float32 directly instead of materialising a float64 array first.
x_np = img.get_fdata(caching='unchanged', dtype=np.float32)

inputvolume = white0(x_np)
# Add batch and channel axes -> shape (1, 1, *x_np.shape), float32 on `device`.
# Do not follow this with `.type(torch.FloatTensor)`: that is the CPU tensor type
# and would silently move the input back to the CPU when `device` is a GPU.
inputvolume = torch.from_numpy(inputvolume).unsqueeze(0).unsqueeze(0).to(device).float()

with torch.no_grad():
    out, all_attention = model(inputvolume, output_attentions=True)
# If `out` is a tuple/list, the prediction is taken from its first element.
logits = out[0] if isinstance(out, (tuple, list)) else out

# Characters 3:6 of the filename hold the numeric IXI subject ID.
_id = int(filename[3:6])
age_rows = df.loc[df['IXI_ID'] == _id, 'AGE']
if age_rows.empty:
    raise KeyError(f"IXI_ID {_id} (from {filename!r}) not found in the metadata table")
AGE = age_rows.iloc[0]
Predicted_Age = logits.item()  # requires a single-element output tensor



In [None]:
# Display the model's prediction next to the AGE value from the metadata table for this subject.
Predicted_Age, AGE

In [None]:
# Collapse the stacked attention tensors by averaging over their first three axes,
# one axis at a time: the stacking axis (one entry per element of `all_attention`)
# first, then the next two.
# NOTE(review): assuming each element is (batch, heads, tokens, tokens), the result
# is a layer-, batch- and head-averaged (tokens, tokens) map. Per-layer structure is
# discarded, so use torch.stack(all_attention) directly if per-layer maps are needed.
Attn = torch.stack(all_attention).mean(dim=0).mean(dim=0).mean(dim=0)
Attn.shape

In [None]:
import cv2
# Attention rollout (Abnar & Zuidema, 2020): propagate attention through the layers,
# adding the identity to account for residual connections.
# Rollout multiplies the *per-layer* matrices, so layers must not be averaged away
# first; build them from `all_attention` rather than from a layer-averaged map.
# NOTE(review): assumes each element of `all_attention` has two leading axes
# (batch of size 1 and heads) before the (tokens, tokens) matrix, and that token 0
# is the class token — confirm against the model definition.
att_mat = torch.stack(all_attention).detach().to("cpu").float()
att_mat = att_mat.mean(dim=(1, 2))  # -> (layers, tokens, tokens)

residual_att = torch.eye(att_mat.size(-1))
aug_att_mat = att_mat + residual_att
aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1, keepdim=True)  # re-normalise rows

# Recursively multiply the per-layer matrices: joint[n] = aug[n] @ joint[n - 1].
joint_attentions = torch.zeros_like(aug_att_mat)
joint_attentions[0] = aug_att_mat[0]
for n in range(1, aug_att_mat.size(0)):
    joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n - 1])

# Rolled-out attention from the class token (row 0) of the last layer to each patch token.
v = joint_attentions[-1, 0, 1:]
print(len(v))
# NOTE(review): assumes an 8x8x8 patch grid (512 patch tokens) — ideally read from config.
mask = v.reshape(8, 8, 8).numpy()

In [None]:
import numpy as np
from scipy.ndimage import zoom

# Upsample the 3-D patch-level attention grid to the voxel grid of the input volume,
# so the overlay cells can index `mask` and `x_np` with the same slice indices.
# (Previously hard-coded to (91, 109, 91); deriving it from x_np keeps the two in sync.)
zoom_factors = tuple(
    target / current for target, current in zip(x_np.shape, mask.shape)
)

mask = zoom(mask, zoom_factors, order=1)  # order=1 -> linear interpolation (trilinear in 3-D)
assert mask.shape == x_np.shape, (mask.shape, x_np.shape)

In [None]:
# Shape of the upsampled attention volume. The overlay cells index it with slice
# indices derived from x_np.shape, so it needs to match the input volume's shape.
mask.shape

In [None]:
import matplotlib.pyplot as plt

# Show one 2-D slice of the 3-D attention volume (`mask` is 3-D; mask[i] is a slice
# along axis 0).
MASK_SLICE_IDX = 34

fig, ax = plt.subplots()
im = ax.imshow(mask[MASK_SLICE_IDX], cmap='hot')
fig.colorbar(im, ax=ax, label='Attention intensity')
ax.set_title(f'Attention Map (axis 0, slice {MASK_SLICE_IDX})')
ax.axis('off')
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Overlay the attention volume on the middle slice along axis 2 (z).
assert mask.shape == x_np.shape, (mask.shape, x_np.shape)
slice_idx = x_np.shape[2] // 2

img_slice = x_np[:, :, slice_idx]
mask_slice = mask[:, :, slice_idx]

fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(img_slice, cmap='gray')  # base image
overlay = ax.imshow(mask_slice, cmap='jet', alpha=0.5)  # alpha controls overlay strength
fig.colorbar(overlay, ax=ax, label="Mask intensity")
ax.set_title(f"Attention overlay: axis 2 (z), slice {slice_idx}")
ax.axis("off")
plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Overlay the attention volume on the middle slice along axis 0 (not z: the slice
# is taken on the first array axis).
assert mask.shape == x_np.shape, (mask.shape, x_np.shape)
slice_idx = x_np.shape[0] // 2

img_slice = x_np[slice_idx, :, :]
mask_slice = mask[slice_idx, :, :]

fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(img_slice, cmap='gray')  # base image
overlay = ax.imshow(mask_slice, cmap='jet', alpha=0.5)  # alpha controls overlay strength
fig.colorbar(overlay, ax=ax, label="Mask intensity")
ax.set_title(f"Attention overlay: axis 0, slice {slice_idx}")
ax.axis("off")
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Overlay the attention volume on the middle slice along axis 1 (not z: the slice
# is taken on the second array axis).
assert mask.shape == x_np.shape, (mask.shape, x_np.shape)
slice_idx = x_np.shape[1] // 2

img_slice = x_np[:, slice_idx, :]
mask_slice = mask[:, slice_idx, :]

fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(img_slice, cmap='gray')  # base image
overlay = ax.imshow(mask_slice, cmap='jet', alpha=0.5)  # alpha controls overlay strength
fig.colorbar(overlay, ax=ax, label="Mask intensity")
ax.set_title(f"Attention overlay: axis 1, slice {slice_idx}")
ax.axis("off")
plt.show()