### Saliency visualization function

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import torch.nn.functional as F

def plot_saliency(images, img_idx=None, patch_size=16, opacity=0.8):
    """Show one image of a batch beside its patch-averaged gradient saliency.

    The saliency of a pixel is the largest absolute input gradient across
    channels. Pixel saliency is averaged over non-overlapping
    ``patch_size`` x ``patch_size`` tiles, min-max normalized, bilinearly
    upsampled back to the image resolution and alpha-blended over the image
    with the "hot" colormap.

    Args:
        images: Batched image tensor indexed as (batch, channel, height,
            width) whose ``.grad`` has already been populated by a
            ``backward()`` call.
        img_idx: Index of the image inside the batch. ``None`` selects the
            first image.
        patch_size: Side length in pixels of the averaging tiles
            (16 matches the ViT-B/16 patch grid).
        opacity: Heatmap weight in the blend (0 = image only,
            1 = heatmap only).

    Raises:
        ValueError: If ``images.grad`` is missing, ``patch_size`` is not
            positive, or the image is smaller than a single patch.
    """
    if images.grad is None:
        raise ValueError(
            "images.grad is None: set images.requires_grad = True and call "
            "backward() on the loss before plotting saliency"
        )
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    if img_idx is None:
        img_idx = 0

    # Select the image first. Squeezing the whole batch would silently drop
    # the batch dimension for a batch of one and index a pixel row instead.
    grad = images.grad.detach()[img_idx]
    # Max of |grad| over the channel axis -> (height, width).
    saliency = grad.abs().max(dim=0).values.float().cpu()
    height, width = saliency.shape
    if height < patch_size or width < patch_size:
        raise ValueError(
            f"image of size {height}x{width} is smaller than one "
            f"{patch_size}x{patch_size} patch"
        )

    # Mean saliency per tile; kernel == stride gives non-overlapping tiles
    # and ignores any remainder rows/columns, like integer division would.
    patches = F.avg_pool2d(saliency[None, None], kernel_size=patch_size)

    # Min-max normalize, guarding against a constant map (0/0 -> NaN).
    low, high = patches.min(), patches.max()
    if high > low:
        patches = (patches - low) / (high - low)
    else:
        patches = patches.new_zeros(patches.shape)

    # Upsample the patch grid back to the original image resolution.
    heat = F.interpolate(
        patches, size=(height, width), mode="bilinear", align_corners=False
    )
    heat_np = heat[0, 0].clamp(0.0, 1.0).numpy()

    # Channels-last copy of the image for matplotlib.
    image_np = images[img_idx].detach().permute(1, 2, 0).float().cpu().numpy()
    # Blending and imshow both expect RGB floats in [0, 1]; rescale only when
    # the tensor falls outside that range (e.g. mean/std-normalized inputs).
    img_low, img_high = image_np.min(), image_np.max()
    if img_low < 0.0 or img_high > 1.0:
        if img_high > img_low:
            image_np = (image_np - img_low) / (img_high - img_low)
        else:
            image_np = np.zeros_like(image_np)

    # Colorize the heatmap and drop the alpha channel.
    cmap = plt.get_cmap("hot")
    saliency_colored = cmap(heat_np)[:, :, :3]

    # Alpha-blend heatmap over the image.
    overlay = (1 - opacity) * image_np + opacity * saliency_colored

    # Display results side by side.
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    ax[0].imshow(image_np)
    ax[0].axis("off")
    ax[0].set_title("Original Image")

    ax[1].imshow(overlay)
    ax[1].axis("off")
    ax[1].set_title(f"Saliency Map Overlay (Opacity={opacity})")

    plt.show()



### Saliency for Finetuning

In [None]:
# Load the fine-tuning splits; the fourth return value (labels_len) is used
# below as the output dimension of the prediction head.
train_dataset, val_dataset, test_dataset, labels_len = Load_finetuning_dataset("../Datasets/Finetuning", "NHCs", "images", shuffle = False)
print ("datasets were sucssesfully loaded")

# Build the ViT-B/16 model and move it to the GPU (a CUDA device is required).
model = Model("ViT-B/16").to("cuda")
print ("model was sucssesfully loaded")

# Placeholder path: replace it with a real checkpoint file before running.
cp_path = '##pretrained_cp##' ### add checkpoint that was trained on the desired dataset

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only use checkpoints from a trusted source.
checkpoint = torch.load(cp_path, map_location=torch.device("cpu"), weights_only=False)


