In [2]:
import sys
sys.path.append('..')
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sfp_nsdsyn import *
import warnings
import itertools

warnings.filterwarnings("ignore", category=UserWarning)
pd.options.mode.chained_assignment = None
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Configurations

In [3]:
# Paths: every derivative (model outputs, dataframes, figures) lives under output_dir.
output_dir = '/Volumes/server/Projects/sfp_nsd/derivatives'
fig_dir = os.path.join(output_dir, 'figures/sfp_model/results_1D/nsdsyn')
precision_dir = os.path.join(output_dir, 'dataframes', 'nsdsyn', 'precision')

# Stimulus classes, ROIs and subjects (nsdsyn subjects 1-8).
stim_classes = ['annulus', 'pinwheel', 'forward spiral', 'reverse spiral']
roi_list = ['V1', 'V2', 'V3']
subj_list = [utils.sub_number_to_string(sn, 'nsdsyn') for sn in np.arange(1, 9)]

# Training hyperparameters; both appear in the model output file names.
lr = 0.005
max_epoch = 8000

# Single-subject defaults for per-subject inspection.
subj = 'subj02'
roi = 'V2'
voxels = 'pRFcenter'
stim_class = 'forward-spiral'

# Eccentricity bin information: `enum` bins between e1 and e2.
e1, e2 = 0.5, 4
enum = 7
# Indices of the per-bin fits to load. Derived from `enum` so that changing
# the bin count does not silently drop bins.
# NOTE(review): assumes get_bin_labels returns `enum` labels — confirm.
enum_range = range(enum)
bin_list, bin_labels = tuning.get_bin_labels(e1, e2, enum=enum)
# Build a palette with one more color than there are bins, drop its first
# color, reverse the rest, and map each bin label to one of them.
ecc_palette = utils.get_continuous_colors(len(bin_labels) + 1, '#3f0377')
ecc_colors = dict(zip(bin_labels, ecc_palette[1:][::-1]))

# Figure settings
utils.set_rcParams({'figure.dpi': 72*2})
sns.set_context("notebook", font_scale=2)

## Load final-epoch losses (corrected and uncorrected fits)

In [5]:
# Final-epoch loss of the current (corrected) 1D fits, one file per
# (eccentricity bin, ROI, subject, stimulus class).
# `args` are the file-name fields handed to utils.load_dataframes
# (presumably parsed into columns — confirm in sfp_nsdsyn.utils).
args = ['class', 'lr', 'eph', 'sub', 'roi', 'curbin']
loss_classes = ['avg']

# Distinct loop names so the single-subject defaults from the configuration
# cell (subj, roi, stim_class) are not overwritten by this loop.
cur_loss_files = []
for curbin, roi_name, subj_name, class_name in itertools.product(enum_range, roi_list, subj_list, loss_classes):
    class_name = class_name.replace(' ', '-')
    loss_file_name = f'loss-history_class-{class_name}_lr-{lr}_eph-{max_epoch}_e1-{e1}_e2-{e2}_nbin-{enum}_curbin-{curbin}_sub-{subj_name}_roi-{roi_name}_vs-{voxels}.h5'
    cur_loss_files.append(os.path.join(output_dir, 'sfp_model', 'results_1D', 'nsdsyn', loss_file_name))

# Keep only the last epoch of each history as it is loaded, then concatenate
# once instead of growing a DataFrame inside the loop (quadratic copying).
cur_loss_frames = []
for loss_file in cur_loss_files:
    loss_history = utils.load_dataframes([loss_file], *args)
    cur_loss_frames.append(loss_history.query('epoch == @max_epoch-1'))
cur_loss_df = pd.concat(cur_loss_frames, axis=0) if cur_loss_frames else pd.DataFrame({})

In [7]:
# Final-epoch loss of the earlier fits stored under 'before_w_a_correction'
# (note the extra `dset-nsdsyn` field in these file names).
# `args` are the file-name fields handed to utils.load_dataframes
# (presumably parsed into columns — confirm in sfp_nsdsyn.utils).
args = ['class', 'lr', 'eph', 'sub', 'roi', 'curbin']
loss_classes = ['avg']

# Distinct loop names so the single-subject defaults from the configuration
# cell (subj, roi, stim_class) are not overwritten by this loop.
old_loss_files = []
for curbin, roi_name, subj_name, class_name in itertools.product(enum_range, roi_list, subj_list, loss_classes):
    class_name = class_name.replace(' ', '-')
    loss_file_name = f'loss-history_class-{class_name}_lr-{lr}_eph-{max_epoch}_e1-{e1}_e2-{e2}_nbin-{enum}_curbin-{curbin}_dset-nsdsyn_sub-{subj_name}_roi-{roi_name}_vs-{voxels}.h5'
    old_loss_files.append(os.path.join(output_dir, 'before_w_a_correction', 'sfp_model', 'results_1D', 'nsdsyn', loss_file_name))

# Keep only the last epoch of each history as it is loaded, then concatenate
# once instead of growing a DataFrame inside the loop (quadratic copying).
old_loss_frames = []
for loss_file in old_loss_files:
    loss_history = utils.load_dataframes([loss_file], *args)
    old_loss_frames.append(loss_history.query('epoch == @max_epoch-1'))
old_loss_df = pd.concat(old_loss_frames, axis=0) if old_loss_frames else pd.DataFrame({})

In [8]:
# Stack both fit versions into one long-format frame tagged by `df_type`.
# `.assign` returns copies, so the frames from the loading cells stay
# unmodified and this cell is safe to re-run. The per-version frames' indices
# can overlap, so the combined frame gets a fresh unique index.
loss_df = pd.concat(
    (cur_loss_df.assign(df_type='corrected'),
     old_loss_df.assign(df_type='uncorrected')),
    axis=0,
    ignore_index=True,
)


In [29]:
sns.set_context("notebook", font_scale=2)
os.makedirs(fig_dir, exist_ok=True)

# Final-epoch loss per eccentricity bin, one color per subject, one column per ROI.
# loss_df holds both corrected and uncorrected fits. They get separate rows;
# without the row facet, pointplot would pool both versions into every point.
for plot_class in ['avg']:
    class_loss_df = loss_df.query('names == @plot_class')
    g = sns.catplot(data=class_loss_df, sharey=True,
                    x='ecc_bin', y='loss', hue='sub',
                    order=bin_labels, hue_order=subj_list,
                    row='df_type', row_order=['corrected', 'uncorrected'],
                    col='vroinames', kind='point',
                    height=utils.get_height_based_on_width(6, 1.5), aspect=1.3)
    g.set_xticklabels(rotation=30)
    # NOTE(review): confirm this fixed y-range also covers the uncorrected losses.
    g.set(ylim=[-0.05, 0.6])

    # Solid, thick major grid lines on the y-axis of every panel.
    for ax in g.axes.flat:
        ax.yaxis.grid(True, which='major', linestyle='-', linewidth=2)
    g.fig.suptitle(plot_class, y=1.05)
    # Save and close this grid's own figure, independent of pyplot's "current" figure.
    g.savefig(os.path.join(fig_dir, f'ecc_loss_{plot_class}.png'), bbox_inches='tight')
    plt.close(g.fig)