In [1]:
from torch_geometric.datasets import Planetoid
import torch
from torch_geometric.data import Data
from torch_geometric.data import DataLoader

In [2]:
# Toy graph: 4 nodes, 3 float features per node.
x = torch.tensor([[-1, 0, 1], [0, 1, 2], [1, 2, 3], [2, 3, 4]], dtype=torch.float)
# NOTE(review): y has one entry per node of x (4), but below each entry is used
# as the label of a whole graph -- confirm whether node- or graph-level labels
# are intended. Labels are float; cast to torch.long if they feed
# F.cross_entropy / F.nll_loss.
y = torch.tensor([0, 1, 0, 1], dtype=torch.float)

# COO connectivity: row 0 = source nodes, row 1 = target nodes (PyG convention).
edge_index = torch.tensor([[0, 2, 1, 0, 3],
                           [3, 1, 0, 1, 2]], dtype=torch.long)

# One Data object per label. Iterating y (instead of a hard-coded range(4))
# keeps the dataset size tied to the number of labels and does not leak loop
# variables into the notebook namespace. All graphs share the same x and
# edge_index, so they differ only by label.
dataset_list = [Data(x=x, edge_index=edge_index, y=label) for label in y]
assert len(dataset_list) == len(y)

In [3]:
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves items straight from an in-memory sequence.

    ``len(ds)`` and ``ds[idx]`` are forwarded unchanged to the wrapped object,
    so any indexing it supports (e.g. negative ints or slices for a list)
    works here as well. The sequence is stored by reference, not copied.

    Args:
        data: Object supporting ``len()`` and ``[]`` indexing, such as a list
            of ``torch_geometric.data.Data`` graphs.
    """

    def __init__(self, data):
        # Hold a reference to the caller's sequence.
        self.data = data

    def __len__(self):
        """Return how many items the wrapped sequence holds."""
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        """Return the item at position ``idx`` of the wrapped sequence."""
        item = self.data[idx]
        return item

In [4]:
# Wrap the list of graphs in a Dataset and display the first graph as a sanity check.
custom_dataset = CustomDataset(data=dataset_list)
custom_dataset[0]

Data(x=[4, 3], edge_index=[2, 5], y=0.0)

In [5]:
# torch_geometric.data.DataLoader is a deprecated alias (PyG >= 2.0); the
# supported loader lives in torch_geometric.loader. It collates Data objects
# into a DataBatch, concatenating nodes/edges and adding a `batch` vector.
from torch_geometric.loader import DataLoader

dataloader = DataLoader(custom_dataset, batch_size=2, shuffle=True)



In [6]:
# Show each mini-batch summary followed, on its own line, by its labels.
for batch in dataloader:
    print(batch, batch.y, sep="\n")

DataBatch(x=[8, 3], edge_index=[2, 10], y=[2], batch=[8], ptr=[3])
tensor([1., 0.])
DataBatch(x=[8, 3], edge_index=[2, 10], y=[2], batch=[8], ptr=[3])
tensor([0., 1.])


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import SAGEConv

class Net(nn.Module):
    """Two-layer GraphSAGE model: SAGEConv -> ReLU -> dropout -> SAGEConv -> softmax.

    By default the output has one row per node. The labels built in this
    notebook are one per graph (``batch.y`` has one entry per graph), so pass
    ``graph_level=True`` to mean-pool node outputs per graph via ``data.batch``
    and get one row per graph instead.

    Args:
        in_channels: Node feature width (default 3, matching ``x``).
        hidden_channels: Hidden layer width (default 16).
        out_channels: Number of output classes (default 2).
        dropout: Dropout probability, active only in training mode
            (default 0.5, the ``F.dropout`` default).
        graph_level: If True, pool node outputs to one row per graph.
    """

    def __init__(self, in_channels=3, hidden_channels=16, out_channels=2,
                 dropout=0.5, graph_level=False):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels, aggr='mean')
        self.conv2 = SAGEConv(hidden_channels, out_channels, aggr='mean')
        self.dropout = dropout
        self.graph_level = graph_level

    def forward(self, data):
        """Return class probabilities of shape [num_nodes, C] or, with
        ``graph_level=True``, [num_graphs, C].

        NOTE(review): the output is already softmaxed. F.cross_entropy expects
        raw logits, so feeding these probabilities to it would apply softmax
        twice; use a log-softmax output with F.nll_loss for training.
        """
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        if self.graph_level:
            from torch_geometric.nn import global_mean_pool
            # `batch` maps each node to its graph index; an un-batched Data
            # may not carry one, in which case all nodes form a single graph.
            batch_vec = getattr(data, 'batch', None)
            x = global_mean_pool(x, batch_vec)
        return F.softmax(x, dim=1)
    
model = Net()
print(model)

# Forward-only smoke test over the loader. eval() disables dropout and
# no_grad() skips recording autograd history, since nothing here
# backpropagates. Call model.train() again before any training loop.
model.eval()
with torch.no_grad():
    for batch in dataloader:
        out = model(batch)

Net(
  (conv1): SAGEConv(3, 16, aggr=mean)
  (conv2): SAGEConv(16, 2, aggr=mean)
)


In [5]:
import torch
from enum import Enum

class Node_Type(Enum):
    """Node categories with explicit integer codes (``.value`` is what gets
    stored in tensors)."""

    Activity = 1
    Condition = 2


# Encode a node type as a 1-element float32 tensor. torch.tensor with an
# explicit dtype is the recommended constructor: the legacy torch.Tensor(...)
# follows the global default tensor type and treats a bare int as a *size*
# (torch.Tensor(1) allocates an uninitialized tensor).
a = torch.tensor([Node_Type.Activity.value], dtype=torch.float)
a

tensor([1.])