# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
import scipy.io as sio
from sympy import symbols, solve
import trimesh
import math
from scipy.io import loadmat
import random
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial.transform import Rotation as R

def piecewise_random_path(bifurcation):
    """
    Pick the five waypoints of a piecewise path through parameter space.

    The path visits p0 -> p1 -> p1_5 -> p2 -> p3. Depending on the class, a
    waypoint is a fixed coordinate, a fixed column of a bifurcation curve, or
    a column of a bifurcation curve drawn at random with ``np.random`` (so the
    result depends on NumPy's global RNG state).

    Curves are read from 'curves2.mat' (SNIC) and 'bifurcation_crossing.mat'
    (Fold, Hopf, Homoclinic_to_saddle). Each column is treated as one point,
    i.e. the curves are assumed to be (3, N) arrays with rows x, y, z (the
    plotting code indexes them that way) -- TODO confirm.

    Parameters:
    bifurcation : int
        Path class; one of 3, 7, 9, 10 or 11.

    Returns:
    p0, p1, p1_5, p2, p3 : ndarray, shape (3,)
        Waypoints in visiting order.

    Raises:
    ValueError
        If ``bifurcation`` is not a supported class (checked before any file
        is read).
    """
    if bifurcation not in (3, 7, 9, 10, 11):
        raise ValueError("Invalid bifurcation value")

    data_curves_2 = loadmat('curves2.mat')
    data_bif = loadmat('bifurcation_crossing.mat')

    Homoclinic_to_saddle = data_bif['Homoclinic_to_saddle']
    Fold = data_bif['Fold']
    Hopf = data_bif['Hopf']
    SNIC = data_curves_2['SNIC']

    if bifurcation == 3:
        # Fixed rest point (0-based column; the MATLAB index is one higher)
        p0 = Hopf[:, 929]
        # Bifurcation curve
        p1 = _random_curve_column(Fold, 144, 170)
        p1_5 = np.array([0.3196, 0.2389, -0.0279])
        # Bifurcation curve
        p2 = _random_curve_column(Hopf, 599, 750)
        # Fixed rest
        p3 = np.array(_REST_POINT)

    elif bifurcation == 7:
        # Fixed rest point
        p0 = Fold[:, 399]
        # Bifurcation curve: SNIC is a whole curve, so pick one of its points.
        # NOTE(review): the full column range is sampled -- narrow it if only
        # part of the curve should be crossed.
        p1 = _random_curve_column(SNIC)
        # Random point in limit cycle
        p1_5 = np.array([0.1314, 0.3298, -0.1843])
        # Bifurcation curve
        p2 = _random_curve_column(Hopf, 599, 750)
        # Fixed rest
        p3 = np.array(_REST_POINT)

    elif bifurcation == 9:
        # Fixed rest point
        p0 = np.array(_REST_POINT)
        # Bifurcation curve
        p1 = _random_curve_column(Hopf, 599, 750)
        # Random point in limit cycle
        p1_5 = np.array([-0.0441, 0.2591, -0.3015])
        # Bifurcation curve (one point of the SNIC curve; full range sampled)
        p2 = _random_curve_column(SNIC)
        # Fixed rest
        p3 = Fold[:, 449]

    elif bifurcation == 10:
        # Fixed rest point
        p0 = np.array(_REST_POINT)
        # Bifurcation curve
        p1 = _random_curve_column(Hopf, 599, 750)
        # Random point in limit cycle
        p1_5 = np.array([-0.1237, 0.3388, -0.1729])
        # Bifurcation curve (one point of the homoclinic curve; full range sampled)
        p2 = _random_curve_column(Homoclinic_to_saddle)
        # Fixed rest
        p3 = Fold[:, 399]

    else:  # bifurcation == 11
        # Fixed rest point
        p0 = np.array(_REST_POINT)
        # Bifurcation curve
        p1 = _random_curve_column(Hopf, 599, 750)
        # Random point in limit cycle
        p1_5 = np.array([-0.2104, 0.3180, -0.1209])
        # Bifurcation curve
        p2 = _random_curve_column(Hopf, 599, 750)
        # Fixed rest
        p3 = np.array(_REST_POINT)

    return p0, p1, p1_5, p2, p3


