In [None]:
import numpy as np
import networkx as nx
import scipy.integrate as odeint
import kuramoto as kr
import matplotlib.pyplot as plt
import math 

def cyclic_rgb(value):
    """
    Maps a phase onto a single cyclic colour intensity.

    Only the red channel of a phase-shifted sinusoidal RGB scheme is
    produced: the result is a scalar intensity per phase, not an RGB triple.
    The green and blue channels (the same expression shifted by 2*pi/3 and
    4*pi/3) used to be computed here but were never returned, so they are
    no longer evaluated.

    Args:
        value (float or ndarray): Phase in radians. The mapping is
            2*pi-periodic, so any real value is accepted; arrays are
            processed element-wise.

    Returns:
        float or ndarray: (128 + 127*sin(value)) / 255, which lies in
        [1/255, 1]. The shape follows the shape of `value`.
    """
    return (128 + 127 * np.sin(value)) / 255
# class Oscillator:
#     def __init__(self,nat):
#         self.nat = nat  # nat freq of oscillator
        
#     def update(self,int): # function takes in a interaction 
''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
class Net:
    """
    Network of phase oscillators with an adaptive, weighted coupling matrix.

    Attributes:
        Adj:   (n, n) float coupling matrix. Adj[i, j] weights the pull that
               node j exerts on node i (see Update).
        n:     number of oscillators.
        k:     coupling strength; stored but not read by any method here.
        dt:    integration time step.
        state: (n, 1) column of current phases.
        nats:  (n, 1) column of per-step natural phase increments.
        I:     (n, 1) column holding the interaction term of the last Update.
        space: node coordinates; empty until Gauss_Space() is called.
        pos:   node layout handed to networkx when drawing.
    """

    def __init__(self,Adj,n,k,dt,var,mean):
        """
        Args:
            Adj: square (n, n) array-like of weighted, directional couplings.
            n: number of oscillators.
            k: coupling strength (kept for reference only).
            dt: integration time step.
            var: spread of the natural frequencies. NOTE: it is forwarded to
                np.random.normal as the scale, i.e. as a standard deviation.
            mean: mean of the natural frequencies.

        Raises:
            ValueError: if Adj is not an (n, n) matrix.
        """
        # Keep a private float copy. Connect() and Update() write fractional
        # weights into the matrix, which an integer array would truncate, and
        # the caller's array must not be modified behind its back.
        self.Adj = np.array(Adj, dtype=float)  # adjacency with weighted and directional couplings
        if self.Adj.shape != (n, n):
            raise ValueError(
                'Adj must have shape ({0}, {0}), got {1}'.format(n, self.Adj.shape)
            )
        self.n = n
        self.dt = dt
        self.k = k
        self.space = []  # node coordinates, filled in by Gauss_Space()
        self.state = np.random.uniform(-2*np.pi,2*np.pi,(n,1)) # nx1 vector of current oscillator states
        print('init state = {}'.format(self.state))
        # Natural frequencies drawn from a normal distribution and pre-scaled
        # by 2*pi*dt, so each entry is the phase advance per time step.
        self.nats =  2*np.pi*dt*np.random.normal(mean,var,(n,1))
        print('init nats = {}'.format(self.nats))
        self.I = np.zeros((n,1))
        self.noise_A = 0.5*np.sqrt(dt)  # amplitude for the (currently disabled) noise term
        self.graph = nx.from_numpy_array(self.Adj)
        self.pos = nx.kamada_kawai_layout(self.graph)

    def Gauss_Space(self,space_mean,space_var):
        """
        Places every node at a normally distributed 2-D position and uses
        those positions as the drawing layout.

        Args:
            space_mean: mean of each coordinate.
            space_var: spread of each coordinate (forwarded to
                np.random.normal as the scale, i.e. a standard deviation).
        """
        self.space = np.random.normal(space_mean,space_var,(self.n,2))
        self.pos = self.space

    def Connect(self,a,P_inhib):
        """
        Rewires the coupling matrix from the node positions.

        For every unordered pair (i, j) at distance `mag`, each direction is
        drawn independently: it is connected with probability
        1 / (mag**a + 1), the connection is inhibitory (negative) with
        probability P_inhib, and its magnitude is 1/mag. Directions that are
        not connected are set to 0. Nodes at identical positions are left
        unconnected because 1/mag is undefined for them.

        Args:
            a: exponent of the distance law (-a is the power-law exponent).
            P_inhib: probability that a connection is inhibitory.

        Raises:
            ValueError: if node positions have not been generated yet.
        """
        space = np.asarray(self.space, dtype=float)
        if space.ndim != 2 or space.shape[0] != self.n or space.shape[1] < 2:
            raise ValueError('node positions are not set; call Gauss_Space() before Connect()')

        for i in range(self.n-1):
            for j in range(i+1,self.n):
                dx = space[i,0] - space[j,0]
                dy = space[i,1] - space[j,1]
                mag = np.sqrt(dx**2 + dy**2)
                if mag == 0:
                    self.Adj[i,j] = 0.0
                    self.Adj[j,i] = 0.0
                    continue
                prob_connect = 1/(mag**a+1) # connection law. +1 is offset to normalise ## CHANGE to control connectivity
                # One Bernoulli draw for "connected" and one for "inhibitory"
                # per direction, so Adj[i, j] and Adj[j, i] stay independent.
                for dst, src in ((i, j), (j, i)):
                    connected = np.random.uniform() < prob_connect
                    sign = -1.0 if np.random.uniform() < P_inhib else 1.0
                    self.Adj[dst,src] = sign/mag if connected else 0.0

    def Update(self):
        """
        Advances phases and couplings by one time step.

        With d[i, j] = state[j] - state[i], evaluated on the state before the
        step:
            I[i]      = dt * sum_j Adj[i, j] * sin(d[i, j])
            Adj[i, j] += dt * cos(d[i, j])        for i != j
            state     += nats + I

        NOTE(review): every off-diagonal coupling is adapted, including
        entries that are currently 0, so absent connections can appear over
        time. Confirm this is the intended plasticity rule.
        """
        phase_diff = self.state.T - self.state  # phase_diff[i, j] = state[j] - state[i]

        # Interaction uses the couplings as they were before this step.
        self.I[:] = (self.Adj*np.sin(phase_diff)).sum(axis=1, keepdims=True)*self.dt

        growth = np.cos(phase_diff)*self.dt
        np.fill_diagonal(growth, 0.0)  # no self-coupling
        self.Adj += growth

        self.state += self.nats + self.I  #+ self.noise_A*np.random.randn(self.n,1)

    def View(self):
        """
        Draws the coupling matrix (left) and the network graph (right).

        The matrix is shown with nodes reordered as even indices ascending
        followed by odd indices descending. Graph nodes are coloured by
        cyclic_rgb of their current phase.
        """
        self.graph = nx.from_numpy_array(self.Adj)
        # TODO: edge colours should be related to normalised weights

        # Node colours follow the current phase; one scalar per node.
        node_colors = cyclic_rgb(self.state[:, 0])

        # Even indices ascending, then odd indices descending. Built from the
        # odd indices directly so that odd n does not repeat the even ones.
        Order = np.hstack((np.arange(0, self.n, 2), np.arange(1, self.n, 2)[::-1]))

        fig, axs = plt.subplots(1, 2, figsize=(15, 8))
        ax = axs[0]
        ax.pcolormesh(self.Adj[Order, :][:, Order])

        ax.set_xlabel('Nodes')
        ax.set_ylabel('Nodes')
        ax.set_title('Coupling Matrix')
        ax.axis('square')

        # Target the right-hand panel explicitly instead of relying on the
        # pyplot "current axes".
        nx.draw(self.graph,node_color=node_colors,pos=self.pos,ax=axs[1]) # BROKEN WITH EVOLUTIONS
        
