In [9]:
from annubes.tasks.guido import GuidoTask
from annubes.tasks.habituation import HabituationTask
import numpy as np
import sys
import random
import plotly.graph_objects as go
from plotly.subplots import make_subplots

### Guido Task

#### Generate trials

In [2]:
tau = 100
dt = tau / 5

task_param = {}
task_param['std_inp_noise'] = 0.01
task_param['fixation'] = True
task_param['tau'] = tau
task = GuidoTask(task_param)

numpy_seed = random.randrange(sys.maxsize)
rng = np.random.default_rng(numpy_seed)
minibatch_size = 7
catch_prob = 0.1

trials = task.generate_trials(rng, dt, minibatch_size, catch_prob)
t = np.linspace(dt, task.T, int(task.T / dt))

In [3]:
fig = make_subplots(rows=minibatch_size, cols=1,
                    shared_xaxes=True,
                    vertical_spacing=0.04,
                    subplot_titles=[
                        "Trial " + str(i + 1) + " - modality " + str(trials['modality'][i])
                        for i in range(minibatch_size)])
showlegend = True
for i in range(minibatch_size):
    fig.add_trace(go.Scatter(
        name="VISUAL",
        mode="markers+lines", x=t, y=trials['inputs'][i][:,0],
        marker_symbol="star",
        legendgroup="VISUAL",
        showlegend=showlegend,
        line_color = 'blue'
    ), row=i+1, col=1)
    fig.add_trace(go.Scatter(
        name="AUDITORY",
        mode="markers+lines", x=t, y=trials['inputs'][i][:,1],
        marker_symbol="star",
        legendgroup="AUDITORY",
        showlegend=showlegend,
        line_color = 'black'
    ), row=i+1, col=1)
    fig.add_trace(go.Scatter(
        name="START",
        mode="markers+lines", x=t, y=trials['inputs'][i][:,2],
        marker_symbol="star",
        legendgroup="START",
        showlegend=showlegend,
        line_color = 'green'
    ), row=i+1, col=1)
    fig.add_trace(go.Scatter(
        name="NO STIMULUS",
        mode="lines", x=t, y=trials['outputs'][i][:,0],
        legendgroup="OUTPUT 1",
        showlegend=showlegend,
        line_color = 'orange'
    ), row=i+1, col=1)
    fig.add_trace(go.Scatter(
        name="STIMULUS/STIMULI",
        mode="lines", x=t, y=trials['outputs'][i][:,1],
        legendgroup="OUTPUT 2",
        showlegend=showlegend,
        line_color = 'purple'
    ), row=i+1, col=1)
    fig.add_vline(x=task.fixation + dt, line_width=3, line_dash="dash", line_color="red")
    showlegend = False
    fig.update_yaxes(range=[0, 2], row=i+1, col=1)
fig.update_layout(height=1200, width=800, title_text="Trials")
fig.show()

# Q
# Why START is needed at all as input?
# Sometimes the input stimulus is very low; are we fine with that?
# Confirm the outputs are correct
# I'd suggest to add to the future package a visualization module like this; suggestions/opinions? YEs

#### Training

In [4]:
import os
import logging
import time
import datetime
from tqdm import tqdm
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from annubes.nn.custom_rnn import CustomRNN

In [7]:
# n_networks should be flexible (args)
n_networks = 100
save_path = './training_results/'
logging.basicConfig(filename=f'{save_path}log.txt', level=logging.DEBUG)

# instantiate task class
tau = 100
dt = tau / 5

task_param = {}
task_param['std_inp_noise'] = 0.01
task_param['fixation'] = True
task_param['tau'] = tau
task = GuidoTask(task_param)

# RNN settings
rnn_param = {}

rnn_param['hidden_size'] = 150
rnn_param['ex_in_ratio'] = 0.8
rnn_param['rec_noise_std'] = 0.15

# training hyperparameters
test_size = 1024

