In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Dark background for every matplotlib figure drawn after this cell.
plt.style.use("dark_background")

In [None]:
def form_perceptron(layers, activation_function = None, input_shape = None):
    """Build a stack of Dense layers (a multilayer perceptron).

    Args:
        layers: iterable of unit counts, one Dense layer per entry; the last
            entry is the output width.
        activation_function: activation applied to every layer, including the
            output layer. Defaults to 'leaky_relu'.
        input_shape: when given, the model is built right away with this shape
            so its weights exist before the first call.

    Returns:
        A tf.keras.Sequential model.
    """
    activation = 'leaky_relu' if activation_function is None else activation_function
    dense_stack = []
    for units in layers:
        dense_stack.append(tf.keras.layers.Dense(units, activation))
    model = tf.keras.Sequential(dense_stack)
    if input_shape is None:
        return model
    model.build(input_shape)
    return model

class activity_container:
    """Base class for a trainable function of time fitted to recorded activity.

    Subclasses implement `__call__(t)`, which maps a (resolution, 1) column of
    times to a prediction comparable with `activity`, and `params`, the list of
    trainable variables. The target `activity` is matched on
    `resolution = activity.shape[0]` evenly spaced times spanning
    `time_interval`.
    """

    def __init__(self, activity, time_interval, ignore = None, optimizer = None):
        if ignore is None:
            ignore = []
        # NOTE(review): stored for subclasses; no method here uses it yet.
        self.ignore = ignore
        self.time_interval = time_interval
        self.resolution = activity.shape[0]
        self.optimizer = tf.keras.optimizers.Adam() if optimizer is None else optimizer
        self.activity = activity

    def __call__(self, t):
        """Predict activity at times `t`; must be provided by a subclass."""
        raise NotImplementedError('subclasses must implement __call__(t)')

    @property
    def params(self):
        """Trainable variables; must be provided by a subclass."""
        raise NotImplementedError('subclasses must implement params')

    @tf.function
    def loss(self):
        """Mean squared error between the prediction and `activity` on the time grid."""
        time = tf.reshape(
            tf.linspace(self.time_interval[0], self.time_interval[1], self.resolution),
            (self.resolution, 1))
        prediction = self(time)
        # Make the dtype match explicit (e.g. float64 numpy target vs float32 model).
        target = tf.cast(self.activity, prediction.dtype)
        return tf.reduce_mean((prediction - target) ** 2)

    def optimize(self, epochs):
        """Run `epochs` optimizer steps on `params`; return the per-epoch loss list."""
        loss_history = []
        for _ in range(epochs):
            # Only the forward pass is recorded; differentiate after the tape closes.
            with tf.GradientTape() as tape:
                loss_value = self.loss()
            params = self.params
            grads = tape.gradient(loss_value, params)
            self.optimizer.apply_gradients(zip(grads, params))
            loss_history.append(loss_value.numpy())
        return loss_history

class model_container(activity_container):
    """activity_container whose time -> activity mapping is a Keras model or layer."""

    def __init__(self, model, activity, time_interval, ignore=None, optimizer=None):
        super().__init__(activity, time_interval, ignore, optimizer)
        # Any Keras callable exposing `trainable_variables` works, including a
        # bare layer.
        self.model: tf.keras.layers.Layer = model

    @tf.function
    def __call__(self, time):
        """Evaluate the wrapped model on a column of time points."""
        prediction = self.model(time)
        return prediction

    @property
    def params(self):
        """Trainable variables of the wrapped model."""
        wrapped = self.model
        return wrapped.trainable_variables

class furie_container(model_container):
    """Fit activity with one sine unit per variable: sin(t * w + b).

    The model is a single Dense layer of width `activity.shape[1]` with a sine
    activation applied to a scalar time input. This is a Fourier-style basis;
    the historical class name is kept for compatibility.
    """

    def __init__(self, activity, time_interval, ignore=None, optimizer=None):
        sine_layer = tf.keras.layers.Dense(activity.shape[1], tf.sin, bias_initializer='random_uniform')
        # Build for a (batch, 1) column of times. A full shape tuple is
        # accepted by both Keras 2 and Keras 3. A bare int like `build(1)` is
        # not a valid shape in Keras 3.
        sine_layer.build((None, 1))
        super().__init__(sine_layer, activity, time_interval, ignore, optimizer)


