In [4]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from cgtasknet.instruments.instrument_accuracy_network import CorrectAnswerNetwork
from cgtasknet.tasks.reduce import (
    CtxDMTaskParameters,
    DMTaskParameters,
    DMTaskRandomModParameters,
    GoDlTaskParameters,
    GoDlTaskRandomModParameters,
    GoRtTaskParameters,
    GoRtTaskRandomModParameters,
    GoTaskParameters,
    GoTaskRandomModParameters,
    MultyReduceTasks,
    RomoTaskParameters,
    RomoTaskRandomModParameters,
)
from tqdm import tqdm
# Candidate stimulus values handed to the Go-family task parameters:
# 8 evenly spaced points covering [0, 1].
go_task_list_values = np.linspace(start=0, stop=1, num=8)

# Use the first CUDA device when one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"device={device}")

device=cuda:0


In [5]:
# ---- Global run settings ----
batch_size = 100
number_of_epochs = 3000
number_of_tasks = 1

# ---- Parameters for each task family ----
romo_parameters = RomoTaskRandomModParameters(
    romo=RomoTaskParameters(
        trial_time=0.2,
        positive_shift_trial_time=0.4,
        delay=0.2,
        positive_shift_delay_time=0.8,
        answer_time=0.25,
    ),
)

dm_parameters = DMTaskRandomModParameters(
    dm=DMTaskParameters(
        trial_time=0.3,
        positive_shift_trial_time=1.5,
        answer_time=0.25,
    )
)

# The context-dependent DM task reuses the very same DMTaskParameters object.
ctx_parameters = CtxDMTaskParameters(dm=dm_parameters.dm)

go_parameters = GoTaskRandomModParameters(
    go=GoTaskParameters(
        value=go_task_list_values,
        trial_time=0.3,
        positive_shift_trial_time=1.5,
        answer_time=0.25,
    )
)

gort_parameters = GoRtTaskRandomModParameters(
    go_rt=GoRtTaskParameters(
        value=go_task_list_values,
        trial_time=0.3,
        positive_shift_trial_time=1.5,
        answer_time=1,
    )
)

godl_parameters = GoDlTaskRandomModParameters(
    go_dl=GoDlTaskParameters(
        go=GoTaskParameters(
            value=go_task_list_values,
            trial_time=0.2,
            positive_shift_trial_time=0.4,
            answer_time=0.25,
        ),
        delay=0.2,
        positive_shift_delay_time=0.8,
    )
)

# ---- Task registry ----
# Every family is registered twice ("<Family>1" and "<Family>2") with the same
# parameter object. Insertion order: Romo, DM, CtxDM, Go, GoRt, GoDl.
_family_parameters = (
    ("RomoTask", romo_parameters),
    ("DMTask", dm_parameters),
    ("CtxDMTask", ctx_parameters),
    ("GoTask", go_parameters),
    ("GoRtTask", gort_parameters),
    ("GoDlTask", godl_parameters),
)
task_dict = {
    f"{family}{variant}": parameters
    for family, parameters in _family_parameters
    for variant in (1, 2)
}
tasks = list(task_dict)

# Multi-task generator covering every registered task.
Task = MultyReduceTasks(
    tasks=task_dict,
    batch_size=batch_size,
    delay_between=0,
    enable_fixation_delay=True,
    mode="random",
)

# ---- Summary ----
print("Task parameters:")
for task_name, task_parameters in task_dict.items():
    print(f"{task_name}:\n{task_parameters}\n")

print(f"inputs/outputs: {Task.feature_and_act_size[0]}/{Task.feature_and_act_size[1]}")

Task parameters:
RomoTask1:
RomoTaskRandomModParameters(romo=RomoTaskParameters(dt=0.001, trial_time=0.2, answer_time=0.25, value=(None, None), delay=0.2, negative_shift_trial_time=0, positive_shift_trial_time=0.4, negative_shift_delay_time=0, positive_shift_delay_time=0.8), n_mods=2)

RomoTask2:
RomoTaskRandomModParameters(romo=RomoTaskParameters(dt=0.001, trial_time=0.2, answer_time=0.25, value=(None, None), delay=0.2, negative_shift_trial_time=0, positive_shift_trial_time=0.4, negative_shift_delay_time=0, positive_shift_delay_time=0.8), n_mods=2)

DMTask1:
DMTaskRandomModParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.3, answer_time=0.25, value=None, negative_shift_trial_time=0, positive_shift_trial_time=1.5), n_mods=2)

