# Kuramoto: Figs 6a, 6b, 9
Figures comparing the energy and order-parameter trajectories of NODEC and feedback control on Kuramoto dynamics.

Please make sure that the required data folder is available at the paths used by the script.
You may generate the required data by running the python script
```nodec_experiments/kuramoto/gen_parameters.py```.

Please also make sure that a training procedure has produced results in the corresponding paths used in the plot and table scripts.
Running ```nodec_experiments/ct_lti/single_sample/train.ipynb``` or ```nodec_experiments/kuramoto/train.ipynb``` with the default paths should generate results at the required location for the plot and table scripts in each folder.

As neural network initialization is stochastic, please make sure that appropriate seeds are used, or expect some variance relative to the results reported in the paper.


## Imports

In [None]:
# IPython autoreload extension: mode 2 reloads imported modules before code is
# executed, so edits to the project's `nnc` package are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch

In [None]:
import os
os.sys.path.append('../../../')

import math
from copy import deepcopy

import torch
from torchdiffeq import odeint
import numpy as np
import pandas as pd

import networkx as nx


import plotly.express as px
from plotly import graph_objects as go

from nnc.controllers.neural_network.nnc_controllers import NNCDynamics
from nnc.controllers.baselines.oscillators.dynamics import AdditiveControlKuramotoDynamics
from nnc.controllers.baselines.oscillators.optimal_controllers import KuramotoFeedbackControl

from nnc.helpers.torch_utils.graphs import adjacency_tensor, maximum_matching_drivers, drivers_to_tensor
from nnc.helpers.torch_utils.oscillators import order_parameter_cos
from nnc.helpers.torch_utils.numerics import faster_adj_odeint
from nnc.helpers.plot_helper import ColorRegistry, base_layout
from nnc.helpers.torch_utils.evaluators import FixedInteractionEvaluator

from tqdm.notebook import tqdm

In [None]:
# Run configuration: every tensor and module below is placed on `device` and
# cast to `dtype` (single-precision float).
device, dtype = 'cpu', torch.float


## Loading data and parameters

In [None]:
# Loading parameters for the graph.
data_folder = '../../../../data/parameters/kuramoto/'
graph = 'erdos_renyi'
result_folder = '../../../../data/results/kuramoto/' # if you have new results don't forget to put them here.
graph_folder = data_folder + graph + '/'

A = torch.load(graph_folder + 'adjacency.pt',  map_location=device).float() # adjacency matrix
# `nx.from_numpy_matrix` was removed in NetworkX 3.0; `nx.from_numpy_array`
# builds the same graph from a 2-D array. `.cpu()` keeps `.numpy()` valid even
# when `device` is not the CPU.
G = nx.from_numpy_array(A.cpu().numpy())
n_nodes = G.number_of_nodes()
mean_degree = np.mean(list(dict(G.degree()).values()))

A = A.to(device, dtype) # adjacency
L = A.sum(-1).diag() - A # laplacian: diagonal row-sum (degree) matrix minus adjacency

In [None]:
# Dynamics-dependent parameters and the initial state; each tensor is mapped
# onto `device` and cast to `dtype` right after loading.
coupling_constants, frustration_constants, natural_frequencies = (
    torch.load(data_folder + file_name, map_location=device).to(device, dtype)
    for file_name in (
        'coupling_constants.pt',
        'frustration_constants.pt',
        'nominal_angular_velocities.pt',
    )
)
K = coupling_constants[2].item()  # coupling constant; index 2 is expected to hold 0.4
frustration_constant = frustration_constants[0]  # this example uses no frustration
dynamics_params_folder = f'{graph_folder}dynamics_parameters/coupling_{K:.1f}/'

x0 = torch.load(data_folder + 'single_init.pt', map_location=device).to(device=device, dtype=dtype)

# The driver *vector* is applied element-wise instead of building a dense
# driver matrix, which saves memory. Driver nodes are its non-zero entries.
gain_vector = torch.load(dynamics_params_folder + 'driver_vector.pt', map_location=device).to(device, dtype)
driver_nodes = gain_vector.nonzero().flatten().tolist()
driver_percentage = len(driver_nodes) / len(gain_vector)
steady_state = torch.load(dynamics_params_folder + 'steady_state.pt', map_location=device).to(device, dtype)