In [None]:
class activity_restorer:
    """Fit an interaction tensor `weights` so a fitted activity curve obeys
    dx_i/dt = sum_jk x_j * weights[i, j, k] * x_k  (the same form as `xwx`).

    `container` supplies the smooth activity x(t), for example a trained
    model_container. Its time derivative, obtained by autodiff, is the
    regression target for the right-hand side built from `weights`.
    """

    def __init__(self, weights, container, indexes_to_restore, optimizer=None):
        self.weights = tf.Variable(weights)
        self.container:activity_container = container
        # TODO(review): not used yet; presumably the indices of unobserved
        # variables whose activity should be restored.
        self.indexes_to_restore = tf.constant(indexes_to_restore, tf.int32)
        # Time grid on which the dynamics are matched: the same evenly spaced
        # grid activity_container.loss uses, as a (time_count, 1) column.
        self.time_count = container.resolution
        self.time = tf.reshape(
            tf.cast(tf.linspace(container.time_interval[0], container.time_interval[1],
                                container.resolution), tf.float32),
            (-1, 1))
        # Use a dedicated optimizer rather than the container's, which already
        # tracks the container's own model variables.
        self.optimizer = tf.keras.optimizers.Adam() if optimizer is None else optimizer

    def activity_function(self, time):
        """Fitted activity at `time`, cast to the dtype of `weights`."""
        return tf.cast(self.container(time), self.weights.dtype)

    @tf.function
    def dx_dt(self, time):
        """Right-hand side x W x at each time; assumes container output is (T, n)."""
        x = self.activity_function(time)
        return tf.einsum('tj,ijk,tk->ti', x, self.weights, x)

    @tf.function
    def da_dt(self, time):
        """Time derivative of the fitted activity, shape (T, n)."""
        with tf.GradientTape() as tape:
            tape.watch(time)
            activity = self.container(time)
        # Assumes row t of the output depends only on time[t] (true for Dense
        # stacks), so the per-example Jacobian (T, n, 1) is d activity / dt.
        jacobian = tape.batch_jacobian(activity, time)
        return tf.cast(jacobian[:, :, 0], self.weights.dtype)

    @tf.function
    def loss_activity_dire(self):
        """Sum over time points of the mean squared mismatch between x W x and dx/dt."""
        residual = self.dx_dt(self.time) - self.da_dt(self.time)
        return tf.reduce_sum(tf.reduce_mean(residual ** 2, axis=1))

    @tf.function
    def grad_step(self):
        """Take one optimizer step on `weights`; return the loss before the step."""
        with tf.GradientTape() as tape:
            loss = self.loss_activity_dire()
        grads = tape.gradient(loss, [self.weights])
        self.optimizer.apply_gradients(zip(grads, [self.weights]))
        return loss


