# MillenniumAI performance comparison

## Set executable paths

In [21]:
import sys
import networkx as nx
import matplotlib.pyplot as plt
import os
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
from torch_geometric.loader import NeighborLoader
import torch
import subprocess
from typing import Tuple, List, Dict
import socket
import pickle
import time
import tracemalloc

# Necessary to import from sibling directory
sys.path.append("..")


from pymdb import (
    MDBClient,
    TrainGraphLoader,
    EvalGraphLoader,
    SamplingGraphLoader,
    Sampler,
)

In [22]:
# Locations of the MillenniumDB/MillenniumAI executables and of the directory
# that holds the generated graphs and databases
SERVER_PYMDB_PATH = "/home/mdbai/MillenniumDB-Dev/build/Release/bin/server_pymdb"
CREATE_DB_PATH = "/home/mdbai/MillenniumDB-Dev/build/Release/bin/create_db"
GENERATION_BASE_PATH = "/home/mdbai/PyMDB/examples/generated"

# Fail fast if the server executable cannot be found
if not os.path.exists(SERVER_PYMDB_PATH):
    raise Exception(
        "SERVER_PYMDB_PATH is not set to the correct path. Please set it to "
        "the path of the MillenniumDB server_pymdb executable."
    )

# Fail fast if the database creation executable cannot be found
if not os.path.exists(CREATE_DB_PATH):
    raise Exception(
        "CREATE_DB_PATH is not set to the correct path. Please set it to "
        "the path of the MillenniumDB create_db executable."
    )

# TCP port the MillenniumDB server is started on
SERVER_PORT = 8080

## Define performance test function

In [23]:
def gen_node_feat(num_nodes, num_node_feat):
    """Return an all-zero float32 tensor of shape (num_nodes, num_node_feat)."""
    print("  Generating node features...")
    feature_shape = (num_nodes, num_node_feat)
    return torch.zeros(feature_shape, dtype=torch.float32)
        
def gen_edge(num_nodes, num_edges):
    """Return a random edge index tensor of shape (2, num_edges).

    Every entry is a node id drawn uniformly from [0, num_nodes - 1].
    num_nodes must be at least 1.
    """
    print("  Generating graph edges...")
    # torch.randint excludes its upper bound, so the bound has to be num_nodes
    # itself for the last node (num_nodes - 1) to be reachable.
    return torch.randint(0, num_nodes, (2, num_edges))

def generate_graphs(
    num_nodes: int,
    num_edges: int,
    num_node_feat: int,
) -> Tuple[Data, str]:
    """Create (or reuse) a graph in both formats used by the benchmark.

    Two files are kept under GENERATION_BASE_PATH, named after the graph
    dimensions: a pickle of the in-memory graph (".pkl") and a MillenniumDB
    text dump (".milldb"). A file that already exists is reused without any
    check of its contents.

    Args:
        num_nodes: number of nodes of the graph.
        num_edges: number of edges of the graph.
        num_node_feat: number of features per node.

    Returns:
        The in-memory graph and the path of the MillenniumDB dump file.
    """
    graph_name = f"N{num_nodes}_E{num_edges}_F{num_node_feat}"

    # In-memory graph
    pickle_path = f"{GENERATION_BASE_PATH}/{graph_name}.pkl"
    if os.path.exists(pickle_path):
        # Load a graph from an existing pickle
        print("  Pickle graph dump already exists. Loading file...")
        # NOTE: unpickling can execute arbitrary code; this is only meant for
        # files previously written by this function.
        with open(pickle_path, "rb") as pickle_file:
            graph = pickle.load(pickle_file)
    else:
        # Generate a new graph and dump to a pickle
        graph = Data(
            num_nodes=num_nodes,
            node_feat=gen_node_feat(num_nodes, num_node_feat),
            edge_index=gen_edge(num_nodes, num_edges),
        )
        print("  Writing the generated graph to a pickle...")
        # The context manager guarantees the file is flushed and closed
        with open(pickle_path, "wb") as pickle_file:
            pickle.dump(graph, pickle_file)

    # On-disk MillenniumDB graph
    mdb_dump_path = f"{GENERATION_BASE_PATH}/{graph_name}.milldb"
    if os.path.exists(mdb_dump_path):
        # Skip MDB dump creation
        print("  MillenniumDB's graph dump already exists. Skipping dump creation...")
    else:
        with open(mdb_dump_path, "w") as f:
            # One line per node: its name followed by its feature list
            for idx in range(graph.num_nodes):
                f.write(f"N{idx} feat:{graph.node_feat[idx].tolist()}\n")
            # One line per edge; transposing yields (source, target) pairs
            for edge in graph.edge_index.T:
                f.write(f"N{edge[0]}->N{edge[1]} :T\n")

    return graph, mdb_dump_path


