### Plot RSA matrices by odor (Encoding/Retrieval similarity)

In [12]:
from confusion_matrix import plot_confusion_matrix
from brainpipe.system import study
from brainpipe.visual import *

import numpy as np
from os.path import join, exists
from os import listdir
import matplotlib.pyplot as plt
import glob,os
from itertools import product

In [24]:
"""
Plot RSA matrices for all electrodes in a specific ROI and patient // ANALYSES BY ODOR
"""

# Walk over every .npz file of the RSA folder, average its 'rsa' array down to
# a single 2-D matrix and save it as a PNG. Figures that already exist on disk
# are skipped.
st = study('Olfacto')
path_rsa = join(st.path,'feature/1_RSA_E_R_by_odor/')
path_save = join(path_rsa,'figures_by_odor/')
# savefig does not create missing folders, so make sure the target exists
os.makedirs(path_save, exist_ok=True)

files = listdir(path_rsa)

for fi in files:
    if not fi.endswith('.npz'):
        continue

    # File names are underscore-separated; fields 0, 1, 3 and 4 are used as
    # subject, condition, odor and ROI respectively
    splits = fi.split('_')
    su,cond,odor,roi = splits[0],splits[1],splits[3],splits[4]
    savename = 'RSA_'+su+'_'+cond+'_'+roi+'_odor'+odor+'.png'

    if exists(path_save+savename):
        print('already computed',savename)
        continue

    # The context manager closes the npz archive once the array is in memory
    with np.load(path_rsa+fi) as mat:
        rsa = mat['rsa'] #nelecs,ncomb,nsamples,nsamples
    if rsa.ndim > 3:
        rsa_mean = np.mean(rsa,axis=(0,1)) #average trials and electrodes in an ROI
    else:
        rsa_mean = np.mean(rsa,axis=0) #average electrodes in an ROI

    #PLOT AND SAVE RSA MATRIX
    title = 'RSA for '+su+' cond '+cond+' in '+roi+' // Odor('+odor+')'
    fig, ax = plt.subplots()
    ax.set_title(title)
    # Symmetric color limits so that zero sits in the middle of the colormap
    vmin, vmax = np.nanmin(rsa_mean), np.nanmax(rsa_mean)
    extr = abs(vmin) if abs(vmin) > vmax else vmax
    im = ax.imshow(rsa_mean,vmin=-extr,vmax=extr,origin='lower')

    # Five evenly spaced ticks over the sample axis
    # NOTE(review): labels assume the matrix spans 0 to 2 s - confirm
    size = rsa_mean.shape[0]-1
    ticks = [0,size/4,size/2,(size)*3/4,size]
    ticks_labels = [0,0.5,1,1.5,2]
    ax.set_xticks(ticks, minor=False)
    ax.set_yticks(ticks, minor=False)
    ax.set_xticklabels(ticks_labels)
    ax.set_yticklabels(ticks_labels)
    ax.set_xlabel('Retrieval time (s)')
    ax.set_ylabel('Encoding time (s)')
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label('RSA',rotation=270)

    fig.savefig(path_save+savename)
    # Release the figure so memory does not grow with the number of files
    plt.close(fig)

-> Olfacto loaded


In [29]:
"""
Plot RSA matrices for ONE electrode in a specific ROI // ANALYSES BY CONDITION (3-4 odors)
"""

# For each subject / condition / ROI, concatenate the RSA arrays of all odors
# along the trial-combination axis, then save one trial-averaged RSA matrix per
# electrode. Figures that already exist on disk are skipped.
st = study('Olfacto')
path_rsa = join(st.path,'feature/1_RSA_E_R_by_odor/')
path_save = join(path_rsa,'figures_by_cond/')
# savefig does not create missing folders, so make sure the target exists
os.makedirs(path_save, exist_ok=True)
subjects = ['LEFC','SEMC','PIRJ','FERJ','VACJ'] #CHAF not included no elec in ROIs
conds = ['low','high']
rois = ['HC','IFG','OFC']

all_files = listdir(path_rsa)
average = 'False'

