In [None]:
!pip install plotly

In [None]:
pip install jax

In [None]:
!pip install dash

In [4]:
from dgh import upper

In [5]:
import numpy as np
def generate_symmetric_matrix(size, min_val=1, max_val=100):
    """Return a random symmetric ``size x size`` float matrix with a zero diagonal.

    Integer entries are drawn from ``[min_val, max_val]`` (inclusive) using the
    global NumPy random state. They are then averaged with their transpose, so
    off-diagonal values are multiples of 0.5 within that range.
    """
    raw = np.random.randint(min_val, max_val + 1, size=(size, size))
    # Averaging with the transpose guarantees result[i, j] == result[j, i].
    result = np.add(raw, raw.T) / 2
    result[np.diag_indices(size)] = 0.0
    return result


In [6]:
def get_points_circle(n, R, x_center, y_center, z_center, rng=None):
    """Sample ``n`` random points on the sphere of radius ``R`` around a center.

    Despite the name, the points lie on a 2-sphere in 3-D:

    - ``x`` is uniform on ``[x_center - R, x_center + R)``;
    - ``y`` is uniform on the chord allowed by that ``x``;
    - ``z`` is the matching height with a random sign.

    This is NOT a uniform distribution over the sphere's surface.

    Parameters
    ----------
    n : int
        Number of points.
    R : float
        Sphere radius.
    x_center, y_center, z_center : float
        Sphere center.
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted, the global ``numpy.random`` state is
        used, exactly as before.

    Returns
    -------
    tuple of three ndarrays of shape (n,)
        The x, y and z coordinates.
    """
    if rng is None:
        rng = np.random
    x = rng.uniform(x_center - R, x_center + R, n)
    dx = x - x_center
    # Clamp at 0: floating-point rounding can push the radicand a hair below
    # zero, and np.sqrt would then return NaN.
    y_half = np.sqrt(np.maximum(np.square(R) - np.square(dx), 0.0))
    y = rng.uniform(y_center - y_half, y_center + y_half, n)
    dy = y - y_center
    z_half = np.sqrt(np.maximum(np.square(R) - np.square(dx) - np.square(dy), 0.0))
    # Random sign picks the upper or lower hemisphere for each point.
    signs = rng.choice([-1, 1], len(x))
    z = signs * z_half + z_center
    return x, y, z

In [7]:
# Sphere example: X is a random cloud of 35 points on a sphere of radius 30.
# Y is the same cloud translated by (SHIFT, SHIFT, SHIFT), so X and Y are
# isometric and their true Gromov–Hausdorff distance is 0.
SEED = 0
SHIFT = 61
np.random.seed(SEED)  # makes the sampled cloud (and the outputs below) reproducible

points_X = get_points_circle(35, 30, 0, 0, 0)

Y_x = points_X[0] + SHIFT
Y_y = points_X[1] + SHIFT
Y_z = points_X[2] + SHIFT
points_Y = (Y_x, Y_y, Y_z)

# (35, 3) coordinate arrays.
coords_X = np.vstack(points_X).T
coords_Y = np.vstack(points_Y).T

# Pairwise Euclidean distance matrices, shape (35, 35).
X = np.linalg.norm(coords_X[:, None] - coords_X, axis=2)
Y = np.linalg.norm(coords_Y[:, None] - coords_Y, axis=2)

# A translation preserves distances, so the two matrices must agree.
assert np.allclose(X, Y)

In [None]:
# Run dgh.upper on the sphere data. return_fg=True also returns the mappings
# f (X indices -> Y indices) and g (Y indices -> X indices) that the plotting
# cells draw. NOTE(review): per its name, `upper` gives an upper-bound
# estimate of d_GH — confirm in the dgh docs.
SPHERE_ITER_BUDGET = 500
result = upper(X, Y, iter_budget=SPHERE_ITER_BUDGET, return_fg=True)
d_gh, f, g = result

In [14]:
d_gh  # first value returned by upper(X, Y, ...) for the sphere example

np.float64(10.892527285641098)

In [None]:
f  # f[i]: index of the Y point that X point i is mapped to (as the plotting cells use it)