minibatch_size = 20
test_minibatch_size = 4
lr = 0.01 # 0.01 seems to be better when zero_grad is used, also in accordance to Song 2016

dt_test = 2
THRESHOLD = 0.2
TARGET_PERFORMANCE = 70
catch_prob = 0.5

num_epochs = 10002

training_stats_dict = {}

In [8]:
# begin simulations
for n in range(n_networks):
    start_time = time.time()
    now = datetime.datetime.now()
    print(f'Started {n}')
    numpy_seed = random.randrange(sys.maxsize)
    #print(numpy_seed)
    rng = np.random.default_rng(numpy_seed)  # seed for generating trials
    torch_seed = random.randrange(sys.maxsize)
    torch.manual_seed(torch_seed)
    training_stats_dict[n] = {}
    logging.info(f'{now} Started {n} with th={THRESHOLD}, lr={lr}, mb={minibatch_size}')
    #rng = np.random.default_rng()  # seed for generating trials

    # create the network
    net = CustomRNN(task.Nin, rnn_param['hidden_size'], task.Nout, rnn_param)
    # load the model if it exists already

    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(net.parameters(), lr, weight_decay=0.1)

    if not os.path.exists(f'{save_path}{n}'):
        os.makedirs(f'{save_path}{n}')

    with open(os.path.join(save_path, str(n), 'accuracy.txt'), 'a+') as file_accuracy:
        for e in range(num_epochs):
            if e == num_epochs-1:
                logging.warning(f'Model {n} did not convergence within {num_epochs}, final accuracy={accuracy_over_choices} over {n_analyzed_trials_total} trials')
                print(f'Model {n} did not convergence within {num_epochs}, final accuracy={accuracy_over_choices} over {n_analyzed_trials_total} trials')
                os.makedirs(f'{save_path}bad_runs/{n}')
                torch.save(net.state_dict(), f'{save_path}bad_runs/{n}/model_not_converged')
                np.save(os.path.join(save_path, str(n), 'numpy_train_seed.npy'), numpy_seed)
                np.save(os.path.join(save_path, str(n), 'torch_seed.npy'), torch_seed)
                with open(os.path.join(f'{save_path}bad_runs/{n}/training_stats.pkl'), 'wb') as f:
                        pickle.dump(training_stats_dict[n], f, pickle.HIGHEST_PROTOCOL)
            # generate trials
            trials = task.generate_trials(rng, dt, minibatch_size,catch_prob)
            #print(trials['modality'])
            #reset gradient to avoid accumulation
            optimizer.zero_grad()

            # feed trials to network
            output, rnn_output = net(torch.Tensor(trials['inputs']), tau, dt)

            # compute loss (dont penalize first 200ms of stimu)
            if task_param['fixation']:
                fixation = trials['phases']['fixation']
                #Skip first 200ms of stimulus as punishment
                stimulus = trials['phases']['stimulus'][10:]
                time_steps_to_punish = np.concatenate((fixation, stimulus))
                #print(time_steps_to_punish)
                loss = loss_fn(output[:, time_steps_to_punish, :],
                            torch.tensor(trials['outputs'][:, time_steps_to_punish, :]))
            else:
                loss = loss_fn(output[:, trials['phases']['stimulus'][10:], :],
                            torch.tensor(trials['outputs'][:, trials['phases']['stimulus'][10:], :]))
                            #!Loss calculated over all timesteps of trial, so not only end? as to also punish slow?

            # update weights
            loss.backward(torch.ones_like(loss))
            optimizer.step()
            net.set_weights()

            # check accuracy
            prediction = torch.argmax(output[:, -1, :], dim=1)#Accuracy only on final timestep of data
            nun_correct_prediction = np.sum(prediction.numpy() == trials['choice'])
            accuracy = nun_correct_prediction * 100 / minibatch_size
            #write training accuracy to file
            file_accuracy.write(f'Iteration {e}: ')
            file_accuracy.write(str(nun_correct_prediction * 100 / minibatch_size) + '\n')

            if  e >= 200 and e % 500 == 0:
                with torch.no_grad():
                    print(f'Starting validation of epoch {e}')
                    logging.info(f'Starting validation of epoch {e}')
                    print(f'Training Accuracy = {str(accuracy)}')
                    print('Loss ' + str(e) + ': ' + str(loss.item()))
                    test_trials = task.generate_trials(rng, dt_test, test_size, catch_prob)

                    n_test_batch = int(test_size / test_minibatch_size)
                    n_test_correct = 0
                    n_analyzed_trials_total = 0
                    n_choices_made = 0
                    for j in tqdm(range(n_test_batch)):
                        start_idx = j * test_minibatch_size
                        end_idx = (j + 1) * test_minibatch_size

                        cur_batch = test_trials['inputs'][start_idx:end_idx]
                        cur_batch_choice = test_trials['choice'][start_idx:end_idx]

                        test_batch_output, _ = net(torch.Tensor(cur_batch), tau, dt_test)

                        output = test_batch_output.detach().numpy()

                        out_diff = output[:, test_trials['phases']['stimulus'], 1] - output[:, test_trials['phases']['stimulus'], 0]

                        decision_time = np.argmax(np.abs(out_diff) > THRESHOLD, axis=1)

                        analysed_trials = np.nonzero(decision_time != 0)[0]

                        out_diff_onset_stimulus = output[:, trials['phases']['stimulus'][0], 1] - output[:, trials['phases']['stimulus'][0], 0]

                        analysed_trials_valid_start = np.nonzero(np.abs(out_diff_onset_stimulus) <= THRESHOLD)[0]


                        analysed_trials_choice_made = np.nonzero(np.sum(np.abs(out_diff) > THRESHOLD, axis=1) != 0)[0]

                        analysed_trials_good_start_choice_made = np.intersect1d(analysed_trials_valid_start, analysed_trials_choice_made)

                        choice = (out_diff[analysed_trials_good_start_choice_made, decision_time[analysed_trials_good_start_choice_made]] > 0).astype(np.int_)

                        n_analyzed_trials = len(analysed_trials_valid_start)
                        n_analyzed_trials_total += n_analyzed_trials
                        n_choices_made += len(choice)

                        n_test_correct += np.sum(cur_batch_choice[analysed_trials_good_start_choice_made] == choice)
                        #accuracy_over_choices = 100 * np.sum(trials['choice'][analysed_trials_good_start_choice_made] == choice) / n_analyzed_trials

                # print("Total analyzed trials===")
                # print(n_analyzed_trials_total)
                # print('Total choices made===')
                # print(n_choices_made)
                # # print('Total trials difference from onset===')
                # # print(out_diff_onset_stimulus)
                # print('Correct guesses===')
                # print(n_test_correct)


                accuracy_over_choices = 100 * n_test_correct / n_analyzed_trials_total

                print('Testing accuracy:' + str(accuracy_over_choices))
                logging.info(f'Testing accuracy: {str(accuracy_over_choices)}')

                training_stats_dict[n][e] = {
                    'n_analyzed_trials': n_analyzed_trials_total,
                    'n_choices_made':n_choices_made,
                    'n_test_correct':n_test_correct,
                    'accuracy': accuracy_over_choices
                }
                print(training_stats_dict[n][e])

                if accuracy_over_choices > TARGET_PERFORMANCE and n_analyzed_trials_total >= 0.9 * test_size:
                    print(f'Final number of analyzed trials: {n_analyzed_trials_total}')
                    logging.info(f'Final number of analyzed trials: {n_analyzed_trials_total}')
                    print(f'Stopped training with an accuracy of {np.round(accuracy_over_choices,2)} at epoch {e}.')
                    os.makedirs(f'{save_path}good_runs/{n}')
                    logging.info(f'Stopped training with an accuracy of {np.round(accuracy_over_choices,2)} at epoch {e}.')
                    torch.save(net.state_dict(), f'{save_path}good_runs/{n}/model')
                    with open(os.path.join(f'{save_path}good_runs/{n}/training_stats.pkl'), 'wb') as f:
                        pickle.dump(training_stats_dict[n], f, pickle.HIGHEST_PROTOCOL)
                    np.save(os.path.join(save_path, str(n), 'numpy_train_seed.npy'), numpy_seed)
                    np.save(os.path.join(save_path, str(n), 'torch_seed.npy'), torch_seed)
                    break
                    # create directory for saving results if it does not exist
                # # save model
                #torch.save(net.state_dict(), f'{save_path}{n}/EPOCH/{e}/model')

