<a href="https://colab.research.google.com/github/AnonymousAlzheimersGaze/Eye-Gaze-Alzheimers-Paper/blob/main/Attention%20Maps%20Generation/Creation_attention_maps.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Author: Carlos Antunes (2022)

- Load dataframes
- Create location and duration based attention maps for each scan
- Create average attention maps

# Imports and Google Drive mount

In [None]:
# Install a pip package in the current Jupyter kernel
import sys
!{sys.executable} -m pip install mat73

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting mat73
  Downloading mat73-0.59-py3-none-any.whl (19 kB)
Installing collected packages: mat73
Successfully installed mat73-0.59


In [None]:
from google.colab import drive, files # to use Google Drive
import mat73
import numpy as np
import matplotlib.pyplot as plt # to plot
from scipy.io import loadmat # to load matlab files
from scipy.stats import gaussian_kde
import h5py # to load matlab files v7.3
import pandas as pd
from functools import reduce
import copy
from scipy.ndimage import gaussian_filter

In [None]:
# Start from a clean mount so re-running this cell is harmless, then mount
# Google Drive and point root_path at the dataset folder used throughout.
drive.flush_and_unmount()
drive.mount(mountpoint='/content/drive/')
root_path = '/content/drive/My Drive/dataset'  # base folder for inputs and generated maps

Drive not mounted, so nothing to flush and unmount.
Mounted at /content/drive/


# Load Dataframes

In [None]:
# One pickled data frame per class, stored under <root_path>/pandas_dataframes.
# Unpickling can execute arbitrary code: only load files you created yourself.
nc_allData = pd.read_pickle(root_path + "/pandas_dataframes/nc_allData.pkl")
mci_allData = pd.read_pickle(root_path + "/pandas_dataframes/mci_allData.pkl")
ad_allData = pd.read_pickle(root_path + "/pandas_dataframes/ad_allData.pkl")

# Scan visualization

Auxiliary functions to convert flat (1D) pixel indices on the 128×128 grid to 2D coordinates and vice versa

In [None]:
def coordinates_transform(point_vector_form):
  """Convert flat pixel index/indices on the 128x128 grid to 2D coordinates.

  `point_vector_form` is used as a NumPy index into a vector of length
  128*128 (e.g. an int or a list/array of ints). The marked pixels are
  returned as the np.nonzero-style tuple (rows, cols) of the 128x128 grid,
  in row-major order (row = index // 128, col = index % 128). A repeated
  index appears only once.
  """
  grid = np.zeros((128 * 128,))
  grid[point_vector_form] = 1
  rows, cols = np.nonzero(grid.reshape((128, 128)))
  return rows, cols

def coordinates_untransform(point_x, point_y):
  """Map a 2D (point_x, point_y) pair on the 128x128 grid to its flat index 128*point_x + point_y.

  Returns False when either coordinate lies outside [0, 127].
  NOTE: False == 0 in Python, so callers must test `result is False` to
  distinguish the out-of-range sentinel from the valid index 0.
  """
  in_grid = all(0 <= value <= 127 for value in (point_x, point_y))
  if not in_grid:
    return False
  return 128 * point_x + point_y

Plot one slice of a scan

In [None]:
# plot the scan
def show_scan_plot(scan, slice_nr):
  """Draw slice `slice_nr` (index on the last axis) of a 3D scan with the jet colormap and a colorbar.

  Only draws on the current matplotlib figure; it does not call plt.show().
  """
  axial_slice = scan[:, :, slice_nr]
  plt.imshow(axial_slice, cmap='jet')
  plt.colorbar()

Plot a slice of a scan with the corresponding fixations on top

