In [1]:
import os
import torch
from Models.MultiViewViT import MultiViewViT
from load_data import IMG_Folder
import torch.nn as nn

In [2]:

def weights_init(w):
    """Re-initialise one module in place; intended for ``model.apply(weights_init)``.

    Dispatch is on the module's class name (substring match):
      * ``Conv``      -> Kaiming-normal weight (fan_in, leaky_relu), zero bias
      * ``Linear``    -> Xavier-normal weight, zero bias
      * ``BatchNorm`` -> weight 1, bias 0
    Any ``weight``/``bias`` attribute that is missing or ``None`` is skipped in
    every branch, so containers, ``bias=False`` layers and ``affine=False`` norms
    are all left alone instead of crashing the init call.

    Args:
        w: the ``nn.Module`` currently being visited by ``Module.apply``.
    """
    classname = w.__class__.__name__
    weight = getattr(w, 'weight', None)
    bias = getattr(w, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            nn.init.kaiming_normal_(weight, mode='fan_in', nonlinearity='leaky_relu')
        if bias is not None:
            nn.init.constant_(bias, 0)
    if 'Linear' in classname:
        if weight is not None:
            nn.init.xavier_normal_(weight)
        if bias is not None:
            nn.init.constant_(bias, 0)
    if 'BatchNorm' in classname:
        if weight is not None:
            nn.init.constant_(weight, 1)
        if bias is not None:
            nn.init.constant_(bias, 0)

In [3]:
# Load model
# Hyper-parameters must match the ones used for training so that every key in
# the checkpoint finds a matching parameter/buffer.
model = MultiViewViT(
    image_sizes=[(91, 109), (91, 91), (109, 91)],
    patch_sizes=[(7, 7), (7, 7), (7, 7)],
    num_channals=[91, 109, 91],
    vit_args={
        'emb_dim': 768, 'mlp_dim': 3072, 'num_heads': 12,
        'num_layers': 12, 'num_classes': 1,
        'dropout_rate': 0.1, 'attn_dropout_rate': 0.0
    },
    mlp_dims=[3, 128, 256, 512, 1024, 512, 256, 128, 1]
)
# NOTE(review): this init is overwritten by the strict state-dict load below
# (all keys matched); it only matters if the load is skipped.
model.apply(weights_init)
model = model.to("cpu")

# Load checkpoint
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
CheckpointPath = r'C:\Users\Rishabh\training_output_metricsMulti_VIT_best_model.pth.tar'
checkpoint = torch.load(CheckpointPath, map_location="cpu")
state_dict = checkpoint["state_dict"]
# Keys saved from a DataParallel/DistributedDataParallel wrapper presumably start
# with "module.". Strip only that leading prefix: str.replace would also rewrite
# "module." occurring inside a key (e.g. "submodule.weight" -> "subweight").
new_state_dict = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [4]:
# NOTE(review): this path differs from the checkpoint actually loaded in the
# model cell above and is not read again in the visible notebook.
CheckpointPath = r'C:\Users\Rishabh\trainingMulti_VIT_best_model.pth.tar'

# Evaluation data: the metadata spreadsheet and the folder of volumes it indexes.
CSVPath = r'C:\Users\Rishabh\Documents\TransBTS\IXI.xlsx'
DataFolder = r'C:\Users\Rishabh\Documents\TrimeseData'

# Everything in this notebook runs on the CPU.
device = "cpu"
test_data = IMG_Folder(CSVPath, DataFolder)

In [5]:
# Evaluation loader: one volume per batch, in dataset order (shuffle defaults to False).
valid_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=1,
    num_workers=0,
    # Pinned host memory only speeds up host->GPU copies; skip it on a CPU run.
    pin_memory=str(device).startswith("cuda"),
    # Evaluation must see every sample, so keep a final partial batch.
    drop_last=False,
)

In [6]:
# Tensors captured by `hook_attn_probs`: one entry per hooked forward call.
# Entry i of `attention_grads` is the gradient of entry i of `attentions`
# (None until backward reaches it, or forever if the call had no grad).
attentions = []       # [B,H,T,T] per layer (forward) — author's shape note, TODO confirm
attention_grads = []  # same shape, aligned with `attentions` (backward)

