In [None]:
import numpy as np
import torch
from torch import nn, optim


def _plot_grad_RFs(images, ncols=5):
    """Draw 2D gradient RFs on a grid with ``ncols`` columns.

    Relies on a global ``plt`` (``matplotlib.pyplot``) already imported in the notebook.
    Rows are added as needed, so any number of neurons fits (the original fixed 3x5 grid
    raised IndexError beyond 15 neurons).

    Returns:
        matplotlib.figure.Figure: the created figure.
    """
    nrows = max(1, (len(images) + ncols - 1) // ncols)
    fig, axs = plt.subplots(nrows, ncols, figsize=(20, 20), squeeze=False)
    for idx, img in enumerate(images):
        row, col = divmod(idx, ncols)
        axs[row][col].imshow(img)
    fig.tight_layout()
    return fig


def compute_RFs(model,
                dataloaders,
                data_key,
                selected_channels=None,
                init_img=None,
                neuron_positions=None,
                lr=1,
                model_forward_kwargs=None,
                plot=True):
    """
    Computes gradient receptive fields: the gradient of each selected neuron's
    response with respect to the input image.

    Args:
        model (nn.Module): A model trained on static images predicting neural responses.
                It is called as ``model(img, data_key=data_key, pretrained=True, **model_forward_kwargs)``
                and its output is indexed as ``output[0, neuron_position]``.
        dataloaders (dict): A dictionary of dictionaries consisting of
                dataloaders (similar to nnfabrik convention). The first batch of
                ``dataloaders['train'][data_key]`` supplies the default image.
        data_key (str): data_key specifying a specific dataset (or session).
        selected_channels (iterable, optional): input channels whose gradient is returned.
                If None (default), only input channel 0 is returned, as one 2D array per
                neuron. Otherwise each neuron gets an array of shape
                (len(selected_channels), h, w).
        init_img (torch.Tensor, optional): Image (assumed shape (1, c, h, w)) at which the
                gradients are evaluated. Defaults to the first image of the first
                training batch. The caller's tensor is not modified.
        neuron_positions (iterable, optional): Indices of the neurons of interest.
                Defaults to every neuron in the model output (``output.shape[1]``).
        lr (float, optional): Unused; kept only for backward compatibility.
        model_forward_kwargs (dict, optional): Extra keyword arguments for the model's
                forward call. These may override the default ``pretrained=True``.
        plot (bool, optional): If True (default), draw the channel-0 gradient of every
                neuron on a 5-column grid. This requires ``plt`` (matplotlib.pyplot)
                to be imported in the notebook.

    Returns:
        tuple: (list of np.ndarray gradient RFs, one per entry of ``neuron_positions``;
                np.ndarray copy of the image at which the gradients were evaluated).
    """
    sample_input = next(iter(dataloaders['train'][data_key]))[0]

    # Run on the model's device so a GPU model works with CPU dataloaders; fall back
    # to the data's device for a model without parameters.
    device = next(model.parameters(), sample_input).device

    # Honour a caller-supplied image (the original always overwrote it).
    if init_img is None:
        init_img = sample_input[0:1]
    # Fresh leaf tensor: gradients land here, never in the caller's tensor.
    init_img = nn.Parameter(torch.as_tensor(init_img).detach().clone().to(device))

    # Caller kwargs may override pretrained instead of raising a duplicate-keyword TypeError.
    forward_kwargs = {'pretrained': True, **(model_forward_kwargs or {})}
    responses = model(init_img, data_key=data_key, **forward_kwargs)

    neuron_positions = (
        list(neuron_positions)
        if neuron_positions is not None
        else list(range(responses.shape[1]))
    )
    channel_index = None if selected_channels is None else list(selected_channels)

    grad_RFs = []
    channel0_RFs = []  # what gets plotted, independent of selected_channels
    for neuron_position in neuron_positions:
        init_img.grad = None  # discard the previous neuron's gradient
        # retain_graph: the single forward pass is reused for every neuron.
        responses[0, neuron_position].backward(retain_graph=True)
        grad_RF = init_img.grad.detach().cpu().numpy()[0]  # drop the batch axis
        channel0_RFs.append(grad_RF[0].copy())
        if channel_index is None:
            grad_RFs.append(grad_RF[0].copy())
        else:
            grad_RFs.append(grad_RF[channel_index].copy())

    if plot:
        _plot_grad_RFs(channel0_RFs)

    # .cpu() so this also works when the model lives on a GPU.
    return grad_RFs, init_img.detach().cpu().numpy()

In [None]:
# Gradient receptive fields for every neuron of the first session. The optional
# arguments are spelled out with their default values so they are easy to tweak.
gradient_image, image = compute_RFs(
    model,
    dataloaders=dataloader,
    data_key=first_session_ID,
    selected_channels=None,
    init_img=None,            # start from the first training image
    neuron_positions=None,    # every neuron in the model output
    lr=1,
    model_forward_kwargs=None,
)







In [None]:

# Summary figure: gradient receptive fields in the first three rows, and the input
# image channels they were computed from in slot [2][4] plus the last row.
N_RF_PANELS = 14  # slots [0][0]..[2][3]; slot [2][4] is reserved for input channel 0

fig, axs = plt.subplots(4, 5, figsize=(20, 20))

# Slicing tolerates fewer than 14 neurons (the hard-coded range(0, 14) raised IndexError).
for idx, rf in enumerate(gradient_image[:N_RF_PANELS]):
    row, col = divmod(idx, 5)
    axs[row][col].imshow(rf)

input_channels = image.squeeze(0)
axs[2][4].imshow(input_channels[0, :, :], cmap="gray")

axs[3][0].imshow(input_channels[1, :, :])
axs[3][0].title.set_text("Saliency map")

axs[3][1].title.set_text("Gradient to x")
axs[3][1].imshow(input_channels[2, :, :], cmap="gray")

axs[3][2].title.set_text("Gradient to y")
axs[3][2].imshow(input_channels[3, :, :], cmap="gray")

# Hide panels that received no image so they don't render as empty frames.
for ax in axs.flat:
    if not ax.images:
        ax.set_axis_off()

# Lay out after the titles exist, save before showing (show may close the figure).
fig.tight_layout()
fig.savefig("saliency_maps_visualization")
plt.show()