In [None]:
# Controller hyper-parameters.
# Gain of the feedback control (FC) baseline.
feedback_control_constant = 10

# NODEC network width and training settings.
n_hidden_units, batch_size, epochs = 3, 8, 20

In [None]:
# Summary of the loaded experiment configuration.
# The graph name is now separated from "graph with:" by a space
# (it used to print e.g. "erdos_renyigraph with:").
print('Current experiment info:')
print(f'\t Loaded {graph} graph with: {n_nodes} nodes and {G.number_of_edges()} edges.')
print(f'\t Coupling Constant: {K}')
print(f'\t Frustration Constant: {frustration_constant.item()}')
print(f'\t Natural Frequencies: mean: {natural_frequencies.mean().item()} variance: {natural_frequencies.var().item()}')
print(f'\t Ratio of driver node vs total nodes: {len(driver_nodes) / n_nodes}')
print(f'\t Feedback Control Constant: {feedback_control_constant}')


In [None]:
# Kuramoto dynamics with an additive control input, built from the adjacency
# matrix, coupling constant, natural frequencies and frustration constant,
# then moved to the configured device.
dyn = AdditiveControlKuramotoDynamics(A, K, natural_frequencies, frustration_constant=frustration_constant)
dyn = dyn.to(device)

In [None]:
class EluFeedbackControl(torch.nn.Module):
    """
    Small ELU network that maps node phases to a node-level control signal.

    The state is passed through ``sin`` and two ELU-activated layers; the
    ``n_drivers`` outputs are then scattered onto the nodes by left-multiplying
    with ``driver_matrix``.

    :param n_nodes: width of the input state (number of nodes).
    :param n_drivers: number of network outputs, one per driver node.
    :param driver_matrix: tensor whose last dimension is ``n_drivers``;
        presumably ``(n_nodes, n_drivers)`` with ones at driver positions.
    :param n_hidden: width of the hidden layers.
    """
    def __init__(self, n_nodes, n_drivers, driver_matrix, n_hidden=3):
        super().__init__()
        self.linear = torch.nn.Linear(n_nodes, n_hidden)
        self.linear_h1 = torch.nn.Linear(n_hidden, n_hidden)
        self.linear_final = torch.nn.Linear(n_hidden, n_drivers)
        # Non-persistent buffer: ``module.to(device/dtype)`` now moves and casts it
        # together with the weights (a plain attribute is left behind), while the
        # ``state_dict()`` keys stay unchanged, so checkpoints saved without it
        # still load with ``strict=True``.
        self.register_buffer('driver_matrix', driver_matrix, persistent=False)

    def forward(self, t, x):
        """
        :param t: time (scalar or batch of scalars); not used by this architecture.
        :param x: node states, last dimension ``n_nodes``.
        :return: control signal whose last dimension equals the number of rows
            of ``driver_matrix``.
        """
        u = self.linear(torch.sin(x))
        u = torch.nn.functional.elu(u)
        u = self.linear_h1(u)
        u = torch.nn.functional.elu(u)
        u = self.linear_final(u)
        # Scatter the per-driver outputs onto the nodes via the driver matrix.
        u = (self.driver_matrix @ u.unsqueeze(-1)).squeeze(-1)
        return u

In [None]:
# Rebuild a driver matrix from the driver node indices. Its non-zero entries are
# 1, so the neural network is agnostic of the exact gain values in `gain_vector`.
driver_matrix = drivers_to_tensor(A.shape[-1], driver_nodes)
driver_matrix = driver_matrix.to(dtype=dtype, device=device)

In [None]:
# NODEC controller with the same architecture as in training, then restored
# from the trained checkpoint stored under the graph's result folder.
neural_net = EluFeedbackControl(
    n_nodes,
    len(driver_nodes),
    driver_matrix.cpu(),
    n_hidden=n_hidden_units,
).to(dtype=dtype, device=device)
neural_net.load_state_dict(
    torch.load(f'{result_folder}{graph}/trained_model.pt', map_location=device)
)

