In [1]:
from scipy.ndimage import zoom
import torch
import numpy as np

In [2]:
def resample_to_target_shape(data, target_shape):
    """
    Resample data to match the target shape
    """
    # Calculate zoom factors
    factors = (target_shape[0] / data.shape[0],
               target_shape[1] / data.shape[1],
               target_shape[2] / data.shape[2])

    # Resample using order=1 (linear interpolation) for continuous data
    resampled_data = zoom(data, factors, order=1)

    return resampled_data


In [3]:

def perform_region_occlusion_analysis(test_loader, model, device, atlas_data, region_labels, region_mapping, img_shape):
    """
    Perform occlusion analysis based on brain regions defined in the atlas
    """
    model.eval()

    print(f"Target image shape for resampling: {img_shape}")

    # Check if the atlas needs resampling
    if atlas_data.shape != img_shape:
        print(f"Resampling atlas from {atlas_data.shape} to {img_shape}")
        resampled_atlas = resample_to_target_shape(atlas_data, img_shape)
    else:
        print("Atlas already matches target shape, no resampling needed")
        resampled_atlas = atlas_data

    print(f"Resampled atlas shape: {resampled_atlas.shape}")

    # Initialize results dictionary to store effect per region
    region_occlusion_effects = {region: 0 for region in region_labels}
    region_sample_counts = {region: 0 for region in region_labels}

    print('======= Starting Region-Based Occlusion Analysis =============')

    sample_count = 0

    with torch.no_grad():
        for _, (input_img, ids, target, male) in enumerate(test_loader):
            # Print input_img shape for debugging
            print(f"Input image shape: {input_img.shape}")

            # Debugging: Print detailed shape information
            print(f"Input data type: {input_img.dtype}")

            # Get original prediction
            input_img = input_img.to(device).type(torch.FloatTensor)

            # Handle gender information if needed by model
            # if opt.model == 'ScaleDense':
            #     male_onehot = torch.unsqueeze(male, 1)
            #     male_onehot = torch.zeros(male_onehot.shape[0], 2).scatter_(1, male_onehot, 1)
            #     male_onehot = male_onehot.type(torch.FloatTensor).to(device)
            #     original_output = model(input_img, male_onehot)
            # else:
            #     original_output = model(input_img)

            original_output = model(input_img)

            # original_output = original_output.cpu().numpy()
            original_output = original_output[0].numpy()

            # Process each region one by one
            for region in region_labels:
                # Free up memory
                torch.cuda.empty_cache()

                # Create mask for this region
                region_mask = (resampled_atlas == region)

                # Skip if region is not present in the resampled atlas
                if not np.any(region_mask):
                    continue

                # Clone the original input
                masked_input = input_img.clone()

                # Move input to CPU for masking
                # cpu_input = masked_input.cpu().numpy()

                cpu_input = masked_input.numpy()

                # Create a zero array with the same shape
                zeroed_array = np.zeros_like(cpu_input)

                # Create a mask array by broadcasting the region mask
                # This safely handles all dimension arrangements
                mask_array = np.ones_like(cpu_input)

                # Apply the region mask - this is the key change
                # We're assuming the last 3 dimensions of cpu_input correspond to the 3D volume
                for i in range(cpu_input.shape[0]):  # batch dimension
                    # Create a view that can be applied to the 3D volume regardless of channel arrangement
                    mask_view = np.broadcast_to(~region_mask, cpu_input[i].shape)
                    cpu_input[i] = cpu_input[i] * mask_view

                # Move back to GPU
                masked_input = torch.from_numpy(cpu_input).to(device)

                # Get prediction for masked input
                # if opt.model == 'ScaleDense':
                #     masked_output = model(masked_input, male_onehot)
                # else:
                #     masked_output = model(masked_input)

                masked_output = model(masked_input)

                # masked_output = masked_output.cpu().numpy()
                masked_output = masked_output[0].numpy()
                print('region:- ',region_mapping[region])
                print('original_output:- ',original_output)
                print('masked_output:- ',masked_output)

                # Calculate effect for this region (difference from original)
                effect = abs(masked_output - original_output)

                # Accumulate effect for this region
                region_occlusion_effects[region] += effect.item()
                region_sample_counts[region] += 1

                # Clean up
                del masked_input, cpu_input
                if 'masked_output' in locals():
                    del masked_output
                torch.cuda.empty_cache()

            sample_count += 1
            print(f"Processed sample {sample_count}/{len(test_loader)}: {ids[0]}")

    # Average effects across samples
    for region in region_labels:
        if region_sample_counts[region] > 0:
            region_occlusion_effects[region] /= region_sample_counts[region]

    # Convert results to a structured array
    result_array = np.array([region_occlusion_effects[region] for region in region_labels])

    return result_array