In [None]:
# plot the fixations on top of the scan
def show_scan_with_fixations(allData, slice_nr, scan_index, save):
  """Plot one slice of a scan with its fixation points scattered on top.

  Parameters
  ----------
  allData : pandas.DataFrame
    One row per scan, read with `.at[scan_index, ...]`. Uses the columns
    'Has Fix', 'Fixations', 'Durations', 'Scan', 'Subject ID' and 'Class'.
  slice_nr : int
    Index on the last axis of the 'Scan' volume and into the per-slice
    'Fixations' / 'Durations' sequences.
  scan_index : int
    Row label of the scan in `allData`.
  save : str
    "save" additionally writes the figure as a PDF under
    /content/drive/My Drive/plots/; any other value only draws it.

  Returns None. If the scan has no fixations, or this slice has none, a
  message is printed and nothing is drawn. plt.show() is left to the caller.
  """
  # len() works for both lists and NumPy arrays; comparing an array with
  # `== []` is elementwise and fails or warns depending on the NumPy version.
  if allData.at[scan_index, 'Has Fix'] == 0 or len(allData.at[scan_index, 'Fixations'][slice_nr]) == 0:
    print("There are no fixations for slice number ", slice_nr, " of scan on row ", scan_index)
    return

  print("Subject ", allData.at[scan_index, 'Subject ID'])
  print("Fixation points ", allData.at[scan_index, 'Durations'][slice_nr])

  one_slice_2d = copy.deepcopy(allData.at[scan_index, 'Scan'][:,:,slice_nr]) # array with one slice of a single patient

  plt.imshow(one_slice_2d, cmap='jet')
  plt.colorbar()

  # flat fixation indices -> (row, col) of the 128x128 grid; row goes on the x-axis
  locations = allData.at[scan_index, 'Fixations'][slice_nr]
  coord = coordinates_transform(locations)
  plt.scatter(coord[0], coord[1])

  if save == "save":
    plt.savefig(f"/content/drive/My Drive/plots/scan_fixations_{allData.at[scan_index, 'Class']}_{slice_nr}_{allData.at[scan_index, 'Subject ID']}.pdf",
                bbox_inches="tight")

# Location based attention maps

In [None]:
def create_location_AM(allData, sigma=3, show_plots=True):
  """Build, plot and save a location-based attention map for each scan with fixations.

  For every row of `allData` whose 'Has Fix' is not 0:
  - Each of the 60 slices that has fixations gets a 128x128 binary image with a 1
    at every fixated pixel (flat indices from 'Fixations').
  - That image is Gaussian-blurred and transposed; slices without fixations stay
    all zero.
  - The resulting 128x128x60 volume is divided by its maximum, so it peaks at 1,
    and saved to <root_path>/attention_maps/Location_based/<Scan ID>.npy.

  Parameters
  ----------
  allData : pandas.DataFrame
    One row per scan, read with `.at[i, ...]` for i in 0..n-1 (so it
    presumably has a default RangeIndex). Uses 'Subject ID', 'Has Fix',
    'Fixations', 'Durations' and 'Scan ID', plus the columns that
    show_scan_with_fixations reads when plotting.
  sigma : float, default 3
    Standard deviation, in pixels, of the Gaussian blur.
  show_plots : bool, default True
    If True, show every slice that has fixations next to its attention map.

  Returns
  -------
  None; the maps are only written to disk.
  """
  n_slices = 60
  cmap = plt.get_cmap('viridis')

  for scan_index in range(allData.shape[0]):
    print("Subject ", allData.at[scan_index, 'Subject ID'])

    # scans without any fixation points get no attention map
    if allData.at[scan_index, 'Has Fix'] == 0:
      continue

    fixations = allData.at[scan_index, 'Fixations']
    durations = allData.at[scan_index, 'Durations']

    # One volume per scan. A shared (59, 128, 128, 60) float64 buffer
    # (~460 MB) is not needed because each map is saved right away, and it
    # would overflow with IndexError for a class with more than 59 such scans.
    smooth_map = np.zeros((128, 128, n_slices))
    for slice_nr in range(n_slices):
      # a slice counts as fixated when its 'Durations' entry is non-empty
      if len(durations[slice_nr]):
        fix = np.zeros(128*128)
        fix[fixations[slice_nr]] = 1
        # NOTE(review): the transpose presumably aligns the map with the scan
        # slice (show_scan_with_fixations puts the fixation row on the x-axis)
        # -- confirm.
        smooth_map[:, :, slice_nr] = gaussian_filter(fix.reshape(128, 128), sigma=sigma).transpose()

    peak = np.max(smooth_map)
    print(peak)
    # an all-zero map would otherwise become all-NaN after dividing by 0
    if peak > 0:
      smooth_map /= peak

    if show_plots:
      for slice_nr in range(n_slices):
        if len(durations[slice_nr]):
          print("slice_nr ", slice_nr)
          show_scan_with_fixations(allData, slice_nr, scan_index, "not save") # plot scan with fixations
          plt.show()
          plt.imshow(smooth_map[:, :, slice_nr], cmap=cmap) # plot attention map
          plt.clim(0, 1)
          plt.colorbar()
          plt.show()

    np.save(root_path + "/attention_maps/Location_based/"+str(allData.at[scan_index, 'Scan ID'])+".npy", smooth_map)