def hook_attn_probs(module, inp, out):
    """Forward hook: record attention probabilities and, later, their gradients.

    Args:
        module: the hooked attention module (unused).
        inp: the module's inputs (unused).
        out: either the probabilities tensor or a tuple/list whose item 1 is it.
    """
    probs = out[1] if isinstance(out, (tuple, list)) else out
    # Stored as-is; the gradient is captured by the tensor hook below.
    attentions.append(probs)

    # Reserve the matching gradient slot now. Backward visits layers in reverse
    # order, so appending from the gradient hook would store the last layer's
    # gradient first and misalign the two lists.
    slot = len(attention_grads)
    attention_grads.append(None)

    if not probs.requires_grad:
        # e.g. a forward pass under torch.no_grad(): register_hook would raise.
        return

    def _store_grad(grad):
        attention_grads[slot] = grad

    probs.register_hook(_store_grad)

# Register `hook_attn_probs` on the attention module of every vit_1 encoder layer.
# Handles from a previous run of this cell are removed first; otherwise each
# re-run stacks another hook and every forward pass records each layer twice.
for _old_handle in globals().get("attn_hook_handles", []):
    _old_handle.remove()
attn_hook_handles = []
for blk in model.vit_1.transformer.encoder_layers:
    attn_hook_handles.append(blk.attn.register_forward_hook(hook_attn_probs))

In [None]:
import numpy as np
import cv2
# Run every validation volume through the model with autograd enabled so the
# tensor hooks on the attention probabilities receive gradients.
# NOTE(review): the lists below are never filled in the visible notebook.
out, targ, ID, Attn1, Attn2, Attn3 = [], [], [], [], [], []
target_numpy, predicted_numpy, ID_numpy = [], [], []
model.eval()  # eval-mode layers (e.g. dropout off); gradients are still tracked
# NOTE(review): each forward pass appends one tensor per hooked layer to
# `attentions`, so memory grows with the number of volumes in the loader.
for input, ids, target, male in valid_loader:
    # Cast and move in one call: `.type(torch.FloatTensor)` always returns a
    # CPU tensor and would silently undo `.to(device)` on a GPU.
    input = input.to(device=device, dtype=torch.float32)
    print(input.shape)
    output, (attn1, attn2, attn3) = model(input, return_attention_weights=True)
    print(output)
    # Same gradient as output.backward() for a one-element output, and still
    # valid when a batch holds several samples.
    output.sum().backward()

torch.Size([1, 91, 109, 91])
tensor([[35.0340]], grad_fn=<AddmmBackward0>)


In [None]:
# Walk the loader again, keeping the last volume as a CPU float32 tensor
# (`.type(torch.FloatTensor)` always lands on the CPU) for numpy-based plotting.
inputvolume = []
for input, ids, target, male in valid_loader:
    on_device = input.to(device)
    inputvolume = on_device.type(torch.FloatTensor)
    print(input.shape)

In [None]:
# Stack the per-call hook captures along a new leading dimension.
Attn, AttnGr = torch.stack(attentions), torch.stack(attention_grads)

In [None]:
# Shapes of the stacked forward / backward captures.
(Attn.shape, AttnGr.shape)

In [None]:
# Average the hook captures over the stacking dimension. Computed from the hook
# lists rather than by reducing `Attn`/`AttnGr` in place, so re-running this
# cell does not average an already-averaged tensor again.
Attn = torch.stack(attentions).mean(dim=0)
AttnGr = torch.stack(attention_grads).mean(dim=0)

In [None]:
# How many forward captures and gradient entries the hooks have recorded.
(len(attentions), len(attention_grads))

In [None]:
# Shape of one individual capture from each list.
(attentions[0].shape, attention_grads[0].shape)