In [None]:
# Evaluator shared by both controllers. It runs a fixed number of interactions
# and keeps every intermediate state, control, time, energy and loss.
# The "loss" is the mean of `order_parameter_cos` applied to `x[-1]`
# (presumably the last state of the trajectory), i.e. the order parameter.
evaluation_steps = 5000
evalu = FixedInteractionEvaluator(
    exp_id='kuramoto_er',
    log_dir=None,
    n_interactions= evaluation_steps, # neither control works consistently for less than 2-3k steps
    loss_fn=lambda t,x: torch.tensor(order_parameter_cos(x[-1].cpu().detach())).mean(),
    ode_solver=None,
    ode_solver_kwargs={},
    preserve_intermediate_states=True,
    preserve_intermediate_controls=True,
    preserve_intermediate_times=True,
    preserve_intermediate_energies=True,
    preserve_intermediate_losses=True,
    preserve_params=False,
    preserve_init_loss = True
)

# Wrap the trained network in the (t, x) -> u signature passed to `evaluate`.
nn_contorl_fun = lambda t,x: neural_net(t, x)
# NOTE(review): 150 appears to be the time horizon (the uncontrolled reference
# trajectory is integrated over the same 0..150 span); the meaning of -1 is not
# visible here. Confirm against FixedInteractionEvaluator.evaluate.
nnres = evalu.evaluate(dyn, 
                       nn_contorl_fun, 
                       x0.cpu(), 
                       150, 
                       -1
                      )


In [None]:
def cont(t, x):
    """Feedback-control baseline: element-wise gain * sin(-x), scaled by the FC constant.

    Nodes with zero entries in `gain_vector` (non-drivers) receive no control.
    """
    return feedback_control_constant * gain_vector.cpu() * torch.sin(-x)


# Evaluate the baseline with the same evaluator, initial state and arguments as NODEC.
contres = evalu.evaluate(dyn, cont, x0.cpu(), 150, -1)

In [None]:
# Final evaluator loss (order parameter, per `loss_fn`) of the feedback control baseline.
contres['final_loss']

In [None]:
# Final evaluator loss (order parameter, per `loss_fn`) of NODEC.
nnres['final_loss']

In [None]:
# Element-wise difference of the recorded losses, FC minus NODEC; positive
# entries mean FC has the larger order parameter at that point. The first entry
# is skipped (presumably the shared initial loss, since preserve_init_loss=True).
contres['all_losses'][1:] - nnres['all_losses'][1:]

In [None]:
# Total control energy of each controller, shown as an (FC, NODEC) tuple.
# A notebook cell only displays its last expression, so both values must be
# part of that expression to be visible.
contres['total_energy'], nnres['total_energy']

## Figure 6a
Energy comparison between the feedback control baseline and NODEC.

In [None]:
# Figure 6a: total energy over time (log y-axis) for FC and NODEC.
# NOTE(review): energies are plotted against all_times[0], while the loss plot
# (Fig 6b) uses all_times[1]. Presumably these are different time grids per
# quantity; confirm.
oc_r = px.line(y=contres['all_energies'].squeeze().numpy(), 
               x=contres['all_times'][0].numpy(), log_y=True, render_mode='svg')
oc_r.data[0].line.color = ColorRegistry.oc
oc_r.data[0].name = 'FC'
oc_r.data[0].showlegend = True
# NODEC trace is built separately and merged into the FC figure, so both share
# the FC figure's layout (including its log y-axis).
nn_r = px.line(y=nnres['all_energies'].squeeze().numpy(),  
               x=nnres['all_times'][0].numpy(), log_y=True, 
               render_mode='svg')
nn_r.data[0].line.color = ColorRegistry.nodec
nn_r.data[0].name = 'NODEC'
nn_r.data[0].showlegend = True
oc_r.add_trace(nn_r.data[0])
# Paper-sized layout. Properties assigned after update_layout(base_layout)
# take precedence over it, so keep this order.
oc_r.update_layout(base_layout)
oc_r.layout.xaxis.title = 'Time'
oc_r.layout.yaxis.title = 'Total Energy'
oc_r.layout.yaxis.exponentformat = 'power'
oc_r.layout.width = 165
oc_r.layout.height =150
# Compact horizontal legend inside the plot area, tight margins.
oc_r.update_layout(legend=dict(
                                        orientation="h",
                                  font = dict(size=8),
                                  x=0.3,
                                  y=0.4,                                
                                  bgcolor="rgba(0,0,0,0)",
                                  bordercolor="Black",
                                  borderwidth=0
                                  ),
                   margin = dict(t=0,b=20,l=20,r=0)
                  )
oc_r