In [None]:
# Interactive viewer, variant 1: clicking a point toggles the line from that
# point to its image (X_i -> Y_f[i] in blue, Y_i -> X_g[i] in red).
# Handy for inspecting specific points, but tedious for clicking through all of them.

import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import numpy as np
import plotly.graph_objects as go

x_X, y_X, z_X = points_X
x_Y, y_Y, z_Y = points_Y

n = len(x_X)

app = dash.Dash(__name__)


def _base_figure_v1():
    """Build a fresh figure with the X (blue) and Y (red) point clouds.

    ``customdata`` puts both clouds into one index space: 0..n-1 for X points
    and n..2n-1 for Y points. The click callback decodes it.
    """
    base = go.Figure()
    base.add_trace(go.Scatter3d(
        x=x_X, y=y_X, z=z_X, mode='markers+text', text=[f'X{i}' for i in range(n)],
        marker=dict(color='blue', size=5), textposition='top center', name='X',
        customdata=np.arange(n)
    ))
    base.add_trace(go.Scatter3d(
        x=x_Y, y=y_Y, z=z_Y, mode='markers+text', text=[f'Y{i}' for i in range(n)],
        marker=dict(color='red', size=5), textposition='top center', name='Y',
        customdata=np.arange(n, 2*n)
    ))
    # A constant uirevision makes Plotly keep the user's camera/zoom when the
    # callback returns a new figure; without it every click resets the 3-D view.
    base.update_layout(scene=dict(aspectmode='cube'), uirevision='mappings-v1')
    return base


def _mapping_trace_v1(idx):
    """Line from combined-index point ``idx`` to its image under f (X) or g (Y)."""
    if idx < n:
        j = f[idx]
        return go.Scatter3d(
            x=[x_X[idx], x_Y[j]], y=[y_X[idx], y_Y[j]], z=[z_X[idx], z_Y[j]],
            mode='lines', line=dict(color='blue', width=2), name=f'f(X{idx}) → Y{j}'
        )
    i_Y = idx - n
    k = g[i_Y]
    return go.Scatter3d(
        x=[x_Y[i_Y], x_X[k]], y=[y_Y[i_Y], y_X[k]], z=[z_Y[i_Y], z_X[k]],
        mode='lines', line=dict(color='red', width=2), name=f'g(Y{i_Y}) → X{k}'
    )


fig = _base_figure_v1()

app.layout = html.Div([
    dcc.Graph(id='3d-plot', figure=fig, style={'height': '80vh'}),
    html.P("Кликните на точку, чтобы добавить/убрать стрелку.")
])

# Combined indices (0..2n-1) whose mapping line is currently drawn.
# This module-level set is shared by every browser tab connected to the app.
# That is acceptable for a local notebook viewer; variant 2 keeps the state in a dcc.Store instead.
selected_points = set()


@app.callback(
    Output('3d-plot', 'figure'),
    Input('3d-plot', 'clickData')
)
def update_graph(clickData):
    """Toggle the mapping line of the clicked point and return the redrawn figure."""
    if clickData and clickData.get('points'):
        point_info = clickData['points'][0]
        # Line traces carry no customdata, so clicks on them are ignored.
        if 'customdata' in point_info:
            point_idx = int(point_info['customdata'])
            if point_idx in selected_points:
                selected_points.remove(point_idx)
            else:
                selected_points.add(point_idx)

    new_fig = _base_figure_v1()
    for idx in sorted(selected_points):
        new_fig.add_trace(_mapping_trace_v1(idx))
    return new_fig


if __name__ == '__main__':
    app.run(debug=True)

<IPython.core.display.Javascript object>

In [None]:
# Interactive viewer, variant 2: the same click-to-toggle lines, plus buttons.
# The buttons show all mappings, only f, only g, only the "incorrect" ones, or none.
# "Incorrect" means f[i] != i (or g[i] != i). That only makes sense when Y is a
# point-by-point copy of X, as built above; other setups need a different reference.
import dash
from dash import dcc, html
from dash.dependencies import Input, Output, State
import numpy as np
import plotly.graph_objects as go

x_X, y_X, z_X = points_X
x_Y, y_Y, z_Y = points_Y
n = len(x_X)

app = dash.Dash(__name__)