Started 0
Starting validation of epoch 500
Training Accuracy = 75.0
Loss 500: 0.3170486390590668


100%|██████████| 256/256 [00:33<00:00,  7.75it/s]


Testing accuracy:69.9902248289345
{'n_analyzed_trials': 1023, 'n_choices_made': 996, 'n_test_correct': 716, 'accuracy': 69.9902248289345}
Starting validation of epoch 1000
Training Accuracy = 75.0
Loss 1000: 0.2803773283958435


100%|██████████| 256/256 [00:34<00:00,  7.39it/s]


Testing accuracy:73.5960591133005
{'n_analyzed_trials': 1015, 'n_choices_made': 1005, 'n_test_correct': 747, 'accuracy': 73.5960591133005}
Final number of analyzed trials: 1015
Stopped training with an accuracy of 73.6 at epoch 1000.
Started 1
Starting validation of epoch 500
Training Accuracy = 90.0
Loss 500: 0.26183581352233887


100%|██████████| 256/256 [00:34<00:00,  7.48it/s]


Testing accuracy:67.5464320625611
{'n_analyzed_trials': 1023, 'n_choices_made': 1020, 'n_test_correct': 691, 'accuracy': 67.5464320625611}
Starting validation of epoch 1000
Training Accuracy = 85.0
Loss 1000: 0.23967643082141876


