# GNN with Pytorch Geometric - heterogeneous
In this notebook, our data is loaded with the heterogeneous graph functionality of PyTorch Geometric. Afterwards, the loaded data is tested with a simple graph neural network (GNN). <br>
Note that, due to problems with the local environment and PyTorch, this notebook was developed and tested to work in Google Colab.

## Install and import necessary PyTorch packages

In [None]:
# Install required packages.
# The wheel index URL for torch-scatter / torch-sparse depends on the installed
# torch version, so that version is exported as the TORCH environment variable
# and substituted into the URLs by the shell commands below.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# NOTE(review): none of these installs is version-pinned, and PyTorch Geometric is
# installed from the current state of its git repository, so re-running this
# cell later may pull different code.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

1.12.1+cu113


In [None]:
from torch_geometric.data import HeteroData
from torch_geometric.data import InMemoryDataset
from typing import Callable, List, Optional
import os.path as osp

from torch_geometric.nn import GATConv, Linear, to_hetero
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

from torch_geometric.nn import GATConv, Linear, to_hetero
import torch.nn.functional as F

### Additional functions
As the existing txt-to-array function did not work for our data, two variants are defined here: one that parses the values as floats and one that parses them as integers.

In [None]:
def parse_txt_array_float(src, sep=None, start=0, end=None, dtype=None, device=None):
    """Parse lines of delimited numbers into a tensor, reading each field as float.

    Args:
        src: Iterable of text lines, one row per line.
        sep: Field separator passed to ``str.split`` (``None`` = any whitespace).
        start, end: Slice bounds selecting which fields of each line are used.
        dtype: Target ``torch.dtype``; ``None`` lets torch infer it.
        device: Target device; ``None`` uses torch's default device.

    Returns:
        The parsed tensor with all size-1 dimensions squeezed away, so a
        single-column input yields a 1-D tensor.

    Raises:
        ValueError: If a selected field cannot be converted with ``float``.
    """
    # Blank / whitespace-only lines carry no row and are skipped; otherwise they
    # would either fail float('') or produce a ragged (unconvertible) row list.
    rows = [
        [float(field) for field in line.split(sep)[start:end]]
        for line in src
        if line.strip()
    ]
    # `device` is forwarded so the argument is honoured instead of ignored.
    return torch.tensor(rows, dtype=dtype, device=device).squeeze()


def read_txt_array_float(path, sep=None, start=0, end=None, dtype=None, device=None):
    """Read a delimited text file and parse it with ``parse_txt_array_float``.

    Args:
        path: Path of the text file, one row per line.
        sep, start, end, dtype, device: Forwarded unchanged to
            ``parse_txt_array_float``.

    Returns:
        The tensor produced by ``parse_txt_array_float`` for the file's rows.
    """
    with open(path, 'r') as f:
        # splitlines() does not depend on a trailing newline, so the last data
        # row is kept whether or not the file ends with '\n'. Blank lines are
        # dropped here so they never reach the parser.
        lines = [line for line in f.read().splitlines() if line.strip()]
    return parse_txt_array_float(lines, sep, start, end, dtype, device)


def parse_txt_array_int(src, sep=None, start=0, end=None, dtype=None, device=None):
    """Parse lines of delimited numbers into a tensor, reading each field as int.

    Args:
        src: Iterable of text lines, one row per line.
        sep: Field separator passed to ``str.split`` (``None`` = any whitespace).
        start, end: Slice bounds selecting which fields of each line are used.
        dtype: Target ``torch.dtype``; ``None`` lets torch infer it.
        device: Target device; ``None`` uses torch's default device.

    Returns:
        The parsed tensor with all size-1 dimensions squeezed away, so a
        single-column input yields a 1-D tensor.

    Raises:
        ValueError: If a selected field cannot be converted with ``int``.
    """
    # Blank / whitespace-only lines carry no row and are skipped; otherwise they
    # would either fail int('') or produce a ragged (unconvertible) row list.
    rows = [
        [int(field) for field in line.split(sep)[start:end]]
        for line in src
        if line.strip()
    ]
    # `device` is forwarded so the argument is honoured instead of ignored.
    return torch.tensor(rows, dtype=dtype, device=device).squeeze()