def find_incorrect_mappings():
    """Return combined indices whose image is not the same-index point.

    X point ``i`` contributes ``i`` when ``f[i] != i``.
    Y point ``i`` contributes ``i + n`` when ``g[i] != i``.
    """
    incorrect_X = [i for i in range(n) if f[i] != i]
    incorrect_Y = [i + n for i in range(n) if g[i] != i]
    return set(incorrect_X + incorrect_Y)


def _base_figure_v2():
    """Build a fresh figure with the X (blue) and Y (red) point clouds.

    ``customdata`` puts both clouds into one index space: 0..n-1 for X points
    and n..2n-1 for Y points.
    """
    base = go.Figure()
    base.add_trace(go.Scatter3d(
        x=x_X, y=y_X, z=z_X, mode='markers+text', text=[f'X{i}' for i in range(n)],
        marker=dict(color='blue', size=5), textposition='top center', name='X',
        customdata=np.arange(n)
    ))
    base.add_trace(go.Scatter3d(
        x=x_Y, y=y_Y, z=z_Y, mode='markers+text', text=[f'Y{i}' for i in range(n)],
        marker=dict(color='red', size=5), textposition='top center', name='Y',
        customdata=np.arange(n, 2*n)
    ))
    # A constant uirevision keeps the user's camera/zoom across callback redraws.
    base.update_layout(scene=dict(aspectmode='cube'), uirevision='mappings-v2')
    return base


def _mapping_trace_v2(idx):
    """Line for one mapping; dotted when the image index differs from the source index."""
    if idx < n:
        j = f[idx]
        line_style = dict(color='blue', width=3, dash='dot' if j != idx else 'solid')
        return go.Scatter3d(
            x=[x_X[idx], x_Y[j]], y=[y_X[idx], y_Y[j]], z=[z_X[idx], z_Y[j]],
            mode='lines', line=line_style, name=f'f(X{idx}) → Y{j}',
            hoverinfo='name'
        )
    i_adj = idx - n
    k = g[i_adj]
    line_style = dict(color='red', width=3, dash='dot' if k != i_adj else 'solid')
    return go.Scatter3d(
        x=[x_Y[i_adj], x_X[k]], y=[y_Y[i_adj], y_X[k]], z=[z_Y[i_adj], z_X[k]],
        mode='lines', line=line_style, name=f'g(Y{i_adj}) → X{k}',
        hoverinfo='name'
    )


fig = _base_figure_v2()

# Dash passes `style` dicts to React, which expects camelCase CSS property names
# (backgroundColor, justifyContent, ...), not kebab-case.
BUTTON_STYLE = {'margin': '10px', 'padding': '10px'}

app.layout = html.Div([
    dcc.Graph(id='3d-plot', figure=fig, style={'height': '80vh'}),
    html.Div([
        html.Button('Все стрелки', id='toggle-all', n_clicks=0,
                    style=dict(BUTTON_STYLE)),
        html.Button('Только f (X→Y)', id='show-f', n_clicks=0,
                    style={**BUTTON_STYLE, 'backgroundColor': 'lightblue'}),
        html.Button('Только g (Y→X)', id='show-g', n_clicks=0,
                    style={**BUTTON_STYLE, 'backgroundColor': 'lightcoral'}),
        html.Button('Неправильные', id='show-incorrect', n_clicks=0,
                    style={**BUTTON_STYLE, 'backgroundColor': 'yellow'}),
        html.Button('Очистить все', id='clear-all', n_clicks=0,
                    style=dict(BUTTON_STYLE)),
    ], style={'display': 'flex', 'justifyContent': 'center', 'flexWrap': 'wrap'}),
    html.P("Кликните на точку, чтобы добавить/убрать стрелку."),
    dcc.Store(id='selected-points', data={'points': []})
])

# Buttons that replace the whole selection with a fixed set of combined indices.
_SELECTION_PRESETS = {
    'toggle-all': lambda: set(range(2 * n)),
    'show-f': lambda: set(range(n)),
    'show-g': lambda: set(range(n, 2 * n)),
    'show-incorrect': find_incorrect_mappings,
    'clear-all': set,
}


