<a href="https://colab.research.google.com/github/Clearbloo/Feynman_GNN/blob/main/Feynman_GNN_v4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#**Install and import libraries**


In [None]:
## Standard libraries
import os
import os.path as osp
import json
import math
import numpy as np 
import time
from time import time, ctime
import pandas as pd
import ast
from typing import Optional
from functools import partial
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, f1_score, \
    accuracy_score, precision_score, recall_score, roc_auc_score
from google.colab import drive
drive.mount("/content/gdrive", force_remount=True)

import copy
import os.path as osp
import re
import sys
import warnings
from collections.abc import Sequence
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
import torch.utils.data
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.data.makedirs import makedirs

IndexType = Union[slice, Tensor, np.ndarray, Sequence]

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline 
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()

# Load the TensorBoard notebook extension
%load_ext tensorboard

## Progress bar
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
from torch.nn import MSELoss
# Torchvision
import torchvision
from torchvision import transforms
# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default.
    !pip install pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

## Ray
try:
    import ray
except ModuleNotFoundError: # Google Colab does not have Ray installed by default.
    !pip install ray
    import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from ray.tune.integration.pytorch_lightning import TuneReportCallback, TuneReportCheckpointCallback

## Tensorboard
try:
  import tensorboardX
except ModuleNotFoundError:
  !pip install tensorboardX
  import tensorboardX

# Root folder of the dataset on the mounted Google Drive; it is the default
# `root` of FeynmanDataset and the default `data_dir` of FeynModel.
DATASET_PATH = "/content/gdrive/MyDrive/Part_III_Project/data/"
# Folder under which train_feyn_no_tune writes its TensorBoard logs.
CHECKPOINT_PATH = "/content/gdrive/MyDrive/Part_III_Project/saved_models/"

# Seed the random number generators.
# NOTE(review): no explicit seed is passed, so no fixed value is pinned in this
# notebook - confirm whether reproducible runs are intended.
pl.seed_everything()

# cuDNN determinism is currently switched off; set this to True to request
# deterministic GPU kernels. (The attribute name used to be misspelled, so the
# assignment never reached the real flag.)
torch.backends.cudnn.deterministic = False
# Keep the cuDNN auto-tuner disabled as well.
torch.backends.cudnn.benchmark = False

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# torch geometric
try: 
    import torch_geometric
except ModuleNotFoundError:
    # Installing torch geometric packages with specific CUDA+PyTorch version. 
    # See https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html for details 
    TORCH = torch.__version__.split('+')[0]
    CUDA = 'cu' + torch.version.cuda.replace('.','')

    !pip install --quiet torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    !pip install --quiet torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    !pip install --quiet torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    !pip install --quiet torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    !pip install --quiet torch-geometric 
    import torch_geometric
import torch_geometric.nn as geom_nn
import torch_geometric.data as geom_data
from torch_geometric.data import Dataset, Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from torch.nn import Linear, BatchNorm1d, ModuleList
from torch_geometric.nn import TopKPooling 
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp


Mounted at /content/gdrive


ModuleNotFoundError: ignored

#**My own dataset class**

In [None]:
class myDataset(torch.utils.data.Dataset):
    r"""Dataset base class for creating graph datasets.
    See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
    create_dataset.html>`__ for the accompanying tutorial.

    Subclasses provide :meth:`len` and :meth:`get` and, when data has to be
    fetched or pre-processed, :meth:`download` / :meth:`process` together with
    :attr:`raw_file_names` / :attr:`processed_file_names`.

    Args:
        root (string, optional): Root directory where the dataset should be
            saved. (optional: :obj:`None`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """
    @property
    def raw_file_names(self) -> Union[str, List[str], Tuple]:
        r"""The name of the files in the :obj:`self.raw_dir` folder that must
        be present in order to skip downloading."""
        raise NotImplementedError

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple]:
        r"""The name of the files in the :obj:`self.processed_dir` folder that
        must be present in order to skip processing."""
        raise NotImplementedError

    def download(self):
        r"""Downloads the dataset to the :obj:`self.raw_dir` folder."""
        raise NotImplementedError

    def process(self):
        r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
        raise NotImplementedError

    def len(self) -> int:
        r"""Returns the number of graphs stored in the dataset."""
        raise NotImplementedError

    # The annotations that name module-level objects (`Data`, `IndexType`) are
    # written as strings so that they are not evaluated while the class body
    # is being executed.
    def get(self, idx: int) -> 'Data':
        r"""Gets the data object at index :obj:`idx`."""
        raise NotImplementedError

    def __init__(self, root: Optional[str] = None,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None):
        r"""Stores the arguments and triggers downloading / processing.

        :meth:`_download` and :meth:`_process` only run when the concrete
        subclass itself defines :meth:`download` / :meth:`process` (the check
        looks at the class' own ``__dict__``, not at inherited methods)."""
        super().__init__()

        if isinstance(root, str):
            root = osp.expanduser(osp.normpath(root))

        self.root = root
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter
        # `None` means "all examples"; subsets created by `index_select`
        # carry an explicit index sequence instead.
        self._indices: Optional[Sequence] = None

        if 'download' in self.__class__.__dict__:
            self._download()

        if 'process' in self.__class__.__dict__:
            self._process()

    def indices(self) -> Sequence:
        r"""Returns the indices of the examples visible through this object."""
        return range(self.len()) if self._indices is None else self._indices

    @property
    def raw_dir(self) -> str:
        r"""The ``raw`` sub-folder of :obj:`self.root`."""
        return osp.join(self.root, 'raw')

    @property
    def processed_dir(self) -> str:
        r"""The ``processed`` sub-folder of :obj:`self.root`."""
        return osp.join(self.root, 'processed')

    @property
    def num_node_features(self) -> int:
        r"""Returns the number of features per node in the dataset."""
        data = self[0]
        data = data[0] if isinstance(data, tuple) else data
        if hasattr(data, 'num_node_features'):
            return data.num_node_features
        raise AttributeError(f"'{data.__class__.__name__}' object has no "
                             f"attribute 'num_node_features'")

    @property
    def num_features(self) -> int:
        r"""Returns the number of features per node in the dataset.
        Alias for :py:attr:`~num_node_features`."""
        return self.num_node_features

    @property
    def num_edge_features(self) -> int:
        r"""Returns the number of features per edge in the dataset."""
        data = self[0]
        data = data[0] if isinstance(data, tuple) else data
        if hasattr(data, 'num_edge_features'):
            return data.num_edge_features
        raise AttributeError(f"'{data.__class__.__name__}' object has no "
                             f"attribute 'num_edge_features'")

    @property
    def raw_paths(self) -> List[str]:
        r"""The absolute filepaths that must be present in order to skip
        downloading."""
        files = to_list(self.raw_file_names)
        return [osp.join(self.raw_dir, f) for f in files]

    @property
    def processed_paths(self) -> List[str]:
        r"""The absolute filepaths that must be present in order to skip
        processing."""
        files = to_list(self.processed_file_names)
        return [osp.join(self.processed_dir, f) for f in files]

    def _download(self):
        r"""Calls :meth:`download` unless every raw file already exists."""
        if files_exist(self.raw_paths):  # pragma: no cover
            return

        makedirs(self.raw_dir)
        self.download()

    def _process(self):
        r"""Calls :meth:`process` unless every processed file already exists.

        Warns when the stored ``pre_transform`` / ``pre_filter`` descriptions
        differ from the current ones, and stores the current descriptions
        after a successful processing run."""
        f = osp.join(self.processed_dir, 'pre_transform.pt')
        if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
            warnings.warn(
                f"The `pre_transform` argument differs from the one used in "
                f"the pre-processed version of this dataset. If you want to "
                f"make use of another pre-processing technique, make sure to "
                f"delete '{self.processed_dir}' first")

        f = osp.join(self.processed_dir, 'pre_filter.pt')
        if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
            # The last fragment must be an f-string, otherwise the directory
            # placeholder is printed literally.
            warnings.warn(
                f"The `pre_filter` argument differs from the one used in the "
                f"pre-processed version of this dataset. If you want to make "
                f"use of another pre-filtering technique, make sure to delete "
                f"'{self.processed_dir}' first")

        if files_exist(self.processed_paths):  # pragma: no cover
            return

        print('Processing...', file=sys.stderr)

        makedirs(self.processed_dir)
        self.process()

        path = osp.join(self.processed_dir, 'pre_transform.pt')
        torch.save(_repr(self.pre_transform), path)
        path = osp.join(self.processed_dir, 'pre_filter.pt')
        torch.save(_repr(self.pre_filter), path)

        print('Done!', file=sys.stderr)

    def __len__(self) -> int:
        r"""The number of examples in the dataset."""
        return len(self.indices())

    def __getitem__(
        self,
        idx: Union[int, np.integer, 'IndexType'],
    ) -> Union['Dataset', 'Data']:
        r"""In case :obj:`idx` is of type integer, will return the data object
        at index :obj:`idx` (and transforms it in case :obj:`transform` is
        present).
        In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a
        tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of integer or
        bool type, will return a subset of the dataset at the specified
        indices."""
        if (isinstance(idx, (int, np.integer))
                or (isinstance(idx, Tensor) and idx.dim() == 0)
                or (isinstance(idx, np.ndarray) and np.isscalar(idx))):

            data = self.get(self.indices()[idx])
            data = data if self.transform is None else self.transform(data)
            return data

        else:
            return self.index_select(idx)

    def index_select(self, idx: 'IndexType') -> 'Dataset':
        r"""Creates a subset of the dataset from specified indices :obj:`idx`.
        Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a
        list, a tuple, a :obj:`torch.Tensor` of type long or bool, or a
        :obj:`np.ndarray` of any integer type or of type bool.

        Raises:
            IndexError: If :obj:`idx` is of none of the supported kinds."""
        indices = self.indices()

        if isinstance(idx, slice):
            indices = indices[idx]

        elif isinstance(idx, Tensor) and idx.dtype == torch.long:
            return self.index_select(idx.flatten().tolist())

        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
            idx = idx.flatten().nonzero(as_tuple=False)
            return self.index_select(idx.flatten().tolist())

        # Any integer dtype is accepted (previously only int64 was).
        elif (isinstance(idx, np.ndarray)
                and np.issubdtype(idx.dtype, np.integer)):
            return self.index_select(idx.flatten().tolist())

        # `np.bool_` replaces the `np.bool` alias, which newer NumPy
        # releases no longer provide.
        elif isinstance(idx, np.ndarray) and idx.dtype == np.bool_:
            idx = idx.flatten().nonzero()[0]
            return self.index_select(idx.flatten().tolist())

        elif isinstance(idx, Sequence) and not isinstance(idx, str):
            indices = [indices[i] for i in idx]

        else:
            raise IndexError(
                f"Only slices (':'), list, tuples, torch.tensor and "
                f"np.ndarray of dtype long or bool are valid indices (got "
                f"'{type(idx).__name__}')")

        # Shallow copy: the subset shares everything with the parent except
        # the index sequence.
        dataset = copy.copy(self)
        dataset._indices = indices
        return dataset

    def shuffle(
        self,
        return_perm: bool = False,
    ) -> Union['Dataset', Tuple['Dataset', Tensor]]:
        r"""Randomly shuffles the examples in the dataset.

        Args:
            return_perm (bool, optional): If set to :obj:`True`, will also
                return the random permutation used to shuffle the dataset.
                (default: :obj:`False`)
        """
        perm = torch.randperm(len(self))
        dataset = self.index_select(perm)
        return (dataset, perm) if return_perm is True else dataset

    def __repr__(self) -> str:
        arg_repr = str(len(self)) if len(self) > 1 else ''
        return f'{self.__class__.__name__}({arg_repr})'



def to_list(value: Any) -> Sequence:
    """Return *value* as-is when it is a non-string sequence, otherwise wrap
    it in a one-element list."""
    is_plain_sequence = isinstance(value, Sequence) and not isinstance(value, str)
    return value if is_plain_sequence else [value]


def files_exist(files: List[str]) -> bool:
    """Tell whether *files* is non-empty and every listed path exists."""
    # NOTE: An empty `files` yields `False`, which leads to a re-processing
    # of files on every instantiation.
    if len(files) == 0:
        return False
    return all(osp.exists(path) for path in files)


def _repr(obj: Any) -> str:
    if obj is None:
        return 'None'
    return re.sub('(<.*?)\\s.*(>)', r'\1\2', obj.__repr__())

In [None]:
class FeynmanDataset(myDataset):
    """Graph dataset built from a CSV file, one saved `Data` object per row.

    `process` samples `dataset_size` rows from the raw CSV and writes each as
    `<label>_data_<idx>.pt` into the processed directory; `get` loads them
    back by index.
    """

    def __init__(self, dataset_size, filename, reprocess: bool = False, root=DATASET_PATH, test: bool = False, train: bool = False, val: bool = False, pred: bool = False, transform=None, pre_transform=None, pre_filter=None):
        """
        dataset_size = number of rows sampled from the CSV (and number of processed files)
        filename = name of the raw CSV file inside raw_dir
        reprocess = stored as `self.reproc`; not read anywhere in this class
        root = directory where dataset should be stored. Contains raw data in raw_dir and processed data in processed_dir
        test, train, val, pred = bools, what type of dataset you want. default all false
        transform, pre_transform, pre_filter = forwarded unchanged to myDataset
        """
        self.filename = filename
        self.test = test
        self.train = train
        self.val = val
        self.pred = pred
        self.reproc = reprocess

        # The label prefixes the processed file names. When several flags are
        # set, the later entry wins (train < val < test < pred).
        self.label = ""
        for flag, name in ((train, "train"), (val, "val"), (test, "test"), (pred, "pred")):
            if flag:
                self.label = name

        # Must be set before the parent constructor runs, because that may
        # call `process`, which reads it.
        self.dataset_size = dataset_size
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        # skips download if this is found
        return self.filename

    @property
    def processed_file_names(self):
        # will skip the process method if the following files are found
        return [f'{self.label}_data_{idx}.pt' for idx in range(self.dataset_size)]

    def download(self):
        # Download to `self.raw_dir`. In the future I will make this call a python file to build the dataset as a csv
        print("No files to download")

    def process(self):
        """Sample rows from the raw CSV and save one `Data` object per row."""
        self.data = pd.read_csv(self.raw_paths[0])
        self.data = self.data.sample(n=self.dataset_size)

        # Parse every target once and take the extrema once; doing this per
        # row made the normalisation quadratic in the number of rows.
        targets = [self._get_targets(feyndiag) for _, feyndiag in self.data.iterrows()]
        if targets:
            y_min = min(targets)
            y_range = max(targets) - y_min

        # cycle through graphs and create data objects for each
        print("Saving Data objects")
        rows = tqdm(self.data.iterrows(), total=self.data.shape[0])
        for idx, ((_, feyndiag), y) in enumerate(zip(rows, targets)):
            # node features
            x = self._get_node_features(feyndiag)
            # edge features
            edge_attr = self._get_edge_features(feyndiag)
            # adjacency list
            edge_index = self._get_adj_list(feyndiag)

            # targets normalized to the interval [0,1]; when every target is
            # equal the range is zero, so use 0 instead of dividing by zero
            # (which would store NaN labels).
            if y_range == 0:
                y_norm = torch.zeros_like(y)
            else:
                y_norm = (y - y_min) / y_range

            # create data object
            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, y_norm=y_norm)

            # save file
            torch.save(data, osp.join(self.processed_dir, f'{self.label}_data_{idx}.pt'))

    def _get_node_features(self, diagram):
        """
        This will return the node feature vectors (which are 1D), parsed from
        the literal stored in column 'x'
        [Number of Nodes, 1]
        """
        x = ast.literal_eval(diagram.loc['x'])
        return torch.tensor(x, dtype=torch.float).view(-1, 1)

    def _get_edge_features(self, diagram):
        """
        This will return the edge feature vectors (which are 11D), parsed from
        the literal stored in column 'edge_attr'
        [Number of Edges, 11]
        """
        attr = ast.literal_eval(diagram.loc['edge_attr'])
        return torch.tensor(attr, dtype=torch.float).view(-1, 11)

    def _get_adj_list(self, diagram):
        """
        This will return the adjacency list, parsed from the literal stored in
        column 'edge_index'
        [2, Number of Edges]
        """
        adj_list = ast.literal_eval(diagram.loc['edge_index'])
        return torch.tensor(adj_list, dtype=torch.long).view(2, -1)

    def _get_targets(self, diagram):
        """
        This will return the value of column 'y' as a float tensor
        """
        y = diagram.loc['y']
        return torch.tensor(y, dtype=torch.float)

    def len(self):
        # Equals the number of processed file names, without rebuilding that
        # list on every call (this runs for each item access).
        return self.dataset_size

    def get(self, idx):
        """Load the processed `Data` object saved under index `idx`."""
        return torch.load(osp.join(self.processed_dir, f'{self.label}_data_{idx}.pt'))

#**Loss functions**
Defining some loss functions

In [None]:
class LogCoshLoss(torch.nn.Module):
    """Mean of log(cosh(y_t - y_prime_t)) over all elements."""

    def __init__(self):
        super().__init__()

    def forward(self, y_t, y_prime_t):
        """Return the mean log-cosh of the element-wise difference.

        Evaluated as |e| + log(1 + exp(-2|e|)) - log(2), which is the same
        quantity as log(cosh(e)) but never forms cosh(e) itself: that
        intermediate overflows to inf once |e| gets large.
        """
        abs_err = torch.abs(y_t - y_prime_t)
        return torch.mean(abs_err + torch.log1p(torch.exp(-2.0 * abs_err)) - math.log(2.0))


In [None]:
class RMSLELoss(nn.Module):
    """Root of the mean squared error between log(pred + 1) and log(actual + 1)."""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, pred, actual):
        log_pred = torch.log(pred + 1)
        log_actual = torch.log(actual + 1)
        mean_sq_log_err = self.mse(log_pred, log_actual)
        return torch.sqrt(mean_sq_log_err)

In [None]:
class RMSELoss(nn.Module):
    """Root of the mean squared error between prediction and target."""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, pred, actual):
        mean_sq_err = self.mse(pred, actual)
        return torch.sqrt(mean_sq_err)

#**Training code and GNN model using Lightning Module**

* Lightning training module
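* Uses a graph attention (`GAT`) convolution layer by default; other layer types, e.g. the Transformer convolution (`Trans`), can be selected via `layer_name`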

In [None]:
class FeynModel(pl.LightningModule):
    """Lightning module bundling the GNN, its loss, its datasets and optimizer.

    The network is a stack of graph convolutions (class chosen through
    `layer_name`), each followed by a linear projection and batch norm, with
    TopK pooling every `model_top_k_every_n` layers. The pooled graph
    representations are summed, concatenated with the edge momenta and passed
    through three linear layers ending in a sigmoid.
    """

    def __init__(self, c_in, c_out, layer_name, model_params, data_dir=DATASET_PATH, filename='QED_data.csv'):
        """
        c_in = channels in (feature dimensions, e.g. RGB is 3)
        c_out = channels out (target dimension, e.g. classification is 1)
        layer_name = key into `gnn_layer_by_name` selecting the convolution class
        model_params = dictionary of hyperparameters using the "model_..." keys read below
        data_dir = root directory handed to FeynmanDataset in `setup`
        filename = raw CSV file name handed to FeynmanDataset in `setup`
        """
        super().__init__()
        self.data_dir = data_dir
        self.filename = filename
        self.batch_size = model_params["model_batch_size"]
        embedding_size = model_params["model_embedding_size"]
        n_heads = model_params["model_attention_heads"]
        self.n_layers = model_params["model_layers"]
        dropout_rate = model_params["model_dropout_rate"]
        top_k_ratio = model_params["model_top_k_ratio"]
        self.top_k_every_n = model_params["model_top_k_every_n"]
        dense_neurons = model_params["model_dense_neurons"]
        edge_dim = model_params["model_edge_dim"] - 3  # remove momenta from edge_attr
        edge_num = 5  # need to update this

        gnn_layer = gnn_layer_by_name[layer_name]
        self.lr = model_params["model_learning_rate"]
        self.weight_decay = model_params["model_weight_decay"]
        self.lin_dropout_prob = model_params["model_lin_dropout_prob"]
        self.save_hyperparameters()
        self.loss_fn = MSELoss()

        self.conv_layers = ModuleList([])
        self.transf_layers = ModuleList([])
        self.pooling_layers = ModuleList([])
        self.bn_layers = ModuleList([])

        # Transformation layer
        self.conv1 = gnn_layer(in_channels=c_in,
                               out_channels=embedding_size,
                               heads=n_heads,
                               dropout=dropout_rate,
                               edge_dim=edge_dim
                               )

        # Projects the concatenated attention heads back to embedding_size.
        self.transf1 = Linear(embedding_size*n_heads, embedding_size)
        self.bn1 = BatchNorm1d(embedding_size)

        # Other layers
        for i in range(self.n_layers):
            self.conv_layers.append(gnn_layer(embedding_size,
                                              embedding_size,
                                              heads=n_heads,
                                              dropout=dropout_rate,
                                              edge_dim=edge_dim,
                                              ))

            self.transf_layers.append(Linear(embedding_size*n_heads, embedding_size))
            self.bn_layers.append(BatchNorm1d(embedding_size))
            if i % self.top_k_every_n == 0:
                self.pooling_layers.append(TopKPooling(embedding_size, ratio=top_k_ratio))

        # Linear layers
        # Input width: max- and mean-pooled embedding (embedding_size*2) plus
        # the flattened momenta, 3 values for each of 2*edge_num edge rows.
        # NOTE(review): presumably each of the edge_num edges is stored in
        # both directions - confirm against the dataset.
        self.linear0 = Linear(embedding_size*2+3*2*edge_num, embedding_size*2)
        self.linear1 = Linear((embedding_size)*2, dense_neurons)
        self.linear2 = Linear(dense_neurons, c_out)

        # could use super node instead of topKPooling an linear layers
        # or more topK pooling rather than linear layers

    def forward(self, x, edge_index, edge_attr, batch_index):
        """Return one sigmoid output row per graph in the batch.

        x = node features
        edge_index = adjacency list
        edge_attr = edge features; columns 0:8 go through the convolutions,
                    columns 8:11 (the momenta) are appended before the output block
        batch_index = graph id of every node
        """
        # Graph ids start at 0, so the largest id + 1 is the number of graphs.
        # Tensor.max() is a single reduction; the builtin max() would iterate
        # the tensor element by element.
        num_graphs = int(batch_index.max()) + 1

        # Remove momenta from edge features
        p = edge_attr[:, 8:11].reshape(num_graphs, -1)
        edge_attr = edge_attr[:, 0:8]

        # Initial transformation
        x = self.conv1(x, edge_index, edge_attr)
        x = F.leaky_relu(self.transf1(x))
        x = self.bn1(x)

        # Holds the intermediate graph representations
        global_representation = []

        for i in range(self.n_layers):
            x = self.conv_layers[i](x, edge_index, edge_attr)
            x = F.leaky_relu(self.transf_layers[i](x))
            x = self.bn_layers[i](x)
            # Always aggregate last layer
            # NOTE(review): `i == self.n_layers` can never hold inside
            # range(self.n_layers), so only the modulo test decides. Left as
            # is: pooling the last layer as well would need an extra pooling
            # module and change the architecture.
            if i % self.top_k_every_n == 0 or i == self.n_layers:
                pool = self.pooling_layers[i // self.top_k_every_n]
                x, edge_index, edge_attr, batch_index, _, _ = pool(
                    x, edge_index, edge_attr, batch_index
                    )
                # Add current representation
                global_representation.append(torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1))

        x = sum(global_representation)

        # add momenta on
        x = torch.cat((x, p), 1)

        # Output block
        x = F.leaky_relu(self.linear0(x))
        x = F.dropout(x, p=self.lin_dropout_prob, training=self.training)
        x = F.leaky_relu(self.linear1(x))
        x = F.dropout(x, p=self.lin_dropout_prob, training=self.training)
        x = torch.sigmoid(self.linear2(x))

        return x

    def _shared_step(self, batch):
        """Run the model on `batch`; return (loss vs. `y_norm`, number of graphs)."""
        x, edge_index, edge_attr, y = batch['x'], batch['edge_index'], batch['edge_attr'], batch['y_norm']
        graph_index = batch['batch']
        y_hat = self(x, edge_index, edge_attr, graph_index)
        loss = self.loss_fn(y_hat, y.view(-1, 1))
        # Plain int, so the logger gets a number rather than a tensor.
        num_graphs = int(graph_index.max()) + 1
        return loss, num_graphs

    def training_step(self, batch, batch_idx):
        """Compute and log the per-step training loss."""
        loss, num_graphs = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, batch_size=num_graphs)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss, logged once per epoch."""
        loss, num_graphs = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=num_graphs)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log the per-step test loss."""
        loss, num_graphs = self._shared_step(batch)
        self.log("test_loss", loss, prog_bar=True, on_step=True, on_epoch=False, batch_size=num_graphs)
        return loss

    def predict_step(self, batch, batch_idx):
        """Return (prediction, normalized target) as Python floats.

        Relies on a batch size of one: `.item()` only works on
        single-element tensors (see `predict_dataloader`).
        """
        x, edge_index, edge_attr, y = batch['x'], batch['edge_index'], batch['edge_attr'], batch['y_norm']
        graph_index = batch['batch']
        y_hat = self(x, edge_index, edge_attr, graph_index)
        return y_hat.item(), y.item()

    def prepare_data(self):
        print("No data to prepare")

    def setup(self, stage: Optional[str]) -> None:
        """Create the train / test / val / pred datasets with fixed sizes."""
        # TODO: edit this data to make the y values more uniformly
        # distributed across the range
        print("Loading datasets...")
        self.train_dataset = FeynmanDataset(100000, reprocess=False, root=self.data_dir, filename=self.filename, train=True)
        self.test_dataset = FeynmanDataset(10000, reprocess=False, root=self.data_dir, filename=self.filename, test=True)
        self.val_dataset = FeynmanDataset(50000, reprocess=False, root=self.data_dir, filename=self.filename, val=True)
        self.pred_dataset = FeynmanDataset(100, reprocess=False, root=self.data_dir, filename=self.filename, pred=True)
        print("Finished all!")

    def train_dataloader(self):
        return DataLoader(dataset=self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=2)

    def val_dataloader(self):
        return DataLoader(dataset=self.val_dataset, batch_size=self.batch_size, num_workers=2)

    def test_dataloader(self):
        return DataLoader(dataset=self.test_dataset, batch_size=self.batch_size, num_workers=2)

    def predict_dataloader(self):
        return DataLoader(dataset=self.pred_dataset, batch_size=1)  # keep this batch_size as one to get predictions to work

    def configure_optimizers(self):
        """Adam over all parameters with the configured lr and weight decay."""
        return torch.optim.Adam(self.parameters(),
                                lr=self.lr,
                                weight_decay=self.weight_decay,
                                )


#**Tuning the hyperparameters with Ray Tune**


In [None]:
def train_feyn_no_tune(params, num_gpus, num_epochs=10):
    """
    Function to train the Feynman GNN without a hyperparameter search.
    params = The hyperparameters to use, stored as a dictionary with the notation "model_...";
             every value is a sequence whose first element is used
    num_gpus = number of GPUs, rounded up to an integer for the Trainer
    num_epochs = maximum number of epochs

    Fits, validates and tests the model, then returns (model, trainer) so that
    later steps (e.g. prediction) can reuse them instead of losing them when
    the function ends.
    """
    # need to make layer type a hyperparameter
    model_params = {k: v[0] for k, v in params.items() if k.startswith("model_")}
    model = FeynModel(c_in=-1,  # train_dataset.num_node_features
                      c_out=1,  # train_dataset.num_classes
                      layer_name="GAT",
                      model_params=model_params,
                      )
    trainer = pl.Trainer(logger=TensorBoardLogger(CHECKPOINT_PATH, name="saved_model"),
                         max_epochs=num_epochs,
                         gpus=math.ceil(num_gpus),
                         log_every_n_steps=10,
                         # Stop when val_loss has not improved for 10 checks.
                         callbacks=[EarlyStopping('val_loss', patience=10)],
                         )
    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
    return model, trainer


In [None]:
def train_feyn_tune(config, num_epochs=10, num_gpus=0):
    """
    Single training run, meant to be launched by the tuning function.
    config = hyperparameter dictionary passed to FeynModel as model_params
    Logs into the current trial directory and reports val_loss to Tune as
    "loss" at the end of every validation run.
    """
    model = FeynModel(
        c_in=-1,  # train_dataset.num_node_features
        c_out=1,  # train_dataset.num_classes
        layer_name="GAT",
        model_params=config,
    )

    # Write the TensorBoard files straight into the trial directory.
    trial_logger = TensorBoardLogger(save_dir=tune.get_trial_dir(), name="", version=".")
    gpu_count = math.ceil(num_gpus)
    report_to_tune = TuneReportCallback({"loss": "val_loss"}, on="validation_end")

    trainer = pl.Trainer(
        logger=trial_logger,
        max_epochs=num_epochs,
        gpus=gpu_count,
        log_every_n_steps=10,
        callbacks=[report_to_tune],
    )
    trainer.fit(model)

def tune_feyn_asha(config, gpus_per_trial=0, num_epochs=10, num_samples=10):
    """
    Run a Ray Tune search over `config` with the ASHA scheduler.
    config = search space using the "model_..." keys
    gpus_per_trial = GPUs reserved for each trial (each trial also gets 1 CPU)
    num_epochs = maximum epochs per trial (the scheduler's max_t)
    num_samples = number of configurations sampled from the search space

    Prints the best configuration and returns the analysis object produced by
    `tune.run`, so the results stay available to the caller.
    """
    scheduler = ASHAScheduler(
        max_t=num_epochs,
        grace_period=1,
        reduction_factor=2)

    reporter = CLIReporter(
        parameter_columns=[
                           "model_batch_size",
                           "model_weight_decay",
                           "model_learning_rate",
                           "model_embedding_size",
                           "model_attention_heads",
                           "model_layers",
                           "model_dropout_rate",
                           "model_top_k_ratio",
                           "model_top_k_every_n",
                           "model_dense_neurons",
                           "model_edge_dim",
                           "model_lin_dropout_prob"],
        metric_columns=["loss", "training_iteration"])

    # Bind the fixed arguments; Tune supplies `config` for each trial.
    train_fn_with_parameters = tune.with_parameters(train_feyn_tune,
                                                    num_epochs=num_epochs,
                                                    num_gpus=gpus_per_trial,
                                                    )

    resources_per_trial = {"cpu": 1, "gpu": gpus_per_trial}

    # NOTE(review): the experiment name looks carried over from an MNIST
    # example; it is kept so existing result directories still match.
    analysis = tune.run(train_fn_with_parameters,
        resources_per_trial=resources_per_trial,
        metric="loss",
        mode="min",
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        name="tune_mnist_asha")

    print("Best hyperparameters found were: ", analysis.best_config)
    return analysis


#**Create layer dictionary and Hyperparameters**


In [None]:
# Layer name dictionary: maps the `layer_name` string accepted by FeynModel to
# the torch_geometric convolution class it instantiates.
# NOTE(review): FeynModel passes `heads=`, `dropout=` and `edge_dim=` to
# whichever class is chosen - confirm a layer accepts those keyword arguments
# before selecting it.
gnn_layer_by_name = {
    "GCN": geom_nn.GCNConv,
    "GAT": geom_nn.GATConv,
    "GraphConv": geom_nn.GraphConv,
    "NNConv": geom_nn.NNConv,
    "RGCN": geom_nn.RGCNConv,
    "Trans": geom_nn.TransformerConv
}

# Hyperparameters to use if not tuning.
# Every value is a one-element list because train_feyn_no_tune takes element
# [0] of each "model_..." entry.
HYPERPARAMETERS = {
    "model_batch_size": [64],
    "model_weight_decay": [0.000001],
    "model_learning_rate": [0.01],
    "model_embedding_size": [128],
    "model_attention_heads": [4],
    "model_layers": [5],
    "model_dropout_rate": [0.5],
    "model_top_k_ratio": [0.2],
    "model_top_k_every_n": [1],
    "model_dense_neurons": [4],
    "model_edge_dim": [11],
    "model_lin_dropout_prob": [0.3],
    }

# Hyperparameters for ray tune to search through.
# As written, only the learning rate is actually searched (log-uniform
# between 0.0001 and 0.1); every other entry offers a single choice.
config = {
    "model_batch_size": tune.choice([64]),
    "model_weight_decay": tune.choice([0.000001]),
    "model_learning_rate": tune.loguniform(0.0001,0.1),
    "model_embedding_size": tune.choice([4]),
    "model_attention_heads": tune.choice([4]),
    "model_layers": tune.choice([3]),
    "model_dropout_rate": tune.choice([0.5]),
    "model_top_k_ratio": tune.choice([0.2]),
    "model_top_k_every_n": tune.choice([1]),
    "model_dense_neurons": tune.choice([4]),
    "model_edge_dim": tune.choice([11]),
    "model_lin_dropout_prob": tune.choice([0.3]),
    }

In [None]:
# Report the library versions and whether CUDA can be used.
print(f"Torch version: {torch.__version__}")
print(f"Cuda available: {torch.cuda.is_available()}")
print(f"Torch geometric version: {torch_geometric.__version__}")

# GPU count handed to the training helpers: one when CUDA is usable, else none.
gpus = 1 if torch.cuda.is_available() else 0

In [None]:
# Train once with the fixed HYPERPARAMETERS for at most 10 epochs.
train_feyn_no_tune(HYPERPARAMETERS,gpus,10)

In [None]:
# Launch the Ray Tune search over `config`, using the detected GPU count per trial.
tune_feyn_asha(config, gpus_per_trial=gpus)

#**Predict with last model**

In [None]:
# NOTE(review): `trainer`, `model` and `pred_loader` are not defined at module
# level anywhere in the visible notebook - define them before running this
# cell, otherwise it raises NameError.
out = trainer.predict(model, dataloaders=pred_loader)
print(out)

#**TensorBoard Logs and running training**


In [None]:
%tensorboard --logdir /content/gdrive/MyDrive/Part_III_Project/saved_models/lightning_logs