# Two-task network

The network has eight inputs:

1. The fixation. 
1. $u_{rule}^{1}$
1. $u_{rule}^{2}$
1. The first context mod. 
1. The second context mod. 
1. The first context status. 
1. The second context status. 
1. The Romo signals.

The network has five outputs: 
1. The fixation. 
1. The first context output. 
1. The second context output. 
1. The first Romo task output. 
1. The second Romo task output. 


<div>
<img src="./images/Sheme.png" width="300"/>
</div>

> Learning rule: superspike

> Neuron type: LIF


$$\begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} (v_{\text{leak}} - v + i) \\
            \dot{i} &= -1/\tau_{\text{syn}} i
        \end{align*}
$$ 

In [6]:
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt  # for analys
from cgtasknet.net.lifrefrac import SNNLifRefrac
from cgtasknet.tasks.tasks import MultyTask
from norse.torch.functional.lif_refrac import LIFRefracParameters
from norse.torch.functional.lif import LIFParameters

## Step -1: Select device

In [7]:
# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
# To force CPU execution, assign torch.device("cpu") unconditionally instead.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device: " + ("gpu (cuda)" if device.type == "cuda" else "cpu"))

Device: gpu (cuda)


In [8]:
# Batch size handed to the task generator. The generator was previously
# built with a hard-coded 15 while this variable (set to 10) went unused;
# the effective value of 15 is kept and now flows through the variable.
batch_size = 15
# Argument later passed to Task.dataset(); presumably the number of
# consecutive trials concatenated into one sequence -- TODO confirm.
number_of_tasks = 10
# Task name -> task parameter overrides (empty dicts are passed for both tasks).
task_list = [("WorkingMemory", dict()), ("ContextDM", dict())]
tasks = dict(task_list)
Task = MultyTask(tasks=tasks, batch_size=batch_size)

## Step 1.1: Create model

In [9]:
# Input and output widths are read from the first entry of the task's
# feature/action size table.
feature_size, output_size = Task.feature_and_act_size[0]
hidden_size = 200

# Refractory LIF parameters wrapped around norse's default LIF parameters.
neuron_parameters = LIFRefracParameters(LIFParameters())
model = SNNLifRefrac(
    feature_size,
    hidden_size,
    output_size,
    neuron_parameters=neuron_parameters,
)
model = model.to(device)

## Step 1.2: Save pre-learning weights

In [10]:
# Snapshot every trainable parameter before training so it can be compared
# with the learned weights afterwards.
# - detach() makes the NumPy conversion legal for tensors that require grad.
# - copy() is essential: on a CPU device `.cpu()` returns the parameter
#   itself and `.numpy()` shares its memory, so without an explicit copy the
#   optimizer's in-place updates would silently overwrite this snapshot.
weights_pre_l = [
    param.detach().cpu().numpy().copy()
    for _, param in model.named_parameters()
    if param.requires_grad
]

## Step 2: Loss criterion and optimizer

In [11]:
# Mean-squared-error objective optimised with plain SGD.
learning_rate = 5e-2
criterion = nn.MSELoss()
# Alternative optimizer kept for reference:
# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.8, 0.85))
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

In [12]:
# Draw one dataset, jitter the inputs in place with small Gaussian noise
# (sigma = 0.01), then move both arrays to the selected device as float32
# tensors.
inputs, target_outputs = Task.dataset(number_of_tasks)
inputs += np.random.normal(loc=0, scale=0.01, size=inputs.shape)
inputs = torch.from_numpy(inputs).float().to(device)
target_outputs = torch.from_numpy(target_outputs).float().to(device)

## Step 3: Train loop