In [None]:
create_location_AM(allData=nc_allData)  # NC class

In [None]:
create_location_AM(allData=mci_allData)  # MCI class

In [None]:
create_location_AM(allData=ad_allData)  # AD class

# Duration based attention maps

In [None]:
def create_duration_AM(allData, sigma=3, show_plots=True):
  """Build, plot and save a duration-based attention map for each scan with fixations.

  For every row of `allData` whose 'Has Fix' is not 0:
  - Each of the 60 slices that has fixations gets a 128x128 image. Every entry
    `point` of 'Durations' for that slice sets pixel [point[0], point[1]] to
    point[2], presumably the fixation duration.
  - That image is Gaussian-blurred and transposed; slices without fixations stay
    all zero.
  - The resulting 128x128x60 volume is divided by its maximum, so it peaks at 1,
    and saved to <root_path>/attention_maps/Duration_based/<Scan ID>.npy.

  Parameters
  ----------
  allData : pandas.DataFrame
    One row per scan, read with `.at[i, ...]` for i in 0..n-1 (so it
    presumably has a default RangeIndex). Uses 'Subject ID', 'Has Fix',
    'Durations' and 'Scan ID', plus the columns that
    show_scan_with_fixations reads when plotting.
  sigma : float, default 3
    Standard deviation, in pixels, of the Gaussian blur.
  show_plots : bool, default True
    If True, show every slice that has fixations next to its attention map.

  Returns
  -------
  None; the maps are only written to disk.
  """
  n_slices = 60
  cmap = plt.get_cmap('viridis')

  for scan_index in range(allData.shape[0]):
    print("Subject ", allData.at[scan_index, 'Subject ID'])

    # scans without any fixation points get no attention map
    if allData.at[scan_index, 'Has Fix'] == 0:
      continue

    durations = allData.at[scan_index, 'Durations']

    # One volume per scan. A shared (59, 128, 128, 60) float64 buffer
    # (~460 MB) is not needed because each map is saved right away, and it
    # would overflow with IndexError for a class with more than 59 such scans.
    smooth_map = np.zeros((128, 128, n_slices))
    for slice_nr in range(n_slices):
      points = durations[slice_nr]
      if len(points):
        fix = np.zeros((128, 128))
        for point in points:
          # NOTE(review): a pixel that appears twice keeps only the last
          # value instead of accumulating -- confirm this is intended.
          fix[point[0], point[1]] = point[2]
        smooth_map[:, :, slice_nr] = gaussian_filter(fix, sigma=sigma).transpose()

    peak = np.max(smooth_map)
    # an all-zero map would otherwise become all-NaN after dividing by 0
    if peak > 0:
      smooth_map /= peak

    if show_plots:
      for slice_nr in range(n_slices):
        if len(durations[slice_nr]):
          print("slice_nr ", slice_nr)
          show_scan_with_fixations(allData, slice_nr, scan_index, "not save") # plot scan with fixations
          plt.show()
          plt.imshow(smooth_map[:, :, slice_nr], cmap=cmap) # plot attention map
          plt.clim(0, 1)
          plt.colorbar()
          plt.show()

    np.save(root_path + "/attention_maps/Duration_based/"+str(allData.at[scan_index, 'Scan ID'])+".npy", smooth_map)

In [None]:
create_duration_AM(allData=nc_allData)  # NC class

In [None]:
create_duration_AM(allData=mci_allData)  # MCI class

In [None]:
create_duration_AM(allData=ad_allData)  # AD class

# Average attention maps (per class and per combination of classes)

