# `nn.ipynb`
This file aims to use a neural network to solve the multiple-peaks problem. There is information in the peaks and positions of the nodes, but we haven't been able to draw it out yet. Maybe a neural network can.

## NN specs
* Input
    * Positions of N nodes
    * Peaks of N-1 MIs (mutual information curves): for each node pair, the peak closest to zero shift
    * Period: avg period of each signal using MI between signal and self
    * Total: 3N inputs
* Output
    * N-1 peaks/time delays

More details below...

Liberties (assumptions) taken:
* When finding the period, you need an upper bound. This could be deduced using the speed and the width of the space, but I am just using a fraction of the signal length. I happen to know that one fifth of the length of the signal (`2000 // 5 = 400`, or `1600 // 5 = 320`) always contains the period.

Other notes:
* This is the first time I am implementing a form of randomization of the target location in FN data. There will be a bounding box with a minimum size of 50 in each direction, inside which all the target nodes have to live. This way, we do not need to regenerate the FN data each time, but the target is still at a random position relative to the nodes.

In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
from scipy.signal import find_peaks

from matplotlib import pyplot as plt

from tqdm import tqdm

from WSN import *
import mutual_information as mtin

In [2]:
# Override the default FN equation settings held by `mtin`, then rebuild the
# default equation object so that it is constructed from the new values.
mtin.default_fn_equ_params.update(
    {
        "dt": 0.01,
        "T": 20_000,
        "stim": [[[250, 350], [45, 65], [45, 65]]],
    }
)
# Frame index equal to 40 / dt; handed to `transmit_continuous` as
# `start_frame` when signals are generated.
start_frame = int(40 / mtin.default_fn_equ_params["dt"])
mtin.default_fn_equ = FHN(**mtin.default_fn_equ_params)

In [3]:
# Load the saved default FN solution from disk unless `mtin` already holds one.
if mtin.default_fn_sol is None:
    mtin.default_fn_sol = np.load("default_fn_sol_dt_0.01.npy")

In [4]:
# mtin.solve_default_fn_equ()

In [5]:
# np.save("default_fn_sol_dt_0.01.npy", mtin.default_fn_sol)

