In [3]:
import os
import os.path as osp
import random

import numpy as np
import torch
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.loader import DataLoader
import scipy.io as sio

In [4]:
raw_dir = '/home/jacob/Documents/DiGress/Data/Network/data'
splits = ['train', 'valid', 'test']
all_types = set()
raw_datas = {}
for split in splits:
    raw_name = f'data_{split}.mat'
    path = osp.join(raw_dir, raw_name)
    mat = sio.loadmat(path, squeeze_me=True, struct_as_record=False)
    data = mat['data']
    raw_datas[split] = data
    for item in data:
        types = item.rType
        if isinstance(types, np.ndarray):
            types = types.tolist()
        if not isinstance(types, (list, tuple)):
            types = [types]
        all_types.update(types)

type_list = sorted(all_types)
type_to_idx = {t: i for i, t in enumerate(type_list)}
num_types = len(type_list)

for i, split in enumerate(splits):
    data_list = []
    for item in raw_datas[split]:
        types = item.rType
        if isinstance(types, np.ndarray):
            types = types.tolist()
        if not isinstance(types, (list, tuple)):
            types = [types]
        idxs = [type_to_idx[t] for t in types]
        x = torch.eye(num_types, dtype=torch.float)[idxs]
        # Edges with relation attribute
        edges, rel_attrs = [], []
        if hasattr(item, 'rEdge') and item.rEdge is not None:
            for (u, v, r) in item.rEdge:
                u_idx, v_idx = int(u) - 1, int(v) - 1
                edges.append([u_idx, v_idx])
                rel_attrs.append([r])
        if edges:
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
            edge_attr = torch.tensor(rel_attrs, dtype=torch.long)
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_attr = torch.empty((0, 1), dtype=torch.long)

        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
        data_list.append(data)
        print(data)