In [None]:
# Collapse each hook list into its element-wise mean.
# NOTE(review): this rebinds the lists that `hook_attn_probs` appends to, so any
# later forward pass through the hooked model fails; run only after all passes.
# Guarded so a re-run does not call torch.stack on an already-reduced tensor.
if isinstance(attentions, list):
    attentions = torch.mean(torch.stack(attentions), dim=0)
if isinstance(attention_grads, list):
    attention_grads = torch.mean(torch.stack(attention_grads), dim=0)

In [None]:
# Shapes after averaging the hook captures.
(attentions.shape, attention_grads.shape)

In [None]:
# Reduce to a single 2-D map for imshow by averaging over every leading
# dimension. For 3-D input this equals `.mean(dim=0)`; unlike that call it is a
# no-op on an already 2-D tensor, so re-running the cell is safe.
attentions = attentions.reshape(-1, *attentions.shape[-2:]).mean(dim=0)
attention_grads = attention_grads.reshape(-1, *attention_grads.shape[-2:]).mean(dim=0)

In [None]:
import matplotlib.pyplot as plt

# Heat map of the averaged 2-D attention tensor.
fig, ax = plt.subplots()
im = ax.imshow(attentions.detach().numpy(), cmap='hot')
fig.colorbar(im, ax=ax, label='Attention intensity')
ax.set_title('Attention Map')
ax.axis('off')
plt.show()


In [None]:
import matplotlib.pyplot as plt

# Heat map of the averaged 2-D attention *gradients*. Gradients can be negative,
# so use a diverging colormap centred on zero instead of 'hot'.
grads_2d = attention_grads.detach().numpy()
limit = float(abs(grads_2d).max()) or 1.0  # avoid a degenerate colour range
fig, ax = plt.subplots()
im = ax.imshow(grads_2d, cmap='coolwarm', vmin=-limit, vmax=limit)
fig.colorbar(im, ax=ax, label='Attention gradient')
ax.set_title('Attention Gradient Map')
ax.axis('off')
plt.show()

In [None]:
# Display the attention-gradient tensor as left by the reduction cells above.
attention_grads

In [None]:
import numpy as np
import torch
def grad_rollout(attentions, gradients, discard_ratio):
    """Gradient-weighted attention rollout across transformer layers.

    For each layer, the attention probabilities are multiplied element-wise by
    their gradients, averaged over dim 1, and clipped at zero. The lowest
    ``discard_ratio`` fraction of entries per sample is then zeroed, keeping the
    CLS->CLS entry. The identity is added for the residual path, and rows are
    renormalised to sum to one. The per-layer matrices are chained by matrix
    multiplication.

    Args:
        attentions: sequence of per-layer attention tensors; ``zip`` iterates
            over its first dimension. Presumably each item is ``[B, H, T, T]``
            (heads on dim 1) — TODO confirm against the hook's output.
        gradients: gradients matching ``attentions`` item by item.
        discard_ratio: fraction of the smallest fused entries per sample to zero.

    Returns:
        The accumulated rollout matrix, on the device/dtype of ``attentions[0]``.
    """
    first = attentions[0]
    result = torch.eye(first.size(-1), dtype=first.dtype, device=first.device)
    with torch.no_grad():
        for attention, grad in zip(attentions, gradients):
            # Weight probabilities by their gradients, average over dim 1 and
            # keep only positive evidence.
            fused = (attention * grad).mean(dim=1).clamp(min=0)

            # Drop the lowest attentions per sample, but never the CLS->CLS
            # entry (flat index 0).
            flat = fused.reshape(fused.size(0), -1)
            num_discard = int(flat.size(-1) * discard_ratio)
            if num_discard > 0:
                _, indices = flat.topk(num_discard, dim=-1, largest=False)
                drop = torch.zeros_like(flat, dtype=torch.bool)
                drop.scatter_(-1, indices, True)
                drop[:, 0] = False
                flat = flat.masked_fill(drop, 0)
            fused = flat.reshape(fused.shape)

            # Residual connection, then normalise each ROW to sum to one
            # (keepdim=True; without it the division broadcasts over columns).
            identity = torch.eye(fused.size(-1), dtype=fused.dtype, device=fused.device)
            a = (fused + identity) / 2
            a = a / a.sum(dim=-1, keepdim=True)
            result = torch.matmul(a, result)
    return result

