In [2]:
import torch 
from torchdiffeq import odeint
from torch import nn
import numpy as n


G = torch.rand(4,4)
print(torch.real(torch.linalg.eigvals(G @ G.T)))

tensor([4.3046+0.j, 0.7331+0.j, 0.1868+0.j, 0.0072+0.j])


The model

In [None]:
class kuramoto(nn.Module):  # dense_resnet_relu1,2,7

    def __init__(self,  adj_mat, coupling, n_nodes=4, natfreqs=None):
        super(kuramoto, self).__init__()
        self.adj_mat = adj_mat 
        self.coupling = coupling
        if natfreqs is not None:
            self.natfreqs = natfreqs
            self.n_nodes = len(natfreqs)
        else:
            self.n_nodes = n_nodes
            self.natfreqs = torch.randn((self.n_nodes,1))

       
        
    def forward(self, t, x):
    
        assert len(x) == len(self.natfreqs) == len(self.adj_mat), \
            'Input dimensions do not match, check lengths'

        print(x)
        angles_i, angles_j = np.meshgrid(x.float(), x.float())
        interactions = self.adj_mat * torch.sin(torch.tensor((angles_j - angles_i))  # Aij * sin(j-i)

        dxdt = self.natfreqs + self.coupling * interactions.sum(axis=0)  # sum over incoming interactions
        return dxdt

In [None]:


Adjacecy = torch.tensor([[0 ,1, 0 ,1],
                        [1, 0, 1 ,0],
                        [0 ,1, 0 ,1],
                        [1, 0, 1, 0]])


oscilator = kuramoto(adj_mat= Adjacecy, coupling = 1, n_nodes=4, natfreqs=None)

vec_angles = torch.randn(4,1)
oscilator(0.0,vec_angles)

In [None]:
class Kuramoto:

    def __init__(self, coupling=1, dt=0.01, T=10, n_nodes=None, natfreqs=None):
        '''
        coupling: float
            Coupling strength. Default = 1. Typical values range between 0.4-2
        dt: float
            Delta t for integration of equations.
        T: float
            Total time of simulated activity.
            From that the number of integration steps is T/dt.
        n_nodes: int, optional
            Number of oscillators.
            If None, it is inferred from len of natfreqs.
            Must be specified if natfreqs is not given.
        natfreqs: 1D ndarray, optional
            Natural oscillation frequencies.
            If None, then new random values will be generated and kept fixed
            for the object instance.
            Must be specified if n_nodes is not given.
            If given, it overrides the n_nodes argument.
        '''
        if n_nodes is None and natfreqs is None:
            raise ValueError("n_nodes or natfreqs must be specified")

        self.dt = dt
        self.T = T
        self.coupling = coupling

        if natfreqs is not None:
            self.natfreqs = natfreqs
            self.n_nodes = len(natfreqs)
        else:
            self.n_nodes = n_nodes
            self.natfreqs = torch.randn((self.n_nodes,1))

    def init_angles(self):
        '''
        Random initial random angles (position, "theta").
        '''
        return 2 * torch.pi * torch.randn(self.n_nodes)

    def derivative(self, angles_vec, t, adj_mat, coupling):
        '''
        Compute derivative of all nodes for current state, defined as

        dx_i    natfreq_i + k  sum_j ( Aij* sin (angle_j - angle_i) )
        ---- =             ---
         dt                M_i

        t: for compatibility with scipy.odeint
        '''
        assert len(angles_vec) == len(self.natfreqs) == len(adj_mat), \
            'Input dimensions do not match, check lengths'

        angles_i, angles_j = torch.meshgrid(angles_vec, angles_vec)
        interactions = adj_mat * torch.sin(angles_j - angles_i)  # Aij * sin(j-i)

        dxdt = self.natfreqs + coupling * interactions.sum(axis=0)  # sum over incoming interactions
        return dxdt

    def integrate(self, angles_vec, adj_mat):
        '''Updates all states by integrating state of all nodes'''
        # Coupling term (k / Mj) is constant in the integrated time window.
        # Compute it only once here and pass it to the derivative function
        n_interactions = (adj_mat != 0).sum(axis=0)  # number of incoming interactions
        coupling = self.coupling / n_interactions  # normalize coupling by number of interactions
        tol = 1e-7
        t = torch.linspace(0, self.T, int(self.T/self.dt))
        timeseries = odeint(self.derivative(angles_vec=angles_vec,t = t,adj_mat=adj_mat,coupling=coupling), angles_vec, t, rtol=tol, atol=tol, method="euler")
        return timeseries.T  # transpose for consistency (act_mat:node vs time)

    def run(self, adj_mat=None, angles_vec=None):
        '''
        adj_mat: 2D nd array
            Adjacency matrix representing connectivity.
        angles_vec: 1D ndarray, optional
            States vector of nodes representing the position in radians.
            If not specified, random initialization [0, 2pi].

        Returns
        -------
        act_mat: 2D ndarray
            Activity matrix: node vs time matrix with the time series of all
            the nodes.
        '''
        if angles_vec is None:
            angles_vec = self.init_angles()

        return self.integrate(angles_vec, adj_mat)

    @staticmethod
    def phase_coherence(angles_vec):
        '''
        Compute global order parameter R_t - mean length of resultant vector
        '''
        suma = sum([(torch.e ** (1j * i)) for i in angles_vec])
        return abs(suma / len(angles_vec))

    def mean_frequency(self, act_mat, adj_mat):
        '''
        Compute average frequency within the time window (self.T) for all nodes
        '''
        assert len(adj_mat) == act_mat.shape[0], 'adj_mat does not match act_mat'
        _, n_steps = act_mat.shape

        # Compute derivative for all nodes for all time steps
        dxdt = torch.zeros_like(act_mat)
        for time in range(n_steps):
            dxdt[:, time] = self.derivative(act_mat[:, time], None, adj_mat)

        # Integrate all nodes over the time window T
        integral = torch.sum(dxdt * self.dt, axis=1)
        # Average across complete time window - mean angular velocity (freq.)
        meanfreq = integral / self.T
        return meanfreq


In [None]:






oscilator = Kuramoto(coupling=1, dt=0.01, T=10, n_nodes=4)

Adjacecy = torch.tensor([[0 ,1, 0 ,1],
                        [1, 0, 1 ,0],
                        [0 ,1, 0 ,1],
                        [1, 0, 1, 0]])


oscilator.run(adj_mat=Adjacecy)