In [30]:
# Generate signals and find all peaks
def get_fn_peaks_inside_period(self: WSN, r0=None, use_mst=False):
    """Estimate the signal period and, per node pair, the MI peak nearest zero shift.

    Every node's FN signal is obtained from ``self.transmit_continuous``. The
    period is estimated from the mutual information (MI) of each signal with
    itself and averaged over all nodes. For every node pair in the tree, the MI
    between the two signals is scanned over shifts within two periods on either
    side of zero, and the qualifying peak closest to zero shift is recorded
    together with the ideal delay computed from the node geometry.

    Parameters
    ----------
    r0 : array-like, optional
        Source position used to compute the ideal delay. Defaults to
        ``np.array([55, 55])``.
    use_mst : bool, optional
        If true, the node pairs come from ``self.find_MST()``; otherwise node 0
        is paired with every other node, in index order.

    Returns
    -------
    all_peaks : list of tuple
        One ``(ri, rj, p, p0)`` entry per pair: the two node positions, the
        ideal delay ``p = (dist(ri, r0) - dist(rj, r0)) / self.c`` and the
        measured MI peak closest to zero ``p0``. Both ``p`` and ``p0`` are in
        time units (shift index multiplied by ``dt``).
    pair_to_ind : dict
        Maps ``(i, j)`` node-index pairs to their position in ``all_peaks``.
    avg_period : float
        Mean period over all nodes, in time units.

    Raises
    ------
    ValueError
        If a signal's self-MI has fewer than three peaks (the period cannot be
        estimated) or if no MI peak passes the height threshold for a pair.
    """
    if r0 is None:
        r0 = np.array([55, 55])
    dt = mtin.default_fn_equ_params["dt"]
    results = self.transmit_continuous(signal_type="fn", start_frame=start_frame)
    assert np.all(np.array([result[0] for result in results]) == self.nodes)

    self.printv("Finding period")
    avg_period = 0
    for pos, sigi in results:
        # This is arbitrary, but I know this will work for this case
        max_tau = len(sigi) // 5
        shifts = np.arange(-max_tau, max_tau, 1)
        mis = np.array(mi_shift(np.transpose([sigi, sigi]), shifts))

        peaks, _ = find_peaks(mis)
        if len(peaks) < 3:
            # Raise ValueError (not IndexError) so callers that retry on
            # ValueError can draw a fresh node layout.
            raise ValueError(
                f"Fewer than three self-MI peaks found for the node at {pos}; "
                "cannot estimate the period."
            )
        # Keep the five strongest peaks, then put them back in shift order so
        # that differences between neighbours are period estimates.
        peaks = peaks[np.argsort(mis[peaks])][-5:]
        peaks.sort()
        period1 = peaks[1] - peaks[0]
        period2 = peaks[2] - peaks[1]
        # Shift indices -> time units.
        avg_period += (period1 + period2) / 2 * dt
    avg_period /= len(results)
    self.printv("Period found to be", avg_period)

    all_peaks = []   # [(rn, rm, p, p0), ...]
    pair_to_ind = {} # {(n, m): i}

    # For now, tree must be a list so that the neural network
    # sees the peaks in the same order every time
    if use_mst:
        tree = list(self.find_MST())
    else:
        tree = [
            (None, 0, i)
            for i in range(1, len(self.nodes))
        ]

    # Search window: two periods on either side of zero shift.
    max_tau = int((2 * avg_period) / dt)
    shifts = np.arange(-max_tau, max_tau, 1)
    inclusion_factor = 0.5
    for _, i, j in tree:
        ri, sigi = results[i]
        rj, sigj = results[j]
        mis = np.array(mtin.mi_shift(np.transpose([sigi, sigj]), shifts))

        # p = ideal peak
        p = (dist(ri, r0) - dist(rj, r0)) / self.c

        # Only peaks at least `inclusion_factor` times the global maximum count.
        peaks, _ = find_peaks(mis, height=mis.max() * inclusion_factor)
        if self.verbose and np.random.uniform(0, 1) < 0.1:
            plt.plot(shifts, mis)
            plt.show(block=False)
        if len(peaks) == 0:
            plt.plot(shifts * dt, mis)
            plt.show(block=False)
            raise ValueError(
                f"No peaks found between nodes {i} and {j} "
                f"at positions {ri} and {rj}.\n"
                f"The peak should be at time {p}."
            )
        peak_shifts = shifts[peaks]
        # p0 = peak closest to 0, converted from a sample shift to time units
        # so that it shares units with `p` and `avg_period`.
        # NOTE(review): earlier versions returned p0 as a raw sample shift;
        # datasets generated before this change are scaled by 1/dt.
        p0 = peak_shifts[np.argmin(np.abs(peak_shifts))] * dt

        pair_to_ind[(i, j)] = len(all_peaks)
        all_peaks.append((ri, rj, p, p0))
    return all_peaks, pair_to_ind, avg_period


In [31]:
# Attach the functions defined in this notebook to WSN as methods. A plain loop
# is used because only the side effect of setattr is wanted; a list
# comprehension would build (and display) a throwaway list of None values.
for attr in (
    "get_fn_peaks_inside_period",
):
    setattr(WSN, attr, globals()[attr])

[None]

Create testing and training data

Input format:
* `x0, y0, x1, y1, ...`   (normalized by `wsn.size`)
* `p0, p1, p2, ...`       (normalized by period)
* `period`

Output format:
* `p0, p1, p2, ...`       (normalized by period)

In [2]:
# Network used to generate samples.
N = 12  # number of nodes; the model below is sized from it (3*N inputs, N-1 outputs)
c = 1.6424885622140555  # used as a speed: ideal delays are computed as distance / c
# NOTE(review): the first argument looks like the side length of the space
# (node positions are later divided by 100) — confirm against WSN.
wsn = WSN(100, N, D=142, std=0, c=c, verbose=False)
wsn.reset_anchors(range(N))

In [9]:
def get_input_and_output(wsn: WSN):
    """Build one (features, targets) sample from the current node layout.

    Features, in order:
        * ``wsn.nodes.flatten() / 100``      -- len = 2N
        * measured peaks ``p0 / period``     -- len = N-1
        * ``period``                         -- len = 1

    Targets:
        * ideal peaks ``p / period``         -- len = N-1

    Returns the pair ``(features, targets)`` as numpy arrays.
    """
    peak_records, _, period = wsn.get_fn_peaks_inside_period()

    measured = [p0 / period for _, _, _, p0 in peak_records]
    ideal = [p / period for _, _, p, _ in peak_records]

    features = np.concatenate((wsn.nodes.flatten() / 100, measured + [period]))
    targets = np.array(ideal)
    return features, targets