100%|██████████| 256/256 [00:34<00:00,  7.45it/s]


Testing accuracy:72.8515625
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 746, 'accuracy': 72.8515625}
Final number of analyzed trials: 1024
Stopped training with an accuracy of 72.85 at epoch 1000.
Started 2
Starting validation of epoch 500
Training Accuracy = 75.0
Loss 500: 0.31995677947998047


100%|██████████| 256/256 [00:33<00:00,  7.64it/s]


Testing accuracy:65.81027667984189
{'n_analyzed_trials': 1012, 'n_choices_made': 1002, 'n_test_correct': 666, 'accuracy': 65.81027667984189}
Starting validation of epoch 1000
Training Accuracy = 90.0
Loss 1000: 0.21445322036743164


100%|██████████| 256/256 [00:34<00:00,  7.35it/s]


Testing accuracy:75.17241379310344
{'n_analyzed_trials': 1015, 'n_choices_made': 1013, 'n_test_correct': 763, 'accuracy': 75.17241379310344}
Final number of analyzed trials: 1015
Stopped training with an accuracy of 75.17 at epoch 1000.
Started 3
Starting validation of epoch 500
Training Accuracy = 95.0
Loss 500: 0.24103637039661407


100%|██████████| 256/256 [00:35<00:00,  7.28it/s]


Testing accuracy:65.13545347467608
{'n_analyzed_trials': 849, 'n_choices_made': 849, 'n_test_correct': 553, 'accuracy': 65.13545347467608}
Starting validation of epoch 1000
Training Accuracy = 85.0
Loss 1000: 0.23125243186950684


100%|██████████| 256/256 [00:35<00:00,  7.26it/s]


Testing accuracy:67.47967479674797
{'n_analyzed_trials': 984, 'n_choices_made': 984, 'n_test_correct': 664, 'accuracy': 67.47967479674797}
Starting validation of epoch 1500
Training Accuracy = 95.0
Loss 1500: 0.18020616471767426