In [4]:
import os
from Models.MultiViewViT import MultiViewViT
from load_data import IMG_Folder
import torch.nn as nn
import nibabel as nib
from nilearn import datasets

In [5]:
def weights_init(w):
    classname = w.__class__.__name__
    if classname.find('Conv') != -1:
        if hasattr(w, 'weight'):
            # nn.init.kaiming_normal_(w.weight, mode='fan_out', nonlinearity='relu')
            nn.init.kaiming_normal_(w.weight, mode='fan_in', nonlinearity='leaky_relu')
        if hasattr(w, 'bias') and w.bias is not None:
                nn.init.constant_(w.bias, 0)
    if classname.find('Linear') != -1:
        if hasattr(w, 'weight'):
            torch.nn.init.xavier_normal_(w.weight)
        if hasattr(w, 'bias') and w.bias is not None:
            nn.init.constant_(w.bias, 0)
    if classname.find('BatchNorm') != -1:
        if hasattr(w, 'weight') and w.weight is not None:
            nn.init.constant_(w.weight, 1)
        if hasattr(w, 'bias') and w.bias is not None:
            nn.init.constant_(w.bias, 0)

In [6]:
# Load model
model = MultiViewViT(
    image_sizes=[(91, 109), (91, 91), (109, 91)],
    patch_sizes=[(7, 7), (7, 7), (7, 7)],
    num_channals=[91, 109, 91],
    vit_args={
        'emb_dim': 768, 'mlp_dim': 3072, 'num_heads': 12,
        'num_layers': 12, 'num_classes': 1,
        'dropout_rate': 0.1, 'attn_dropout_rate': 0.0
    },
    mlp_dims=[3, 128, 256, 512, 1024, 512, 256, 128, 1]
)
model.apply(weights_init)
model = model.to("cpu")

# Load checkpoint
CheckpointPath = r'C:\Users\Rishabh\training_output_metricsMulti_VIT_best_model.pth.tar'
checkpoint = torch.load(CheckpointPath, map_location="cpu")
state_dict = checkpoint["state_dict"]
new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [7]:
CheckpointPath = r'C:\Users\Rishabh\trainingMulti_VIT_best_model.pth.tar'
CSVPath = r'C:\Users\Rishabh\Documents\TransBTS\IXI.xlsx'
# DataFolder = r'C:\Users\Rishabh\Documents\TrimeseData'
DataFolder = r"C:\Users\Rishabh\TrimeesePreprocessedData"
test_data = IMG_Folder(CSVPath, DataFolder)
device = "cpu"

In [8]:
valid_loader = torch.utils.data.DataLoader(test_data
                                         ,batch_size=1
                                         ,num_workers=0
                                         ,pin_memory=True
                                         ,drop_last=True
                                         )

In [9]:
# ======== Load AAL atlas ======== #
aal_atlas = datasets.fetch_atlas_aal()
atlas_filename = aal_atlas.maps
atlas_nii = nib.load(atlas_filename)
atlas_data = atlas_nii.get_fdata()
region_labels = np.unique(atlas_data)[1:]  # Exclude 0 (background)
region_mapping = {code: label for code, label in zip(region_labels, aal_atlas.labels)}

print(f"Number of regions in atlas: {len(region_labels)}")
print(f"Atlas shape: {atlas_data.shape}")

  aal_atlas = datasets.fetch_atlas_aal()


[fetch_atlas_aal] Dataset found in C:\Users\Rishabh\nilearn_data\aal_SPM12
Number of regions in atlas: 116
Atlas shape: (91, 109, 91)


In [10]:
atlas_data.shape

(91, 109, 91)

In [11]:
len(np.unique(atlas_data))

117

In [12]:
model.eval()

