In [1]:
from annubes.tasks.task import Task
import numpy as np
import sys
import random

### Guido Task

#### Generate trials with Task

In [8]:
# Guido task configuration: modality mix of the session ('a', 'v', 'av'),
# stimulus duration t_in, the set of stimulus values, and catch_prob.
session_in = dict(a=0.33, v=0.33, av=0.33)
t_in = 1000
value_in = [9, 9.25, 9.5, 9.75, 10, 10.25, 10.5, 11, 12, 13]
scaling = True
catch_prob = 0.1

guido_task = Task(
    name='guido',
    session_in=session_in,
    t_in=t_in,
    value_in=value_in,
    scaling=scaling,
    catch_prob=catch_prob,
    high_out=1.5,
    low_out=0.2,
)

# Draw a batch of trials; Jupyter displays the value of this last expression.
batch_size = 20
guido_task.generate_trials(batch_size=batch_size)

In [9]:
# Plot trials of the Guido task generated in the previous cell.
# NOTE(review): the argument 8 is presumably the number of trials to plot —
# confirm against Task.plot_trials.
fig = guido_task.plot_trials(8)
fig.show()

#### Training

In [None]:
import os
import logging
import time
import datetime
from tqdm import tqdm
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from annubes.nn.custom_rnn import CustomRNN

In [None]:
# Training configuration.
# TODO: n_networks should be configurable (e.g. via command-line args).
n_networks = 100
save_path = './training_results/'
# Create the output directory before configuring logging: basicConfig opens
# the log file immediately and raises FileNotFoundError if the directory is
# missing (the training loop only creates per-network subdirectories later).
os.makedirs(save_path, exist_ok=True)
# force=True (Python 3.8+) removes handlers left over from a previous run of
# this cell; without it basicConfig is a silent no-op on re-run.
logging.basicConfig(filename=os.path.join(save_path, 'log.txt'),
                    level=logging.DEBUG,
                    force=True)

# Simulation time constants: dt = tau / 5 (= 20 for tau = 100).
tau = 100
dt = tau / 5

# NOTE(review): `GuidoTask` is not imported in any cell shown here, so this
# line raises NameError on a fresh kernel unless it is defined elsewhere.
# Its API as used by the training loop (.Nin, .Nout,
# generate_trials(rng, dt, batch_size, catch_prob)) differs from annubes'
# `Task` used above — TODO import it explicitly.
task_param = {
    'std_inp_noise': 0.01,
    'fixation': True,
    'tau': tau,
}
task = GuidoTask(task_param)

# RNN settings.
rnn_param = {
    'hidden_size': 150,
    'ex_in_ratio': 0.8,
    'rec_noise_std': 0.15,
}

# Training hyperparameters.
test_size = 1024          # trials generated for each validation
minibatch_size = 20       # training trials generated per epoch
test_minibatch_size = 4   # validation trials per forward pass
lr = 0.01  # 0.01 seems to be better when zero_grad is used, also in accordance to Song 2016

dt_test = 2               # time step used for validation trials
THRESHOLD = 0.2           # |output difference| needed to count as a choice
TARGET_PERFORMANCE = 70   # validation accuracy (%) required to stop training
# NOTE: this overrides catch_prob = 0.1 set in the Guido task cell above;
# the training loop reads this global.
catch_prob = 0.5

num_epochs = 10002

# Per-network validation statistics, keyed by network index, then epoch.
training_stats_dict = {}

In [None]:
# Train `n_networks` independent networks. Each run draws fresh numpy/torch
# seeds, trains with SGD until the validation criterion is met (or
# `num_epochs` is exhausted) and saves the model, statistics and seeds.
# Uses globals from the configuration cell above.

# Stimulus time excluded from the loss at stimulus onset. 200 corresponds to
# the 10 steps skipped at dt = 20 (assuming dt is in ms); deriving the step
# count from dt keeps the window fixed in time if dt changes.
STIMULUS_GRACE_MS = 200
# Validation runs every VALIDATION_EVERY epochs once e >= MIN_VALIDATION_EPOCH.
VALIDATION_EVERY = 500
MIN_VALIDATION_EPOCH = 200


def _masked_loss(loss_fn, output, target, time_steps):
    """Compute ``loss_fn`` between ``output`` and ``target`` on ``time_steps`` only.

    Both are indexed as ``[:, time_steps, :]``. The target is converted to a
    tensor with ``output``'s dtype so that float64 numpy targets cannot clash
    with the network's float32 output during backward.
    """
    target_steps = torch.as_tensor(target[:, time_steps, :], dtype=output.dtype)
    return loss_fn(output[:, time_steps, :], target_steps)