100%|██████████| 256/256 [00:35<00:00,  7.28it/s]


Testing accuracy:67.0275590551181
{'n_analyzed_trials': 1016, 'n_choices_made': 1015, 'n_test_correct': 681, 'accuracy': 67.0275590551181}
Starting validation of epoch 2000
Training Accuracy = 90.0
Loss 2000: 0.21321657299995422


100%|██████████| 256/256 [00:35<00:00,  7.16it/s]


Testing accuracy:68.75612144955926
{'n_analyzed_trials': 1021, 'n_choices_made': 1021, 'n_test_correct': 702, 'accuracy': 68.75612144955926}
Starting validation of epoch 2500
Training Accuracy = 75.0
Loss 2500: 0.2877655029296875


100%|██████████| 256/256 [00:34<00:00,  7.51it/s]


Testing accuracy:70.08797653958945
{'n_analyzed_trials': 1023, 'n_choices_made': 1022, 'n_test_correct': 717, 'accuracy': 70.08797653958945}
Final number of analyzed trials: 1023
Stopped training with an accuracy of 70.09 at epoch 2500.
Started 4
Starting validation of epoch 500
Training Accuracy = 55.0
Loss 500: 0.3090377748012543


100%|██████████| 256/256 [00:35<00:00,  7.19it/s]


Testing accuracy:61.19257086999023
{'n_analyzed_trials': 1023, 'n_choices_made': 1021, 'n_test_correct': 626, 'accuracy': 61.19257086999023}
Starting validation of epoch 1000
Training Accuracy = 80.0
Loss 1000: 0.22231720387935638


100%|██████████| 256/256 [00:35<00:00,  7.19it/s]


Testing accuracy:65.234375
{'n_analyzed_trials': 1024, 'n_choices_made': 1022, 'n_test_correct': 668, 'accuracy': 65.234375}
Starting validation of epoch 1500
Training Accuracy = 100.0
Loss 1500: 0.18897323310375214


100%|██████████| 256/256 [00:34<00:00,  7.52it/s]


Testing accuracy:59.86328125
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 613, 'accuracy': 59.86328125}
Starting validation of epoch 2000
Training Accuracy = 95.0
Loss 2000: 0.2236461341381073


100%|██████████| 256/256 [00:35<00:00,  7.13it/s]


Testing accuracy:64.453125
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 660, 'accuracy': 64.453125}
Starting validation of epoch 2500
Training Accuracy = 85.0
Loss 2500: 0.24933715164661407


100%|██████████| 256/256 [00:34<00:00,  7.40it/s]


Testing accuracy:64.55078125
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 661, 'accuracy': 64.55078125}
Starting validation of epoch 3000
Training Accuracy = 85.0
Loss 3000: 0.26174768805503845


100%|██████████| 256/256 [00:35<00:00,  7.16it/s]


Testing accuracy:64.453125
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 660, 'accuracy': 64.453125}
Starting validation of epoch 3500
Training Accuracy = 85.0
Loss 3500: 0.1870669722557068


100%|██████████| 256/256 [00:36<00:00,  7.06it/s]


Testing accuracy:64.2578125
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 658, 'accuracy': 64.2578125}
Starting validation of epoch 4000
Training Accuracy = 95.0
Loss 4000: 0.1775290071964264


100%|██████████| 256/256 [00:34<00:00,  7.49it/s]


Testing accuracy:63.76953125
{'n_analyzed_trials': 1024, 'n_choices_made': 1023, 'n_test_correct': 653, 'accuracy': 63.76953125}
Starting validation of epoch 4500
Training Accuracy = 85.0
Loss 4500: 0.16065900027751923


100%|██████████| 256/256 [00:35<00:00,  7.18it/s]


Testing accuracy:67.08984375
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 687, 'accuracy': 67.08984375}
Starting validation of epoch 5000
Training Accuracy = 80.0
Loss 5000: 0.22463910281658173


