In [78]:
from src.utils import get_signal_by_type, eye_experiment_sa_start, save_figure
import pickle
import os
import numpy as np
import matplotlib.pyplot as plt

In [22]:
# Plot colour for each experiment type; the plotting cell looks colours up
# with `colors[experiment]`, so every experiment type needs an entry here.
colors = {
    "non-stimulated": "#651b7e",
    "stimulated": "#dd4f66",
    "TTX": "#fca572",
}

In [89]:
# For each trained model (one pickle per simulated dataset), predict on the real
# recordings for each eye and experiment type and plot the predictions as box
# plots. Only the figure is saved; no prediction table is written.

model_name = 'LinearRegression'
# NOTE(review): not used in this cell -- presumably intended as the output path
# for the predictions, but nothing writes to it here.
results_file = 'real_data_preds.csv'

# Simulated datasets whose models are evaluated. To evaluate all of them instead,
# use the entries of the 'simulated_data' directory whose names start with 'DS'.
datasets = ['DS_-10_10_10', 'DS_-10_80_10', 'DS_0_10_10']

def get_label_from_ds(dataset_name):
    """Return a two-line tick label for a dataset named like ``DS_<white>_<me>_<x>``.

    The second and third underscore-separated fields are shown as the
    "White SNR" and "ME SNR" values; the first and fourth fields are ignored.
    A name that does not split into exactly four fields raises ``ValueError``.
    """
    _prefix, white_snr, me_snr, _suffix = dataset_name.split('_')
    return f'White SNR = {white_snr}\n ME SNR = {me_snr}'


# Epoch geometry, as index counts into the recorded signal
# (NOTE(review): presumably samples -- confirm the unit against the recordings).
sa_length = 300         # leading segment of each epoch that is skipped
response_length = 2700  # remainder of each epoch, which is fed to the model
experiment_length = response_length + sa_length  # length of one full epoch
max_stimuli = 120       # upper bound on the number of epochs used per recording
channel_id = 0          # column of the recording that is analysed

design_types = ['2D', '3D']
experiment_types = ['stimulated', 'non-stimulated', 'TTX']

def _load_model(dataset):
    """Unpickle the trained ``model_name`` model for ``dataset`` from ``../models``.

    The file is opened in a ``with`` block so the handle is always closed.
    """
    # NOTE: pickle.load can execute arbitrary code -- only load model files you trust.
    with open(f'../models/{model_name}_{dataset}.pkl', 'rb') as model_file:
        return pickle.load(model_file)


def _extract_response_windows(signal, offset, sa_length, experiment_length, max_stimuli):
    """Cut the post-``sa_length`` part of each complete epoch out of a 1-D signal.

    Epoch ``n`` covers
    ``signal[offset + n*experiment_length : offset + (n+1)*experiment_length]``;
    its first ``sa_length`` entries are dropped and the rest is kept.

    Parameters
    ----------
    signal : 1-D array-like
    offset : int
        Index at which the first epoch starts; must be non-negative.
    sa_length, experiment_length : int
    max_stimuli : int
        Maximum number of epochs to return.

    Returns
    -------
    np.ndarray of shape ``(n_epochs, experiment_length - sa_length)`` with
    ``n_epochs = min(max_stimuli, (len(signal) - offset) // experiment_length)``
    (never negative). Slice ends are exclusive, so an epoch that ends exactly at
    ``len(signal)`` is complete and is included.

    Raises
    ------
    ValueError
        If ``offset`` is negative (negative indices would wrap around).
    """
    if offset < 0:
        raise ValueError(f'offset must be non-negative, got {offset}')
    signal = np.asarray(signal)
    n_epochs = max(0, min(max_stimuli, (len(signal) - offset) // experiment_length))
    epochs = signal[offset:offset + n_epochs * experiment_length].reshape(n_epochs, experiment_length)
    return epochs[:, sa_length:]


# The models do not depend on design, eye or experiment, so load each one once.
models = {dataset: _load_model(dataset) for dataset in datasets}

# Only the first design is plotted (the previous loop ended in an unconditional
# `break`). NOTE(review): save_figure below uses a fixed file name, so plotting
# more designs here would overwrite the same file -- add the design to the name first.
plotted_designs = design_types[:1]

for design in plotted_designs:
    fig, axs = plt.subplots(2, 3, constrained_layout=True)
    fig.set_figwidth(15)
    fig.suptitle(f'{design} design')
    legend_handles = {}  # experiment -> whisker artist drawn in that experiment's colour
    for eye in range(1, 7):
        eye_idx = eye - 1
        ax = axs[eye_idx // 3, eye_idx % 3]
        ax.set_title(f'Eye {eye}')

        # The recordings do not depend on the model, so cut the windows once per experiment.
        windows_by_experiment = {}
        for experiment in experiment_types:
            _, signal = get_signal_by_type(eye=eye, design=design, experiment=experiment, verbose=0)
            windows = _extract_response_windows(
                signal[:, channel_id],
                eye_experiment_sa_start[design][eye][experiment],
                sa_length, experiment_length, max_stimuli,
            )
            if len(windows) == 0:
                raise ValueError(
                    f'No complete stimulus epochs for design={design}, '
                    f'eye={eye}, experiment={experiment}')
            windows_by_experiment[experiment] = windows

        for j, dataset in enumerate(datasets):
            for k, experiment in enumerate(experiment_types):
                y_pred = models[dataset].predict(windows_by_experiment[experiment])
                # Each dataset gets a group of three rows: positions 3j+1 .. 3j+3.
                box = ax.boxplot([y_pred], showfliers=False, vert=False, patch_artist=True,
                                 positions=[j*3 + (k+1)], widths=0.6, label=experiment)
                for item in ['boxes', 'whiskers', 'caps']:
                    plt.setp(box[item], color=colors[experiment])
                plt.setp(box['means'], color='black')
                plt.setp(box['medians'], color='black')
                legend_handles[experiment] = box['whiskers'][0]

        # One tick per dataset, at the middle row of its group of three.
        ax.set_yticks(np.arange(2, len(datasets)*3, 3))
        if eye_idx % 3 == 0:
            ax.set_yticklabels([get_label_from_ds(x) for x in datasets])
        else:
            ax.set_yticklabels([])

    fig.legend([legend_handles[experiment] for experiment in experiment_types], experiment_types,
               loc='center right', bbox_to_anchor=(1.15, 0.8),
               title='Experiment type')
    save_figure("predict_real_data", width=8) # take up a whole page