# Fixed rest-point coordinates shared by several path classes; a fresh array is
# built from this tuple each time so returned waypoints never alias each other.
_REST_POINT = (0.1944, 0.0893, 0.3380)


def _random_curve_column(curve, low=0, high=None):
    """Return ``curve[:, j]`` for ``j = np.random.randint(low, high)``; ``high`` defaults to the column count."""
    if high is None:
        high = curve.shape[1]
    return curve[:, np.random.randint(low, high)]

def sphereArcPath(k, tstep, point1, point2, radius=0.4):
    """
    sphere_arc_path - Generates a great-circle arc between two points on a sphere.

    Both points are first projected radially onto the sphere of the given
    radius. The arc then runs from the projection of ``point1`` to that of
    ``point2`` by rotating about the axis ``point1 x point2``.

    Parameters:
    k : float
        Angular speed along the arc; with ``tstep`` it fixes the number of
        samples, N = floor(theta / k / tstep).
    tstep : float
        Time step for arc path generation.
    point1 : array-like, shape (3,)
        Start point (any nonzero vector; it is rescaled onto the sphere).
    point2 : array-like, shape (3,)
        End point (any nonzero vector; it is rescaled onto the sphere).
    radius : float, optional
        Sphere radius. Defaults to 0.4.

    Returns:
    mu2 : array, shape (N,)
        x-coordinates of points along the arc.
    mu1 : array, shape (N,)
        y-coordinates of points along the arc.
    nu : array, shape (N,)
        z-coordinates of points along the arc.
    theta : float
        Angle in radians between the two points.

    Both endpoints are included when N >= 2. N == 1 yields only the start
    point, and N == 0 yields empty arrays.

    Raises:
    ValueError
        If the points are (numerically) the same or antipodal, so the
        rotation axis is undefined.
    """
    start = np.asarray(point1, dtype=float)
    end = np.asarray(point2, dtype=float)
    start = start / np.linalg.norm(start) * radius
    end = end / np.linalg.norm(end) * radius

    # Rounding can push the cosine just outside [-1, 1], where arccos returns NaN.
    cos_theta = np.clip(np.dot(start, end) / radius ** 2, -1.0, 1.0)
    theta = np.arccos(cos_theta)

    axis = np.cross(start, end)
    axis_norm = np.linalg.norm(axis)
    # Tolerance (relative to |start||end| = radius**2) instead of an exact
    # float == 0: nearly parallel/antipodal vectors give a tiny cross product
    # whose direction is dominated by rounding noise.
    if axis_norm <= 1e-12 * radius ** 2:
        raise ValueError('The points are the same or antipodal.')
    axis = axis / axis_norm

    num_points = int(np.floor(theta / k / tstep))
    # linspace handles num_points == 1 (a single angle 0) without dividing by zero.
    angles = np.linspace(0.0, theta, num_points)

    # Rodrigues' rotation of `start` about `axis`, vectorized over all angles:
    # v cos(a) + (axis x v) sin(a) + axis (axis . v)(1 - cos(a)).
    cos_a = np.cos(angles)[:, None]
    sin_a = np.sin(angles)[:, None]
    arc_path = (cos_a * start
                + sin_a * np.cross(axis, start)
                + (1.0 - cos_a) * np.dot(axis, start) * axis)

    mu2 = arc_path[:, 0]
    mu1 = arc_path[:, 1]
    nu = arc_path[:, 2]

    return mu2, mu1, nu, theta