In [4]:
# Sample count; also interpolated into the file names of the saved datasets
# loaded further down.
# NOTE(review): data_size is reassigned in later cells (100 for generation,
# len(input) after loading), so its value depends on which cells have run.
data_size = 1000

In [13]:
data_size = 100
# input_shape = (data_size, 3 * N)
# output_shape = (data_size, N - 1)
input = [0] * data_size
output = [0] * data_size
for j in tqdm(range(data_size)):
    # Keep drawing fresh node layouts until one yields a usable sample.
    while True:
        try:
            wsn.reset_nodes()
            input[j], output[j] = get_input_and_output(wsn)
        except ValueError as e:
            print(f"ValueError: {e}")
        else:
            break
input = np.array(input)
output = np.array(output)


  mi = -0.5 * np.log(np.linalg.det(cor))
100%|██████████| 100/100 [46:51<00:00, 28.12s/it]


In [53]:
def load_data(data_size1, data_size2, data_size3):
    """Load three saved batches from ``nn_saves/copy`` and stack them.

    Each ``data_sizeK`` selects the file pair whose name carries that number.
    The second batch is additionally cut to its first ``data_size2`` rows and
    re-stacked into a regular array. The dtypes of all six arrays are printed,
    the last column of every array is dropped (correction for old data), and
    the batches are concatenated in order 1, 2, 3.

    Returns ``(input, output)`` as two arrays.
    """
    # NOTE(review): allow_pickle=True executes pickled content on load; only
    # use it with files you created yourself.
    in1, out1, in2, out2, in3, out3 = (
        np.load(path, allow_pickle=True)
        for path in (
            f"nn_saves/copy/input{data_size1}.npy",
            f"nn_saves/copy/output{data_size1}.npy",
            f"nn_saves/copy/(2) input{data_size2}.npy",
            f"nn_saves/copy/(2) output{data_size2}.npy",
            f"nn_saves/copy/(3) input{data_size3}.npy",
            f"nn_saves/copy/(3) output{data_size3}.npy",
        )
    )

    # Only the second batch is truncated and rebuilt row by row.
    in2 = np.array(list(in2[:data_size2]))
    out2 = np.array(list(out2[:data_size2]))

    print(*(arr.dtype for arr in (in1, out1, in2, out2, in3, out3)))

    # Corrections for old data: discard the trailing column everywhere.
    features = np.concatenate([arr[:, :-1] for arr in (in1, in2, in3)])
    targets = np.concatenate([arr[:, :-1] for arr in (out1, out2, out3)])
    return features, targets
# Rebind the notebook-level dataset to the three saved batches of 2000 samples.
input, output = load_data(2000, 2000, 2000)

float64 float64 float64 float64 float64 float64


In [5]:
# Load one saved batch whose file names carry the current data_size.
# NOTE(review): the inputs are read from nn_saves/orig but the outputs from
# nn_saves/copy — confirm the two directories hold matching samples.
input = np.load(f"nn_saves/orig/input{data_size}.npy", allow_pickle=True)
output = np.load(f"nn_saves/copy/output{data_size}.npy", allow_pickle=True)

In [6]:
# Corrections for old data
input = np.array([x for x in input[:data_size]])
output = np.array([x for x in output[:data_size]])
input = input[:, :-1]
output = output[:, :-1]

In [7]:
# Load the second saved batch (400 samples, per the file names).
input2 = np.load("nn_saves/copy/(2) input400.npy", allow_pickle=True)
output2 = np.load("nn_saves/copy/(2) output400.npy", allow_pickle=True)

In [8]:
# Corrections for old data
input2 = input2[:, :-1]
output2 = output2[:, :-1]

In [9]:
# Check the corrected shapes of the second batch.
input2.shape, output2.shape

((400, 36), (400, 11))

In [10]:
# Append the second batch to the first along the sample axis.
input = np.concatenate((input, input2), axis=0)
output = np.concatenate((output, output2), axis=0)

In [54]:
# Report the dataset dimensions and reset data_size to the number of samples
# actually held in `input`.
print(input.shape)
print(output.shape)
data_size = len(input)

(6000, 36)
(6000, 11)


In [10]:
# Peek at the first sample: feature vector first, then target vector.
print(input[0])
print(output[0])