def create_db(mdb_dump_path: str) -> str:
    """Create a MillenniumDB database from a dump file and return its path.

    The database directory is the dump path without its trailing ".milldb".
    If that directory already exists, creation is skipped.

    Args:
        mdb_dump_path: path of the MillenniumDB dump file.

    Returns:
        The path of the database directory.

    Raises:
        RuntimeError: if the create_db executable exits with a non-zero code.
    """
    dump_suffix = ".milldb"
    # Strip only the trailing suffix: str.replace would also rewrite the same
    # text appearing in a parent directory name.
    if mdb_dump_path.endswith(dump_suffix):
        dest_path = mdb_dump_path[: -len(dump_suffix)]
    else:
        dest_path = mdb_dump_path

    if os.path.isdir(dest_path):
        print("  MillenniumDB's database already exists. Skipping database creation...")
    else:
        print("  Creating MillenniumDB's database...")
        result = subprocess.run(
            [CREATE_DB_PATH, mdb_dump_path, dest_path],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
        )
        if result.returncode != 0:
            raise RuntimeError(f"create_db: {result.stderr.decode('utf-8')}")
    return dest_path


def start_server(db_path: str):
    """Start a MillenniumDB server for a database directory.

    Blocks until a TCP connection to localhost:SERVER_PORT succeeds.

    Args:
        db_path: path of the database directory to serve.

    Returns:
        The subprocess.Popen handle of the server process.

    Raises:
        RuntimeError: if the server process exits before the port accepts
            connections.
    """
    process = subprocess.Popen(
        [SERVER_PYMDB_PATH, db_path, "-p", str(SERVER_PORT)],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.PIPE,
    )

    # Wait for server to listen to port
    while True:
        # A server that died during start-up will never open the port;
        # report its error output instead of polling forever.
        if process.poll() is not None:
            error_output = process.stderr.read().decode("utf-8", errors="replace")
            raise RuntimeError(
                f"server_pymdb exited with code {process.returncode}: {error_output}"
            )

        # Close every probe socket so retries do not leak file descriptors
        with socket.socket() as probe:
            if probe.connect_ex(("localhost", SERVER_PORT)) == 0:
                return process

        time.sleep(0.5)


def kill_server(process) -> int:
    """Kill a MillenniumDB server process and return its exit code."""
    process.kill()
    # wait() reaps the killed process and yields its exit status
    exit_code = process.wait()
    return exit_code

def clear_os():
    """Clear both the buffer/cache and the swap of the OS (requires sudo)."""
    shell_commands = (
        # Flush dirty pages, then drop the caches
        "sudo sync && echo 3 | sudo tee /proc/sys/vm/drop_caches > /dev/null",
        # Disable and re-enable swap
        "sudo swapoff -a && sudo swapon -a",
    )
    for command in shell_commands:
        os.system(command)

def graph_size(graph):
    """Estimate the in-memory size of a graph in MB (node features + edge index).

    The size of each tensor is its element count times the byte width of its
    own dtype, so the estimate stays correct for any dtype.
    """
    node_feat = graph.node_feat
    edge_index = graph.edge_index
    total_bytes = (
        node_feat.numel() * node_feat.element_size()
        + edge_index.numel() * edge_index.element_size()
    )
    return total_bytes / 1e6