''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''       
        # print('th0 {}'.format(self.state))
        # self.sins = np.sin(np.subtract.outer(self.state, np.transpose(self.state))) #sines of differences
        # #print(np.outer(self.state, np.transpose(self.state)))
        # #print('sin {}'.format(np.sin(np.subtract.outer(self.state, np.transpose(self.state)))))
        
        # #print('sins {}'.format(self.sins))
        # self.I = np.sum(np.multiply(self.Adj,self.sins), axis=0) # I is sum of contributions on each oscillator
        # print(np.multiply(self.Adj,self.sins))
        # print('I {}'.format(self.I))
        # self.state += self.nats + self.I
        # #print('nat = {}'.format(self.nats*dt))
        # #print('th1 {}'.format(self.state))
       
''' 
Can't interact with adjacency here 
''' 
    # def theta_dot(self,t,angles_vec):
    #     angles_i, angles_j = np.meshgrid(angles_vec, angles_vec)
    #     I = self.Adj* np.sin(angles_j - angles_i)
    #     dxdt = self.nats + I.sum(axis=0)
    #     return dxdt
    
    # def Run(self,dt,T):
    #     t = np.linspace(0, self.T, int(self.T/self.dt))
    #     timeseries = odeint(self.derivative, self.state, t)
    #     return timeseries
        
        


# Looking at source nodes


In [None]:
k = 3.5
connectivity = 0.45

# Hand-built coupling matrix. Row i lists the couplings acting on node i, so
# the all-zero first row makes node 0 a pure source. Stored as float because
# the network writes fractional weights into it.
graph = np.array([[0,0,0,0], # No connections into source
         [1,0,2,1],
         [1,1,0,2],
         [1,2,0,0]], dtype=float)
N = np.shape(graph)[0]
T = 1
res = 10000
dt = T/res
# endpoint=False makes the grid spacing equal to dt (T/res); including the
# endpoint would give a spacing of T/(res-1).
time = np.linspace(0,T,res,endpoint=False)
results = np.zeros((res,N))
evo_adj = np.zeros((N,N,res))
# Random-graph alternative (set N to the desired size first):
#graph_nx = nx.erdos_renyi_graph(n=N, p=connectivity) # p=1 -> all-to-all connectivity
#graph = nx.to_numpy_array(graph_nx)*k

# Natural-frequency distribution. These names were previously undefined in the
# notebook, so this cell failed with NameError on a fresh kernel.
# TODO confirm values: mean 1 with zero spread is assumed here, i.e. all
# natural frequencies equal to unity.
mean = 1.0
var = 0.0

sourcenet = Net(graph,N,k,dt,var,mean)
sourcenet.View()
