In [1]:
import os
import torch
from Models.MultiViewViT import MultiViewViT
from load_data import IMG_Folder
import torch.nn as nn

In [2]:
def weights_init(w):
    classname = w.__class__.__name__
    if classname.find('Conv') != -1:
        if hasattr(w, 'weight'):
            # nn.init.kaiming_normal_(w.weight, mode='fan_out', nonlinearity='relu')
            nn.init.kaiming_normal_(w.weight, mode='fan_in', nonlinearity='leaky_relu')
        if hasattr(w, 'bias') and w.bias is not None:
                nn.init.constant_(w.bias, 0)
    if classname.find('Linear') != -1:
        if hasattr(w, 'weight'):
            torch.nn.init.xavier_normal_(w.weight)
        if hasattr(w, 'bias') and w.bias is not None:
            nn.init.constant_(w.bias, 0)
    if classname.find('BatchNorm') != -1:
        if hasattr(w, 'weight') and w.weight is not None:
            nn.init.constant_(w.weight, 1)
        if hasattr(w, 'bias') and w.bias is not None:
            nn.init.constant_(w.bias, 0)

In [3]:
# Load model
model = MultiViewViT(
    image_sizes=[(91, 109), (91, 91), (109, 91)],
    patch_sizes=[(7, 7), (7, 7), (7, 7)],
    num_channals=[91, 109, 91],
    vit_args={
        'emb_dim': 768, 'mlp_dim': 3072, 'num_heads': 12,
        'num_layers': 12, 'num_classes': 1,
        'dropout_rate': 0.1, 'attn_dropout_rate': 0.0
    },
    mlp_dims=[3, 128, 256, 512, 1024, 512, 256, 128, 1]
)
model.apply(weights_init)
model = model.to("cpu")

# Load checkpoint
CheckpointPath = r'C:\Users\Rishabh\training_output_metricsMulti_VIT_best_model.pth.tar'
checkpoint = torch.load(CheckpointPath, map_location="cpu")
state_dict = checkpoint["state_dict"]
new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
model.load_state_dict(new_state_dict)


<All keys matched successfully>

In [4]:
import pandas as pd
CheckpointPath = r'C:\Users\Rishabh\trainingMulti_VIT_best_model.pth.tar'
CSVPath = r'C:\Users\Rishabh\Documents\TransBTS\IXI.xlsx'
DataFolder = r'C:\Users\Rishabh\Documents\TrimeseData'
device = "cpu"
Files = os.listdir(DataFolder)
ixi_ids = [int(f[3:6]) for f in Files]

In [5]:
df = pd.read_excel(CSVPath)

In [8]:
import nibabel as nib
import numpy as np
model.eval()
idx = 15
Pred = []
Acct = []
for idx in range(len(Files)):
    filename = Files[idx]
    file_path = os.path.join(DataFolder, filename)
    img = nib.load(file_path)
    x_np = img.get_fdata().astype(np.float32)       # avoid float64 bloat
    x_tn = torch.from_numpy(x_np).unsqueeze(0).to(device).float()
    # print(type(inputvolume), inputvolume.device, inputvolume.shape)
    
    # print(x_tn.shape)
    
    _id = int(filename[3:6])
    AGE = df[df['IXI_ID']==_id]['AGE'].values[0]
    
    inputvolume = x_tn.to(device).type(torch.FloatTensor)
    with torch.no_grad():
        output, (attn1, attn2, attn3) = model(inputvolume, return_attention_weights=True)
    Predicted_Age = output.item()
    Pred.append(Predicted_Age)
    Acct.append(AGE)
    
    print(Predicted_Age, AGE)

79.59088897705078 35.800136892539356
73.44522094726562 35.800136892539356
67.14981842041016 38.78165639972622
78.80917358398438 38.78165639972622
83.26321411132812 46.71047227926078
72.69591522216797 46.71047227926078
62.460731506347656 34.23682409308692
76.03118896484375 34.23682409308692
73.41158294677734 24.28473648186174
80.88882446289062 24.28473648186174
69.3128662109375 55.167693360711844
80.55084228515625 55.167693360711844
75.12857818603516 29.09240246406571
82.59210205078125 29.09240246406571
80.14286804199219 58.65845311430527
75.5879135131836 58.65845311430527
79.01376342773438 39.46611909650924
71.54000854492188 39.46611909650924
73.0363998413086 21.56605065023956
74.21957397460938 21.56605065023956
73.0558853149414 30.67214236824093
81.54778289794922 30.67214236824093
80.54924774169922 33.07871321013005
71.6091537475586 33.07871321013005
81.21086120605469 37.6974674880219
67.09789276123047 37.6974674880219
89.66226959228516 29.779603011635867
77.80033111572266 29.77960301

IndexError: index 0 is out of bounds for axis 0 with size 0

In [9]:

mae = np.mean(np.abs(np.array(Pred) - np.array(Acct)))
print("MAE:", mae)

MAE: 40.58223050291973
