In [None]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import scienceplots as scp
import seaborn as sns

import numpy as np
import torch

matplotlib.rcParams['figure.figsize'] = (20, 6)

import sys
sys.path.append("..")

from argparse import Namespace
from main import main, parse_args
from utils.plotting import *
from utils.analysis import process_outputs, get_change_responses, get_omission_responses

from utils.data import load_results_files

import warnings
warnings.filterwarnings('ignore')

import os

In [None]:
# Augment every saved run in the adaptation-baseline results with per-trial
# window means of the familiar / novel change responses, then write each file
# back in place. Re-running is idempotent: the means are always recomputed from
# the raw 'familiar' / 'novel' entries and the *_means keys are overwritten.
RESULTS_DIR = "../results/adaptation_baseline"
PRE_START = 5      # first index (dim 1) of the pre-change averaging window
CHANGE_START = 20  # first index (dim 1) of the change averaging window
WINDOW_LEN = 10    # number of steps averaged in each window


def _window_means(resp, pre_start=PRE_START, change_start=CHANGE_START, width=WINDOW_LEN):
    """Average `resp` over the pre-change and change windows along dim 1.

    Returns a float tensor of shape (resp.shape[0], resp.shape[-1], 2) whose
    last axis holds [pre-change window mean, change window mean].
    NOTE(review): assumes `resp` is 3-D with dim 1 being time
    (e.g. trials x time x units) — confirm against the simulation output.
    """
    means = torch.zeros(resp.shape[0], resp.shape[-1], 2)
    means[:, :, 0] = resp[:, pre_start:pre_start + width].mean(1)
    means[:, :, 1] = resp[:, change_start:change_start + width].mean(1)
    return means


# Only regular, non-hidden files are result files; this skips directories such
# as .ipynb_checkpoints and OS metadata files such as .DS_Store, which would
# otherwise make torch.load fail.
ad_files = sorted(
    name for name in os.listdir(RESULTS_DIR)
    if not name.startswith(".") and os.path.isfile(os.path.join(RESULTS_DIR, name))
)

for f in ad_files:

    file_path = os.path.join(RESULTS_DIR, f)
    data = torch.load(file_path, map_location='cpu')

    change_responses = data['change_responses']

    # Summarise each condition over its own populations, so a population present
    # in only one condition is neither left at zero nor raises a KeyError.
    change_responses["familiar_means"] = {
        k: _window_means(v) for k, v in change_responses['familiar'].items()
    }
    change_responses["novel_means"] = {
        k: _window_means(v) for k, v in change_responses['novel'].items()
    }

    data['change_responses'] = change_responses

    torch.save(data, file_path)

In [None]:
# Load the results of the run with Hebbian learning.
with_args, with_change, with_omission, with_progress = load_results_files(
    "../results/adaptation_baseline", "with"
)
# Expose the stored arguments with attribute access (args.name).
with_args = Namespace(**with_args)

In [None]:
# Load the results of the run without Hebbian learning.
# The progress entry is bound to `no_progress` (not `with_progress`) so the
# Hebbian run's progress loaded in the previous cell is not overwritten.
no_args, no_change, no_omission, no_progress = load_results_files("../results/adaptation_baseline", "no")
# Expose the stored arguments with attribute access (args.name).
no_args = Namespace(**no_args)

#### With Hebbian learning: change and omission responses

In [None]:
# Change responses (Hebbian run): one panel per population, familiar vs. novel.

with plt.style.context(['nature', 'notebook']):

    with_change_fig = plt.figure(figsize=(15, 10))
    # NOTE(review): called before any subplot exists, so this most likely has
    # no effect on the final layout — move it after the loop if padding is wanted.
    plt.tight_layout(pad=10, h_pad=5)

    for panel_idx, (pop, fam_trace) in enumerate(with_change['familiar'].items(), start=1):
        ax = with_change_fig.add_subplot(3, 3, panel_idx)
        plot_trial_responses(with_args, ax, fam_trace, with_change['novel'][pop], normalize=False)
        ax.set_title(f"{pop}")
        ax.legend()