def read_txt_array_int(path, sep=None, start=0, end=None, dtype=None, device=None):
    """Read a delimited text file and parse it with ``parse_txt_array_int``.

    Args:
        path: Path of the text file, one row per line.
        sep, start, end, dtype, device: Forwarded unchanged to
            ``parse_txt_array_int``.

    Returns:
        The tensor produced by ``parse_txt_array_int`` for the file's rows.
    """
    with open(path, 'r') as f:
        # splitlines() does not depend on a trailing newline, so the last data
        # row is kept whether or not the file ends with '\n'. Blank lines are
        # dropped here so they never reach the parser.
        lines = [line for line in f.read().splitlines() if line.strip()]
    return parse_txt_array_int(lines, sep, start, end, dtype, device)

### Create dataloader
Load the HeteroData with the dataset class created below.

In [None]:


class hetero(InMemoryDataset):
    """Single heterogeneous graph built from comma-separated text files.

    The raw files are read from ``<root>/raw`` and the collated graph is stored
    in ``<root>/processed``. The graph has the node types ``molecule``,
    ``atom``, ``bond``, ``ring`` and ``reaction`` plus the edge types assigned
    in :meth:`process`.

    Args:
        root: Dataset root directory containing the ``raw`` sub-directory.
        preprocess: ``None`` or ``'transe'`` (case-insensitive). It only
            affects the expected raw/processed file names.
        transform: Optional transform handed to ``InMemoryDataset``.
        pre_transform: Optional transform applied once in :meth:`process`
            before the graph is saved.

    NOTE: ``InMemoryDataset`` reuses an existing processed file. After changing
    :meth:`process`, delete ``<root>/processed`` so the graph is rebuilt.
    """

    def __init__(self, root: str, preprocess: Optional[str] = None,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None):
        preprocess = None if preprocess is None else preprocess.lower()
        self.preprocess = preprocess
        assert self.preprocess in [None, 'transe']
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def num_classes(self) -> int:
        """Number of molecule classes, taken as the largest label plus one."""
        return int(self.data['molecule']['y'].max()) + 1

    @property
    def raw_dir(self) -> str:
        """Directory holding the raw text files."""
        return osp.join(self.root, 'raw')

    @property
    def processed_dir(self) -> str:
        """Directory holding the processed (collated) graph."""
        return osp.join(self.root, 'processed')

    @property
    def raw_file_names(self) -> List[str]:
        """Every file currently present in the raw directory."""
        file_names = os.listdir(osp.join(self.root, 'raw'))

        # NOTE(review): with preprocess='transe' an embedding file is expected
        # in addition; nothing visible in this notebook creates it - confirm
        # whether this option is still needed.
        if self.preprocess is not None:
            file_names += [f'mag_{self.preprocess}_emb.pt']

        return file_names

    @property
    def processed_file_names(self) -> str:
        """Name of the processed file, depending on the preprocess option."""
        if self.preprocess is not None:
            return f'data_{self.preprocess}.pt'
        else:
            return 'data.pt'

    def process(self):
        """Read all raw text files into one ``HeteroData`` graph and save it."""
        data = HeteroData()

        def raw(name):
            return osp.join(self.raw_dir, name)

        def features(name):
            # Node features must be floating point: the GNN layers multiply
            # them with float weights, and a long dtype would also truncate
            # any fractional attribute values.
            x = read_txt_array_float(raw(name), sep=',', dtype=torch.float)
            # The reader squeezes size-1 dimensions, so a single-column file
            # comes back 1-D; restore the [num_nodes, 1] layout.
            # Assumes a 1-D result means one value per node - TODO confirm for
            # node types that could contain a single node only.
            return x.unsqueeze(-1) if x.dim() == 1 else x

        def mask(name):
            # Stored as bool so that `tensor[mask]` performs boolean masking.
            # A long 0/1 tensor would instead be used as integer row indices.
            # Assumes the files hold one 0/1 flag per molecule - TODO confirm.
            return read_txt_array_int(raw(name), sep=',', dtype=torch.bool)

        def edges(name):
            pairs = read_txt_array_int(raw(name), sep=',', dtype=torch.long)
            # Files list one (source, target) pair per row; reshape keeps this
            # true for a single-edge file (squeezed to 1-D by the reader), and
            # the result is laid out as a contiguous [2, num_edges] tensor.
            return pairs.reshape(-1, 2).t().contiguous()

        data['molecule'].x = features('att_molecule.txt')
        data['molecule'].y = read_txt_array_int(raw('mol_y.txt'), sep=',', dtype=torch.long)
        data['molecule'].train_mask = mask('train_mask.txt')
        data['molecule'].test_mask = mask('test_mask.txt')
        data['atom'].x = features('att_atom.txt')
        data['bond'].x = features('att_bond.txt')
        data['ring'].x = features('att_ring.txt')
        data['reaction'].x = features('att_reaction.txt')

        data['molecule', 'has_atom1', 'atom'].edge_index = edges('A_HAS_ATOM_Molecule.txt')
        data['molecule', 'has_bond1', 'bond'].edge_index = edges('A_HAS_BOND_Molecule.txt')
        data['molecule', 'has_ring', 'ring'].edge_index = edges('A_HAS_RING.txt')
        data['atom', 'bonded_with', 'bond'].edge_index = edges('A_BONDED_WITH.txt')
        data['ring', 'has_atom2', 'atom'].edge_index = edges('A_HAS_ATOM_Ring.txt')
        data['ring', 'has_bond2', 'bond'].edge_index = edges('A_HAS_BOND_Ring.txt')
        data['molecule', 'reacts_in', 'reaction'].edge_index = edges('A_REACTS_IN.txt')
        data['reaction', 'produces', 'molecule'].edge_index = edges('A_PRODUCES.txt')

        # Honour the pre_transform accepted by the constructor.
        if self.pre_transform is not None:
            data = self.pre_transform(data)

        torch.save(self.collate([data]), self.processed_paths[0])

    def __repr__(self) -> str:
        return 'hetero()'