def rotation_matrix(axis, angle):
    """
    rotation_matrix - Builds the 3x3 matrix rotating vectors by `angle` about `axis`.

    The entries are Rodrigues' rotation formula written out element by element:
    cos(angle) * I + sin(angle) * [a]_x + (1 - cos(angle)) * a a^T.
    The axis is not normalized here, so pass a unit vector for a proper rotation.

    Parameters:
    axis : array-like, shape (3,)
        Unit vector along the rotation axis.
    angle : float
        Rotation angle in radians (right-hand rule about `axis`).

    Returns:
    R : array, shape (3, 3)
        The rotation matrix.
    """
    ax, ay, az = axis

    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    one_minus_cos = 1 - cos_a

    # Symmetric part (1 - cos) * a a^T; each term appears in two mirrored entries.
    sym_xy = one_minus_cos * ax * ay
    sym_xz = one_minus_cos * ax * az
    sym_yz = one_minus_cos * ay * az

    # Skew-symmetric part sin * [a]_x.
    skew_x = sin_a * ax
    skew_y = sin_a * ay
    skew_z = sin_a * az

    return np.array([
        [one_minus_cos * ax * ax + cos_a, sym_xy - skew_z, sym_xz + skew_y],
        [sym_xy + skew_z, one_minus_cos * ay * ay + cos_a, sym_yz - skew_x],
        [sym_xz - skew_y, sym_yz + skew_x, one_minus_cos * az * az + cos_a],
    ])

def slow_wave_model_piecewise(x, k, mu2, mu1, nu):
    """
    Evaluate the model's vector field at state `x` for fixed parameters.

    dx/dt = -y
    dy/dt = x**3 - mu2*x - mu1 - y*(nu + x + x**2)
    dz/dt = k
    where x, y are x[0], x[1] (x[2] is not read).

    Returns (derivative as a length-3 ndarray, mu2, mu1, nu). The parameters are
    echoed back unchanged so the caller can record the values used at each step.
    """
    x_fast = x[0]
    y_fast = x[1]
    damping = nu + x_fast + x_fast ** 2
    rates = [-y_fast, x_fast ** 3 - mu2 * x_fast - mu1 - y_fast * damping, k]
    return np.array(rates), mu2, mu1, nu




def pinknoise(DIM, BETA, MAG):
    """
    Generate 2-D noise with a power-law spectrum (power proportional to |f|**BETA).

    A spectrum with amplitude |f|**(BETA/2) and uniformly random phases (drawn
    from NumPy's global RNG) is built on the FFT frequency grid. It is scaled
    so that its largest magnitude equals MAG and transformed back with an
    inverse 2-D FFT; the real part is returned.

    Parameters:
    DIM (tuple): Size of the spatial pattern (e.g., DIM=(10, 5) for a 10x5 spatial grid)
    BETA (float): Spectral exponent
                  BETA = 0 is random white noise
                  BETA = -1 is pink noise
                  BETA = -2 is Brownian noise
    MAG (float): Peak spectral magnitude before the inverse FFT (scales the noise amplitude)

    Returns:
    np.ndarray: Real array of shape DIM. The zero-frequency (DC) component is
    removed, so the result has zero mean. If the spectrum is identically zero
    (e.g. DIM == (1, 1), where only the DC bin exists), an all-zero array is
    returned instead of the NaNs a 0/0 normalization would produce.
    """

    u = np.fft.fftfreq(DIM[0]).reshape(-1, 1)
    v = np.fft.fftfreq(DIM[1]).reshape(1, -1)

    freq_sq = u ** 2 + v ** 2
    is_dc = freq_sq == 0
    # Substitute 1 at DC so a negative exponent never hits 0, then zero DC out.
    S_f = np.where(is_dc, 1.0, freq_sq) ** (BETA / 2.0)
    S_f[is_dc] = 0

    # Draw the phases before any early return so the RNG stream is unchanged.
    phi = np.random.rand(*DIM)
    y = S_f ** 0.5 * (np.cos(2 * np.pi * phi) + 1j * np.sin(2 * np.pi * phi))

    peak = np.max(np.abs(y))
    if peak == 0:
        return np.zeros(DIM)
    y = y * MAG / peak

    return np.real(np.fft.ifft2(y))



