In [None]:
import os
import torch
from Models.MultiViewViT import MultiViewViT
from load_data import IMG_Folder
import torch.nn as nn

In [None]:
def weights_init(w):
    """Initialise one module's parameters according to its class name.

    Intended for ``model.apply(weights_init)``, which visits every submodule.
    The checks are substring matches on the class name and are independent of
    each other, so a name containing several keywords receives every matching
    rule, in this order:

    * ``Conv``      -> Kaiming-normal weight (mode='fan_in',
                       nonlinearity='leaky_relu'), zero bias
    * ``Linear``    -> Xavier-normal weight, zero bias
    * ``BatchNorm`` -> weight filled with 1 (when not None), bias filled with 0

    A bias is only touched when the module has one and it is not None.
    """
    kind = w.__class__.__name__
    has_weight = hasattr(w, 'weight')

    def zero_bias():
        # Modules built with bias=False expose ``bias`` as None; leave those alone.
        if getattr(w, 'bias', None) is not None:
            nn.init.constant_(w.bias, 0)

    if 'Conv' in kind:
        if has_weight:
            nn.init.kaiming_normal_(w.weight, mode='fan_in', nonlinearity='leaky_relu')
        zero_bias()

    if 'Linear' in kind:
        if has_weight:
            nn.init.xavier_normal_(w.weight)
        zero_bias()

    if 'BatchNorm' in kind:
        # BatchNorm with affine=False has weight=None, hence the explicit check.
        if getattr(w, 'weight', None) is not None:
            nn.init.constant_(w.weight, 1)
        zero_bias()

In [None]:
# Load model
# The architecture hyper-parameters must match those used for training; otherwise the
# (strict, by default) load_state_dict call below fails on missing/unexpected keys or
# shape mismatches.
model = MultiViewViT(
    image_sizes=[(91, 109), (91, 91), (109, 91)],
    patch_sizes=[(7, 7), (7, 7), (7, 7)],
    num_channals=[91, 109, 91],
    vit_args={
        'emb_dim': 768, 'mlp_dim': 3072, 'num_heads': 12,
        'num_layers': 12, 'num_classes': 1,
        'dropout_rate': 0.1, 'attn_dropout_rate': 0.0
    },
    mlp_dims=[3, 128, 256, 512, 1024, 512, 256, 128, 1]
)
# NOTE(review): every parameter present in the checkpoint is overwritten by the strict
# load_state_dict below, so this random init only matters if the load is skipped.
model.apply(weights_init)
model = model.to("cpu")

# Load checkpoint
# NOTE(review): the file name looks like a folder separator may be missing
# (training_output_metrics\Multi_VIT_...); confirm the path.
CheckpointPath = r'C:\Users\Rishabh\training_output_metricsMulti_VIT_best_model.pth.tar'
# torch.load unpickles the file, which can execute arbitrary code: only load
# checkpoints from a trusted source.
checkpoint = torch.load(CheckpointPath, map_location="cpu")
# Accept both a training checkpoint ({"state_dict": ..., ...}) and a bare state dict.
state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint


def _strip_module_prefix(key, prefix="module."):
    """Remove leading DataParallel/DDP ``module.`` prefixes from a state-dict key.

    Only the *leading* prefix is removed. ``str.replace`` would also mangle keys
    whose inner names merely end in "module" (e.g. ``submodule.weight`` would
    become ``subweight``).
    """
    while key.startswith(prefix):
        key = key[len(prefix):]
    return key


new_state_dict = {_strip_module_prefix(k): v for k, v in state_dict.items()}
model.load_state_dict(new_state_dict)


In [None]:
import re

import pandas as pd

# NOTE(review): absolute, user-specific Windows paths; consider making these configurable.
# NOTE(review): this reassignment is not read by any later cell in this notebook, and the
# checkpoint was already loaded in the previous cell from a different path. Confirm which
# file is intended.
CheckpointPath = r'C:\Users\Rishabh\trainingMulti_VIT_best_model.pth.tar'
# Despite its name this is an Excel workbook (it is read with pd.read_excel); the
# evaluation loop reads its 'IXI_ID' and 'AGE' columns.
CSVPath = r'C:\Users\Rishabh\Documents\TransBTS\IXI.xlsx'
DataFolder = r'C:\Users\Rishabh\Documents\TrimeseData'
device = "cpu"

# Subject files are expected to be named "IXI<3-digit id>...". Any other entry in the
# folder (e.g. desktop.ini, Thumbs.db) would make int(f[3:6]) crash, so keep only
# matching names. os.listdir returns entries in arbitrary order; sorting makes the
# evaluation order and printed output reproducible.
_IXI_NAME = re.compile(r'^IXI(\d{3})')
Files = sorted(f for f in os.listdir(DataFolder) if _IXI_NAME.match(f))
ixi_ids = [int(f[3:6]) for f in Files]

In [None]:
# Demographics sheet; the evaluation loop looks up 'AGE' by 'IXI_ID'.
df = pd.read_excel(CSVPath)
# Fail fast with a clear message instead of a KeyError deep inside the evaluation loop.
_missing_cols = {'IXI_ID', 'AGE'} - set(df.columns)
assert not _missing_cols, f"{CSVPath} is missing required columns: {sorted(_missing_cols)}"

In [None]:
import nibabel as nib
import numpy as np

# Run the loaded model on every subject volume and pair each predicted age with the
# AGE recorded for that subject in the demographics sheet.
# NOTE(review): IMG_Folder (the training data loader) is imported but not used here. If it
# applied any intensity normalisation or resampling during training, inference should
# apply the same preprocessing; confirm.
model.eval()
Pred = []
Acct = []
for filename in Files:
    file_path = os.path.join(DataFolder, filename)

    # Look the target up first, so subjects without a usable age are skipped before
    # paying for image loading and a forward pass. Previously a missing row raised
    # IndexError, and a NaN age silently turned the final MAE into NaN.
    _id = int(filename[3:6])
    ages = df.loc[df['IXI_ID'] == _id, 'AGE'].dropna()
    if ages.empty:
        print(f"Skipping {filename}: no AGE found for IXI_ID {_id}")
        continue
    AGE = ages.iloc[0]

    # Ask nibabel for float32 directly: get_fdata() followed by astype() would first
    # materialise a full float64 copy of the volume.
    img = nib.load(file_path)
    x_np = img.get_fdata(dtype=np.float32)
    # Add a leading batch dimension of 1. There is deliberately no .type(torch.FloatTensor):
    # that call always returns a CPU tensor and would undo .to(device) on a GPU.
    inputvolume = torch.from_numpy(x_np).unsqueeze(0).to(device)

    with torch.no_grad():
        output, (attn1, attn2, attn3) = model(inputvolume, return_attention_weights=True)
    # .item() requires a single-element output (batch of 1, one regression target).
    Predicted_Age = output.item()
    Pred.append(Predicted_Age)
    Acct.append(AGE)

    print(Predicted_Age, AGE)

In [None]:

# Mean absolute error between predicted and recorded ages (same units as the AGE column).
# Guard against silently printing NaN (empty lists) or comparing misaligned lists.
assert len(Pred) == len(Acct), "predictions and targets are misaligned"
assert Pred, "no predictions were collected; check DataFolder and the demographics sheet"
mae = np.mean(np.abs(np.asarray(Pred, dtype=float) - np.asarray(Acct, dtype=float)))
print("MAE:", mae)