In [1]:
import torch
import torch_geometric.transforms as T
import torch_geometric.nn.conv.cg_conv as CGConv
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch_geometric.utils import to_dgl
import dgl

# Pick the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [2]:
# Load the previously saved heterogeneous graph object from disk.
# NOTE(review): torch.load unpickles the file, so only point this at files
# from a trusted source. Newer PyTorch releases may default to
# weights_only=True and reject non-tensor objects — confirm against the
# installed version.
data = torch.load('../data/hetero_graph_data.pt')

In [13]:
# Display the graph summary: node types, edge types, and the shape of every
# tensor stored under each of them.
data

HeteroData(
  customer={
    x=[161086, 12],
    index=[161086],
  },
  product={
    x=[2708, 15],
    index=[2708],
  },
  (customer, order, product)={
    edge_index=[2, 11892915],
    edge_attr=[11892915, 11],
  },
  (product, rev_order, customer)={
    edge_index=[2, 11892915],
    edge_attr=[11892915, 11],
  }
)

In [4]:
# Run the graph object's built-in validation; the recorded run returned True.
data.validate()

True

In [7]:
# Show the second element of metadata(): the list of
# (source, relation, destination) edge types present in the graph.
data.metadata()[1]

[('customer', 'order', 'product'), ('product', 'rev_order', 'customer')]

In [12]:
# Give every node type a consecutive integer `index` (0 .. N-1).
# N is read from the node type's feature matrix rather than hard-coded, so the
# cell stays correct if the graph is regenerated with a different number of
# customers or products. torch.arange builds the same int64 values as
# torch.tensor(range(N)) without iterating over a Python range.
for node_type in ('customer', 'product'):
    data[node_type].index = torch.arange(data[node_type].x.size(0))

In [19]:
# Convert the PyG graph into a DGL graph using torch_geometric.utils.to_dgl.
dgl_data = to_dgl(data)

In [20]:
# Display the converted graph: node counts, edge counts and the metagraph.
dgl_data

Graph(num_nodes={'customer': 161086, 'product': 2708},
      num_edges={('customer', 'order', 'product'): 11892915, ('product', 'rev_order', 'customer'): 11892915},
      metagraph=[('customer', 'product', 'order'), ('product', 'customer', 'rev_order')])

In [16]:
# Persist the converted graph to disk. save_graphs is handed a list, which
# here holds the single converted graph.
dgl.save_graphs('../data/graph_data.bin', [dgl_data])

In [2]:
# Reload the graph saved above. The first item returned by load_graphs is
# indexed again to pull out the single graph that was written to the file.
saved_graphs = dgl.load_graphs('../data/graph_data.bin')[0]
dgl_data = saved_graphs[0]

In [3]:
# Display the reloaded graph; the recorded output matches the summary shown
# before saving.
dgl_data

Graph(num_nodes={'customer': 161086, 'product': 2708},
      num_edges={('customer', 'order', 'product'): 11892915, ('product', 'rev_order', 'customer'): 11892915},
      metagraph=[('customer', 'product', 'order'), ('product', 'customer', 'rev_order')])

In [12]:
# Inspect the data attached to the 'order' edges; the recorded output shows a
# single `edge_attr` tensor.
dgl_data.edges['order']

EdgeSpace(data={'edge_attr': tensor([[  1.0000,  19.0000,   0.0000,  ...,   1.0000,   0.0000,   0.0000],
        [  2.0000, 166.0000,   0.0000,  ...,   1.0000,   0.0000,   0.0000],
        [  2.0000,  43.0000,   0.0000,  ...,   1.0000,   0.0000,   0.0000],
        ...,
        [ 10.0000,  27.8847,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
        [ 10.0000, 138.9000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
        [  6.0000, 118.8000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]])})

In [15]:
# Check the shape of the customer feature matrix on the DGL graph; the
# recorded output (161086 x 12) matches the shape shown for the original graph.
dgl_data.nodes['customer'].data['x'].shape

torch.Size([161086, 12])