DMTask2:
DMTaskRandomModParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.3, answer_time=0.25, value=None, negative_shift_trial_time=0, positive_shift_trial_time=1.5), n_mods=2)

CtxDMTask1:
CtxDMTaskParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.3,

In [6]:
# Split the positions of the alphabetically sorted task names into two groups:
# names containing "Go" land in values_tasks, all others in choices_tasks.
sorted_tasks = sorted(tasks)
re_word = 'Go'
values_tasks = [idx for idx, name in enumerate(sorted_tasks) if re_word in name]
choices_tasks = [idx for idx, name in enumerate(sorted_tasks) if re_word not in name]

# Scorer used by the evaluation cells.
# NOTE(review): the third argument (0.15) is presumably a tolerance — confirm
# against CorrectAnswerNetwork.
can = CorrectAnswerNetwork(choices_tasks, values_tasks, 0.15)

In [7]:
from cgtasknet.net import SNNlifadex
from norse.torch import LIFAdExParameters

# Network dimensions: the hidden width is fixed here, while the input and
# output widths are taken from the task generator.
hidden_size = 256
feature_size, output_size = Task.feature_and_act_size
def model_load(tau_a, model_path):
    """Build an SNNlifadex network and restore its weights from ``model_path``.

    The network is sized from the notebook-level ``feature_size``,
    ``hidden_size`` and ``output_size`` and is moved to the notebook-level
    ``device`` before being returned.

    Args:
        tau_a: Value passed as ``tau_ada_inv`` of the LIFAdEx neuron parameters.
        model_path: Path to a file readable by ``torch.load`` that contains the
            network's state dict.

    Returns:
        The SNNlifadex instance with the loaded weights, located on ``device``.
    """
    neuron_parameters = LIFAdExParameters(
        v_th=torch.as_tensor(0.45),
        tau_ada_inv=torch.as_tensor(tau_a),
        alpha=100,
        method="super",
    )
    model = SNNlifadex(
        feature_size,
        hidden_size,
        output_size,
        neuron_parameters=neuron_parameters,
        tau_filter_inv=20,
        save_states=True,
    )
    # Deserialize onto the CPU first: a checkpoint saved from a CUDA run would
    # otherwise fail to load on a machine without a GPU. The model is moved to
    # the target device right after the weights are copied in.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model.to(device)

In [8]:
# Task positions are looked up in the alphabetically sorted list of task names.
sorted_name_of_tasks = sorted(tasks)

# Index group evaluated as the "dl" set.
dl_tasks_indexes = [
    sorted_name_of_tasks.index(task_name)
    for task_name in ('RomoTask1', 'RomoTask2')
]

# Index group evaluated as the comparison ("without dl") set.
without_dl_tasks_indexes = [
    sorted_name_of_tasks.index(task_name)
    for task_name in ('DMTask1', 'DMTask2')
]

print(f'dl_tasks_indexes = {dl_tasks_indexes}')
print(f'other = {without_dl_tasks_indexes}')

dl_tasks_indexes = [10, 11]
other = [2, 3]


In [9]:
# Build one single-task generator per task index (selected via task_number)
# and keep those whose index belongs to one of the two evaluation groups.
dl_tasks, without_dl_tasks = [], []

# Settings shared by every single-task generator.
shared_task_kwargs = dict(
    tasks=task_dict,
    batch_size=batch_size,
    delay_between=0,
    enable_fixation_delay=True,
    mode="random",
)

for task_number in range(len(tasks)):
    single_task = MultyReduceTasks(task_number=task_number, **shared_task_kwargs)
    if task_number in dl_tasks_indexes:
        dl_tasks.append(single_task)
    elif task_number in without_dl_tasks_indexes:
        without_dl_tasks.append(single_task)

In [10]:
# Weight-file name, the directories that contain it, and the tau_a value that
# belongs to the model in each directory (dirs and tau_a_values are index-aligned).
name_weights = 'weights_100_N_256_without_square_2999_'
dirs = ['weights']
tau_a_values = [1 / 2]

In [11]:
# Number of evaluation passes over each task group; accuracies are averaged
# over all passes in the evaluation cells.
repeat = 10