In [None]:
def add_am(allData, type):
  """Sum the saved individual attention maps of every scan with 'Has Fix' == 1.

  Each map is loaded from <root_path>/attention_maps/<type>/<Scan ID>.npy and
  added to a 128x128x60 float accumulator, which is returned (all zeros when
  no row has 'Has Fix' == 1).
  """
  summed = np.zeros((128, 128, 60))
  rows_with_fix = [row for row in range(allData.shape[0]) if allData.at[row, 'Has Fix'] == 1]
  for row in rows_with_fix:
    map_path = root_path + '/attention_maps/' + type + '/' + str(allData.at[row, 'Scan ID']) + '.npy'
    summed += np.load(map_path)
  return summed

def normalized_fixation_map(fixed_map):
  """Scale a map so that its maximum value becomes 1.

  Returns a new float array; the input is not modified. A map whose maximum
  is 0 (e.g. the all-zero sum add_am returns for a class with no fixation
  maps) is returned as a float copy rather than divided by zero, which would
  turn it into all-NaN values.
  """
  m = np.max(fixed_map)
  if m == 0:
    return np.array(fixed_map, dtype=float)
  return fixed_map / m

def plot_save_fixation_map(fixed_map, classes, type, show=False):
  """Save an average attention map and optionally plot each of its slices.

  The map is written with np.save to
  <root_path>/attention_maps/Avg_attention/<type>_<classes>.npy; np.save
  appends the .npy extension itself.

  Parameters
  ----------
  fixed_map : numpy.ndarray
    Map indexed as [:, :, slice_nr]; expected to be scaled to [0, 1].
  classes : str
    Class label used in the file name, e.g. "NC" or "NC_AD".
  type : str
    Map type used in the file name, e.g. "Location_based".
  show : bool, default False
    If True, first draw every slice with the viridis colormap on a fixed
    0-1 colour scale.
  """
  if show:
    cmap = plt.get_cmap('viridis')
    for slice_nr in range(fixed_map.shape[2]):
      print("slice_nr ", slice_nr)
      plt.imshow(fixed_map[:, :, slice_nr], cmap=cmap)
      plt.clim(0, 1)
      plt.colorbar()
      plt.show()
  np.save(root_path + "/attention_maps/Avg_attention/"+type+"_"+classes, np.array(fixed_map))

def create_fixed_AM(nc_allData, ad_allData, mci_allData, type):
  """Build and save average attention maps per class and per class combination.

  For NC, AD and MCI (in that order), the individual `type` maps are summed
  with add_am, scaled to a maximum of 1 and saved under the class label.
  The per-class maps are then added for NC_AD, NC_MCI, AD_MCI and NC_MCI_AD,
  rescaled to a maximum of 1 and saved too.
  """
  class_maps = {}
  for label, class_data in (("NC", nc_allData), ("AD", ad_allData), ("MCI", mci_allData)):
    class_maps[label] = normalized_fixation_map(add_am(class_data, type))
    plot_save_fixation_map(class_maps[label], label, type)

  # (saved label, classes added left to right). NC_MCI_AD is summed as
  # NC + AD + MCI so the floating-point result is reproducible.
  combinations = (
    ("NC_AD", ("NC", "AD")),
    ("NC_MCI", ("NC", "MCI")),
    ("AD_MCI", ("AD", "MCI")),
    ("NC_MCI_AD", ("NC", "AD", "MCI")),
  )
  for label, members in combinations:
    combined = class_maps[members[0]]
    for member in members[1:]:
      combined = combined + class_maps[member]
    plot_save_fixation_map(normalized_fixation_map(combined), label, type)

In [None]:
create_fixed_AM(nc_allData, ad_allData, mci_allData, type="Location_based")

In [None]:
create_fixed_AM(nc_allData, ad_allData, mci_allData, type="Duration_based")

# Attention Maps Correlation

Pearson correlation coefficient between two attention maps

In [None]:
def Pearson_Coefficient(map1, map2):
  """Return the 2x2 Pearson correlation matrix (np.corrcoef) of two arrays, each flattened to 1-D.

  The coefficient between the two maps is element [0, 1] of the result.
  """
  flat1, flat2 = (np.ndarray.flatten(m) for m in (map1, map2))
  return np.corrcoef(flat1, flat2)

Correlation between two maps