In [None]:
# Responses around omissions (Hebbian run): one panel per population.

with plt.style.context(['nature', 'notebook']):

    with_omission_fig = plt.figure(figsize=(15, 10))
    # A plt.tight_layout(pad=19, h_pad=15) call was tried here and disabled.

    for panel_idx, (pop, fam_trace) in enumerate(with_omission['familiar'].items(), start=1):
        ax = with_omission_fig.add_subplot(4, 3, panel_idx)
        plot_trial_responses(
            with_args, ax, fam_trace, with_omission['novel'][pop],
            trial_mode='omission', normalize=False,
        )
        ax.set_title(f"{pop}")

In [None]:
# Save the Hebbian-run change / omission figures as PDFs.
# savefig does not create missing folders, so make sure the target exists.
os.makedirs("../figures/figure_7", exist_ok=True)
with_change_fig.savefig(f"../figures/figure_7/with_hebb_change_responses.pdf", dpi=600, bbox_inches="tight", pad_inches=0)
with_omission_fig.savefig(f"../figures/figure_7/with_hebb_omission_responses.pdf", dpi=600, bbox_inches="tight", pad_inches=0)

#### Without Hebbian learning: change and omission responses

In [None]:
# Change responses (no-Hebbian run): one panel per population, familiar vs. novel.

with plt.style.context(['nature', 'notebook']):

    no_change_fig = plt.figure(figsize=(15, 10))
    # NOTE(review): called before any subplot exists, so this most likely has
    # no effect on the final layout — move it after the loop if padding is wanted.
    plt.tight_layout(pad=10, h_pad=5)

    for panel_idx, (pop, fam_trace) in enumerate(no_change['familiar'].items(), start=1):
        ax = no_change_fig.add_subplot(3, 3, panel_idx)
        plot_trial_responses(no_args, ax, fam_trace, no_change['novel'][pop], normalize=False)
        ax.set_title(f"{pop}")
        ax.legend()

In [None]:
# Responses around omissions (no-Hebbian run): one panel per population.

with plt.style.context(['nature', 'notebook']):

    no_omission_fig = plt.figure(figsize=(15, 10))
    # A plt.tight_layout(pad=19, h_pad=15) call was tried here and disabled.

    for panel_idx, (pop, fam_trace) in enumerate(no_omission['familiar'].items(), start=1):
        ax = no_omission_fig.add_subplot(4, 3, panel_idx)
        plot_trial_responses(
            no_args, ax, fam_trace, no_omission['novel'][pop],
            trial_mode='omission', normalize=False,
        )
        ax.set_title(f"{pop}")

In [None]:
# Save the no-Hebbian-run change / omission figures as PDFs.
# savefig does not create missing folders, so make sure the target exists.
os.makedirs("../figures/figure_7", exist_ok=True)
no_change_fig.savefig(f"../figures/figure_7/no_hebb_change_responses.pdf", dpi=600, bbox_inches="tight", pad_inches=0)
no_omission_fig.savefig(f"../figures/figure_7/no_hebb_omission_responses.pdf", dpi=600, bbox_inches="tight", pad_inches=0)

#### Confidence-interval plots: familiar vs. novel mean responses (pre-change and change windows)

In [None]:
# Selects which run the confidence plots below use:
# 0 -> with Hebbian learning, any other value -> without Hebbian learning.
conf_cond = 0

##### Excitatory