@app.callback(
    Output('3d-plot', 'figure'),
    Output('selected-points', 'data'),
    Input('3d-plot', 'clickData'),
    Input('toggle-all', 'n_clicks'),
    Input('show-f', 'n_clicks'),
    Input('show-g', 'n_clicks'),
    Input('show-incorrect', 'n_clicks'),
    Input('clear-all', 'n_clicks'),
    State('selected-points', 'data'),
    prevent_initial_call=True
)
def update_graph(clickData, toggle_all, show_f, show_g, show_incorrect, clear_all, stored_data):
    """Update the selection from whichever input fired, then redraw.

    The selection is kept in the browser via the ``selected-points`` dcc.Store,
    as a list of combined indices.
    """
    ctx = dash.callback_context
    trigger_id = ctx.triggered[0]['prop_id'].split('.')[0]

    selected_points = set(stored_data['points'])

    if trigger_id in _SELECTION_PRESETS:
        selected_points = _SELECTION_PRESETS[trigger_id]()
    elif trigger_id == '3d-plot' and clickData:
        point_info = clickData['points'][0]
        # Line traces carry no customdata, so clicks on them are ignored.
        if 'customdata' in point_info:
            point_idx = int(point_info['customdata'])
            if point_idx in selected_points:
                selected_points.remove(point_idx)
            else:
                selected_points.add(point_idx)

    new_fig = _base_figure_v2()
    for idx in sorted(selected_points):
        new_fig.add_trace(_mapping_trace_v2(idx))

    return new_fig, {'points': sorted(selected_points)}


if __name__ == '__main__':
    app.run(debug=True)

In [None]:
# Static (non-interactive) plots — useful for reviewing many results at once
import plotly.graph_objects as go
import numpy as np

def plot_incorrect_mappings(points_X, points_Y, f, g, title="", expected_f=None, expected_g=None):
    """Static 3-D plot of X, Y and only the mappings that disagree with a reference.

    Parameters
    ----------
    points_X, points_Y : tuple of three array-likes
        ``(x, y, z)`` coordinates of the two clouds. They may differ in size.
    f : sequence of int
        ``f[i]`` is the index of the Y point that X point ``i`` is mapped to.
    g : sequence of int
        ``g[i]`` is the index of the X point that Y point ``i`` is mapped to.
    title : str
        Figure title.
    expected_f, expected_g : sequence of int, optional
        Reference correspondences. Both default to the identity (``f[i] == i``).
        That default is right when Y is a point-by-point copy of X.

    Returns
    -------
    plotly.graph_objects.Figure
        Blue/red markers for X/Y, plus dotted lines for each mismatching mapping.
    """
    x_X, y_X, z_X = points_X
    x_Y, y_Y, z_Y = points_Y
    # f is indexed over X and g over Y, so the two sizes must be tracked
    # separately (using len(X) for Y breaks when |X| != |Y|).
    n_X = len(x_X)
    n_Y = len(x_Y)

    if expected_f is None:
        expected_f = range(n_X)
    if expected_g is None:
        expected_g = range(n_Y)

    incorrect_X = [i for i in range(n_X) if f[i] != expected_f[i]]
    incorrect_Y = [i for i in range(n_Y) if g[i] != expected_g[i]]

    fig = go.Figure()

    fig.add_trace(go.Scatter3d(
        x=x_X, y=y_X, z=z_X, mode='markers+text',
        text=[f'X{i}' for i in range(n_X)],
        marker=dict(color='blue', size=5),
        textposition='top center', name='X'
    ))
    fig.add_trace(go.Scatter3d(
        x=x_Y, y=y_Y, z=z_Y, mode='markers+text',
        text=[f'Y{i}' for i in range(n_Y)],
        marker=dict(color='red', size=5),
        textposition='top center', name='Y'
    ))

    for i in incorrect_X:
        j = f[i]
        fig.add_trace(go.Scatter3d(
            x=[x_X[i], x_Y[j]], y=[y_X[i], y_Y[j]], z=[z_X[i], z_Y[j]],
            mode='lines', line=dict(color='blue', width=2, dash='dot'),
            name=f'f(X{i}) → Y{j}'
        ))

    for i in incorrect_Y:
        k = g[i]
        fig.add_trace(go.Scatter3d(
            x=[x_Y[i], x_X[k]], y=[y_Y[i], y_X[k]], z=[z_Y[i], z_X[k]],
            mode='lines', line=dict(color='red', width=2, dash='dot'),
            name=f'g(Y{i}) → X{k}'
        ))

    fig.update_layout(
        title=title,
        scene=dict(aspectmode='cube'),
        showlegend=True
    )

    return fig