In [None]:
def compare_maps(ID1=199, ID2=199, type="Location_based", target_classes1=None, target_classes2=None, decimals=2):
  """Pearson correlation between two attention maps loaded from disk.

  Each side is either:
  - an individual scan map, <root_path>/attention_maps/<type>/<ID>.npy; or
  - when its `target_classesN` is given, an average map,
    <root_path>/attention_maps/Avg_attention/<type>_<C1>_<C2>....npy. The suffix
    joins the class names in the given order, e.g. ['NC', 'AD'] -> "_NC_AD".

  Parameters
  ----------
  ID1, ID2 : scan IDs of the individual maps; ignored on a side whose
    target_classes is given. The default 199 only applies when no ID is passed.
  type : str
    Map sub-folder, e.g. "Location_based" or "Duration_based".
  target_classes1, target_classes2 : iterable of str, or None
  decimals : int or None, default 2
    Number of decimals the coefficient is rounded to; None returns it unrounded.

  Returns
  -------
  Element [0, 1] of Pearson_Coefficient on the two maps.
  """
  def _load(scan_id, target_classes):
    # an average (class-level) map when class names are given, else the scan's own map
    if target_classes is not None:
      suffix = "".join('_' + c for c in target_classes)
      return np.load(root_path + '/attention_maps/Avg_attention/' + type + suffix + '.npy')
    return np.load(root_path + '/attention_maps/' + type + '/' + str(scan_id) + '.npy')

  coef = Pearson_Coefficient(_load(ID1, target_classes1), _load(ID2, target_classes2))[0, 1]
  return coef if decimals is None else round(coef, decimals)

In [None]:
# Correlation between the average location-based maps of each pair of classes.
class_pairs = ((['NC'], ['MCI']), (['NC'], ['AD']), (['MCI'], ['AD']))
for classes_a, classes_b in class_pairs:
  print(compare_maps(type="Location_based", target_classes1=classes_a, target_classes2=classes_b))

0.88
0.85
0.87


Correlation between two slices, one of each map

In [None]:
def compare_slice_of_maps(ID1=199, ID2=199, slice_nr=25, type="Location_based", target_classes1=None, target_classes2=None, decimals=2):
  """Pearson correlation between the same slice of two attention maps loaded from disk.

  Each side is either:
  - an individual scan map, <root_path>/attention_maps/<type>/<ID>.npy; or
  - when its `target_classesN` is given, an average map,
    <root_path>/attention_maps/Avg_attention/<type>_<C1>_<C2>....npy (class
    names joined in the given order).
  Only [:, :, slice_nr] of each loaded volume is compared.

  Parameters
  ----------
  ID1, ID2 : scan IDs of the individual maps; ignored on a side whose
    target_classes is given.
  slice_nr : int, default 25
    Index on the last axis of both volumes.
  type : str
    Map sub-folder, e.g. "Location_based" or "Duration_based".
  target_classes1, target_classes2 : iterable of str, or None
  decimals : int or None, default 2
    Number of decimals the coefficient is rounded to; None returns it unrounded.
  """
  def _load_slice(scan_id, target_classes):
    if target_classes is not None:
      suffix = "".join('_' + c for c in target_classes)
      volume = np.load(root_path + '/attention_maps/Avg_attention/' + type + suffix + '.npy')
    else:
      volume = np.load(root_path + '/attention_maps/' + type + '/' + str(scan_id) + '.npy')
    return volume[:, :, slice_nr]

  coef = Pearson_Coefficient(_load_slice(ID1, target_classes1), _load_slice(ID2, target_classes2))[0, 1]
  return coef if decimals is None else round(coef, decimals)

Average correlation between maps of the same class

In [None]:
from statistics import mean, stdev

def avg_corr_same(allData1, type="Location_based"):
  """Print mean and standard deviation of map correlations over all pairs within one class.

  Every unordered pair of distinct rows with 'Has Fix' == 1 is compared once
  via compare_maps on their individual `type` maps. compare_maps rounds each
  coefficient (to 2 decimals by default) before it is aggregated here.
  statistics.stdev needs at least two values, so with fewer than two pairs
  only a message is printed.
  """
  fix_rows = [row for row in range(allData1.shape[0]) if allData1.at[row, 'Has Fix'] == 1]
  corr = [
    compare_maps(allData1.at[a, 'Scan ID'], allData1.at[b, 'Scan ID'], type=type)
    for pos, a in enumerate(fix_rows)
    for b in fix_rows[pos + 1:]
  ]

  if len(corr) < 2:
    print("Not enough map pairs to summarize: need at least 2, got", len(corr))
    return

  print("Mean correlation ", round(mean(corr), 2))
  print("Standard deviation correlation ", round(stdev(corr), 2))