def _evaluate(net, test_trials, n_trials, batch_size, tau, dt_test, threshold):
    """Run ``net`` on ``test_trials`` in mini-batches and score its choices.

    A trial is *analysed* when the output difference (unit 1 minus unit 0)
    at stimulus onset is within ``threshold``. Its choice is whether that
    difference is positive at the first stimulus step where it exceeds
    ``threshold`` in magnitude; only analysed trials that reach the threshold
    are scored against ``test_trials['choice']``.

    Returns
    -------
    tuple
        ``(n_analysed, n_choices_made, n_correct)`` summed over all batches.
    """
    # Phase indices of the *test* trials: they live on the dt_test time grid.
    stimulus = test_trials['phases']['stimulus']
    onset = stimulus[0]
    n_analysed = n_choices = n_correct = 0
    with torch.no_grad():
        for j in tqdm(range(n_trials // batch_size)):
            batch = slice(j * batch_size, (j + 1) * batch_size)
            batch_output, _ = net(torch.Tensor(test_trials['inputs'][batch]), tau, dt_test)
            out = batch_output.detach().numpy()

            out_diff = out[:, stimulus, 1] - out[:, stimulus, 0]
            crossed = np.abs(out_diff) > threshold
            # argmax over a boolean row gives the first crossing (0 if none).
            decision_time = np.argmax(crossed, axis=1)

            onset_diff = out[:, onset, 1] - out[:, onset, 0]
            valid_start = np.nonzero(np.abs(onset_diff) <= threshold)[0]
            choice_made = np.nonzero(crossed.any(axis=1))[0]
            scored = np.intersect1d(valid_start, choice_made)

            choice = (out_diff[scored, decision_time[scored]] > 0).astype(np.int_)
            n_analysed += len(valid_start)
            n_choices += len(choice)
            n_correct += int(np.sum(test_trials['choice'][batch][scored] == choice))
    return n_analysed, n_choices, n_correct


def _save_run(run_dir, model_name, net, stats, seed_dir, numpy_seed, torch_seed):
    """Save weights and training stats to ``run_dir`` and seeds to ``seed_dir``.

    ``run_dir`` is created if needed, so re-running the cell does not fail
    with FileExistsError.
    """
    os.makedirs(run_dir, exist_ok=True)
    torch.save(net.state_dict(), os.path.join(run_dir, model_name))
    with open(os.path.join(run_dir, 'training_stats.pkl'), 'wb') as f:
        pickle.dump(stats, f, pickle.HIGHEST_PROTOCOL)
    np.save(os.path.join(seed_dir, 'numpy_train_seed.npy'), numpy_seed)
    np.save(os.path.join(seed_dir, 'torch_seed.npy'), torch_seed)


n_grace_steps = int(round(STIMULUS_GRACE_MS / dt))

for n in range(n_networks):
    start_time = time.time()
    now = datetime.datetime.now()
    print(f'Started {n}')
    # Draw and record independent seeds so each run can be reproduced.
    numpy_seed = random.randrange(sys.maxsize)
    rng = np.random.default_rng(numpy_seed)  # seed for generating trials
    torch_seed = random.randrange(sys.maxsize)
    torch.manual_seed(torch_seed)
    training_stats_dict[n] = {}
    logging.info(f'{now} Started {n} with th={THRESHOLD}, lr={lr}, mb={minibatch_size}')

    net = CustomRNN(task.Nin, rnn_param['hidden_size'], task.Nout, rnn_param)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(net.parameters(), lr, weight_decay=0.1)

    run_dir = os.path.join(save_path, str(n))
    os.makedirs(run_dir, exist_ok=True)

    # Latest validation results. Initialised here so the non-convergence
    # message is valid even if no validation epoch was reached, and never
    # reports values left over from the previous network.
    accuracy_over_choices = float('nan')
    n_analyzed_trials_total = 0

    with open(os.path.join(run_dir, 'accuracy.txt'), 'a+') as file_accuracy:
        for e in range(num_epochs):
            trials = task.generate_trials(rng, dt, minibatch_size, catch_prob)
            optimizer.zero_grad()  # reset gradients to avoid accumulation
            output, rnn_output = net(torch.Tensor(trials['inputs']), tau, dt)

            # The loss skips the start of the stimulus; with fixation enabled
            # the fixation period is penalised as well.
            time_steps_to_punish = trials['phases']['stimulus'][n_grace_steps:]
            if task_param['fixation']:
                time_steps_to_punish = np.concatenate(
                    (trials['phases']['fixation'], time_steps_to_punish))
            loss = _masked_loss(loss_fn, output, trials['outputs'], time_steps_to_punish)

            loss.backward()
            optimizer.step()
            net.set_weights()

            # Training accuracy: argmax of the output at the final time step only.
            prediction = torch.argmax(output[:, -1, :], dim=1)
            n_correct_prediction = np.sum(prediction.numpy() == trials['choice'])
            accuracy = n_correct_prediction * 100 / minibatch_size
            file_accuracy.write(f'Iteration {e}: {accuracy}\n')

            if e >= MIN_VALIDATION_EPOCH and e % VALIDATION_EVERY == 0:
                print(f'Starting validation of epoch {e}')
                logging.info(f'Starting validation of epoch {e}')
                print(f'Training Accuracy = {str(accuracy)}')
                print('Loss ' + str(e) + ': ' + str(loss.item()))
                test_trials = task.generate_trials(rng, dt_test, test_size, catch_prob)
                n_analyzed_trials_total, n_choices_made, n_test_correct = _evaluate(
                    net, test_trials, test_size, test_minibatch_size, tau, dt_test, THRESHOLD)

                # Guard the 0/0 case when no trial had a valid start.
                if n_analyzed_trials_total:
                    accuracy_over_choices = 100 * n_test_correct / n_analyzed_trials_total
                else:
                    accuracy_over_choices = float('nan')

                print('Testing accuracy:' + str(accuracy_over_choices))
                logging.info(f'Testing accuracy: {str(accuracy_over_choices)}')

                training_stats_dict[n][e] = {
                    'n_analyzed_trials': n_analyzed_trials_total,
                    'n_choices_made': n_choices_made,
                    'n_test_correct': n_test_correct,
                    'accuracy': accuracy_over_choices,
                }
                print(training_stats_dict[n][e])

                if accuracy_over_choices > TARGET_PERFORMANCE and n_analyzed_trials_total >= 0.9 * test_size:
                    print(f'Final number of analyzed trials: {n_analyzed_trials_total}')
                    logging.info(f'Final number of analyzed trials: {n_analyzed_trials_total}')
                    print(f'Stopped training with an accuracy of {np.round(accuracy_over_choices,2)} at epoch {e}.')
                    logging.info(f'Stopped training with an accuracy of {np.round(accuracy_over_choices,2)} at epoch {e}.')
                    _save_run(os.path.join(save_path, 'good_runs', str(n)), 'model', net,
                              training_stats_dict[n], run_dir, numpy_seed, torch_seed)
                    break
        else:
            # for/else: reached only when all epochs ran without meeting the target.
            message = (f'Model {n} did not converge within {num_epochs} epochs, '
                       f'final accuracy={accuracy_over_choices} over {n_analyzed_trials_total} trials')
            logging.warning(message)
            print(message)
            _save_run(os.path.join(save_path, 'bad_runs', str(n)), 'model_not_converged', net,
                      training_stats_dict[n], run_dir, numpy_seed, torch_seed)

    logging.info(f'Finished {n} in {time.time() - start_time:.1f} s')

### Habituation task

#### Generate trials

In [10]:
# Habituation task configuration: session modalities 'v' and 'a' only,
# stimulus duration t_in = 5000, three stimulus values, and catch_prob = 0.5.
session_in = dict(v=0.5, a=0.5)
t_in = 5000
value_in = [0.8, 0.9, 1]
scaling = True
catch_prob = 0.5

hab_task = Task(
    name='habituation',
    session_in=session_in,
    t_in=t_in,
    value_in=value_in,
    scaling=scaling,
    catch_prob=catch_prob,
    high_out=1.5,
    low_out=0.2,
)

# Draw a batch of trials; Jupyter displays the value of this last expression.
batch_size = 20
hab_task.generate_trials(batch_size=batch_size)

In [11]:
# Plot trials of the habituation task generated in the previous cell.
# NOTE(review): the argument 8 is presumably the number of trials to plot —
# confirm against Task.plot_trials.
fig = hab_task.plot_trials(8)
fig.show()