In [None]:
with plt.style.context(['nature', 'notebook']):

    # conf_cond == 0 selects the Hebbian run, anything else the no-Hebbian run.
    chosen_change = with_change if conf_cond == 0 else no_change
    fam_z = chosen_change['familiar_means']['E']
    nov_z = chosen_change['novel_means']['E']

    exc_conf_fig, exc_conf = plt.subplots(1, 2, figsize=(7, 5), sharey=True)
    plt.tight_layout(pad=8, h_pad=5)

    # Average over dim 0 for each window: last-axis index 0 is the pre-change
    # window mean, index 1 the change window mean.
    fam_pre, nov_pre = (z[..., 0].detach().mean(0) for z in (fam_z, nov_z))
    fam_change, nov_change = (z[..., 1].detach().mean(0) for z in (fam_z, nov_z))

    # Left panel: pre-change; right panel: change; same y-range on both.
    panels = ((exc_conf[0], fam_pre, nov_pre), (exc_conf[1], fam_change, nov_change))
    for panel, fam, nov in panels:
        plot_confidence_intervals(panel, fam.numpy(), nov.numpy())
        panel.set_ylim([0.007, .025])
    exc_conf[1].set_ylabel('')  # clear the right panel's y label

##### SST

In [None]:
with plt.style.context(['nature', 'notebook']):

    # conf_cond == 0 selects the Hebbian run, anything else the no-Hebbian run.
    chosen_change = with_change if conf_cond == 0 else no_change
    fam_z = chosen_change['familiar_means']['SST']
    nov_z = chosen_change['novel_means']['SST']

    sst_conf_fig, sst_conf = plt.subplots(1, 2, figsize=(7, 5), sharey=True)
    plt.tight_layout(pad=8, h_pad=5)

    # Average over dim 0 for each window: last-axis index 0 is the pre-change
    # window mean, index 1 the change window mean.
    fam_pre, nov_pre = (z[..., 0].detach().mean(0) for z in (fam_z, nov_z))
    fam_change, nov_change = (z[..., 1].detach().mean(0) for z in (fam_z, nov_z))

    # Left panel: pre-change; right panel: change; same y-range on both.
    panels = ((sst_conf[0], fam_pre, nov_pre), (sst_conf[1], fam_change, nov_change))
    for panel, fam, nov in panels:
        plot_confidence_intervals(panel, fam.numpy(), nov.numpy())
        panel.set_ylim([0.006, .02])
    sst_conf[1].set_ylabel('')  # clear the right panel's y label

##### VIP

In [None]:
with plt.style.context(['nature', 'notebook']):

    # conf_cond == 0 selects the Hebbian run, anything else the no-Hebbian run.
    chosen_change = with_change if conf_cond == 0 else no_change
    fam_z = chosen_change['familiar_means']['VIP']
    nov_z = chosen_change['novel_means']['VIP']

    vip_conf_fig, vip_conf = plt.subplots(1, 2, figsize=(7, 5), sharey=True)
    plt.tight_layout(pad=8, h_pad=5)

    # Average over dim 0 for each window: last-axis index 0 is the pre-change
    # window mean, index 1 the change window mean.
    fam_pre, nov_pre = (z[..., 0].detach().mean(0) for z in (fam_z, nov_z))
    fam_change, nov_change = (z[..., 1].detach().mean(0) for z in (fam_z, nov_z))

    # Left panel: pre-change; right panel: change; same y-range on both.
    panels = ((vip_conf[0], fam_pre, nov_pre), (vip_conf[1], fam_change, nov_change))
    for panel, fam, nov in panels:
        plot_confidence_intervals(panel, fam.numpy(), nov.numpy())
        panel.set_ylim([0.0, .02])
    vip_conf[1].set_ylabel('')  # clear the right panel's y label

In [None]:
# File-name prefix follows the run selected by conf_cond.
cond_txt = 'with_hebb' if conf_cond == 0 else 'no_hebb'
# savefig does not create missing folders, so make sure the target exists.
os.makedirs("../figures/figure_7", exist_ok=True)
# Only the SST confidence figure is currently saved; the E / VIP saves are disabled.
#exc_conf_fig.savefig(f"../figures/figure_7/{cond_txt}_exc_conf.pdf", dpi=600, bbox_inches="tight", pad_inches=0)
sst_conf_fig.savefig(f"../figures/figure_7/{cond_txt}_sst_conf.pdf", dpi=600, bbox_inches="tight", pad_inches=0)
#vip_conf_fig.savefig(f"../figures/figure_7/{cond_txt}_vip_conf.pdf", dpi=600, bbox_inches="tight", pad_inches=0)