In [14]:
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx

# from datasets.mp3d import MP3D, extract_object_graph, extract_room_graph
from neural_tree.construct import generate_htree, nx_dsg_jth_to_torch
import datasets.mp3d as mp3d

# Base directory handed to the project-local `datasets.mp3d` module via
# `mp3d.BASE_DIR`, kept as a named config constant so it is easy to find/change.
# NOTE(review): presumably this is the root the module loads MP3D/Hydra data
# from; a relative './' is resolved against the kernel's current working
# directory, so the notebook likely must be launched from the repo root — confirm.
MP3D_BASE_DIR = './'
mp3d.BASE_DIR = MP3D_BASE_DIR

## Extracting MP3D Hydra Data

In [18]:
# Required: every later cell depends on `graph_torch`, which is defined here.

# Which dataset sample to inspect (presumably one sample == one scene graph).
SCENE_IDX = 0

# NOTE(review): the meaning of `complete=True` is defined in datasets.mp3d — confirm.
dset = mp3d.MP3D(complete=True)

data = dset[SCENE_IDX]
# HeteroData with `objects` and `rooms` node types (see its repr below).
graph_torch = data['dsg_torch']

# Object-only and room-only subgraphs of the scene. They are not used by the
# cells shown in this notebook. Presumably `tonx=False` keeps them as
# torch_geometric data instead of networkx graphs — confirm in datasets.mp3d.
object_graph = mp3d.extract_object_graph(graph_torch, tonx=False)
room_graph = mp3d.extract_room_graph(graph_torch, tonx=False)

In [19]:
# Inspect the scene graph: `objects`/`rooms` node stores and the three edge
# types (objects_to_objects, rooms_to_rooms, rooms_to_objects) with their sizes.
graph_torch

HeteroData(
  [1mobjects[0m={
    x=[74, 6],
    pos=[74, 3],
    label=[74]
  },
  [1mrooms[0m={
    x=[12, 6],
    pos=[12, 3],
    label=[12]
  },
  [1m(objects, objects_to_objects, objects)[0m={
    edge_index=[2, 284],
    edge_attrs={
      objects_to_objects=[0],
      rooms_to_rooms=[0],
      rooms_to_objects=[0]
    }
  },
  [1m(rooms, rooms_to_rooms, rooms)[0m={
    edge_index=[2, 16],
    edge_attrs={
      objects_to_objects=[0],
      rooms_to_rooms=[0],
      rooms_to_objects=[0]
    }
  },
  [1m(rooms, rooms_to_objects, objects)[0m={
    edge_index=[2, 74],
    edge_attrs={
      objects_to_objects=[0],
      rooms_to_rooms=[0],
      rooms_to_objects=[0]
    }
  }
)

In [20]:
# Build one H-tree per connected component of graph_torch; the result is a
# list that is indexed in the cells below.

htree_list = generate_htree(graph_torch)

In [21]:
# Number of H-trees generated, i.e. one per connected component of graph_torch
# according to the generate_htree cell.

len(htree_list)

4

In [22]:
# Take the first H-tree and read its `.jth` attribute.
# NOTE(review): "jth" is not defined in this notebook; presumably it is the
# junction-tree-based hierarchy the H-tree is built from, stored as a networkx
# graph (per the `nx_` prefix) — confirm in neural_tree.construct.

# Fail with a readable message instead of a bare IndexError when the scene
# produced no H-trees (e.g. an empty graph).
assert htree_list, "generate_htree returned no H-trees; nothing to extract"
nx_dsg_jth = htree_list[0].jth

In [23]:
# Convert the `.jth` of the first H-tree (nx_dsg_jth, presumably a networkx
# graph) into a torch_geometric HeteroData (see its repr in the next cell).

torch_dsg_jth = nx_dsg_jth_to_torch(nx_dsg_jth)

In [24]:
# Inspect the converted H-tree: besides `object`/`room` it has composite node
# types (`object-room`, `room-room`), and its edge types are keyed by integer
# relation ids (0-5 for this scene).
torch_dsg_jth

HeteroData(
  [1mobject[0m={
    x=[741, 6],
    y=[741],
    pos=[741, 3]
  },
  [1mroom[0m={
    x=[253, 6],
    y=[253],
    pos=[253, 3]
  },
  [1mobject-room[0m={
    x=[329, 6],
    y=[329],
    pos=[329, 3]
  },
  [1mroom-room[0m={
    x=[8, 6],
    y=[8],
    pos=[8, 3]
  },
  [1m(object-room, 0, object)[0m={ edge_index=[2, 741] },
  [1m(object-room, 1, room)[0m={ edge_index=[2, 237] },
  [1m(room-room, 2, room)[0m={ edge_index=[2, 16] },
  [1m(room-room, 3, object-room)[0m={ edge_index=[2, 75] },
  [1m(object-room, 4, object-room)[0m={ edge_index=[2, 254] },
  [1m(room-room, 5, room-room)[0m={ edge_index=[2, 14] }
)

In [25]:
# Display the original scene graph again after H-tree generation/conversion;
# its repr matches the first display, so node/edge counts are unchanged.
graph_torch

HeteroData(
  [1mobjects[0m={
    x=[74, 6],
    pos=[74, 3],
    label=[74]
  },
  [1mrooms[0m={
    x=[12, 6],
    pos=[12, 3],
    label=[12]
  },
  [1m(objects, objects_to_objects, objects)[0m={
    edge_index=[2, 284],
    edge_attrs={
      objects_to_objects=[0],
      rooms_to_rooms=[0],
      rooms_to_objects=[0]
    }
  },
  [1m(rooms, rooms_to_rooms, rooms)[0m={
    edge_index=[2, 16],
    edge_attrs={
      objects_to_objects=[0],
      rooms_to_rooms=[0],
      rooms_to_objects=[0]
    }
  },
  [1m(rooms, rooms_to_objects, objects)[0m={
    edge_index=[2, 74],
    edge_attrs={
      objects_to_objects=[0],
      rooms_to_rooms=[0],
      rooms_to_objects=[0]
    }
  }
)