Data(x=[8, 13], edge_index=[2, 12], edge_attr=[12, 1])
Data(x=[7, 13], edge_index=[2, 12], edge_attr=[12, 1])
Data(x=[7, 13], edge_index=[2, 9], edge_attr=[9, 1])
Data(x=[8, 13], edge_index=[2, 11], edge_attr=[11, 1])
Data(x=[7, 13], edge_index=[2, 11], edge_attr=[11, 1])
Data(x=[7, 13], edge_index=[2, 11], edge_attr=[11, 1])
Data(x=[5, 13], edge_index=[2, 6], edge_attr=[6, 1])
Data(x=[8, 13], edge_index=[2, 12], edge_attr=[12, 1])
Data(x=[7, 13], edge_index=[2, 11], edge_attr=[11, 1])
Data(x=[6, 13], edge_index=[2, 9], edge_attr=[9, 1])
Data(x=[7, 13], edge_index=[2, 9], edge_attr=[9, 1])
Data(x=[8, 13], edge_index=[2, 12], edge_attr=[12, 1])
Data(x=[6, 13], edge_index=[2, 9], edge_attr=[9, 1])
Data(x=[6, 13], edge_index=[2, 9], edge_attr=[9, 1])
Data(x=[8, 13], edge_index=[2, 13], edge_attr=[13, 1])
Data(x=[6, 13], edge_index=[2, 9], edge_attr=[9, 1])
Data(x=[5, 13], edge_index=[2, 7], edge_attr=[7, 1])
Data(x=[7, 13], edge_index=[2, 12], edge_attr=[12, 1])
Data(x=[6, 13], edge_index

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x757999f4efd0>>
Traceback (most recent call last):
  File "/home/jacob/anaconda3/envs/digress/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


Data(x=[8, 13], edge_index=[2, 12], edge_attr=[12, 1])
Data(x=[7, 13], edge_index=[2, 11], edge_attr=[11, 1])
Data(x=[6, 13], edge_index=[2, 9], edge_attr=[9, 1])
Data(x=[6, 13], edge_index=[2, 9], edge_attr=[9, 1])
Data(x=[6, 13], edge_index=[2, 9], edge_attr=[9, 1])
Data(x=[6, 13], edge_index=[2, 8], edge_attr=[8, 1])
Data(x=[8, 13], edge_index=[2, 13], edge_attr=[13, 1])
Data(x=[6, 13], edge_index=[2, 9], edge_attr=[9, 1])
Data(x=[7, 13], edge_index=[2, 11], edge_attr=[11, 1])
Data(x=[6, 13], edge_index=[2, 7], edge_attr=[7, 1])
Data(x=[7, 13], edge_index=[2, 9], edge_attr=[9, 1])
Data(x=[8, 13], edge_index=[2, 13], edge_attr=[13, 1])
Data(x=[7, 13], edge_index=[2, 11], edge_attr=[11, 1])
Data(x=[6, 13], edge_index=[2, 8], edge_attr=[8, 1])
Data(x=[7, 13], edge_index=[2, 11], edge_attr=[11, 1])
Data(x=[7, 13], edge_index=[2, 9], edge_attr=[9, 1])
Data(x=[6, 13], edge_index=[2, 9], edge_attr=[9, 1])
Data(x=[7, 13], edge_index=[2, 11], edge_attr=[11, 1])
Data(x=[8, 13], edge_index=[2,

In [5]:
import torch_geometric.utils

In [6]:
raw_dataset = torch.load('/home/jacob/Documents/DiGress/data/planar/raw/test.pt')

data_list = []
for adj in raw_dataset:
    n = adj.shape[-1]
    X = torch.ones(n, 1, dtype=torch.float)
    y = torch.zeros([1, 0]).float()
    edge_index, _ = torch_geometric.utils.dense_to_sparse(adj)
    edge_attr = torch.zeros(edge_index.shape[-1], 2, dtype=torch.float)
    edge_attr[:, 1] = 1
    num_nodes = n * torch.ones(1, dtype=torch.long)
    data = torch_geometric.data.Data(x=X, edge_index=edge_index, edge_attr=edge_attr,
                                        y=y, n_nodes=num_nodes)
    data_list.append(data)
    print("~" * 50)
    print(data)

    data_list.append(data)

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Data(x=[64, 1], edge_index=[2, 358], edge_attr=[358, 2], y=[1, 0], n_nodes=[1])
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Data(x=[64, 1], edge_index=[2, 358], edge_attr=[358, 2], y=[1, 0], n_nodes=[1])
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Data(x=[64, 1], edge_index=[2, 358], edge_attr=[358, 2], y=[1, 0], n_nodes=[1])
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Data(x=[64, 1], edge_index=[2, 350], edge_attr=[350, 2], y=[1, 0], n_nodes=[1])
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Data(x=[64, 1], edge_index=[2, 360], edge_attr=[360, 2], y=[1, 0], n_nodes=[1])
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Data(x=[64, 1], edge_index=[2, 356], edge_attr=[356, 2], y=[1, 0], n_nodes=[1])
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Data(x=[64, 1], edge_index=[2, 350], edge_attr=[350, 2], y=[1, 0], n_nodes=[1])
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Data(x=[64, 1], edge_index=[2, 3

In [21]:
torch.zeros([1, 0]).float()

tensor([], size=(1, 0))

In [3]:
import torch
def fill_diagonal_batch(tensor, value):
    bs, n, _ = tensor.shape
    indices = torch.arange(n, device=tensor.device)
    tensor[:, indices, indices] = value  # Fill with scalar
    return tensor

# Example:
batch_tensor = torch.randn(3, 4, 4)  # (bs=3, n=4)
filled_tensor = fill_diagonal_batch(batch_tensor, 0.0)

In [4]:
filled_tensor

tensor([[[ 0.0000,  0.6064, -0.7532, -0.5429],
         [ 1.3584,  0.0000,  1.2060, -0.3411],
         [-0.0103,  1.0938,  0.0000, -1.5752],
         [ 0.3727,  0.3504,  0.1109,  0.0000]],

        [[ 0.0000, -0.3154, -1.7026, -0.1793],
         [ 0.8449,  0.0000, -1.7058, -1.4850],
         [-0.3475,  0.8404,  0.0000,  0.0291],
         [-1.2044, -0.5580, -0.2038,  0.0000]],

        [[ 0.0000,  0.4896, -0.9906,  0.6461],
         [-0.3323,  0.0000,  1.1444,  0.0899],
         [ 1.0081, -1.6179,  0.0000, -0.3174],
         [-0.0133, -0.3786,  0.4702,  0.0000]]])