for su, cond, roi in product(subjects,conds,rois):
    print('plotting rsa for ',su,cond,roi)
    filename = su+'_'+cond+'_odor_*_'+roi+'_E_R_rsa_avg='+average+'.npz'
    # glob returns files in arbitrary order; sort for a reproducible run
    files = sorted(glob.glob(path_rsa+filename))
    if not files:
        # np.concatenate raises on an empty list, so skip missing combinations
        print('no RSA file found for',su,cond,roi)
        continue

    #concatenate odors in the trials dimensions
    #nelecs,ncomb,nsamples,nsamples
    per_odor = []
    for npz_path in files:
        # The context manager closes the npz archive once the array is loaded
        with np.load(npz_path) as npz:
            per_odor.append(npz['rsa'])
    all_data = np.concatenate(per_odor, axis=1)
    nelecs,ncomb,nsamples = all_data.shape[:-1]

    for elec in range(nelecs):
        savename = 'RSA_'+su+'_'+cond+'_'+roi+'_elec'+str(elec)+'.png'
        if exists(path_save+savename):
            print('already computed',savename)
            continue

        # Average over trial combinations for this electrode
        rsa = np.mean(all_data[elec],axis=0)

        #PLOT AND SAVE RSA MATRIX
        title = 'RSA for '+su+' cond '+cond+' in '+roi+' // Elec('+str(elec)+')'
        fig, ax = plt.subplots()
        ax.set_title(title)
        # Symmetric color limits so that zero sits in the middle of the colormap
        vmin, vmax = np.nanmin(rsa), np.nanmax(rsa)
        extr = abs(vmin) if abs(vmin) > vmax else vmax
        im = ax.imshow(rsa,vmin=-extr,vmax=extr,origin='lower')

        # Tick positions come from the matrix actually being plotted
        # NOTE(review): labels assume the matrix spans 0 to 2 s - confirm
        size = rsa.shape[0]-1
        ticks = [0,size/4,size/2,(size)*3/4,size]
        ticks_labels = [0,0.5,1,1.5,2]
        ax.set_xticks(ticks, minor=False)
        ax.set_yticks(ticks, minor=False)
        ax.set_xticklabels(ticks_labels)
        ax.set_yticklabels(ticks_labels)
        ax.set_xlabel('Retrieval time (s)')
        ax.set_ylabel('Encoding time (s)')
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label('RSA',rotation=270)

        fig.savefig(path_save+savename)
        # Release the figure so memory does not grow with the number of plots
        plt.close(fig)

-> Olfacto loaded
plotting rsa for  LEFC low HC
plotting rsa for  LEFC low IFG
plotting rsa for  LEFC low OFC
plotting rsa for  LEFC high HC
plotting rsa for  LEFC high IFG
plotting rsa for  LEFC high OFC
plotting rsa for  SEMC low HC
plotting rsa for  SEMC low IFG
plotting rsa for  SEMC low OFC
plotting rsa for  SEMC high HC
plotting rsa for  SEMC high IFG
plotting rsa for  SEMC high OFC
plotting rsa for  PIRJ low HC
plotting rsa for  PIRJ low IFG
plotting rsa for  PIRJ low OFC
plotting rsa for  PIRJ high HC
plotting rsa for  PIRJ high IFG
plotting rsa for  PIRJ high OFC
plotting rsa for  FERJ low HC
plotting rsa for  FERJ low IFG
plotting rsa for  FERJ low OFC
plotting rsa for  FERJ high HC
plotting rsa for  FERJ high IFG
plotting rsa for  FERJ high OFC
plotting rsa for  VACJ low HC
plotting rsa for  VACJ low IFG
plotting rsa for  VACJ low OFC
plotting rsa for  VACJ high HC
plotting rsa for  VACJ high IFG
plotting rsa for  VACJ high OFC


In [31]:
"""
Plot RSA matrices for ONE electrode in a specific ROI //
HIGH - LOW MATRICES (all odors combined)
"""