In [17]:
def get_points_torus(n, R, r, center=(0, 0, 0), rng=None):
    """Sample ``n`` random points on a torus whose axis is parallel to z through ``center``.

    ``theta`` (angle around the axis) and ``phi`` (angle around the tube) are
    both uniform on ``[0, 2*pi)``. The surface area element grows with
    ``R + r*cos(phi)``, so this oversamples the inner side of the torus
    relative to a uniform-by-area distribution.

    Parameters
    ----------
    n : int
        Number of points.
    R : float
        Distance from the torus axis to the center of the tube.
    r : float
        Tube radius.
    center : sequence of three floats
        Translation applied to every point.
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted, the global ``numpy.random`` state is
        used, exactly as before.

    Returns
    -------
    tuple of three ndarrays of shape (n,)
        The x, y and z coordinates.
    """
    if rng is None:
        rng = np.random
    theta = rng.uniform(0, 2 * np.pi, n)
    phi = rng.uniform(0, 2 * np.pi, n)

    # Distance of each point from the torus axis.
    ring = R + r * np.cos(phi)
    x = ring * np.cos(theta) + center[0]
    y = ring * np.sin(theta) + center[1]
    z = r * np.sin(phi) + center[2]
    return x, y, z


In [35]:
# Torus example: Y is X translated by (SHIFT, SHIFT, SHIFT), so the two clouds
# are isometric and the ideal d_GH is 0. A non-isometric alternative (tried
# earlier) is: points_Y = get_points_torus(n=35, R=10, r=3, center=(61, 61, 61)).
SEED = 0
SHIFT = 61
np.random.seed(SEED)  # makes the sampled cloud (and the outputs below) reproducible

points_X = get_points_torus(n=35, R=30, r=10, center=(0, 0, 0))
Y_x = points_X[0] + SHIFT
Y_y = points_X[1] + SHIFT
Y_z = points_X[2] + SHIFT
points_Y = (Y_x, Y_y, Y_z)

# (35, 3) coordinate arrays.
coords_X = np.vstack(points_X).T
coords_Y = np.vstack(points_Y).T

# Pairwise Euclidean distance matrices, shape (35, 35).
X = np.linalg.norm(coords_X[:, None] - coords_X, axis=2)
Y = np.linalg.norm(coords_Y[:, None] - coords_Y, axis=2)

# A translation preserves distances, so the two matrices must agree.
assert np.allclose(X, Y)

In [None]:
# Run dgh.upper on the torus data. return_fg=True also returns the mappings
# f (X indices -> Y indices) and g (Y indices -> X indices) that the plotting
# cells draw. NOTE(review): per its name, `upper` gives an upper-bound
# estimate of d_GH — confirm in the dgh docs.
TORUS_ITER_BUDGET = 1500
result = upper(X, Y, iter_budget=TORUS_ITER_BUDGET, return_fg=True)
d_gh, f, g = result
d_gh

np.float64(9.511791706186104)

In [None]:
# Example use of the static plots: re-run dgh.upper on the torus data with
# growing iteration budgets. For each budget, plot only the mappings whose image
# index differs from the source index.
# `plot_incorrect_mappings` comes from the static-plotting cell above; it is
# intentionally not re-declared here, so there is a single definition.
import plotly.graph_objects as go
import numpy as np

ITER_BUDGETS = [100, 1000, 1500, 3000, 5000, 7500]

for iter_budget in ITER_BUDGETS:
    result = upper(X, Y, iter_budget=iter_budget, return_fg=True)
    d_gh, f, g = result
    fig = plot_incorrect_mappings(
        points_X, points_Y, f, g,
        title=f"ITERS_BUDGET = {iter_budget}, dgh = {d_gh}",
    )
    fig.show()