100%|██████████| 256/256 [00:34<00:00,  7.35it/s]


Testing accuracy:64.453125
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 660, 'accuracy': 64.453125}
Starting validation of epoch 5500
Training Accuracy = 90.0
Loss 5500: 0.17341502010822296


100%|██████████| 256/256 [00:36<00:00,  7.08it/s]


Testing accuracy:64.453125
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 660, 'accuracy': 64.453125}
Starting validation of epoch 6000
Training Accuracy = 95.0
Loss 6000: 0.15283755958080292


100%|██████████| 256/256 [00:35<00:00,  7.14it/s]


Testing accuracy:64.2578125
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 658, 'accuracy': 64.2578125}
Starting validation of epoch 6500
Training Accuracy = 90.0
Loss 6500: 0.1992107629776001


100%|██████████| 256/256 [00:37<00:00,  6.90it/s]


Testing accuracy:62.98828125
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 645, 'accuracy': 62.98828125}
Starting validation of epoch 7000
Training Accuracy = 80.0
Loss 7000: 0.2791494131088257


100%|██████████| 256/256 [00:35<00:00,  7.28it/s]


Testing accuracy:65.234375
{'n_analyzed_trials': 1024, 'n_choices_made': 1022, 'n_test_correct': 668, 'accuracy': 65.234375}
Starting validation of epoch 7500
Training Accuracy = 95.0
Loss 7500: 0.20661382377147675


100%|██████████| 256/256 [00:36<00:00,  6.94it/s]


Testing accuracy:61.81640625
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 633, 'accuracy': 61.81640625}
Starting validation of epoch 8000
Training Accuracy = 95.0
Loss 8000: 0.18907301127910614


100%|██████████| 256/256 [00:35<00:00,  7.28it/s]


Testing accuracy:64.2578125
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 658, 'accuracy': 64.2578125}
Starting validation of epoch 8500
Training Accuracy = 95.0
Loss 8500: 0.18374192714691162


100%|██████████| 256/256 [00:35<00:00,  7.21it/s]


Testing accuracy:62.5
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 640, 'accuracy': 62.5}
Starting validation of epoch 9000
Training Accuracy = 90.0
Loss 9000: 0.23426076769828796


100%|██████████| 256/256 [00:36<00:00,  7.08it/s]


Testing accuracy:64.94140625
{'n_analyzed_trials': 1024, 'n_choices_made': 1023, 'n_test_correct': 665, 'accuracy': 64.94140625}
Starting validation of epoch 9500
Training Accuracy = 90.0
Loss 9500: 0.1780959814786911


100%|██████████| 256/256 [00:36<00:00,  7.08it/s]


Testing accuracy:64.74609375
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 663, 'accuracy': 64.74609375}
Starting validation of epoch 10000
Training Accuracy = 95.0
Loss 10000: 0.15363676846027374


100%|██████████| 256/256 [00:34<00:00,  7.40it/s]


Testing accuracy:64.94140625
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 665, 'accuracy': 64.94140625}
Model 4 did not convergence within 10002, final accuracy=64.94140625 over 1024 trials
Started 5
Starting validation of epoch 500
Training Accuracy = 75.0
Loss 500: 0.27004361152648926


100%|██████████| 256/256 [00:36<00:00,  7.06it/s]


Testing accuracy:72.33400402414487
{'n_analyzed_trials': 994, 'n_choices_made': 994, 'n_test_correct': 719, 'accuracy': 72.33400402414487}
Final number of analyzed trials: 994
Stopped training with an accuracy of 72.33 at epoch 500.
Started 6
Starting validation of epoch 500
Training Accuracy = 100.0
Loss 500: 0.1808059960603714


100%|██████████| 256/256 [00:34<00:00,  7.41it/s]


Testing accuracy:57.03125
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 584, 'accuracy': 57.03125}
Starting validation of epoch 1000
Training Accuracy = 80.0
Loss 1000: 0.27861201763153076