# Drop every state-dict entry whose key contains 'cls' or 'clip_model' so the
# remaining keys can be loaded into `model`.
# Presumably those entries belong to modules this Model instance does not have - TODO confirm.
checkpoint['model'] = {key: value for key, value in checkpoint['model'].items() if 'cls' not in key and 'clip_model' not in key}

model.load_state_dict(checkpoint['model'])

# Rebuild the MLP head with labels_len outputs and restore its weights from
# the 'head' entry of the same checkpoint.
ff_head = create_mlp(inner_layers = 3, inner_dim = 512, dropout_rate = 0, output_dim = labels_len).to("cuda")
ff_head.load_state_dict(checkpoint['head'])



In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import torch.nn.functional as F

# Compute input-gradient saliency for a single fine-tuning batch.
loss_func = nn.L1Loss() 
epoch_it = iter(train_dataset)

# Only the first batch produced by the iterator is used.
images, labels = next(epoch_it)

# Track gradients with respect to the input so that backward() fills images.grad.
images.requires_grad = True 

# NOTE(review): the model and head were moved to "cuda" but the batch is not
# moved here; assumes model_image / the dataset handle device placement - TODO confirm.
embeddings = model.model_image(images).float()
output = ff_head(embeddings)

# Backpropagate the L1 loss between head output and labels down to the pixels.
loss = loss_func(output, labels)
loss.backward()


# One figure per image in the batch.
for i in range(images.shape[0]):
    plot_saliency(images, i)





### Saliency for pretraining

In [None]:
from Data.Dataloaders import Load_contrastive_dataset

# Load the contrastive pretraining data; `classes` is forwarded to the model constructor below.
train_dataset, val_dataset, classes = Load_contrastive_dataset("../Datasets/Pretraining", "chembl_25", "smiles", batch_size = 32)
print ("datasets were sucssesfully loaded")

# Build the ViT-B/16 model and move it to the GPU (a CUDA device is required).
model = Model("ViT-B/16", classes = classes).to("cuda")
print ("model was sucssesfully loaded")

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only use checkpoints from a trusted source.
checkpoint = torch.load('../Checkpoints/MoleCLIP/MoleCLIP - Primary.pth', map_location=torch.device("cpu"), weights_only=False)

# The full 'model' state dict is loaded here, with no key filtering.
model.load_state_dict(checkpoint['model'])


In [None]:
from Evaluations.pretraining_eval import features_to_logits

# Compute input-gradient saliency of the contrastive loss for one pretraining batch.
loss_function = nn.CrossEntropyLoss()

# Target indices 0..31; 32 matches the batch_size passed to Load_contrastive_dataset.
labels_template = torch.arange(32).to("cuda")


epoch_it = iter(train_dataset)
        
# Only the first batch produced by the iterator is used.
image1, image2, cls1_labels, cls2_labels = next(epoch_it)

# Track gradients with respect to both inputs so that backward() fills their .grad.
image1.requires_grad = True 
image2.requires_grad = True 

image1_features = model.model_image(image1).float()
image2_features = model.model_image(image2).float()

# NOTE(review): both classification heads are fed image1_features; confirm
# that cls_2 is not meant to receive image2_features.
cls_1_preds = model.cls_1(image1_features)
cls_2_preds = model.cls_2(image1_features)

# Slice the template so the targets match the actual batch size.
labels = labels_template[:image1.shape[0]]

# NOTE(review): the meaning of the constant 15 is defined by features_to_logits,
# which is not visible here - TODO confirm.
logits_per_image1, logits_per_image2 = features_to_logits (model, image1_features, image2_features, 15)

# Average of the cross-entropy losses over the two logit matrices.
loss_singles = (loss_function(logits_per_image1, labels) + loss_function(logits_per_image2, labels))/2 

# The classification losses are computed but currently excluded from `loss`
# (see the commented-out terms below).
cls1_loss = loss_function(cls_1_preds, cls1_labels) 
cls2_loss = loss_function(cls_2_preds, cls2_labels)
loss = loss_singles #+ cls1_loss + cls2_loss

loss.backward()
print (loss)

print (image1.shape)

# Two figures per batch position: one for image1 and one for image2.
for i in range(image1.shape[0]):

    plot_saliency(image1, i)
    plot_saliency(image2, i)
