In [None]:
import numpy as np
import pandas as pd

In [None]:
import src
import src.spiral
import src.spiral_plots

In [None]:
# Shared settings: spiral parameter limit (t_max) and the sample count that
# later cells pass to `sample_normalized`.
n = 100
max_radius = 2

# Two spirals with the same t_max; the red one is built with initial_angle=pi.
blue_spiral = src.spiral.Spiral(t_max=max_radius)
red_spiral = src.spiral.Spiral(t_max=max_radius, initial_angle=np.pi)


In [None]:
# Draw `n` normalized samples from each spiral (blue is sampled before red).
blue_samples, red_samples = blue_spiral.sample_normalized(n), red_spiral.sample_normalized(n)

In [None]:
# Quick look at the sample array's dimensions. Later cells index columns 0 and 1
# as x/y, so presumably (n, 2) — TODO confirm against Spiral.sample_normalized.
blue_samples.shape

In [None]:
from plotly import graph_objects as go

In [None]:
# NOTE(review): these repeat the values already set in the configuration cell
# above; keep the two in sync (or drop this cell) so every later cell agrees.
n = 100
max_radius = 2

In [None]:
def create_heatmap(x, y, f, f_is_vectorized = True):
    """Evaluate ``f`` on the grid spanned by the coordinates ``x`` and ``y``.

    The result is laid out with rows following ``y`` and columns following
    ``x``, i.e. ``Z[row, col] == f(x[col], y[row])``.

    Args:
        x: Coordinates along the columns of the result.
        y: Coordinates along the rows of the result.
        f: Callable taking ``(x, y)``.
        f_is_vectorized: If True, ``f`` is called once on the full
            ``np.meshgrid(x, y)`` arrays and its result is returned unchanged.
            If False, ``f`` is called once per grid point (``x`` in the outer
            loop, ``y`` in the inner loop) and each result is stored in a
            float array.

    Returns:
        In the per-point case, a float array of shape ``(len(y), len(x))``;
        in the vectorized case, whatever ``f`` returns for the meshgrid arrays.
    """
    if not f_is_vectorized:
        grid = np.zeros((len(y), len(x)), dtype=float)
        for col, x_coord in enumerate(x):
            for row, y_coord in enumerate(y):
                grid[row, col] = f(x_coord, y_coord)
        return grid

    x_grid, y_grid = np.meshgrid(x, y)
    return f(x_grid, y_grid)


In [None]:
# Evaluation grid for the background field, padded 10% beyond `max_radius`
# (the same factor as `create_spiral_plot`'s default `pad_factor`).
x = np.linspace(-1.1*max_radius, 1.1*max_radius, 300)
y = np.linspace(-1.1*max_radius, 1.1*max_radius, 300)
# Field sin(4*r - theta), with r = |(x, y)| and theta = angle of x + iy; its
# level sets are spirals. It is built purely from numpy ufuncs, so it accepts
# whole arrays and can be evaluated on the meshgrid in a single call instead
# of 300*300 separate Python calls.
f = lambda x,y: np.sin(4*np.sqrt(x**2 + y**2) - np.angle(x+1j*y))
Z = create_heatmap(x, y, f, f_is_vectorized=True)

In [None]:
# Rebuild both spirals from the current `max_radius` so this cell does not rely
# on the spiral objects created earlier in the notebook.
blue_spiral = src.spiral.Spiral(t_max=max_radius)
red_spiral = src.spiral.Spiral(t_max=max_radius, initial_angle=np.pi)

blue_samples = blue_spiral.sample_normalized(n)
red_samples = red_spiral.sample_normalized(n)

# One frame per spiral: sample columns 0/1 as x/y plus a constant 'spiral' label.
blue_df, red_df = (
    pd.DataFrame({'x': pts[:, 0], 'y': pts[:, 1], 'spiral': label})
    for pts, label in ((blue_samples, 'Blue'), (red_samples, 'Red'))
)

In [None]:
# NOTE(review): `create_spiral_plot` is defined in a later cell, so on a fresh
# kernel ("Restart & Run All") this cell raises NameError — move that
# definition above this cell.
fig = create_spiral_plot([blue_df, red_df], plot_radius=max_radius)
fig.add_trace(go.Heatmap(x=x, y=y, z=Z))


In [None]:

from typing import List, Dict, Optional
from plotly import graph_objects as go
import pandas as pd

def create_spiral_plot(spirals: List[pd.DataFrame], plot_radius:float, pad_factor:float=1.1, colors: Optional[List[str]] = None, fig = None):
    """Draw each spiral's points as a marker scatter on a fixed-size square figure.

    Args:
        spirals: One DataFrame per spiral; its ``x`` and ``y`` columns are
            plotted as one scatter trace.
        plot_radius: Half-width of the plotted region before padding.
        pad_factor: Multiplier on ``plot_radius`` for both axis ranges; the
            axes span ``[-pad_factor*plot_radius, pad_factor*plot_radius]``.
            (Previously this argument was overwritten with 1.1 and ignored.)
        colors: Optional marker colour per spiral, matched by position.
            ``None`` (or a ``None`` entry) leaves the colour to plotly.
        fig: Existing figure to draw onto; a new ``go.Figure`` is created
            when omitted.

    Returns:
        The figure, with one scatter trace added per spiral and the layout set
        to 500x500 px, the padded axis ranges, tight margins and a white paper
        background.
    """
    if colors is None:
        colors = [None]*len(spirals)
    if fig is None:
        fig = go.Figure()
    for s_i, spiral in enumerate(spirals):
        fig.add_trace(
            go.Scatter(x = spiral.x, y=spiral.y,
                    name= "Spirals",
                    mode = 'markers',
                    marker = dict(
                        color = colors[s_i]
                    )
            )
        )

    # Both axes share the same symmetric range so the spirals are centred.
    axis_range = [-pad_factor*plot_radius, pad_factor*plot_radius]
    fig.update_layout(
        autosize=False,
        width=500,
        height=500,
        xaxis = dict(
            range = axis_range
        ),
        yaxis = dict(
            range = list(axis_range)
        ),
        margin=dict(
            l=40,
            r=0,
            b=40,
            t=0,
            pad=0
        ),
        paper_bgcolor="White",
    )

    return fig