# SETTINGS - INTEGRATION
x0 = np.array([0.0, 0, 0])  # initial state (1-D array of 3 values)

# SETTINGS - MODEL
# NOTE(review): `b` and `R` are not read anywhere in this file, and `R` rebinds
# the name of the `scipy.spatial.transform.Rotation as R` import.
b = 1.0  # focus
R = 0.4  # radius

# Class Information (paths)
# c6
# Tmax:  nominal duration used to build `tspan`.
# Tstep: step size of the stochastic Euler integration below.
# k:     speed of the slow variable z; with Tstep it also fixes how many samples
#        each arc of the path gets (floor(angle / k / Tstep)), and hence how many
#        oscillations appear in the time series.
k = 0.005
Tmax = 3600
Tstep = 0.01

# Integration
# Class
I = 13  # class index; NOTE(review): not used below

# Class specific timespan
tstep = Tstep
tmax = Tmax
tspan = np.arange(0, tmax + tstep, tstep)
bifurcation = 3
p0, p1, p1_5, p2, p3 = piecewise_random_path(bifurcation)

# Join the five waypoints with four great-circle arcs; rad1..rad4 are the arc angles.
mu2_straight_path0, mu1_straight_path0, nu_straight_path0, rad1 = sphereArcPath(k, tstep, p0, p1)
mu2_straight_path0_5, mu1_straight_path0_5, nu_straight_path0_5, rad2 = sphereArcPath(k, tstep, p1, p1_5)
mu2_straight_path, mu1_straight_path, nu_straight_path, rad3 = sphereArcPath(k, tstep, p1_5, p2)
mu2_straight_path1, mu1_straight_path1, nu_straight_path1, rad4 = sphereArcPath(k, tstep, p2, p3)

mu2_all = np.concatenate((mu2_straight_path0, mu2_straight_path0_5, mu2_straight_path, mu2_straight_path1))
# mu1 is stored negated; the parameter-space plot negates it back (y=-mu1_big).
mu1_all = np.concatenate((mu1_straight_path0, mu1_straight_path0_5, mu1_straight_path, mu1_straight_path1))*-1
nu_all = np.concatenate((nu_straight_path0, nu_straight_path0_5, nu_straight_path, nu_straight_path1))

N_t = len(mu2_all)
X = np.zeros((3, N_t))
xx = x0
sigma = 100  # noise magnitude passed to pinknoise for the first state component

# Pink (1/f) noise drives only the first state component; the other rows are zero.
Rn1 = pinknoise((1, N_t), -1, sigma)
Rn2 = pinknoise((1, N_t), -1, 0)
Rn3 = pinknoise((1, N_t), -1, 0)

Rn = np.vstack([Rn1, Rn2, Rn3])

mu2_big = np.zeros(N_t)
mu1_big = np.zeros(N_t)
nu_big = np.zeros(N_t)

# Sample indices (approximately) where the path finishes its first and third arcs.
onset_index = int(np.floor((rad1 / k) / tstep))
offset_index = int(np.floor(((rad1 + rad2 + rad3) / k) / tstep))

# Euler-Maruyama integration along the precomputed parameter path.
for n in range(N_t):
    Fxx, mu2, mu1, nu = slow_wave_model_piecewise(xx, k, mu2_all[n], mu1_all[n], nu_all[n])
    xx = xx + tstep * Fxx + np.sqrt(tstep) * Rn[:, n]
    X[:, n] = xx
    mu2_big[n] = mu2
    mu1_big[n] = mu1
    nu_big[n] = nu

x = X.T
# One time stamp per simulated sample (spacing tstep, starting at 0). `tspan` is
# sized from Tmax rather than from the path, so its length generally differs from N_t.
t = np.arange(N_t) * tstep