# For each subject / ROI, concatenate the RSA arrays of all odors separately
# for the 'low' and 'high' conditions, then save one (high - low) difference
# matrix per electrode. Figures that already exist on disk are skipped.
st = study('Olfacto')
path_rsa = join(st.path,'feature/1_RSA_E_R_by_odor/')
path_save = join(path_rsa,'figures_high_low_by_elec/')
# savefig does not create missing folders, so make sure the target exists
os.makedirs(path_save, exist_ok=True)
subjects = ['LEFC','SEMC','PIRJ','FERJ','VACJ'] #CHAF not included no elec in ROIs
conds = ['low','high']
rois = ['HC','IFG','OFC']
average = 'False'

all_files = listdir(path_rsa)

for su, roi in product(subjects,rois):
    # This loop has no single condition, so report the contrast being plotted
    print('plotting rsa for ',su,'high - low',roi)
    filename_l = su+'_'+conds[0]+'_odor_*_'+roi+'_E_R_rsa_avg='+average+'.npz'
    filename_h = su+'_'+conds[1]+'_odor_*_'+roi+'_E_R_rsa_avg='+average+'.npz'
    # glob returns files in arbitrary order; sort for a reproducible run
    files_low = sorted(glob.glob(path_rsa+filename_l))
    files_high = sorted(glob.glob(path_rsa+filename_h))
    if not files_low or not files_high:
        # np.concatenate raises on an empty list, and the contrast needs both
        print('missing low or high RSA files for',su,roi)
        continue

    #concatenate odors in the trials dimensions
    #nelecs,ncomb,nsamples,nsamples
    data_low, data_high = [], []
    for paths, store in ((files_low, data_low), (files_high, data_high)):
        for npz_path in paths:
            # The context manager closes the npz archive once the array is loaded
            with np.load(npz_path) as npz:
                store.append(npz['rsa'])
    all_data_l = np.concatenate(data_low, axis=1)
    all_data_h = np.concatenate(data_high, axis=1)
    nelecs,ncombs,nsamples = all_data_l.shape[:-1]

    for elec in range(nelecs):
        savename = 'RSA_'+su+'_high_low_'+roi+'_elec'+str(elec)+'.png'
        if exists(path_save+savename):
            print('already computed',savename)
            continue

        # Average over trial combinations, then contrast the two conditions
        rsa_h = np.mean(all_data_h[elec],axis=0)
        rsa_l = np.mean(all_data_l[elec],axis=0)
        rsa = rsa_h - rsa_l

        #PLOT AND SAVE RSA MATRIX
        title = 'RSA for '+su+' High - Low in '+roi+' // Elec('+str(elec)+')'
        fig, ax = plt.subplots()
        ax.set_title(title)
        # Symmetric color limits so that zero sits in the middle of the colormap
        vmin, vmax = np.nanmin(rsa), np.nanmax(rsa)
        extr = abs(vmin) if abs(vmin) > vmax else vmax
        im = ax.imshow(rsa,vmin=-extr,vmax=extr,origin='lower')

        # Five evenly spaced ticks over the sample axis
        # NOTE(review): labels assume the matrix spans 0 to 2 s - confirm
        size = rsa.shape[0]-1
        ticks = [0,size/4,size/2,(size)*3/4,size]
        ticks_labels = [0,0.5,1,1.5,2]
        ax.set_xticks(ticks, minor=False)
        ax.set_yticks(ticks, minor=False)
        ax.set_xticklabels(ticks_labels)
        ax.set_yticklabels(ticks_labels)
        ax.set_xlabel('Retrieval time (s)')
        ax.set_ylabel('Encoding time (s)')
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label('RSA',rotation=270)

        fig.savefig(path_save+savename)
        # Release the figure so memory does not grow with the number of plots
        plt.close(fig)

-> Olfacto loaded
plotting rsa for  LEFC high HC
plotting rsa for  LEFC high IFG
plotting rsa for  LEFC high OFC
plotting rsa for  SEMC high HC
plotting rsa for  SEMC high IFG
plotting rsa for  SEMC high OFC
plotting rsa for  PIRJ high HC
plotting rsa for  PIRJ high IFG
plotting rsa for  PIRJ high OFC
plotting rsa for  FERJ high HC
plotting rsa for  FERJ high IFG
plotting rsa for  FERJ high OFC
plotting rsa for  VACJ high HC
plotting rsa for  VACJ high IFG
plotting rsa for  VACJ high OFC