In [None]:
class solver:
    """Explicit (forward) Euler integrator for named state variables.

    Each keyword rule `name=f` gives the derivative of the attribute `name`.
    `f` is called with every state variable as a keyword argument. Initial
    values are assigned as attributes (e.g. `s.x = ...`) before `solve`.
    """

    def __init__(self, dt = 0.1, **rules):
        self.rules = rules
        self.params = rules.keys()
        self.dt = dt

    def update(self):
        """Advance every state variable by one Euler step, simultaneously."""
        # Snapshot first so every derivative sees start-of-step values.
        state = {name: getattr(self, name) for name in self.rules}
        self.set_to(**{name: state[name] + rule(**state) * self.dt
                       for name, rule in self.rules.items()})

    def set_to(self, **values):
        """Assign each keyword argument as an attribute."""
        for name, value in values.items():
            setattr(self, name, value)

    def solve(self, t_end, t_start = 0):
        """Integrate from `t_start` to `t_end` with ceil((t_end - t_start) / dt) steps.

        Returns a dict of numpy arrays, one per state variable plus 't'. Entry k
        is the state after step k + 1, and 't'[k] = t_start + (k + 1) * dt is
        the time that state corresponds to. The initial state is not included.

        Raises:
            ValueError: if dt is not positive (the loop would never end).
        """
        if self.dt <= 0:
            raise ValueError('dt must be positive')
        # Count steps up front: accumulating `t += dt` drifts (ten 0.1 steps
        # sum to 0.9999999999999999) and produced a spurious extra step.
        n_steps = max(0, int(np.ceil((t_end - t_start) / self.dt - 1e-9)))
        history = {name: [] for name in self.rules}
        history['t'] = []
        # `self.t` holds the start-of-step time during `update`, as before.
        self.t = t_start
        try:
            for step in range(1, n_steps + 1):
                self.update()
                # Label each recorded state with the time it belongs to.
                self.t = t_start + step * self.dt
                for name in history:
                    history[name].append(getattr(self, name))
        finally:
            del self.t
        return {name: np.array(values) for name, values in history.items()}

class xwx(solver):
    """Euler solver for a single state `x` with dx/dt = x @ W @ x.

    For a numpy (n, n, n) `W` and 1-D `x`, numpy's matmul broadcasting makes
    this dx_i/dt = sum_jk x_j * W[i, j, k] * x_k.
    """

    def __init__(self, W, dt=0.1):
        self.W = W

        def rhs(x):
            # Reads self.W at call time, so reassigning W later takes effect.
            return x @ self.W @ x

        super().__init__(dt, x=rhs)

def xwx_ddx_dt_dx(x, w):
    """Jacobian of the xwx right-hand side with respect to the state.

    For dx_i/dt = sum_jk x_j * w[i, j, k] * x_k, returns J with
    J[i][q] = d(dx_i/dt)/dx_q = sum_k w[i, q, k] x_k + sum_j x_j w[i, j, q].

    Args:
        x: state vector, shape (n,).
        w: interaction tensor, shape (n, n, n).

    Returns:
        numpy array of shape (n, n).
    """
    # `w @ x` contracts w's last axis with x; `x @ w` contracts its middle axis.
    return w @ x + x @ w

def generate_connections(groups, group_complexity, group_density = 0.5, noise = 1, noise_density = 0.3, group_recurent_rate = 0.5, rng = None):
    """Random block-structured connectivity matrix.

    Neurons are split into `groups` blocks of `group_complexity` neurons each.
    Block (a, b) gets scale (1 - r) * on[a, b] + r * (a == b), where
    r = `group_recurent_rate` and on[a, b] is True with probability
    `group_density`. Every entry is that block scale times uniform(-1, 1).
    Sparse skip noise is added on top, independent of blocks: `noise` times
    uniform(-1, 1) on a `noise_density` fraction of entries.

    Args:
        rng: optional numpy Generator for reproducible draws. When None the
            global `np.random` state is used, matching the original behavior.

    Returns:
        Array of shape (groups * group_complexity, groups * group_complexity).
    """
    sample = np.random.sample if rng is None else rng.random
    uniform = np.random.uniform if rng is None else rng.uniform
    size = group_complexity * groups
    # Draw order matches the original single-expression form, so a seeded
    # global np.random state reproduces the same matrix.
    group_mask = sample([groups, groups]) < group_density
    block_scale = (1 - group_recurent_rate) * group_mask + group_recurent_rate * np.eye(groups, groups)
    block_scale = np.repeat(np.repeat(block_scale, group_complexity, 0), group_complexity, 1)
    block_weights = block_scale * uniform(-1, 1, [size] * 2)
    noise_values = noise * uniform(-1, 1, [size] * 2)
    skip_mask = sample([size] * 2) < noise_density
    return block_weights + noise_values * skip_mask  # skip connections