# Load data
# Meshes and bifurcation curves for the parameter-space figure, read in this order.
data_mesh, data_curves, data_curves_2, data_bif, data_sphere = (
    loadmat(mat_file)
    for mat_file in (
        'testmesh.mat',
        'curves.mat',
        'curves2.mat',
        'bifurcation_crossing.mat',
        'sphere_mesh.mat',
    )
)

BCSmesh, Active_restmesh, Seizure_mesh, Bistable_Lcb_mesh = (
    data_mesh[mesh_key]
    for mesh_key in ('BCSmesh', 'Active_restmesh', 'Seizure_mesh', 'Bistable_Lcb_mesh')
)
(
    Fold_of_cycles,
    Homoclinic_to_saddle3,
    Homoclinic_to_saddle2,
    Homoclinic_to_saddle1,
    Homoclinic_to_saddle,
    Fold,
    Hopf,
) = (
    data_bif[curve_key]
    for curve_key in (
        'Fold_of_cycles',
        'Homoclinic_to_saddle3',
        'Homoclinic_to_saddle2',
        'Homoclinic_to_saddle1',
        'Homoclinic_to_saddle',
        'Fold',
        'Hopf',
    )
)
SNIC = data_curves_2['SNIC']

# Reference sphere of radius 0.4 sampled on a 50 x 50 (polar angle, azimuth) grid.
radius = 0.4
phi, theta = np.mgrid[0.0:np.pi:50j, 0.0:2.0*np.pi:50j]
_ring_radius = radius * np.sin(phi)  # radius of each latitude circle
X_sphere = _ring_radius * np.cos(theta)
Y_sphere = _ring_radius * np.sin(theta)
Z_sphere = radius * np.cos(phi)
 # Plotting
# Time series of the first state component against time.
timeseries_trace = go.Scatter(
    x=t,
    y=x[:, 0],
    mode='lines',
    name='x',
    line=dict(color='black', width=1),
)
fig_timeseries = go.Figure(data=[timeseries_trace])
fig_timeseries.update_layout(
    title='Timeseries',
    xaxis_title='t',
    yaxis_title='x',
    paper_bgcolor='white',  # whole figure
    plot_bgcolor='white',  # plotting area
)

fig_timeseries.show()
# Create a Plotly figure
fig_mesh = go.Figure()


def _add_region_mesh(figure, mesh, color, label):
    """
    Append one loadmat-loaded triangle mesh to `figure` as a translucent Mesh3d trace.

    Vertices and faces are read from mesh['vertices'][0][0] and
    mesh['faces'][0][0]. Face indices are shifted by -1, presumably converting
    MATLAB's 1-based indices to Plotly's 0-based ones.
    """
    verts = mesh['vertices'][0][0]
    faces = mesh['faces'][0][0]
    figure.add_trace(go.Mesh3d(
        x=verts[:, 0],
        y=verts[:, 1],
        z=verts[:, 2],
        i=faces[:, 0] - 1,
        j=faces[:, 1] - 1,
        k=faces[:, 2] - 1,
        opacity=0.3,
        color=color,
        name=label,
    ))


_add_region_mesh(fig_mesh, BCSmesh, 'rgba(248, 246, 184, 0.3)', 'BCSmesh')
_add_region_mesh(fig_mesh, Active_restmesh, 'rgba(235, 235, 235, 0.7)', 'Active_restmesh')
_add_region_mesh(fig_mesh, Seizure_mesh, 'rgba(228, 180, 211, 0.2)', 'Seizure_mesh')
_add_region_mesh(fig_mesh, Bistable_Lcb_mesh, 'rgba(248, 246, 184, 0.3)', 'Bistable_Lcb_mesh')

# Simulated parameter path; y uses -mu1_big because mu1 is stored negated.
fig_mesh.add_trace(
    go.Scatter3d(
        x=mu2_big,
        y=-mu1_big,
        z=nu_big,
        mode='lines',
        line=dict(color='rgba(0,0,0, 1)', width=2),
        name='Bursting path',
    )
)



