### Visualizing Attention Maps

# In [None]:

import torch
import matplotlib.pyplot as plt
from src.models import ViTBinaryClassifier
from src.dataset import ImagingDataset
import pandas as pd
import os
import numpy as np
from monai.networks.nets import ViTAutoEnc
import torch.nn.functional as F





# --- Configuration ---
# Experiment identifiers. `trial` selects the classifier checkpoint file;
# `folds` is presumably the number of CV folds (not read in this section).
folds = 5
trial = 50

# Input locations: checkpoint directory, image array, and the two CSV tables
# that are merged to obtain labelled test samples.
weights_dir = "weights"
img_seq_path = "imgs/molab-hardy-leaf-97_images.npy"
test_csv = "dataframes/threshold_df_new.csv"
curated_csv = "dataframes/molab_df_curated.csv"

# Label column used to filter samples, and columns presumably excluded from
# tabular features (not read in this section).
label_col = "label-1RN-0Normal"
exclude_columns = [
    label_col,
    "Patient ID",
    "id",
    "BASELINE_TIME_POINT",
    "CROSSING_TIME_POINT",
    "BASELINE_VOLUME",
    "scan_date",
]

# Output directory for the saved attention-map PDFs.
save_dir = r"C:\Users\gomaaad\PycharmProjects\attn_maps"




# Load the test table and attach labels from the curated table, matching rows
# on 'Patient ID' and 'id'; rows that receive no label are dropped.
geo_df = pd.read_csv(test_csv)
curated_df = pd.read_csv(curated_csv)
geo_df = (
    geo_df
    .merge(curated_df[['Patient ID', 'id', label_col]], how='left', on=['Patient ID', 'id'])
    .dropna(subset=[label_col])
)

# Build the imaging dataset and a loader that yields one sample at a time,
# in dataframe order (no shuffling), over every labelled row.
ds = ImagingDataset(geo_df, data_dir=img_seq_path, is_img=True, is_gap=False)
sample_loader = torch.utils.data.DataLoader(dataset=ds, shuffle=False, batch_size=1)

# Load pretrained ViT model (MONAI ViTAutoEnc backbone) from a local checkpoint.
pretrained_model_path = "C:/Users/gomaaad/PycharmProjects/pretrainedmodels/hardy-leaf-97/best_model_942_0.291.pt"  # Update with actual path
pre_trained_model = ViTAutoEnc(
    img_size=(64, 64, 64),
    patch_size=8,
    in_channels=1,
    out_channels=1,
    num_layers=12,
    num_heads=12,
    hidden_size=384,
    mlp_dim=2048,
)
# NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
state_dict = torch.load(pretrained_model_path, map_location="cpu", weights_only=False)
# Strip a leading DataParallel/DDP "module." prefix. Only the prefix is removed,
# so keys that merely contain "module." elsewhere are left intact.
state_dict = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in state_dict.items()
}
# strict=False tolerates partial matches, but report mismatches so a checkpoint
# whose keys do not line up is not silently ignored (leaving random weights).
load_result = pre_trained_model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"[pretrained] missing keys ({len(load_result.missing_keys)}): {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"[pretrained] unexpected keys ({len(load_result.unexpected_keys)}): {load_result.unexpected_keys}")


# Load ViTBinaryClassifier (wrapping the backbone above) and its fine-tuned weights.
model = ViTBinaryClassifier(pretrained_model=pre_trained_model, unfreeze_last_n=0)
# Fold whose checkpoint is visualized; adjust together with `trial` as needed.
fold = 4
weight_path = os.path.join(weights_dir, "imaging", f"model_fold_{fold}_trial_{trial}.pth")
if not os.path.exists(weight_path):
    # Fail immediately with the offending path instead of printing and then
    # letting torch.load crash on the missing file.
    raise FileNotFoundError(f"Weight file not found: {weight_path}")
model.load_state_dict(torch.load(weight_path, map_location="cpu"))
model.eval()  # inference mode (disables dropout etc.)

# For each sample: run a forward pass, read the last transformer block's
# attention, turn it into a per-patch "attention received" volume upsampled to
# the input resolution, and save an overlay on one axial slice as a PDF.
patch_size = 8  # must match the patch_size given to ViTAutoEnc above
slice_idx = 31  # index along the last spatial axis to plot

# Create the output directory once instead of on every iteration.
os.makedirs(save_dir, exist_ok=True)

# Assumes ViTBinaryClassifier exposes the backbone's transformer `blocks`
# (the original code relied on this too). MONAI's SABlock only stores
# `att_mat` during forward when `save_attn` is True; otherwise `att_mat`
# stays an empty tensor. Enable it explicitly on the block we read from.
last_block = model.blocks[-1]
last_block.attn.save_attn = True

for i, (x, _) in enumerate(sample_loader):
    # x -> [1, 1, D, H, W] (batch_size=1, one channel); 64³ for this backbone.
    with torch.no_grad():
        # The forward pass is run for its side effect of populating att_mat.
        output = model(x)

        attn_weights = last_block.attn.att_mat  # [1, num_heads, N, N]
        if attn_weights.numel() == 0:
            raise RuntimeError(
                "Attention matrix was not recorded; check that the last block's "
                "attention module stores `att_mat` (MONAI SABlock `save_attn`)."
            )

        # Average over heads -> [N, N]; rows are queries, columns are keys.
        attn_weights = attn_weights.mean(dim=1).squeeze(0)

        # Mean over queries -> attention *received* by each patch, [N].
        token_attn = attn_weights.mean(dim=0)

        # Patch grid per spatial axis, e.g. 64 // 8 = 8 -> 8x8x8 for a 64³ volume.
        grid = tuple(int(s) // patch_size for s in x.shape[2:])
        expected_tokens = int(np.prod(grid))
        if token_attn.numel() != expected_tokens:
            # e.g. an extra class token would make N one larger than the grid.
            raise ValueError(
                f"Got {token_attn.numel()} attention tokens, expected "
                f"{expected_tokens} for patch grid {grid}"
            )
        attn_map = token_attn.reshape(grid)

        # Upsample the patch-level map to the input's spatial resolution.
        attn_map = attn_map.unsqueeze(0).unsqueeze(0)  # [1, 1, *grid]
        attn_map_upsampled = F.interpolate(
            attn_map, size=tuple(x.shape[2:]), mode='trilinear', align_corners=False
        )
        attn_map_upsampled = attn_map_upsampled.squeeze().cpu().numpy()  # [D, H, W]

        input_volume = x.squeeze().cpu().numpy()  # [D, H, W]

    # Overlay the attention heat map on a single axial slice of the input.
    print(f"Saving attention map for sample {i}")
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(input_volume[:, :, slice_idx], cmap='gray', alpha=0.8)
    ax.imshow(attn_map_upsampled[:, :, slice_idx], cmap='jet', alpha=0.4)
    ax.axis('off')

    # Save the figure as PDF without title
    destination = os.path.join(save_dir, f"attention_map_sample_{i}_slice_{slice_idx}.pdf")
    fig.savefig(destination, bbox_inches='tight', pad_inches=0)
    plt.show()
    plt.close(fig)  # Close to avoid memory buildup across samples