In [None]:

# Scratch tensors of standard-normal samples, shape (12, 196, 196) — presumably
# 12 heads x 196 tokens (195 patches + CLS) for trying out grad_rollout.
# Not used elsewhere in the visible notebook.
matrix1 = torch.randn(12, 196, 196)
matrix2 = torch.randn(12, 196, 196)
matrix1.shape

In [None]:
# Gradient rollout over the averaged captures, discarding the weakest 90 %.
result = grad_rollout(Attn, AttnGr, discard_ratio=0.9)

In [None]:
# Take the CLS row of the rollout matrix (skipping the CLS->CLS entry) and lay
# the per-patch scores out on the patch grid.
# NOTE(review): assumes `result` is a 2-D [T, T] matrix with the CLS token at
# index 0 — confirm against the shapes fed to grad_rollout.
mask = result[0, 1:]
width = int(mask.size(-1) ** 0.5)  # square-grid guess; the grid below is not square
print(width)
# NOTE(review): if vit_1 uses image_sizes[0] = (91, 109) with 7x7 patches, the
# grid is 13 x 15 in (H, W) order; (15, 13) may be transposed depending on how
# the patch embedding flattens tokens — verify before trusting the layout.
PATCH_GRID = (15, 13)
mask = mask.detach().reshape(PATCH_GRID).cpu().numpy()
# Scale to [0, 1]; an all-zero map is left as-is instead of becoming NaN.
peak = np.max(mask)
if peak > 0:
    mask = mask / peak
mask.shape, width

In [None]:
# Upsample the patch-level map to the slice size. cv2 takes dsize as
# (width, height), so the result has shape (91, 109).
mask = cv2.resize(mask, dsize=(109, 91), interpolation=cv2.INTER_LINEAR)
mask.shape

In [None]:
import matplotlib.pyplot as plt

# Show the upsampled rollout mask without axes.
fig, ax = plt.subplots()
ax.imshow(mask, cmap='hot')
ax.axis('off')
plt.show()
mask.shape

In [None]:
import matplotlib.pyplot as plt

# Show one slice along the last axis of the most recent input volume; the same
# index is used for the plot and for the shape reported below.
SLICE_INDEX = 25
input_slice = inputvolume[0, :, :, SLICE_INDEX]
fig, ax = plt.subplots()
ax.imshow(input_slice.numpy())
ax.set_title(f'Input volume, slice {SLICE_INDEX}')
ax.axis('off')
plt.show()
input_slice.shape

In [None]:
# Patch-token count for a 91 x 109 image cut into 7 x 7 patches (13 x 15 = 195):
# the length of the CLS row once the CLS entry itself is dropped.
(91 // 7) * (109 // 7)

In [None]:
# Shape of the rollout matrix returned by grad_rollout.
result.shape

In [None]:
import matplotlib.pyplot as plt

# Heat map of slice 5 (along dim 0) of the averaged attention tensor.
attn_slice = Attn.detach().numpy()[5]
fig, ax = plt.subplots()
im = ax.imshow(attn_slice, cmap='hot')
fig.colorbar(im, ax=ax, label='Attention intensity')
ax.set_title('Attention Map')
ax.axis('off')
plt.show()


In [None]:
import matplotlib.pyplot as plt

# Heat map of slice 10 (along dim 0) of the averaged attention *gradients*.
# Gradients can be negative, so use a diverging colormap centred on zero.
grad_slice = AttnGr.detach().numpy()[10]
limit = float(abs(grad_slice).max()) or 1.0  # avoid a degenerate colour range
fig, ax = plt.subplots()
im = ax.imshow(grad_slice, cmap='coolwarm', vmin=-limit, vmax=limit)
fig.colorbar(im, ax=ax, label='Attention gradient')
ax.set_title('Attention Gradient Map (slice 10)')
ax.axis('off')
plt.show()


In [None]:
# First slice along dim 0 of the averaged attention gradients.
AttnGr[0]