In [14]:
import networkx as nx
from graph_utils import plot_graph
import numpy as np
import pandas as pd
from gurobipy import GRB, Model

In [15]:
static = pd.read_csv("csv_files/static_with_t.csv")
detected = pd.read_csv("csv_files/detection_graph_softmax.csv")

In [16]:
componenets = detected.columns
componenets = [x for x in componenets if x not in ["name", "number", "x", "y", "z"]]

In [17]:
level_order = [
    "root_cross_right",
    "root_cross_left",
    "root_cross_right2",
    "conn_k",
    "cross_e",
    "cross_bag_br",
    "cross_ij1",
    "cross_e_c",
    "cross_f",
    "bag_b_r",
    "bag_a_r",
    "conn_b",
    "cross_ij2",
    "conn_e",
    "conn_c",
    "cross_k",
    "cross_f30",
    "conn_i",
    "conn_j",
    "box",
    "conn_a",
    "conn_f_30",
    "cross_f37",
    "conn_f_60",
    "bag_b_l",
    "cross_gh",
    "conn_f_37",
    "conn_d",
    "conn_g",
    "conn_h",
    "bag_a_l",
]

In [18]:
quadratic_model = Model("quadratic")

In [19]:
variables = {}
for name in level_order:
    variables[f"{name}_x"] = quadratic_model.addVar(
        vtype=GRB.CONTINUOUS, lb=-250, ub=250, name=f"{name}_x"
    )
    variables[f"{name}_y"] = quadratic_model.addVar(
        vtype=GRB.CONTINUOUS, lb=-250, ub=250, name=f"{name}_y"
    )
    variables[f"{name}_z"] = quadratic_model.addVar(
        vtype=GRB.CONTINUOUS, lb=-250, ub=250, name=f"{name}_z"
    )

In [20]:
obj_fn = 0
for row in detected.iterrows():
    detected_x = row[1]["x"]
    detected_y = row[1]["y"]
    detected_z = row[1]["z"]
    for component in componenets:
        confidence = row[1][component]
        term = confidence * (
            (variables[f"{component}_x"] - detected_x) ** 2
            + (variables[f"{component}_y"] - detected_y) ** 2
            + (variables[f"{component}_z"] - detected_z) ** 2
        )
        obj_fn += term
    

In [21]:
for name in level_order:
    node_info = static[static["name"] == name]
    parent = node_info["parent"].values[0]
    distance_to_parent = node_info["distance_to_parent"].values[0]
    if name == "root_cross_left":
        quadratic_model.addQConstr(variables[f"{name}_x"] == -13)
        quadratic_model.addQConstr(variables[f"{name}_y"] == -20)
        quadratic_model.addQConstr(variables[f"{name}_z"] == 0)

    elif name == "root_cross_right":
        quadratic_model.addQConstr(variables[f"{name}_x"] == 10)
        quadratic_model.addQConstr(variables[f"{name}_y"] == -1)
        quadratic_model.addQConstr(variables[f"{name}_z"] == 0)
    else:
        quadratic_model.addQConstr(
            (
                (variables[f"{name}_x"] - variables[f"{parent}_x"]) ** 2
                + (variables[f"{name}_y"] - variables[f"{parent}_y"]) ** 2
                + (variables[f"{name}_z"] - variables[f"{parent}_z"]) ** 2
            )
            <= distance_to_parent**2
        )

In [22]:
quadratic_model.setObjective(obj_fn, GRB.MINIMIZE)

In [23]:
quadratic_model.optimize()

Gurobi Optimizer version 10.0.1 build v10.0.1rc0 (linux64)

CPU model: 11th Gen Intel(R) Core(TM) i5-11500H @ 2.90GHz, instruction set [SSE2|AVX|AVX2|AVX512]
Thread count: 6 physical cores, 12 logical processors, using up to 12 threads

Optimize a model with 6 rows, 93 columns and 6 nonzeros
Model fingerprint: 0xab29339c
Model has 54 quadratic objective terms
Model has 29 quadratic constraints
Coefficient statistics:
  Matrix range     [1e+00, 1e+00]
  QMatrix range    [1e+00, 2e+00]
  Objective range  [7e-02, 4e+03]
  QObjective range [2e+01, 2e+01]
  Bounds range     [2e+02, 2e+02]
  RHS range        [1e+00, 2e+01]
  QRHS range       [2e+01, 1e+04]
Presolve removed 6 rows and 6 columns
Presolve time: 0.01s
Presolved: 234 rows, 202 columns, 403 nonzeros
Presolved model has 30 second-order cone constraints
Ordering time: 0.00s

Barrier statistics:
 AA' NZ     : 1.419e+04
 Factor NZ  : 1.502e+04
 Factor Ops : 1.585e+06 (less than 1 second per iteration)
 Threads    : 1

                

In [24]:
import networkx as nx
from graph_utils import plot_graph
import numpy as np
import pandas as pd


G1 = nx.Graph()
for i in range(len(static)):
    G1.add_node(static["name"][i], pos=(static["x"][i], static["z"][i], static["y"][i]))