def generate_multiple_connections(C, rng = None):
    """Lift an (N, N) matrix C to an (N, N, N) random interaction tensor.

    Returns T with T[i, j, k] = U[i, j, k] * C[j, k], U ~ uniform(-1, 1);
    C is broadcast over the leading axis.

    Args:
        C: square (N, N) array (mask or weights).
        rng: optional numpy Generator; None uses the global np.random state,
            matching the original behavior.
    """
    N = C.shape[0]
    uniform = np.random.uniform if rng is None else rng.uniform
    return uniform(-1, 1, [N] * 3) * C

def generate_activity(W, t_end = 10, dt = 0.001, rng = None):
    """Simulate the xwx dynamics from a random initial state in [-1, 1).

    Args:
        W: interaction tensor passed to `xwx`; its first dimension sets the
            number of state variables.
        t_end: integration end time (integration starts at 0).
        dt: Euler step size.
        rng: optional numpy Generator for the initial state; None uses the
            global np.random state, matching the original behavior.

    Returns:
        The history dict from `solve`, with arrays under 'x' and 't'.
    """
    uniform = np.random.uniform if rng is None else rng.uniform
    simulation = xwx(W, dt = dt)
    simulation.x = uniform(-1, 1, W.shape[0])
    return simulation.solve(t_end)

def add_axis(a):
    """Reshape `a` into a column vector of shape (number_of_elements, 1)."""
    column_shape = (np.size(a), 1)
    return np.reshape(a, column_shape)

def generate_dataset(activity):
    """Build (features, targets) for regressing dx/dt on pairwise state products.

    features[k] = kron(x_k, x_k), i.e. all products x_i * x_j in row-major
    order, for every recorded state except the last.
    targets[k] = (x_{k+1} - x_k) / (t_{k+1} - t_k), the forward finite difference.
    """
    states, times = activity['x'], activity['t']
    targets = np.diff(states, axis=0) / np.reshape(np.diff(times), (-1, 1))
    pair_products = np.array([np.outer(row, row).ravel() for row in states])
    return pair_products[:-1], targets

def plot_activity(d, var_count = 2):
    """Plot the first `var_count` state variables of a history dict, one panel each.

    Args:
        d: dict with 't' (times) and 'x' (2-D array, one column per variable).
        var_count: number of leading columns of d['x'] to plot.
    """
    # squeeze=False keeps `axes` 2-D even when var_count == 1; otherwise
    # plt.subplots returns a bare Axes and iterating over it fails.
    _, axes = plt.subplots(var_count, 1, squeeze=False, sharex=True)
    for n, ax in enumerate(axes[:, 0]):
        ax.plot(d['t'], d['x'][:, n])
        ax.set_ylabel(f'x[{n}]')
    axes[-1, 0].set_xlabel('t')
    plt.show()

In [None]:
N = 10
# Random 3-way interaction tensor, antisymmetrized under reversal of all axes
# (w.T[i, j, k] == w[k, j, i]), so w[i, j, k] == -w[k, j, i]. With this,
# sum_i x_i * (x W x)_i == 0, so presumably the intent is that the
# continuous-time dynamics preserve |x| (Euler steps only approximately).
w = np.random.uniform(-1, 1, (N, N, N))
w = 5 * (w - np.transpose(w))
activity_t = generate_activity(w, 1)
plot_activity(activity_t, 3)

In [None]:
# MLP mapping a scalar time to all state variables, fitted on t in (0, 1)
# to match the simulation horizon used above.
f = model_container(
    form_perceptron(
        [20, 40, 50, activity_t['x'].shape[1]],
        activation_function='leaky_relu',
        input_shape=[1, 1],
    ),
    activity_t['x'],
    (0, 1),
)

In [None]:
# Train for 1000 epochs and show the loss curve (log scale: loss is a positive MSE).
fig, ax = plt.subplots()
ax.plot(f.optimize(1000))
ax.set(title='model_container training loss', xlabel='epoch', ylabel='MSE loss')
ax.set_yscale('log')
plt.show()

In [None]:
t = tf.linspace(0, 1, 100)
# Evaluate the fitted model on a column of those times, then plot it above
# the simulated activity for comparison.
output = f(tf.reshape(t, (-1, 1))).numpy()
plot_activity({'t': t, 'x': output}, 3)
plot_activity(activity_t, 3)