100%|██████████| 256/256 [00:35<00:00,  7.31it/s]


Testing accuracy:55.37109375
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 567, 'accuracy': 55.37109375}
Starting validation of epoch 1500
Training Accuracy = 95.0
Loss 1500: 0.17758683860301971


100%|██████████| 256/256 [00:36<00:00,  6.98it/s]


Testing accuracy:63.671875
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 652, 'accuracy': 63.671875}
Starting validation of epoch 2000
Training Accuracy = 90.0
Loss 2000: 0.2118021696805954


100%|██████████| 256/256 [00:37<00:00,  6.91it/s]


Testing accuracy:67.578125
{'n_analyzed_trials': 1024, 'n_choices_made': 1024, 'n_test_correct': 692, 'accuracy': 67.578125}


KeyboardInterrupt: 

### Habituation task

#### Generate trials

In [10]:
tau = 100
dt = tau / 5

task_param = {}
task_param['std_inp_noise'] = 0.01
task_param['fixation'] = True
task_param['tau'] = tau
task = HabituationTask(task_param)

numpy_seed = random.randrange(sys.maxsize)
rng = np.random.default_rng(numpy_seed)
minibatch_size = 7 # It will be 40 in the real experiments for each stimulus (80 in total), each block represents an epoch
catch_prob = 0.5

trials = task.generate_trials(rng, dt, minibatch_size, catch_prob)
t = np.linspace(dt, task.T, int(task.T / dt))

In [11]:
fig = make_subplots(rows=minibatch_size, cols=1,
                    shared_xaxes=True,
                    vertical_spacing=0.04,
                    subplot_titles=[
                        "Trial " + str(i + 1) + " - modality " + str(trials['modality'][i])
                        for i in range(minibatch_size)])
showlegend = True
for i in range(minibatch_size):
    fig.add_trace(go.Scatter(
        name="VISUAL",
        mode="markers+lines", x=t, y=trials['inputs'][i][:,0],
        marker_symbol="star",
        legendgroup="VISUAL",
        showlegend=showlegend,
        line_color = 'blue'
    ), row=i+1, col=1)
    fig.add_trace(go.Scatter(
        name="AUDITORY",
        mode="markers+lines", x=t, y=trials['inputs'][i][:,1],
        marker_symbol="star",
        legendgroup="AUDITORY",
        showlegend=showlegend,
        line_color = 'black'
    ), row=i+1, col=1)
    fig.add_trace(go.Scatter(
        name="START",
        mode="markers+lines", x=t, y=trials['inputs'][i][:,2],
        marker_symbol="star",
        legendgroup="START",
        showlegend=showlegend,
        line_color = 'green'
    ), row=i+1, col=1)
    fig.add_trace(go.Scatter(
        name="NO STIMULUS",
        mode="lines", x=t, y=trials['outputs'][i][:,0],
        legendgroup="OUTPUT 1",
        showlegend=showlegend,
        line_color = 'orange'
    ), row=i+1, col=1)
    fig.add_trace(go.Scatter(
        name="STIMULUS/STIMULI",
        mode="lines", x=t, y=trials['outputs'][i][:,1],
        legendgroup="OUTPUT 2",
        showlegend=showlegend,
        line_color = 'purple'
    ), row=i+1, col=1)
    fig.add_vline(x=task.fixation + dt, line_width=3, line_dash="dash", line_color="red")
    showlegend = False
    fig.update_yaxes(range=[0, 2], row=i+1, col=1)
fig.update_layout(height=1300, width=1300, title_text="Trials")
fig.show()

# Q
# How can we insert the inter-trials interval? (3-5 s) Was it present in Guido's task? No
# Sometimes the input stimulus is very low; are we fine with that? YES
# separate auditory and visual neurons in the networks
# 40 visual + 40 auditory in one epoch, 50% of them can be catch trials