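# <center>A-T-L-A-S ATLAS!</center>

Graphs for the *Atlas* word-chain game. Every place name (country or city) is a node. An edge `A → B` means `B` starts with the letter that `A` ends with, so `B` can follow `A` in the chain.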

In [None]:
# Install required libraries
%pip install numpy pandas networkx matplotlib pyvis plotly

In [2]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from pyvis.network import Network
import plotly.graph_objects as go
import plotly.io as pio

In [3]:
def create_atlas_graph(data):
    """
    Creates the directed "Atlas" word-chain graph for a list of place names.

    Every name becomes a node. An edge ``A -> B`` is added whenever the last
    letter of ``A`` equals the first letter of ``B`` (case-insensitive). A name
    that starts and ends with the same letter therefore gets a self-loop.
    Non-alphabetic characters at either end are ignored when picking those
    letters, e.g. the final "." in "Washington, D.C." (last letter "c").

    Parameters:
        data (list): A list of place names (countries/cities). Empty entries
            (e.g. produced by blank lines in an input file) are skipped instead
            of raising ``IndexError``.

    Returns:
        G (networkx.DiGraph): A directed graph.
    """
    G = nx.DiGraph()
    names = [place for place in data if place]

    # Bucket candidates by first letter once, so each place only visits the
    # names it can actually link to rather than re-scanning the whole list
    # (the previous all-pairs loop was O(n^2)). Buckets keep input order, so
    # node and edge insertion order match the all-pairs scan.
    by_first_letter = {}
    for candidate in names:
        first = _edge_letter(candidate)
        if first is not None:
            by_first_letter.setdefault(first, []).append(candidate)

    for place in names:
        G.add_node(place)
        last = _edge_letter(reversed(place))
        # A name with no letters at all (last is None) gets no outgoing edges.
        for candidate in by_first_letter.get(last, []):
            G.add_edge(place, candidate)

    return G


def _edge_letter(chars):
    """Return the first alphabetic character of ``chars`` lower-cased, or None if there is none."""
    return next((ch.lower() for ch in chars if ch.isalpha()), None)

In [4]:
# Directory holding the one-place-per-line input files. Set the ATLAS_DATA_DIR
# environment variable to point elsewhere instead of editing this path.
DATA_DIR = Path(os.environ.get("ATLAS_DATA_DIR", "/Users/mago/Desktop/Atlas/data"))


def load_places(path):
    """
    Read place names from a text file with one name per line.

    Surrounding whitespace is stripped and blank lines are dropped, because an
    empty name cannot take part in the letter chain.
    Decoding is explicit: "utf-8-sig" reads plain UTF-8 and also discards a
    leading byte-order mark, which would otherwise become the first character
    of the first name.
    """
    with open(path, "r", encoding="utf-8-sig") as file:
        return [name for name in (line.strip() for line in file) if name]


countries = load_places(DATA_DIR / "countries.txt")
cities = load_places(DATA_DIR / "cities.txt")

combined = countries + cities

country_graph = create_atlas_graph(countries)
city_graph = create_atlas_graph(cities)
combined_graph = create_atlas_graph(combined)

In [5]:
# vis.js options handed to pyvis: forceAtlas2Based physics with overlap
# avoidance and stabilization enabled (up to 1000 iterations).
_PYVIS_OPTIONS = """
  var options = {
    "physics": {
      "enabled": true,
      "stabilization": {
        "enabled": true,
        "iterations": 1000
      },
      "solver": "forceAtlas2Based",
      "forceAtlas2Based": {
        "gravitationalConstant": -50,
        "springLength": 100,
        "springConstant": 0.08,
        "avoidOverlap": 1
      }
    }
  }
"""


def _show_interactive_graph(G, html_path):
    """Render G with pyvis, write it to html_path and display it inline in the notebook."""
    net = Network(height="100%", width="100%", notebook=True, directed=True)
    net.from_nx(G)
    net.set_options(_PYVIS_OPTIONS)
    # In notebook mode pyvis writes the HTML file and *returns* an IFrame rather
    # than rendering it, so the result must be displayed explicitly. Otherwise
    # nothing appears inline, especially when several graphs are shown from one
    # cell. `display` is the IPython kernel's built-in display function.
    frame = net.show(html_path)
    if frame is not None:
        display(frame)


def _draw_static_graph(G, title, save_path):
    """Draw G with NetworkX/Matplotlib, optionally save it as an image, then show it."""
    fig, ax = plt.subplots(figsize=(15, 15))
    pos = nx.spring_layout(G, seed=42)  # seeded so the layout is reproducible

    # Shade nodes by total degree (in + out); the Blues colormap makes
    # better-connected places darker.
    node_color = [degree for _, degree in G.degree()]
    nx.draw_networkx_nodes(
        G, pos, ax=ax, node_size=500, node_color=node_color, cmap=plt.cm.Blues, edgecolors="black"
    )
    nx.draw_networkx_edges(G, pos, ax=ax, arrowstyle="->", arrowsize=15, edge_color="gray")
    nx.draw_networkx_labels(G, pos, ax=ax, font_size=10, font_color="black", font_family="sans-serif")

    ax.set_title(title, fontsize=16)
    ax.set_axis_off()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    plt.show()


def visualize_graph(G, title, save_path=None, interactive=False):
    """
    Visualizes a directed graph, either statically (NetworkX + Matplotlib) or
    interactively (pyvis).

    Parameters:
        G (networkx.DiGraph): A directed graph.
        title (str): Title for the graph visualization.
        save_path (str, optional): Where to save the visualization.
            - Static mode: an image file written with ``savefig``; nothing is
              saved when this is None.
            - Interactive mode: the HTML file pyvis writes, which must be an
              ``.html`` path. When None, the file name is derived from the title
              (spaces -> underscores, lower-cased, plus ``.html``).
            Defaults to None.
        interactive (bool, optional): If True, creates an interactive graph using pyvis. Defaults to False.
    """
    if interactive:
        html_path = str(save_path) if save_path else f"{title.replace(' ', '_').lower()}.html"
        _show_interactive_graph(G, html_path)
    else:
        _draw_static_graph(G, title, save_path)

In [None]:
# Render the country and city graphs interactively with pyvis.
for graph, graph_title in ((country_graph, "Country Graph"), (city_graph, "City Graph")):
    visualize_graph(graph, graph_title, interactive=True)

# The combined (countries + cities) graph is left disabled; uncomment to draw it statically.
# visualize_graph(combined_graph, "Combined Graph")

In [7]:
def visualize_static_3d_graph(G, title):
    """
    Renders a graph as a rotatable 3D scatter plot with Plotly.

    Node positions come from a seeded 3D spring layout, so repeated calls on the
    same graph give the same picture. Nodes are coloured by their position in
    ``G.nodes()`` order and show their name on hover. Edges are drawn as plain
    grey line segments, so edge direction is not shown.

    Parameters:
        G (networkx.Graph): The graph to visualize.
        title (str): Title of the graph.
    """
    layout_3d = nx.spring_layout(G, dim=3, seed=42)
    node_names = list(G.nodes())

    # One coordinate list per axis (x, y, z), following G.nodes() order.
    node_xyz = [[layout_3d[name][axis] for name in node_names] for axis in range(3)]

    # Each edge contributes (start, end, None) per axis; the None breaks the
    # polyline so every edge is a separate segment within a single trace.
    edge_xyz = [[], [], []]
    for source, target in G.edges():
        for axis in range(3):
            edge_xyz[axis].extend((layout_3d[source][axis], layout_3d[target][axis], None))

    markers = go.Scatter3d(
        x=node_xyz[0],
        y=node_xyz[1],
        z=node_xyz[2],
        mode='markers',
        marker=dict(
            size=5,
            color=np.arange(len(G.nodes())),
            colorscale='Viridis',
            opacity=0.8,
        ),
        text=node_names,
        hoverinfo='text',
    )

    segments = go.Scatter3d(
        x=edge_xyz[0],
        y=edge_xyz[1],
        z=edge_xyz[2],
        mode='lines',
        line=dict(color='gray', width=1),
        hoverinfo='none',
    )

    # Hide the grey background panes on all three scene axes.
    scene_axes = {f"{axis_name}axis": dict(showbackground=False) for axis_name in "xyz"}

    figure = go.Figure(
        data=[segments, markers],
        layout=go.Layout(title=title, showlegend=False, scene=dict(**scene_axes)),
    )
    figure.show()

In [None]:
# The city graph is left disabled; swap the comment markers to render it instead.
# visualize_static_3d_graph(city_graph, "City Graph - Static 3D Visualization")
visualize_static_3d_graph(G=country_graph, title="Country Graph - Static 3D Visualization")