In [None]:
from scipy.ndimage import zoom
import torch
import numpy as np
import os
from Models.MultiViewViT import MultiViewViT
from load_data import IMG_Folder
import torch.nn as nn
import nibabel as nib
from nilearn import datasets

In [None]:
def weights_init(w):
    """Initialise the parameters of a single module in place, chosen by its class name.

    Intended for ``model.apply(weights_init)``, which calls it once per submodule.

    Rules (substring match on the class name):
      - ``'Conv'``      -> Kaiming-normal weight (mode='fan_in', nonlinearity='leaky_relu'), zero bias
      - ``'Linear'``    -> Xavier-normal weight, zero bias
      - ``'BatchNorm'`` -> weight filled with 1, bias filled with 0

    The rules are independent ``if`` checks, so a class name matching several
    patterns has each matching rule applied in the order above. A ``weight`` or
    ``bias`` that is missing or ``None`` (e.g. ``bias=False``, ``affine=False``)
    is skipped rather than passed to ``nn.init``.

    Args:
        w: the module to initialise (any object with optional ``weight``/``bias`` attributes).
    """
    classname = type(w).__name__
    # getattr with a default covers both "attribute absent" and "attribute is None"
    # once combined with the `is not None` checks below.
    weight = getattr(w, 'weight', None)
    bias = getattr(w, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            nn.init.kaiming_normal_(weight, mode='fan_in', nonlinearity='leaky_relu')
        if bias is not None:
            nn.init.constant_(bias, 0)
    if 'Linear' in classname:
        if weight is not None:
            nn.init.xavier_normal_(weight)
        if bias is not None:
            nn.init.constant_(bias, 0)
    if 'BatchNorm' in classname:
        if weight is not None:
            nn.init.constant_(weight, 1)
        if bias is not None:
            nn.init.constant_(bias, 0)

In [None]:
# Load model
model = MultiViewViT(
    image_sizes=[(91, 109), (91, 91), (109, 91)],
    patch_sizes=[(7, 7), (7, 7), (7, 7)],
    num_channals=[91, 109, 91],
    vit_args={
        'emb_dim': 768, 'mlp_dim': 3072, 'num_heads': 12,
        'num_layers': 12, 'num_classes': 1,
        'dropout_rate': 0.1, 'attn_dropout_rate': 0.0
    },
    mlp_dims=[3, 128, 256, 512, 1024, 512, 256, 128, 1]
)
# Random init; load_state_dict below (strict=True by default) then overwrites
# every parameter/buffer with the checkpoint values.
model.apply(weights_init)
model = model.to("cpu")

# Load checkpoint
# NOTE(review): the filename looks like it may be missing a path separator before
# "Multi_VIT" — confirm the path.
CheckpointPath = r'C:\Users\Rishabh\training_output_metricsMulti_VIT_best_model.pth.tar'
# torch.load unpickles the file: only load checkpoints from a trusted source.
checkpoint = torch.load(CheckpointPath, map_location="cpu")
state_dict = checkpoint["state_dict"]
# Checkpoints saved from a DataParallel/DDP-wrapped model typically prefix every key
# with "module.". Strip only that *leading* prefix: str.replace would also mangle keys
# that merely contain "module." somewhere (e.g. "encoder.submodule.weight").
new_state_dict = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(new_state_dict)

In [None]:
# NOTE(review): this path differs from the checkpoint actually loaded in the model
# cell above and is not used afterwards — confirm which checkpoint is intended.
CheckpointPath = r'C:\Users\Rishabh\trainingMulti_VIT_best_model.pth.tar'
CSVPath = r'C:\Users\Rishabh\Documents\TransBTS\IXI.xlsx'
DataFolder = r'C:\Users\Rishabh\Documents\TrimeseData'
# os.listdir returns entries in arbitrary (filesystem-dependent) order; sort so that
# positional lookups such as Files[idx] pick the same file on every machine and run.
Files = sorted(os.listdir(DataFolder))
test_data = IMG_Folder(CSVPath, DataFolder)
device = "cpu"

In [None]:
# ======== Load AAL atlas ======== #
aal_atlas = datasets.fetch_atlas_aal()
atlas_filename = aal_atlas.maps
atlas_nii = nib.load(atlas_filename)
atlas_data = atlas_nii.get_fdata()

# Distinct voxel values in the atlas volume (np.unique returns them sorted);
# drop the background value 0 by value rather than by position.
region_labels = np.unique(atlas_data)
region_labels = region_labels[region_labels != 0]

# Map region code -> region name using the atlas' own `indices` list, which is the
# list aligned with `labels`. Zipping the sorted voxel codes with `labels` would
# silently assume identical ordering and length (zip truncates on mismatch).
# NOTE(review): `indices`/`labels` layout varies across nilearn versions — the
# length check below guards against a misaligned pairing.
assert len(aal_atlas.indices) == len(aal_atlas.labels), "AAL indices/labels length mismatch"
region_mapping = {int(code): label for code, label in zip(aal_atlas.indices, aal_atlas.labels)}

print(f"Number of regions in atlas: {len(region_labels)}")
print(f"Atlas shape: {atlas_data.shape}")

In [None]:
# Get sample image to determine exact shape.
# `valid_loader` is not defined anywhere above (NameError on a fresh kernel), so
# build it from the `test_data` dataset created in the config cell.
# NOTE(review): batch_size=1 / no shuffling is an assumption — confirm it matches
# the intended validation setup.
valid_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
sample_data = next(iter(valid_loader), None)
if sample_data is None:
    raise RuntimeError("valid_loader is empty - check CSVPath / DataFolder")
sample_img = sample_data[0]
# Assumes sample_img is (batch, D1, D2, D3) — TODO confirm against IMG_Folder.
img_shape = tuple(sample_img.shape[1:4])
print(f"Detected image shape: {img_shape}")

In [None]:
# Switch every submodule to inference mode (same as model.eval()); returns the model.
model.train(False)

In [None]:
def white0(image, threshold=0):
    """
    Z-score the voxels whose value is > threshold; leave all other voxels unchanged.

    Mean and population standard deviation (ddof=0) are computed over the
    foreground voxels only (``image > threshold``), accumulated in float64.
    NaN voxels never compare greater than the threshold, so they are treated as
    background and copied through unchanged.

    Args:
        image: Input image (array-like; converted to float32).
        threshold: Voxels strictly greater than this value are standardized.

    Returns:
        float32 np.ndarray with the same shape as ``image``. If there are no
        foreground voxels, or their standard deviation is not > 0, an all-zero
        float32 array is returned instead. The input is never modified.
    """
    image = np.asarray(image, dtype=np.float32)
    # Boolean mask + fancy indexing keeps the result float32; multiplying by an
    # integer mask would promote every intermediate to float64.
    fg_mask = image > threshold

    if not fg_mask.any():
        return np.zeros_like(image, dtype=np.float32)

    fg = image[fg_mask].astype(np.float64)  # float64 accumulation for stable stats
    mean = fg.mean()
    std = fg.std()
    if not (std > 0):  # also catches a NaN std (e.g. inf in the foreground)
        return np.zeros_like(image, dtype=np.float32)

    out = image.copy()
    out[fg_mask] = (fg - mean) / std  # cast back to float32 on assignment
    return out


In [None]:
import nibabel as nib
import numpy as np
model.eval()
# Position of the subject in `Files` (the DataFolder listing) to run through the model.
# NOTE(review): this is an index into the directory listing, not necessarily the same
# subject as test_data[idx] — confirm.
idx = 15
filename = Files[idx]
file_path = os.path.join(DataFolder, filename)
img = nib.load(file_path)
# Ask nibabel for float32 directly so no float64 copy of the volume is ever allocated.
x_np = img.get_fdata(caching='unchanged', dtype=np.float32)

# Standardize, add a leading batch axis, cast to float32 and move to `device`.
# Do not follow this with `.type(torch.FloatTensor)`: that always yields a CPU
# tensor and would undo `.to(device)` on any non-CPU device.
inputvolume = torch.from_numpy(white0(x_np)).unsqueeze(0).float().to(device)

# Inference only: skip building the autograd graph.
with torch.no_grad():
    output = model(inputvolume)

In [None]:
# Compare raw volume shape with the batched model-input shape.
(x_np.shape, inputvolume.shape)

In [None]:
# Display the first element of the model output.
# NOTE(review): presumably the prediction for the single input volume — confirm
# MultiViewViT's return type/shape.
output[0]

In [None]:
# Display the sorted region codes found in the atlas volume (background excluded).
region_labels

In [None]:
# Boolean mask of the voxels belonging to the region code at position 1 of the
# sorted `region_labels` (i.e. the second region code, not the first).
region_mask = np.equal(atlas_data, region_labels[1])

In [None]:
# Mask shape, plus the distinct values present in slice 30 along the first axis.
(region_mask.shape, np.unique(region_mask[30, :, :]))

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Show one 2-D slice (index 30 along the first axis) of the selected region mask.
# A dedicated name is used so the nibabel image held in `img` is not overwritten.
slice_idx = 30
mask_slice = region_mask[slice_idx, :, :]

fig, ax = plt.subplots()
ax.imshow(mask_slice)  # pass cmap="gray", "jet", ... to change the colour map
ax.set_title(f"Region mask - slice {slice_idx} along axis 0")
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Show the same slice (index 30 along the first axis) of the raw input volume, for
# visual comparison with the region-mask slice.
# A dedicated name is used so the nibabel image held in `img` is not overwritten.
slice_idx = 30
volume_slice = x_np[slice_idx, :, :]

fig, ax = plt.subplots()
ax.imshow(volume_slice, cmap="gray")
ax.set_title(f"Input volume - slice {slice_idx} along axis 0")
plt.show()