def run_performance_tests(instances: List[Dict], batch_sizes: List[int]):
    """Run performance tests for a list of instances and a list of batch sizes.

    For every instance the graph is generated (or reused) in both formats.
    For every batch size, one full pass is then made over the in-memory graph
    with NeighborLoader and one over the MillenniumDB graph with
    EvalGraphLoader. Elapsed time and peak traced memory are recorded for each.

    Args:
        instances: keyword arguments for generate_graphs, one dict per graph.
        batch_sizes: batch sizes to benchmark on every graph.

    Returns:
        A dict keyed by database path. Each value holds the parallel lists
        "batch_size", "time_mem", "time_mdb" (seconds), "mempeak_mem" and
        "mempeak_mdb" (MB).
    """
    # exist_ok avoids a separate existence check before creating the directory
    os.makedirs(GENERATION_BASE_PATH, exist_ok=True)

    plot_data = dict()

    for instance in instances:
        print(f"Running for instance: {instance}...")
        # Generate graphs
        graph, mdb_dump_path = generate_graphs(**instance)
        db_path = create_db(mdb_dump_path)

        results = {
            "time_mem": list(),
            "time_mdb": list(),
            "mempeak_mem": list(),
            "mempeak_mdb": list(),
            "batch_size": list(),
        }
        plot_data[db_path] = results

        for batch_size in batch_sizes:
            print(f"    Running for batch size: {batch_size}...")
            results["batch_size"].append(batch_size)

            # 1. In-memory graph
            clear_os()
            tracemalloc.start()
            try:
                t0_mem = time.perf_counter_ns()
                for batch in NeighborLoader(
                    graph, num_neighbors=[5, 5], batch_size=batch_size
                ):
                    # Here the batch would be passed to a model
                    pass
                results["time_mem"].append((time.perf_counter_ns() - t0_mem) / 1e9)
                _, peak = tracemalloc.get_traced_memory()
                # The graph size estimate is added on top of the traced peak,
                # presumably because tensor storage is not visible to
                # tracemalloc -- TODO confirm.
                results["mempeak_mem"].append((peak / 1e6) + graph_size(graph))
            finally:
                # Stop tracing even on failure so the next run starts clean
                tracemalloc.stop()

            # 2. MillenniumDB graph
            clear_os()
            tracemalloc.start()
            try:
                server_process = start_server(db_path)
                try:
                    with MDBClient("localhost", SERVER_PORT) as client:
                        t0_mdb = time.perf_counter_ns()
                        for batch in EvalGraphLoader(
                            client, num_neighbors=[5, 5], batch_size=batch_size
                        ):
                            # Here the batch would be passed to a model
                            pass
                        results["time_mdb"].append(
                            (time.perf_counter_ns() - t0_mdb) / 1e9
                        )
                        _, peak = tracemalloc.get_traced_memory()
                        results["mempeak_mdb"].append(peak / 1e6)
                finally:
                    # Never leave a server behind holding SERVER_PORT
                    kill_server(server_process)
            finally:
                tracemalloc.stop()

    return plot_data

## Run performance tests for instances

In [None]:
# Graphs to benchmark: each entry holds the keyword arguments of generate_graphs
instances = [
    dict(num_nodes=100_000_000, num_edges=1_000_000_000, num_node_feat=4),
    # Small alternative for a quick run:
    # dict(num_nodes=1_000, num_edges=10_000, num_node_feat=4),
]
args = dict(instances=instances, batch_sizes=[10_000, 100_000])

plot_data = run_performance_tests(**args)

Running for instance: {'num_nodes': 100000000, 'num_edges': 1000000000, 'num_node_feat': 4}...
  Generating randomized node features...
  Generating randomized graph edges...
  Writing the generated graph to a pickle...


In [None]:
def plot_comparison(x: List, y1: List, y2: List, x_label: str, y_label: str, title: str):
    """Plot the in-memory series (y1) and the MillenniumDB series (y2) over x.

    Both curves share one figure, which is labelled with the given axis
    labels and title and then shown.
    """
    labelled_series = ((y1, "In-memory"), (y2, "MillenniumDB"))
    for series, series_label in labelled_series:
        plt.plot(x, series, label=series_label)

    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend()
    plt.show()

In [None]:
# Show the raw measurements returned by run_performance_tests
print(plot_data)

### Time vs Batch Size

In [None]:
# One figure per graph: elapsed time of both loaders over batch size
for db_path, data in plot_data.items():
    plot_comparison(
        x=data["batch_size"],
        y1=data["time_mem"],
        y2=data["time_mdb"],
        x_label="Batch size",
        y_label="Time (s)",
        title=f"Graph: {os.path.basename(db_path)}",
    )

### Peak memory usage

In [None]:
# One figure per graph: peak memory usage of both loaders over batch size
for db_path, data in plot_data.items():
    plot_comparison(
        x=data["batch_size"],
        y1=data["mempeak_mem"],
        y2=data["mempeak_mdb"],
        x_label="Batch size",
        y_label="Peak memory usage (MB)",
        title=f"Graph: {os.path.basename(db_path)}",
    )

### Cold storage size