In [30]:
import subprocess
import sys
def install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

install('kaleido')

In [22]:
import torch
import networkx as nx
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import gif
import kaleido

In [2]:
with open('./../data/input.txt', 'r') as f:
    lines = f.readlines()
G = nx.Graph()
for line in lines:
    words = line.replace("\n", "").replace(";", "").replace(",", "").split(" ")
    name = words[1]
    flow_rate = int(words[4].split("=")[1])
    connections = words[9:]
    G.add_node(name)
    G.nodes[name]['fr'] = flow_rate
    G.nodes[name]['pos'] = np.random.normal(0, 1., (2,))


for line in lines:
    words = line.replace("\n", "").replace(";", "").replace(",", "").split(" ")
    name = words[1]
    flow_rate = int(words[4].split("=")[1])
    connections = words[9:]
    for con in connections:
        G.add_edge(name, con)
N = G.number_of_nodes()

In [3]:
def graph_to_plotly(G):
    for node in G.nodes:
        T = len(G.nodes[node]['pos'][0])
    
    edges_pos ={'x': [], 'y': []}
    for t in range(T):
        edges_pos['x'].append([])
        edges_pos['y'].append([])
        for edge in G.edges:
            edges_pos['x'][t].append(G.nodes[edge[0]]['pos'][0][t])
            edges_pos['x'][t].append(G.nodes[edge[1]]['pos'][0][t])
            edges_pos['x'][t].append(None)
            
            edges_pos['y'][t].append(G.nodes[edge[0]]['pos'][1][t])
            edges_pos['y'][t].append(G.nodes[edge[1]]['pos'][1][t])
            edges_pos['y'][t].append(None)
    
    node_pos = {'x': [], 'y': []}
    for t in range(T):
        node_pos['x'].append([])
        node_pos['y'].append([])
        for node in G.nodes:
            node_pos['x'][t].append(G.nodes[node]['pos'][0][t])
            node_pos['y'][t].append(G.nodes[node]['pos'][1][t])
    
    return node_pos, edges_pos

In [4]:
def animate_network(node_pos, edge_pos):
    T = len(node_pos['x'])
    fig = go.Figure(
        data=[go.Scatter(x=node_pos['x'][0], y=node_pos['y'][0],
                        name="frame",
                        mode="markers",
                        line=dict(width=2, color="blue")),
            go.Scatter(x=edge_pos['x'][0], y=edge_pos['y'][0],
                        name="curve",
                        mode="lines",
                        line=dict(width=2, color="blue"))
            ],
        layout=go.Layout(width=600, height=600,
                        xaxis=dict(range=[-12,12], autorange=False, zeroline=False),
                        yaxis=dict(range=[-12, 12], autorange=False, zeroline=False),
                        title="Moving Frenet Frame Along a Planar Curve",
                        hovermode="closest",
                        updatemenus=[dict(type="buttons",
                                        buttons=[dict(label="Play",
                                                        method="animate",
                                                        args=[None])])]
                        ),
        frames=[go.Frame(
            data=[go.Scatter(
                x= edge_pos['x'][k],
                y= edge_pos['y'][k],
                mode="lines",
                line=dict(color="red", width=2))
            ,
            go.Scatter(
                x= node_pos['x'][k],
                y= node_pos['y'][k],
                mode="markers",
                line=dict(color="blue", width=2))
            ]) for k in range(T)]
    )

    return fig

In [5]:
def get_index_maps(G):
    node_to_index = {}
    index_to_node = {}
    current_i = 0
    for node in G.nodes:
        node_to_index[str(node)] = current_i
        index_to_node[current_i] = str(node)
        current_i += 1
    return node_to_index, index_to_node

def get_connection_matrix(G, node_to_index, num_nodes):
    connected = torch.zeros((num_nodes, num_nodes), dtype=torch.float32)
    for edge in G.edges:
        connected[node_to_index[edge[0]], node_to_index[edge[1]]] = 1.
        connected[node_to_index[edge[1]], node_to_index[edge[0]]] = 1.
    return connected

In [6]:
def div_2(M2):
    eps = 0.1
    return 1/(M2+eps)+M2-2

def logexp_2(M2):
    eps = 0.1
    return torch.log(torch.exp(1/(M2+eps))+M2-2.5)

In [26]:
Ig = torch.ones((N,N), dtype=torch.float32)-torch.eye(N, N)
node_to_index, index_to_node = get_index_maps(G)
Ic = get_connection_matrix(G, node_to_index, N)


In [27]:
node_x = torch.normal(0, np.sqrt(N), (N,1), requires_grad=True, dtype=torch.float32)
node_y = torch.normal(0, np.sqrt(N), (N,1), requires_grad=True, dtype=torch.float32)
optim = torch.optim.Adam([node_x, node_y], lr=1.0)

In [28]:
losses = []
T = 500
pos = np.zeros((2, N, T))
pos[0, :, 0] = node_x.clone().detach()[:,0]
pos[1, :, 0] = node_y.clone().detach()[:,0]
for t in range(T-1):
  
  R2 = torch.pow(node_x-node_x.T, 2) + torch.pow(node_y-node_y.T, 2)
  Gen = R2*Ig
  Con = R2*Ic
  Eg = logexp_2(Gen/N)
  Ec = div_2(Con)
  E = Eg + Ec
  L = torch.sum(E)

  optim.zero_grad()
  L.backward()
  optim.step()
  losses.append(L.item())
  pos[0, :, t+1] = node_x.clone().detach()[:,0]
  pos[1, :, t+1] = node_y.clone().detach()[:,0]

for node in G.nodes:
    G.nodes[node]['pos'] = pos[:, node_to_index[node], :]

In [29]:
node_pos, edge_pos = graph_to_plotly(G)
fig = animate_network(node_pos, edge_pos)

In [11]:
fig.show()

In [12]:
import plotly
import json

In [13]:
@gif.frame
def get_fram_as_fig(i, fig):
    fig.frames[i]['layout'] = fig['layout']
    new_fig = plotly.io.from_json(json.dumps(fig.frames[0].to_plotly_json()))
    return new_fig

In [15]:

def get_fram_as_fig(i, fig):
    fig.frames[i]['layout'] = fig['layout']
    new_fig = plotly.io.from_json(json.dumps(fig.frames[0].to_plotly_json()))
    return new_fig
    
frames = []
for i in range(T):
    frame = get_fram_as_fig(i, fig)
    frames.append(frame)

# Save gif from frames with a specific duration for each frame in ms
gif.save(frames, 'example.gif', duration=100)

AttributeError: 'Figure' object has no attribute 'save'

In [21]:
frame.show().write()

In [30]:

import imageio
temp_file = "temp.png"
with imageio.get_writer('example.gif', mode='I') as writer:
    for frame in fig.frames:
        frame['layout'] = fig['layout']
        new_fig = plotly.io.from_json(json.dumps(frame.to_plotly_json()))
        new_fig.write_image(temp_file)
        writer.append_data( imageio.imread(temp_file))

import os
os.remove(temp_file)