## Use the data loader
Load the chosen files with the `hetero` dataset class defined above. This dataset consists of only molecules, no reaction nodes; the exact description of these files can be found in notebook 2.3.

In [None]:
# Raw text files are read from '<root>/raw'; the collated graph is written to
# '<root>/processed' (see raw_dir / processed_dir on the class).
dataset = hetero(root='2-Pytorch geometric data/hetero/MOLfiles/')

Processing...
Done!


Show the first graph as example

In [None]:
# `process` collates exactly one HeteroData object, so index 0 is the whole graph.
data = dataset[0]
# Bare last expression: the notebook displays the graph's summary.
data

HeteroData(
  [1mmolecule[0m={
    x=[31, 2],
    y=[31],
    train_mask=[31],
    test_mask=[31]
  },
  [1matom[0m={ x=[232, 17] },
  [1mbond[0m={ x=[218, 4] },
  [1mring[0m={ x=[17, 3] },
  [1mreaction[0m={ x=[13] },
  [1m(molecule, has_atom1, atom)[0m={ edge_index=[2, 232] },
  [1m(molecule, has_bond1, bond)[0m={ edge_index=[2, 218] },
  [1m(molecule, has_ring, ring)[0m={ edge_index=[2, 17] },
  [1m(atom, bonded_with, bond)[0m={ edge_index=[2, 436] },
  [1m(ring, has_atom2, atom)[0m={ edge_index=[2, 102] },
  [1m(ring, has_bond2, bond)[0m={ edge_index=[2, 102] },
  [1m(molecule, reacts_in, reaction)[0m={ edge_index=[2, 13] },
  [1m(reaction, produces, molecule)[0m={ edge_index=[2, 13] }
)

## Graph Neural Network
In this part, we will test the loaded graph dataset by creating a basic GNN. The code below is broadly based on a tutorial of PyG: https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=HvhgQoO8Svw4.<br> The goal of this GNN is to classify whether or not a molecule is part of a reaction.

### Create the GNN model

In [None]:
# Three-layer graph attention network; it is written for a single node/edge
# type and converted to a heterogeneous model afterwards.
class GAT(torch.nn.Module):
    """Stack of three ``GATConv`` layers with a ReLU after the first two.

    Args:
        hidden_channels: Output size of the first and second layer.
        out_channels: Output size of the last layer.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # Every layer is built without self-loops.
        conv_kwargs = dict(add_self_loops=False)
        # (-1, -1) follows the PyG convention for input sizes that are not
        # fixed at construction time.
        self.conv1 = GATConv((-1, -1), hidden_channels, **conv_kwargs)
        self.conv2 = GATConv(hidden_channels, hidden_channels, **conv_kwargs)
        self.conv3 = GATConv((-1, -1), out_channels, **conv_kwargs)

    def forward(self, x, edge_index):
        """Return the output of the third layer for the given features and edges."""
        hidden = self.conv1(x, edge_index).relu()
        hidden = self.conv2(hidden, edge_index).relu()
        return self.conv3(hidden, edge_index)

# Build the GAT (64 hidden channels, one output per class) and convert it in
# one step to a heterogeneous model that sums messages across edge types.
model = to_hetero(
    GAT(hidden_channels=64, out_channels=dataset.num_classes),
    data.metadata(),
    aggr='sum',
)
print(model)

GraphModule(
  (conv1): ModuleDict(
    (molecule__has_atom1__atom): GATConv((-1, -1), 64, heads=1)
    (molecule__has_bond1__bond): GATConv((-1, -1), 64, heads=1)
    (molecule__has_ring__ring): GATConv((-1, -1), 64, heads=1)
    (atom__bonded_with__bond): GATConv((-1, -1), 64, heads=1)
    (ring__has_atom2__atom): GATConv((-1, -1), 64, heads=1)
    (ring__has_bond2__bond): GATConv((-1, -1), 64, heads=1)
    (molecule__reacts_in__reaction): GATConv((-1, -1), 64, heads=1)
    (reaction__produces__molecule): GATConv((-1, -1), 64, heads=1)
  )
  (conv2): ModuleDict(
    (molecule__has_atom1__atom): GATConv(64, 64, heads=1)
    (molecule__has_bond1__bond): GATConv(64, 64, heads=1)
    (molecule__has_ring__ring): GATConv(64, 64, heads=1)
    (atom__bonded_with__bond): GATConv(64, 64, heads=1)
    (ring__has_atom2__atom): GATConv(64, 64, heads=1)
    (ring__has_bond2__bond): GATConv(64, 64, heads=1)
    (molecule__reacts_in__reaction): GATConv(64, 64, heads=1)
    (reaction__produces__molec

### Run the GNN
Below, the first test of the GNN with heterogeneous data can be seen. The GNN can be created and it worked on the confidential dataset, but some tweaking has to be done in order to train it with this dataset. The error given states that the scalar type has to be 'long', and this type is given to all the tensors; therefore it is uncertain whether the error is due to a mistake in this notebook or a bug in the to_hetero module of PyTorch Geometric. A likely cause is that the node features are loaded as `torch.long`, whereas the weights of the GNN layers are floating point.

In [None]:
# Loss function used for the molecule classification.
criterion = torch.nn.CrossEntropyLoss()

# Adam over all parameters of the heterogeneous model.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

def train():
    """Run one optimisation step on the training molecules.

    Uses the notebook-level ``model``, ``data``, ``criterion`` and
    ``optimizer``.

    Returns:
        The training loss of this step as a Python float.
    """
    model.train()
    # Clear gradients first so stale gradients can never leak into this step.
    optimizer.zero_grad()

    # The layers' weights are floating point, so the features have to be float
    # as well; 1-D feature vectors are expanded to [num_nodes, 1]. Both
    # operations leave already-correct tensors unchanged.
    x_dict = {
        node_type: x.float().reshape(x.size(0), -1)
        for node_type, x in data.x_dict.items()
    }
    out = model(x_dict, data.edge_index_dict)

    # A bool mask selects the flagged rows; a long 0/1 tensor would instead be
    # interpreted as row indices 0 and 1.
    mask = data['molecule'].train_mask.bool()
    loss = criterion(out['molecule'][mask], data['molecule'].y[mask])
    loss.backward()
    optimizer.step()
    return float(loss)

def test():
    """Run the model in evaluation mode and print the test-molecule outputs.

    Uses the notebook-level ``model`` and ``data``.

    Returns:
        The model output: a dict mapping each node type to its output tensor.
    """
    model.eval()

    # No gradients are needed for evaluation.
    with torch.no_grad():
        # The layers' weights are floating point, so the features have to be
        # float as well; 1-D feature vectors are expanded to [num_nodes, 1].
        x_dict = {
            node_type: x.float().reshape(x.size(0), -1)
            for node_type, x in data.x_dict.items()
        }
        out = model(x_dict, data.edge_index_dict)

    # Evaluate on the held-out molecules. A bool mask selects the flagged rows;
    # a long 0/1 tensor would instead be interpreted as row indices 0 and 1.
    mask = data['molecule'].test_mask.bool()

    print(out['molecule'][mask])
    return out

# Smoke test: range(1, 2) gives exactly one iteration, which runs a single
# evaluation pass and prints the returned outputs.
# NOTE(review): train() is never called here, so the model is not trained.
for epoch in range(1, 2):
  print(test())
  

RuntimeError: ignored