In [None]:
# Within-class correlation of individual location-based maps: NC, MCI, AD.
for class_data in (nc_allData, mci_allData, ad_allData):
  avg_corr_same(class_data, type="Location_based")

Mean correlation  0.15
Standard deviation correlation  0.11
Mean correlation  0.15
Standard deviation correlation  0.1
Mean correlation  0.11
Standard deviation correlation  0.09


Average correlation between maps of different classes

In [None]:
def avg_corr_diff(allData1, allData2, type="Location_based"):
  """Print mean and standard deviation of map correlations between two classes.

  Every scan of `allData1` with 'Has Fix' == 1 is compared with every scan of
  `allData2` with 'Has Fix' == 1 via compare_maps on their individual `type`
  maps. compare_maps rounds each coefficient (to 2 decimals by default)
  before it is aggregated here. statistics.stdev needs at least two values,
  so with fewer than two pairs only a message is printed.
  """
  rows1 = [row for row in range(allData1.shape[0]) if allData1.at[row, 'Has Fix'] == 1]
  rows2 = [row for row in range(allData2.shape[0]) if allData2.at[row, 'Has Fix'] == 1]
  corr = [
    compare_maps(allData1.at[i, 'Scan ID'], allData2.at[j, 'Scan ID'], type=type)
    for i in rows1
    for j in rows2
  ]

  if len(corr) < 2:
    print("Not enough map pairs to summarize: need at least 2, got", len(corr))
    return

  print("Mean correlation ", round(mean(corr), 2))
  print("Standard deviation correlation ", round(stdev(corr), 2))

In [None]:
# Between-class correlation of individual location-based maps.
for first_class, second_class in ((nc_allData, ad_allData), (nc_allData, mci_allData), (mci_allData, ad_allData)):
  avg_corr_diff(first_class, second_class, type="Location_based")

Mean correlation  0.12
Standard deviation correlation  0.09
Mean correlation  0.14
Standard deviation correlation  0.1
Mean correlation  0.13
Standard deviation correlation  0.09


Average correlation between individual maps and the average attention maps of a class

In [None]:
def avg_corr_fixed_map(allData, type="Location_based", target_classes=('NC', 'AD')):
  """Print mean and standard deviation of the correlation between each scan's map and an average map.

  Each row with 'Has Fix' == 1 has its individual `type` map compared, via
  compare_maps, with the average map for `target_classes`
  (Avg_attention/<type>_<C1>_<C2>...). The default is a tuple instead of a
  list to avoid a shared mutable default. statistics.stdev needs at least two
  values, so with fewer than two scans only a message is printed.

  NOTE(review): compare_maps rounds each coefficient (to 2 decimals by
  default) before averaging, so the 4-decimal output below is not precise
  to 4 decimals.
  """
  corr = [
    compare_maps(ID1=allData.at[row, 'Scan ID'], type=type, target_classes2=target_classes)
    for row in range(allData.shape[0])
    if allData.at[row, 'Has Fix'] == 1
  ]

  if len(corr) < 2:
    print("Not enough maps to summarize: need at least 2, got", len(corr))
    return

  print("Mean correlation ", round(mean(corr), 4))
  print("Standard deviation correlation ", round(stdev(corr), 4))

In [None]:
# Correlation of each AD scan's location-based map with the NC and MCI
# average maps. Other combinations (each class against its own average, or
# NC/MCI scans against the other classes) are run by changing the data frame
# and the target class below.
for avg_class in ('NC', 'MCI'):
  avg_corr_fixed_map(ad_allData, type="Location_based", target_classes=[avg_class])

Mean correlation  0.2985
Standard deviation correlation  0.1111
Mean correlation  0.3103
Standard deviation correlation  0.109
