In this part of our tutorial, we will explore how GNNs can be used to impute missing values in tabular data, using the [extended Iris dataset](https://www.kaggle.com/datasets/samybaladram/iris-dataset-extended/data). The approach we will implement is called GRAPE and is described in [this paper](https://proceedings.neurips.cc/paper/2020/file/dc36f18a9a0a776671d4879cae69b551-Paper.pdf).

In [None]:
# @title Required setup
!pip3 install torch_geometric

# Downloads and unpacks the dataset
!kaggle datasets download -d samybaladram/iris-dataset-extended
!unzip iris-dataset-extended.zip

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import model_selection
from sklearn import preprocessing
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
import torch_geometric.datasets as datasets


# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

First, let's look at the data.

In [None]:
# Load the extended Iris table and show a handful of random rows.
iris_df = pd.read_csv(filepath_or_buffer='iris_extended.csv')
iris_df.sample(n=5)

The dataset contains several categorical variables, so let's encode them using one-hot encoding.

In [None]:
# One-hot encode the non-numeric columns, dropping the first level of each
# so that no indicator column is redundant with the others.
iris_df = pd.get_dummies(data=iris_df, drop_first=True)
iris_df.sample(n=5)

In [None]:
# @title Consider normalizing the dataset

NORMALIZE = False  # @param {'type': 'boolean'}

# Convert to a plain float matrix. Since pandas 2.0, `get_dummies` emits bool
# columns; mixed with the float columns, a bare `to_numpy()` would return an
# `object` array that every later step has to re-coerce.
iris_df = iris_df.to_numpy(dtype=np.float64)
if NORMALIZE:
  # Rescales every column to [0, 1]. Note that the scaler is fit on all
  # cells, including those that are later held out as test edges.
  iris_df = preprocessing.MinMaxScaler().fit_transform(iris_df)

Now let's encode the data as a graph using the following rules:
nodes represent dataset entries (rows) and features (columns); each edge connects an entry node to a feature node and carries the corresponding feature value.

In [None]:
def encode_data(data: np.ndarray, train_mask: np.ndarray) -> Data:
  """Encodes a 2-D table into a bipartite entry/feature graph.

  The first `num_entries` nodes represent table rows and the next
  `num_features` nodes represent table columns. Every cell `data[i, j]`
  becomes one edge from entry node `i` to feature node `num_entries + j`
  whose single attribute is the cell value. Edges are enumerated in row-major
  order (all features of entry 0, then entry 1, ...), which is the order in
  which `train_mask` is interpreted.

  Args:
    data: Table of shape [num_entries, num_features]; anything `np.asarray`
      can turn into a numeric matrix (a NumPy array or a numeric DataFrame).
    train_mask: Boolean mask with `num_entries * num_features` elements
      (flat, or shaped like `data`); True marks a train edge and False a
      test edge.

  Returns:
    A `Data` object holding NumPy arrays: `x` of shape
    [num_entries + num_features, num_features + 1] (one-hot node type),
    `edge_index_train` / `edge_index_test` of shape [2, E] (int64) and
    `edge_attr_train` / `edge_attr_test` of shape [E, 1] (float32).

  Raises:
    ValueError: If `data` is not 2-D or `train_mask` has the wrong size.
  """
  data = np.asarray(data, dtype=np.float32)
  if data.ndim != 2:
    raise ValueError(f'`data` must be 2-D, got shape {data.shape}.')
  num_entries, num_features = data.shape
  num_edges = num_entries * num_features

  # Force a flat boolean mask: with an integer 0/1 mask, `~` would be a
  # bitwise NOT producing negative indices and select the wrong edges.
  train_mask = np.asarray(train_mask, dtype=bool).reshape(-1)
  if train_mask.size != num_edges:
    raise ValueError(
        f'`train_mask` has {train_mask.size} elements, expected {num_edges}.'
    )

  # Node features one-hot encode the node type: column 0 flags entry nodes,
  # column 1 + j flags feature node j.
  nodes_features = np.zeros(
      (num_entries + num_features, num_features + 1), dtype=np.float32
  )
  nodes_features[:num_entries, 0] = 1.0
  nodes_features[num_entries:, 1:] = np.identity(num_features, dtype=np.float32)

  # Directed edges entry -> feature in row-major cell order: entry i repeats
  # num_features times, paired with feature nodes num_entries + 0..F-1.
  # NOTE(review): with PyG's default `source_to_target` message flow, entry
  # nodes receive no messages along these edges; GRAPE uses undirected
  # edges, so consider also adding the reversed edges.
  sources = np.repeat(np.arange(num_entries, dtype=np.int64), num_features)
  targets = num_entries + np.tile(
      np.arange(num_features, dtype=np.int64), num_entries
  )
  edge_index = np.stack([sources, targets])  # [2, num_edges]
  edge_attr = data.reshape(-1, 1)  # [num_edges, 1], same row-major order

  return Data(
      x=nodes_features,
      edge_index_train=edge_index[:, train_mask],
      edge_index_test=edge_index[:, ~train_mask],
      edge_attr_train=edge_attr[train_mask],
      edge_attr_test=edge_attr[~train_mask],
  )

In [None]:
TRAIN_RATIO = 0.7  # @param {'type': 'number'}
# One Bernoulli(TRAIN_RATIO) draw per table cell, flattened row-major to match
# the edge order built by `encode_data`: True -> train edge, False -> test.
train_mask = np.random.RandomState(seed=0).binomial(
    n=1, p=TRAIN_RATIO, size=iris_df.size
) == 1

In [None]:
# @title Let's visualize the resulting train/test split

# Rows are dataset entries (only the first 40 are shown) and columns are
# features; each cell shows whether that value is observed or held out.
plt.imshow(train_mask.reshape(iris_df.shape[0], iris_df.shape[1])[:40])
plt.colorbar(label='1 = train (observed), 0 = test (held out)')
plt.xlabel('Feature index')
plt.ylabel('Entry index')
plt.title('Train/test mask (first 40 entries)')
plt.show()

In [None]:
class Net(torch.nn.Module):
  """One GNN layer that refines node embeddings and then edge values.

  Node embeddings come from a single `SAGEConv` over `edge_index` (edge
  attributes are not passed to it). Each edge is then re-scored by an MLP on
  the concatenation `[h_source, h_target, edge_attr]`; the trailing ReLU
  keeps the predicted edge values non-negative.

  NOTE(review): the MLP receives the edge's own attribute as input, so a
  model trained to reproduce that same attribute can learn a near-identity
  map instead of imputing values from neighbours. GRAPE predicts an edge from
  its endpoint embeddings and feeds edge values into message passing —
  confirm whether this simplification is intended.
  """

  def __init__(
      self,
      *,
      node_input_dim: int,
      edge_input_dim: int,
      node_hidden_dim: int,
      edge_hidden_dim: int,
  ):
    super().__init__()
    self.node_conv = torch_geometric.nn.SAGEConv(node_input_dim, node_hidden_dim)
    # Input is both endpoint embeddings plus the current edge attribute.
    mlp_input_dim = 2 * node_hidden_dim + edge_input_dim
    self.edge_update_mlps = nn.Sequential(
        nn.Linear(mlp_input_dim, edge_hidden_dim),
        nn.ReLU(),
        nn.Linear(edge_hidden_dim, edge_input_dim),
        nn.ReLU(),
    )

  def forward(
      self, x: torch.Tensor, edge_attr: torch.Tensor, edge_index: torch.Tensor
  ) -> tuple[torch.Tensor, torch.Tensor]:
    """Returns `(node_embeddings, updated_edge_attr)`.

    Args:
      x: Node features, one row per node.
      edge_attr: Edge attributes, one row per column of `edge_index`.
      edge_index: Tensor of shape [2, E] holding (source, target) node ids.
    """
    node_emb = self.node_conv(x, edge_index)
    source_ids, target_ids = edge_index
    mlp_input = torch.cat(
        (node_emb[source_ids], node_emb[target_ids], edge_attr), dim=-1
    )
    return node_emb, self.edge_update_mlps(mlp_input)

In [None]:
NUM_EPOCHS = 200  # @param {'type': 'number'}
LEARNING_RATE = 0.001  # @param {'type': 'number'}
WEIGHT_DECAY = 0.0000001  # @param {'type': 'number'}


def train(gnn: torch.nn.Module, graph: Data) -> tuple[list[float], list[float]]:
  """Fits `gnn` to reproduce edge attributes and records per-epoch MSE.

  Uses Adam with the module-level `LEARNING_RATE` / `WEIGHT_DECAY` for
  `NUM_EPOCHS` full-batch steps, and moves all tensors to `device`.

  Args:
    gnn: Model called as `gnn(x, edge_attr, edge_index)` that returns
      `(node_embeddings, predicted_edge_attr)`; expected to already live on
      `device`.
    graph: Output of `encode_data`, holding NumPy arrays `x`,
      `edge_index_train`, `edge_index_test`, `edge_attr_train` and
      `edge_attr_test`.

  Returns:
    `(train_loss, val_loss)`: one MSE value per epoch for each split.

  NOTE(review): the validation pass runs message passing over the test edges
  and feeds their true values into the model, so `val_loss` measures how well
  held-out values are reproduced from themselves rather than imputed from the
  observed (train) edges.
  """
  x = torch.from_numpy(graph.x).to(device)
  edge_attr_train = torch.from_numpy(graph.edge_attr_train).to(device)
  edge_attr_test = torch.from_numpy(graph.edge_attr_test).to(device)
  edge_index_train = torch.from_numpy(graph.edge_index_train).to(device)
  edge_index_test = torch.from_numpy(graph.edge_index_test).to(device)

  optimizer = torch.optim.Adam(
      gnn.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
  )
  train_loss, val_loss = [], []
  for _ in range(NUM_EPOCHS):
    gnn.train()
    optimizer.zero_grad()
    _, pred_train = gnn(x, edge_attr_train, edge_index_train)
    loss = F.mse_loss(pred_train, edge_attr_train)  # (prediction, target)
    loss.backward()
    optimizer.step()
    train_loss.append(loss.item())

    # Evaluation mode + no_grad: no autograd graph is built for validation.
    gnn.eval()
    with torch.no_grad():
      _, pred_test = gnn(x, edge_attr_test, edge_index_test)
      val_loss.append(F.mse_loss(pred_test, edge_attr_test).item())
  return train_loss, val_loss

In [None]:
# Build the bipartite entry/feature graph, split into train and test edges.
graph = encode_data(data=iris_df, train_mask=train_mask)

In [None]:
# @title Instantiate the GNN

NODE_HIDDEN_DIM = 128  # @param {'type': 'number'}
EDGE_HIDDEN_DIM = 128  # @param {'type': 'number'}


# Input widths are read off the encoded graph: node features are the one-hot
# node type, and every edge carries a single value.
gnn = Net(
    node_input_dim=graph.x.shape[1],
    edge_input_dim=graph.edge_attr_train.shape[1],
    node_hidden_dim=NODE_HIDDEN_DIM,
    edge_hidden_dim=EDGE_HIDDEN_DIM,
).to(device)

In [None]:
train_loss, val_loss = train(gnn, graph)

# Plot against the list index: one point per recorded epoch, so the x-axis
# always matches the losses actually returned, and no stray global is bound.
plt.plot(train_loss, label='train loss')
plt.plot(val_loss, label='val loss')
plt.legend()
plt.title('Training progress')
plt.xlabel('Epoch')
plt.ylabel('Loss value')
plt.show()