for i in range(len(static)):
    if static["parent"][i] != 0 and static["parent"][i] != "0":
        G1.add_edge(static["name"][i], static["parent"][i])


G2 = nx.Graph()
for name in level_order:
    # if x and y and z are  different than 250 or -250
    if variables[f"{name}_x"].x != 250 and variables[f"{name}_x"].x != -250:
        # print(name, variables[f"{name}_x"].x, variables[f"{name}_y"].x, variables[f"{name}_z"].x)
        G2.add_node(
            name,
            pos=(
                variables[f"{name}_x"].x,
                variables[f"{name}_z"].x,
                variables[f"{name}_y"].x,
            ),
        )


plot_graph(static_graph=G1, detected_graph=G2, name="qcqp2", save=True)

Graph saved to qcqp2.html


In [25]:
def calculate_distance(point1, point2):
    """Calculates the Euclidean distance between two 3D points"""
    x1, y1, z1 = point1
    x2, y2, z2 = point2
    return np.sqrt((x1 - x2)**2 + (y1 - y2)**2 + (z1 - z2)**2)

def find_closest_node(G1, G2):
    """Finds the closest G1 node for each node in G2"""
    closest_nodes = {}
    for g2_node in G2.nodes(data=True):
        g2_pos = g2_node[1]['pos']
        min_dist = float('inf')
        min_node = None
        for g1_node in G1.nodes(data=True):
            g1_pos = g1_node[1]['pos']
            dist = calculate_distance(g1_pos, g2_pos)
            if dist < min_dist:
                min_dist = dist
                min_node = g1_node[0]
        closest_nodes[g2_node[0]] = min_node
    return closest_nodes

# Call the function
closest_nodes = find_closest_node(G1, G2)

# If you want to see the result in a DataFrame
df_closest_nodes = pd.DataFrame(closest_nodes.items(), columns=['G2_Node', 'Closest_G1_Node'])
print(df_closest_nodes)


              G2_Node    Closest_G1_Node
0    root_cross_right   root_cross_right
1     root_cross_left    root_cross_left
2   root_cross_right2  root_cross_right2
3              conn_k             conn_k
4             cross_e            cross_e
5        cross_bag_br       cross_bag_br
6           cross_ij1          cross_ij1
7           cross_e_c          cross_e_c
8             cross_f            cross_f
9             bag_b_r            bag_b_r
10            bag_a_r            bag_a_r
11             conn_b             conn_b
12          cross_ij2          cross_ij2
13             conn_e             conn_c
14             conn_c             conn_c
15            cross_k            cross_k
16          cross_f30          cross_f30
17             conn_i             conn_i
18             conn_j             conn_j
19                box                box
20             conn_a             conn_a
21          conn_f_30          conn_f_30
22          cross_f37          cross_f37
23          conn

In [26]:
detected

Unnamed: 0,name,number,x,y,z,conn_a,conn_f_30,conn_f_37,conn_f_60,conn_e,...,conn_g,conn_h,conn_k,conn_d,bag_b_r,bag_a_r,conn_b,conn_i,conn_j,bag_a_l
0,conn_a,104,20.0,-50.0,-9.0,0.866898,0.019781,0.027157,0.020787,0.000000,...,0.000000,0.000000,0.000000,0.015087,0.000000,0.018105,0.000000,0.009052,0.000335,0.005700
1,conn_f_30,30,12.0,-86.0,0.0,0.000000,0.950971,0.000000,0.000000,0.012007,...,0.005503,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.020012,0.000000
2,conn_f_37,37,17.0,-115.0,-9.0,0.010904,0.000000,0.951823,0.005897,0.000000,...,0.000000,0.007900,0.000000,0.000000,0.000445,0.000000,0.008456,0.010792,0.000000,0.003783
3,conn_f_60,60,-24.0,-120.0,-16.0,0.008304,0.006878,0.000000,0.940863,0.000000,...,0.000000,0.002013,0.001929,0.000000,0.000000,0.002936,0.006375,0.007633,0.006039,0.003020
4,conn_e,10,-5.0,-87.0,-12.0,0.003093,0.011599,0.034796,0.000000,0.877829,...,0.000000,0.000000,0.000000,0.000000,0.057219,0.000000,0.000000,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
175,bag_a_r,236,13.0,-150.0,0.0,0.000000,0.015237,0.000000,0.000000,0.000000,...,0.033769,0.004530,0.020179,0.000000,0.000000,0.838566,0.000000,0.000000,0.000000,0.009060
176,conn_b,222,21.0,-86.0,0.0,0.023835,0.000000,0.005700,0.003109,0.005441,...,0.000000,0.021763,0.000000,0.017358,0.015286,0.001814,0.848180,0.017099,0.004404,0.000000
177,conn_i,212,16.0,-75.0,3.0,0.012937,0.006102,0.000000,0.017087,0.024410,...,0.000000,0.020260,0.003661,0.003173,0.000000,0.009520,0.006347,0.844508,0.016599,0.000000
178,conn_j,214,22.0,-73.0,3.0,0.000000,0.053169,0.048829,0.000000,0.000000,...,0.000000,0.023872,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.824760,0.033095