In [None]:
# Generating a trajectory without control
tlin = torch.linspace(0, 150, 500)
state_trajectory_noc = odeint(lambda t,y: dyn(t,y,u=None),x0, tlin, method='dopri5')
y=order_parameter_cos(state_trajectory_noc.squeeze().cpu())
fig_noc = px.line(y=y.cpu().numpy(), x=tlin.cpu().numpy(),width=600, height=300)
fig_noc.data[0].name = 'No control'
fig_noc.data[0].line.color = ColorRegistry.constant
fig_noc.data[0].showlegend = True
fig_noc.layout.xaxis.title.text = 'Time'
fig_noc.layout.yaxis.title.text = '$r(t)$'
fig_noc

## Figure 6b
Loss comparison between NODEC and feedback control baseline.

In [None]:
# Figure 6b: order parameter (evaluator loss) over time for FC, NODEC and the
# uncontrolled system. A t=0 point is prepended to all_times[1] because
# all_losses has one more entry (presumably the initial loss, since
# preserve_init_loss=True).
oc_er = px.line(y=contres['all_losses'].squeeze().numpy(),
                x=torch.cat([torch.zeros([1]), 
                             contres['all_times'][1]]).numpy(), 
                log_y=True,
                render_mode='svg')
oc_er.data[0].line.color = ColorRegistry.oc
oc_er.data[0].name = 'FC'
oc_er.data[0].showlegend = True
nn_er = px.line(y=nnres['all_losses'].squeeze().numpy(), 
                x=torch.cat([torch.zeros([1]), 
                             nnres['all_times'][1]]).numpy(), 
                log_y=True,
                 render_mode='svg'
               )
nn_er.data[0].line.color = ColorRegistry.nodec
nn_er.data[0].name = 'NODEC'
nn_er.data[0].showlegend = True
# Only the traces are copied into a fresh go.Figure, so px-level layout such as
# log_y does not carry over (unless base_layout sets it); the y-axis ticks are
# configured manually below.
fig = go.Figure([oc_er.data[0], nn_er.data[0], fig_noc.data[0]])

# Paper-sized layout. Properties assigned after update_layout(base_layout)
# take precedence over it.
fig.update_layout(base_layout)
fig.layout.xaxis.title = 'Time'
fig.layout.yaxis.title = 'Order Parameter'
fig.layout.width = 165
# NOTE(review): in plotly, nticks presumably has no effect once tick0/dtick are
# set; confirm whether nticks=6 is still needed.
fig.update_yaxes(nticks=6)
fig.update_yaxes(tick0=0.2, dtick=0.15)
fig.layout.height =150
# Compact horizontal legend near the top of the plot area, tight margins.
fig.update_layout(legend=dict(
                                        orientation="h",
                                  font = dict(size=8),
                                  x=0.2,
                                  y=0.96,                                
                                  bgcolor="rgba(0,0,0,0)",
                                  bordercolor="Black",
                                  borderwidth=0
                                  ),
                   margin = dict(t=0,b=20,l=20,r=0)
                  )
fig

## Fig 9
Here we zoom in on the time interval $t \in [0, 1]$ of Figure 6b (FC and NODEC only) to produce Figure 9.

In [None]:
# Figure 9: same FC and NODEC order-parameter traces as Figure 6b (no
# uncontrolled trace), with the x-axis restricted to t in [0, 1].
fig2 = go.Figure([oc_er.data[0], nn_er.data[0]])
# Paper-sized layout. Properties assigned after update_layout(base_layout)
# take precedence over it.
fig2.update_layout(base_layout)
fig2.layout.xaxis.title = 'Time'
fig2.layout.xaxis.range = [0,1]
fig2.layout.yaxis.title = 'Order Parameter'
fig2.layout.width = 165
# NOTE(review): in plotly, nticks presumably has no effect once tick0/dtick are
# set; confirm whether nticks=6 is still needed.
fig2.update_yaxes(nticks=6)
fig2.update_yaxes(tick0=0.2, dtick=0.15)
fig2.layout.height =150
# Compact horizontal legend inside the plot area, tight margins.
fig2.update_layout(legend=dict(
                                        orientation="h",
                                  font = dict(size=8),
                                  x=0.4,
                                  y=0.35,                                
                                  bgcolor="rgba(0,0,0,0)",
                                  bordercolor="Black",
                                  borderwidth=0
                                  ),
                   margin = dict(t=0,b=20,l=20,r=0)
                  )
fig2