[ 6.67032026e-01  2.55447035e-01  6.84546894e-02  9.17007937e-01
  7.08423579e-02  6.83348825e-02  4.96216109e-01  1.74547307e-02
  8.79624750e-01  4.96301752e-01  3.06407044e-01  3.28620519e-01
  5.20289613e-01  1.18830827e-02  1.34066432e-01  6.09224732e-01
  8.44997969e-01  5.18230939e-01  6.72768384e-01  4.89914870e-01
  6.28993510e-01  6.15569032e-01  6.06160553e-01  3.81025015e-01
 -3.84529461e+01 -7.18793633e+01 -9.81848646e+01 -3.93744764e+00
 -2.20329517e+01 -9.72633343e+01 -4.44010053e+01  1.10583636e+01
  2.13627478e+01  4.02122312e+00 -2.93214186e+01  3.10827994e+00]
[-1.47154165 -1.84871252 -1.11346063 -0.08680584 -0.06227296 -1.13224162
 -0.52626109  0.10327194  0.91945981  1.09299102  0.70839955]


In [55]:
# 80/20 split in storage order: the leading rows train, the trailing rows test.
num_train_samples = (8 * data_size) // 10
train_input, test_input = input[:num_train_samples], input[num_train_samples:]
train_output, test_output = output[:num_train_samples], output[num_train_samples:]

Create model...

In [56]:
import os
# Display the id of the process running this notebook's code.
os.getpid()

1432

In [68]:
# Fully connected regressor mapping the 3*N input features to N-1 outputs.
# A linear 512-unit layer is followed by ReLU layers of the widths below and a
# linear output layer.
relu_widths = (1024, 1024, 2048, 1024, 1024, 512)
model = keras.Sequential(
    [
        keras.layers.Flatten(input_shape=(3 * N,)),
        keras.layers.Dense(512, activation="linear"),
        *(keras.layers.Dense(width, activation="relu") for width in relu_widths),
        keras.layers.Dense(N - 1, activation="linear"),
    ]
)
model.compile(
    optimizer="adam",
    loss=keras.losses.mean_squared_error,
    metrics=["mean_squared_error"],
)

In [69]:
# Train for 10 epochs in batches of 25.
# NOTE(review): the test split is also passed as validation data, so there is
# no separate untouched evaluation set.
model.fit(train_input, train_output, validation_data=(test_input, test_output), batch_size=25, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x24caf7d2d60>

In [70]:
# Predict on the test split and show the first prediction next to its target.
prediction = model.predict(test_input)
prediction[0], test_output[0]



(array([-0.00085952,  0.01251677, -0.01614039, -0.05552344, -0.11687545,
        -0.10616183, -0.00083554, -0.06193177, -0.045387  , -0.23937139,
         0.04486244], dtype=float32),
 array([ 0.30201001, -0.48214488,  0.30031751,  0.48871973, -0.88692996,
        -0.47003263, -1.30997936,  1.41865404, -0.10948085, -0.79822337,
        -0.97715453]))

In [71]:
from numba import jit

In [72]:
@jit(nopython=True)
def get_redundancies():
    """Return the indices of rows in ``input`` that repeat an earlier row.

    Reads the notebook-level ``input`` and ``data_size``. Numba treats globals
    as compile-time constants, so re-run this cell after either one changes.
    Each row is compared with the rows before it until the first exact match.
    """
    duplicate_rows = np.zeros(len(input), dtype=np.int64)
    n_found = 0
    for row in range(data_size):
        for earlier in range(row):
            if np.all(input[row] == input[earlier]):
                duplicate_rows[n_found] = row
                n_found += 1
                break
    # Only the first n_found slots were filled.
    return duplicate_rows[:n_found]
def get_redundancy_info():
    """Print how many duplicate rows exist and where they sit.

    Calls ``get_redundancies()`` for the duplicate row indices, prints their
    count, then prints them grouped into runs of consecutive indices: a run is
    shown as ``[start, stop]`` (stop exclusive) and an isolated index as the
    bare number. When there are no duplicates, ``0`` and ``[]`` are printed.
    """
    redundancies = get_redundancies()
    # redundancies = np.array(redundancies)
    print(len(redundancies))
    # Start from an empty list so that an empty result no longer fails on
    # redundancies[0].
    slices = []
    for j in redundancies:
        if slices and j == slices[-1][1]:
            # j continues the current run: extend its exclusive end.
            slices[-1][1] = j + 1
        else:
            slices.append([j, j + 1])
    # Collapse runs of length one to the bare index.
    for i, s in enumerate(slices):
        if s[1] == s[0] + 1:
            slices[i] = s[0]
    print(slices)