MultiViewViT(
  (vit_1): VisionTransformer(
    (embedding): Conv2d(91, 768, kernel_size=(7, 7), stride=(7, 7))
    (transformer): Encoder(
      (pos_embedding): PositionEmbs(
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder_layers): ModuleList(
        (0-11): 12 x EncoderBlock(
          (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): SelfAttention(
            (query): LinearGeneral()
            (key): LinearGeneral()
            (value): LinearGeneral()
            (out): LinearGeneral()
          )
          (dropout): Dropout(p=0.1, inplace=False)
          (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): MlpBlock(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (act): GELU(approximate='none')
            (dropout1): Dropout(p=0.1, inplace=False)
            (dropout2): Dropout(p=0.1,

In [13]:
# Get sample image to determine exact shape
for sample_data in valid_loader:
    sample_img = sample_data[0]
    img_shape = (sample_img.shape[1], sample_img.shape[2], sample_img.shape[3])
    print(f"Detected image shape: {img_shape}")
    break

Detected image shape: (91, 109, 91)


In [14]:
# ======== perform region-based occlusion analysis ======== #
region_occlusion_results = perform_region_occlusion_analysis(
    test_loader=valid_loader,
    model=model,
    device="cpu",
    atlas_data=atlas_data,
    region_labels=region_labels,
    region_mapping=region_mapping,
    img_shape=img_shape
)


Target image shape for resampling: (91, 109, 91)
Atlas already matches target shape, no resampling needed
Resampled atlas shape: (91, 109, 91)
Input image shape: torch.Size([1, 91, 109, 91])
Input data type: torch.float32
region:-  Background
original_output:-  [[30.379274]]
masked_output:-  [[31.153767]]
region:-  Precentral_L
original_output:-  [[30.379274]]
masked_output:-  [[32.800705]]
region:-  Precentral_R
original_output:-  [[30.379274]]
masked_output:-  [[29.600851]]
region:-  Frontal_Sup_L
original_output:-  [[30.379274]]
masked_output:-  [[30.611794]]
region:-  Frontal_Sup_R
original_output:-  [[30.379274]]
masked_output:-  [[30.386362]]
region:-  Frontal_Sup_Orb_L
original_output:-  [[30.379274]]
masked_output:-  [[30.410652]]
region:-  Frontal_Sup_Orb_R
original_output:-  [[30.379274]]
masked_output:-  [[30.791653]]
region:-  Frontal_Mid_L
original_output:-  [[30.379274]]
masked_output:-  [[30.444912]]
region:-  Frontal_Mid_R
original_output:-  [[30.379274]]
masked_output:

KeyboardInterrupt: 

In [None]:

def white0(image, threshold=0):
    """
    Standardize voxels with value > threshold

    Args:
        image: Input image
        threshold: Threshold value

    Returns:
        Standardized image
    """
    image = image.astype(np.float32)
    mask = (image > threshold).astype(int)

    # Vectorized implementation to avoid unnecessary memory allocation
    image_h = image * mask

    # Calculate mean and std only for relevant voxels
    non_zero_voxels = np.sum(mask)
    if non_zero_voxels > 0:
        mean = np.sum(image_h) / non_zero_voxels

        # More memory efficient way to calculate std
        std_sum = np.sum((image_h - mean * mask) ** 2)
        std = np.sqrt(std_sum / non_zero_voxels)

        if std > 0:
            normalized = mask * (image - mean) / std
            # Use in-place operations to reduce memory usage
            image = normalized + image * (1 - mask)
            return image

    # Default case
    return np.zeros_like(image, dtype=np.float32)

In [None]:
fileidx = 10
Files = os.listdir(DataFolder)
img_data = nib.load(os.path.join(DataFolder, Files[fileidx]))
img_data = img_data.get_fdata()
# img_data = white0(img_data)
np.unique(img_data[indx, :, :])

In [None]:
import matplotlib.pyplot as plt
import numpy as np

indx = 45
region_indx = 1
# maskedimage = (sample_img == region_labels[region_indx])
maskedimage = (atlas_data == 2401)
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Axial (XY plane)
axes[0].imshow(img_data[indx, :, :], cmap="gray")
axes[0].set_title("Axial")
axes[0].axis("off")

# Coronal (XZ plane)
axes[1].imshow(maskedimage[indx,:, :], cmap="gray")
axes[1].set_title("Coronal")
axes[1].axis("off")

# print(np.unique(atlas_data[indx,:, :]))
# print(np.unique(maskedimage[0,indx,:, :],))
# Sagittal (YZ plane)
print(np.unique(img_data[indx, :, :]))
print(np.unique(atlas_data[indx,:, :]))
maskedimage = np.where(maskedimage, 1, img_data)
print(np.unique(maskedimage[indx,:, :]))
axes[2].imshow(maskedimage[indx,:, :], cmap="gray")
axes[2].set_title("Sagittal")
axes[2].axis("off")
# print(np.unique(maskedimage[0,indx,:, :],))

plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

indx = 45
region_indx = 1
# maskedimage = (sample_img == region_labels[region_indx])
maskedimage1 = (atlas_data == 4011)
maskedimage2 = (atlas_data == 4021)
maskedimage3 = (atlas_data == 8112)
maskedimage4 = (atlas_data == 8302)
maskedimage5 = (atlas_data == 6401)
maskedimage6 = (atlas_data == 9062)
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Axial (XY plane)
axes[0].imshow(img_data[ :,indx, :], cmap="gray")
axes[0].set_title("Axial")
axes[0].axis("off")

# Coronal (XZ plane)
axes[1].imshow(atlas_data[:,indx, :], cmap="gray")
axes[1].set_title("Coronal")
axes[1].axis("off")

# print(np.unique(atlas_data[indx,:, :]))
# print(np.unique(maskedimage[0,indx,:, :],))
# Sagittal (YZ plane)
# print(np.unique(img_data[indx, :, :]))
print(np.unique(atlas_data[:,indx, :]))
new_img = np.where(maskedimage1, 1, img_data)
new_img = np.where(maskedimage2, 1, new_img)
new_img = np.where(maskedimage3, 1, new_img)
new_img = np.where(maskedimage4, 1, new_img)
new_img = np.where(maskedimage5, 1, new_img)
new_img = np.where(maskedimage6, 1, new_img)
print(np.unique(maskedimage[indx,:, :]))
axes[2].imshow(new_img[:,indx, :], cmap="gray")
axes[2].set_title("Sagittal")
axes[2].axis("off")
# print(np.unique(maskedimage[0,indx,:, :],))

plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

indx = 45
region_indx = 1
# maskedimage = (sample_img == region_labels[region_indx])
maskedimage1 = (atlas_data == 4011)
maskedimage2 = (atlas_data == 4021)
maskedimage3 = (atlas_data == 8112)
maskedimage4 = (atlas_data == 8302)
maskedimage5 = (atlas_data == 6401)
maskedimage6 = (atlas_data == 9062)
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Axial (XY plane)
axes[0].imshow(img_data[ :, :,indx], cmap="gray")
axes[0].set_title("Axial")
axes[0].axis("off")

# Coronal (XZ plane)
axes[1].imshow(atlas_data[:, :, indx], cmap="gray")
axes[1].set_title("Coronal")
axes[1].axis("off")

# print(np.unique(atlas_data[indx,:, :]))
# print(np.unique(maskedimage[0,indx,:, :],))
# Sagittal (YZ plane)
# print(np.unique(img_data[indx, :, :]))
print(np.unique(atlas_data[:,:, indx]))
new_img = np.where(maskedimage1, 1, img_data)
new_img = np.where(maskedimage2, 1, new_img)
new_img = np.where(maskedimage3, 1, new_img)
new_img = np.where(maskedimage4, 1, new_img)
new_img = np.where(maskedimage5, 1, new_img)
new_img = np.where(maskedimage6, 1, new_img)
print(np.unique(maskedimage[indx,:, :]))
axes[2].imshow(new_img[:, :,indx], cmap="gray")
axes[2].set_title("Sagittal")
axes[2].axis("off")
# print(np.unique(maskedimage[0,indx,:, :],))

plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Example: fake 3D image (depth, height, width)
img = np.random.rand(64, 128, 128)

indx = 45
# Pick slice indices
slice1 = sample_img[0,:,indx,:]   # axial slice (z=30)
slice2 = atlas_data[:,indx,:]   # coronal slice (y=60)

# Plot them side by side
plt.figure(figsize=(10,5))

plt.subplot(1, 2, 1)
plt.imshow(slice1, cmap="gray")
plt.title("Axial slice")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(slice2, cmap="gray")
plt.title("Coronal slice")
plt.axis("off")

plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Example: fake 3D image (depth, height, width)
img = np.random.rand(64, 128, 128)

indx = 45
# Pick slice indices
slice1 = sample_img[0,:,:,indx]   # axial slice (z=30)
slice2 = atlas_data[:,:,indx]   # coronal slice (y=60)

# Plot them side by side
plt.figure(figsize=(10,5))

plt.subplot(1, 2, 1)
plt.imshow(slice1, cmap="gray")
plt.title("Axial slice")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(slice2, cmap="gray")
plt.title("Coronal slice")
plt.axis("off")

plt.show()