In [12]:
from norse.torch import LIFAdExState
import os
# Evaluate every saved model on the generators in `dl_tasks`: average the
# per-batch accuracy over all tasks and `repeat` passes, append the result to
# a log file and print it as a percentage.
for i, weights_dir in enumerate(dirs):
    tau_a = tau_a_values[i]
    model = model_load(tau_a, os.path.join(weights_dir, name_weights))
    results = 0
    # Inference only: no gradients are needed, so skip autograd bookkeeping
    # (saves memory and time on long sequences).
    with torch.no_grad():
        for _ in range(repeat):
            for task in dl_tasks:
                inputs, target_outputs = task.dataset(1, delay_between=0)
                inputs = torch.from_numpy(inputs).type(torch.float).to(device)
                target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
                # Fresh neuron state per batch; the second field is drawn
                # uniformly at random, the others start at zero.
                init_state = LIFAdExState(
                    torch.zeros(batch_size, hidden_size).to(device),
                    torch.rand(batch_size, hidden_size).to(device),
                    torch.zeros(batch_size, hidden_size).to(device),
                    torch.zeros(batch_size, hidden_size).to(device),
                )
                outputs = model(inputs, init_state)[0]
                # Column index (within channels 3 and up) that equals 1 at the
                # last step of each sample — presumably the one-hot task id;
                # TODO confirm against the task generator.
                type_tasks = list(np.where(inputs[-1, :, 3:].detach().cpu().numpy() == 1)[1])
                # The first 50 entries along axis 0 are excluded from scoring.
                answers = can.run(
                    target_outputs[50:, :, 0].cpu(),
                    outputs[50:, :, 0].cpu(),
                    target_outputs[50:, :, 1:].cpu(),
                    outputs[50:, :, 1:].cpu(),
                    type_tasks,
                )
                answers /= batch_size
                results += answers
    # Normalise by the actual number of evaluated batches instead of a
    # hard-coded task count.
    results /= len(dl_tasks) * repeat
    with open(f'accuracy_dl_{0.15}', 'a') as f:
        f.write(f'{tau_a}:{results}\n')
    print(weights_dir, results * 100)

weights 87.05000000000001


In [15]:
# Evaluate every saved model on the generators in `without_dl_tasks`: average
# the per-batch accuracy over all tasks and `repeat` passes, append the result
# to a log file and print it as a percentage.
for i, weights_dir in enumerate(dirs):
    tau_a = tau_a_values[i]
    model = model_load(tau_a, os.path.join(weights_dir, name_weights))
    results = 0
    # Inference only: no gradients are needed, so skip autograd bookkeeping
    # (saves memory and time on long sequences).
    with torch.no_grad():
        for _ in range(repeat):
            for task in without_dl_tasks:
                inputs, target_outputs = task.dataset(1, delay_between=0)
                inputs = torch.from_numpy(inputs).type(torch.float).to(device)
                target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
                # Fresh neuron state per batch; the second field is drawn
                # uniformly at random, the others start at zero.
                init_state = LIFAdExState(
                    torch.zeros(batch_size, hidden_size).to(device),
                    torch.rand(batch_size, hidden_size).to(device),
                    torch.zeros(batch_size, hidden_size).to(device),
                    torch.zeros(batch_size, hidden_size).to(device),
                )
                outputs = model(inputs, init_state)[0]
                # Column index (within channels 3 and up) that equals 1 at the
                # last step of each sample — presumably the one-hot task id;
                # TODO confirm against the task generator.
                type_tasks = list(np.where(inputs[-1, :, 3:].detach().cpu().numpy() == 1)[1])
                # The first 50 entries along axis 0 are excluded from scoring.
                answers = can.run(
                    target_outputs[50:, :, 0].cpu(),
                    outputs[50:, :, 0].cpu(),
                    target_outputs[50:, :, 1:].cpu(),
                    outputs[50:, :, 1:].cpu(),
                    type_tasks,
                )
                answers /= batch_size
                results += answers
    # Normalise by the actual number of evaluated batches instead of a
    # hard-coded task count.
    results /= len(without_dl_tasks) * repeat
    with open(f'accuracy_without_dl_{0.15}', 'a') as f:
        f.write(f'{tau_a}:{results}\n')
    print(weights_dir, results * 100)

weights 96.44999999999999


In [14]:
# Debug inspection: display the generators collected in `without_dl_tasks`.
without_dl_tasks

[<cgtasknet.tasks.reduce.multy.MultyReduceTasks at 0x22a46c3c790>,
 <cgtasknet.tasks.reduce.multy.MultyReduceTasks at 0x22a46c3ceb0>]