# Plotting curves
# Curves are drawn at unit scale. `scale` and the *_scaled arrays stay defined
# as module-level names so the scale factor can be changed in one place.
scale = 1
Fold_of_cycles_scaled = scale * Fold_of_cycles
Homoclinic_to_saddle3_scaled = scale * Homoclinic_to_saddle3
Homoclinic_to_saddle2_scaled = scale * Homoclinic_to_saddle2
Homoclinic_to_saddle1_scaled = scale * Homoclinic_to_saddle1
Homoclinic_to_saddle_scaled = scale * Homoclinic_to_saddle
Fold_scaled = scale * Fold
Hopf_scaled = scale * Hopf
SNIC_scaled = scale * SNIC

_SH_COLOR = 'rgba(103, 179, 217, 1)'
_SN_COLOR = 'rgba(244, 156, 52, 1)'
_HOPF_COLOR = 'rgba(116, 191, 69, 1)'

# One entry per trace, in drawing order:
# (curve with rows x/y/z, column slice, line style, legend label or None).
# A None label hides the segment from the legend; such segments continue a
# curve family that is labelled elsewhere in this list.
_curve_traces = [
    (Fold_of_cycles_scaled, slice(None), dict(color='rgba(248, 68, 149, 1)', width=2), 'Fold of cycles (FLC)'),
    (Homoclinic_to_saddle3_scaled, slice(None), dict(color=_SH_COLOR, width=2, dash='dash'), None),
    (Homoclinic_to_saddle2_scaled, slice(None), dict(color=_SH_COLOR, width=2), None),
    (Homoclinic_to_saddle1_scaled, slice(None), dict(color=_SH_COLOR, width=2, dash='dash'), None),
    (Homoclinic_to_saddle_scaled, slice(None), dict(color=_SH_COLOR, width=2), 'Saddle-Homoclinic (SH)'),
    (Fold_scaled, slice(139, 563), dict(color=_SN_COLOR, width=2), 'Saddle Node (SN)'),
    (Fold_scaled, slice(574, None), dict(color=_SN_COLOR, width=2), None),
    (Fold_scaled, slice(None, 79), dict(color=_SN_COLOR, width=2), None),
    (Fold_scaled, slice(564, 574), dict(color=_SN_COLOR, width=2, dash='dash'), None),
    (Hopf_scaled, slice(399, 973), dict(color=_HOPF_COLOR, width=2), 'Supercritical Hopf (SupH)'),
    (Hopf_scaled, slice(None, 400), dict(color=_HOPF_COLOR, width=2, dash='dot'), 'Subcritical Hopf (SubH)'),
    (SNIC_scaled, slice(None), dict(color=_SN_COLOR, width=2, dash='dash'), 'Saddle Node Invariant Cycle (SNIC)'),
]

for _curve, _cols, _line, _label in _curve_traces:
    _legend = {'name': _label} if _label is not None else {'showlegend': False}
    fig_mesh.add_trace(go.Scatter3d(
        x=_curve[0, _cols],
        y=_curve[1, _cols],
        z=_curve[2, _cols],
        mode='lines',
        line=_line,
        **_legend,
    ))

# Translucent reference sphere built from the sampled grid points.
sphere_surface = go.Mesh3d(
    x=X_sphere.ravel(),
    y=Y_sphere.ravel(),
    z=Z_sphere.ravel(),
    opacity=0.3,
    color='rgba(76, 76, 76, 0.1)',
    name='Sphere',
    alphahull=0,  # 0 asks Plotly to triangulate the convex hull of the points
)
fig_mesh.add_trace(sphere_surface)

scene_settings = dict(
    xaxis_title='X',
    yaxis_title='Y',
    zaxis_title='Z',
    bgcolor='white',  # 3D scene background
)
fig_mesh.update_layout(
    title='Parameter Space Plot',
    scene=scene_settings,
    paper_bgcolor='white',  # whole figure
    plot_bgcolor='white',  # plotting area
)

fig_mesh.show()