In [13]:
%matplotlib
plt.ion
fig = plt.figure()
ax = fig.add_subplot(111)
fig2 = plt.figure()
ax2 = fig2.add_subplot(111)
inputs, target_outputs = Task.dataset(number_of_tasks)
(line1,) = ax.plot(np.arange(0, len(target_outputs)), target_outputs[:, 0, 1], "b-")
(line2,) = ax.plot(np.arange(0, len(target_outputs)), target_outputs[:, 0, 2], "r-")
(line3,) = ax.plot(np.arange(0, len(target_outputs)), target_outputs[:, 0, 1], "b-")
(line4,) = ax.plot(np.arange(0, len(target_outputs)), target_outputs[:, 0, 2], "r-")
(line21,) = ax2.plot(np.arange(0, len(target_outputs)), target_outputs[:, 0, 1], "b-")
(line22,) = ax2.plot(np.arange(0, len(target_outputs)), target_outputs[:, 0, 2], "r-")
(line23,) = ax2.plot(np.arange(0, len(target_outputs)), target_outputs[:, 0, 1], "b-")
(line24,) = ax2.plot(np.arange(0, len(target_outputs)), target_outputs[:, 0, 2], "r-")
ax.set_ylim([-0.5, 1.5])
ax.set_xlim([0, 20000])
ax2.set_ylim([-0.5, 1.5])
ax2.set_xlim([0, 20000])
running_loss = 0
fig.canvas.draw()
fig.canvas.flush_events()
fig2.canvas.draw()
fig2.canvas.flush_events()
ax.set_title("LIF")
ax2.set_title("LIF")
for i in range(2000):
    inputs, target_outputs = Task.dataset(number_of_tasks)
    inputs += np.random.normal(0, 0.01, size=(inputs.shape))
    inputs = torch.from_numpy(inputs).type(torch.float).to(device)
    target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs, states = model(inputs)

    loss = criterion(outputs, target_outputs)
    loss.backward()
    optimizer.step()

    # print statistics
    running_loss += loss.item()
    if i % 10 == 9:
        print("epoch: {:d} loss: {:0.5f}".format(i + 1, running_loss / 10))
        running_loss = 0.0
        with torch.no_grad():
            inputs, target_outputs = Task.dataset(number_of_tasks)

            inputs = torch.from_numpy(inputs).type(torch.float).to(device)
            target_outputs = (
                torch.from_numpy(target_outputs).type(torch.float).to(device)
            )
            outputs, states = model(inputs)
            loss = criterion(outputs, target_outputs)

            print("test loss: {:0.5f}".format(loss.item()))
        for_plot = outputs.detach().cpu().numpy()[:, 0, :]
        print(len(for_plot))
        line1.set_xdata(np.arange(0, len(for_plot), 1))
        line2.set_xdata(np.arange(0, len(for_plot), 1))
        line3.set_xdata(np.arange(0, len(for_plot), 1))
        line4.set_xdata(np.arange(0, len(for_plot), 1))
        line21.set_xdata(np.arange(0, len(for_plot), 1))
        line22.set_xdata(np.arange(0, len(for_plot), 1))
        line23.set_xdata(np.arange(0, len(for_plot), 1))
        line24.set_xdata(np.arange(0, len(for_plot), 1))

        line1.set_ydata(for_plot[:, 1])
        line2.set_ydata(for_plot[:, 2])
        line3.set_ydata(target_outputs.detach().cpu().numpy()[:, 0, 1])
        line4.set_ydata(target_outputs.detach().cpu().numpy()[:, 0, 2])

        line21.set_ydata(for_plot[:, 3])
        line22.set_ydata(for_plot[:, 4])
        line23.set_ydata(target_outputs.detach().cpu().numpy()[:, 0, 3])
        line24.set_ydata(target_outputs.detach().cpu().numpy()[:, 0, 4])

    fig.canvas.draw()
    fig.canvas.flush_events()
    fig2.canvas.draw()
    fig2.canvas.flush_events()


print("Finished Training")

Using matplotlib backend: TkAgg
epoch: 10 loss: 0.13755
test loss: 0.11416
20495
epoch: 20 loss: 0.11318
test loss: 0.11389
20394


KeyboardInterrupt: 

invalid command name "140023209386048delayed_destroy"
    while executing
"140023209386048delayed_destroy"
    ("after" script)
