<a href="https://colab.research.google.com/github/Miyamura80/Fuma_Fuzz/blob/main/GNN_Slither.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q dm-haiku spektral optax neptune flax

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/352.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━[0m [32m286.7/352.1 kB[0m [31m8.4 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m352.1/352.1 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m140.1/140.1 kB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m442.6/442.6 kB[0m [31m27.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m135.6/135.6 kB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.3/184.3 kB[0m [31m18.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.7/10.7 MB[0m [31m69.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━

# Init


In [69]:
import jax
import jax.numpy as jnp
import haiku as hk
from flax import linen as nn
from flax.training import train_state
import optax
import neptune
import statistics
import itertools


import numpy as np
from typing import Sequence


# Models


## GIN

In [15]:
class GINLayer(nn.Module):
    hidden_dim: int
    eps: float = 0.0

    @nn.compact
    def __call__(self, x, adj):
        # Initialize the learnable parameters
        weight_init = jax.nn.initializers.xavier_uniform()
        bias_init = jax.nn.initializers.zeros
        w = self.param("w", weight_init, (x.shape[-1],self.hidden_dim)) # (f, h)
        b = self.param("b", bias_init, (self.hidden_dim,)) # (h)
        
        # Update the node features using the GIN aggregation function
        sum_neighbors = jnp.dot(adj, x) # (n,n) @ (n, f) -> (n,f)
        out = ((1 + self.eps) * x) + sum_neighbors # (n,f)
        return nn.relu(jnp.dot(out, w) + b) # (n,h)

class GINModel(nn.Module):
    hidden_dims: Sequence[int]
    eps_list: Sequence[float] = None

    def setup(self):
        self.epsilons = [0.0 for i in range(len(self.hidden_dims))] if self.eps_list==None else self.eps_list
        assert len(self.hidden_dims) == len(self.epsilons), "Hidden dimensions and epsilons should have the same length."

    @nn.compact
    def __call__(self, x, adj):
        for i,h_dim in enumerate(self.hidden_dims):
            x = GINLayer(h_dim, self.epsilons[i])(x, adj)
        return x

def example_usage2():
    hidden_dims = [32,32]
    model = GINModel(hidden_dims=hidden_dims)

    key = jax.random.PRNGKey(0)
    batch_adj = jax.random.randint(key, (64, 64), 0, 2)
    batch = jnp.ones((64, 10))

    params = model.init(jax.random.PRNGKey(0), batch, batch_adj)
    output = model.apply(params, batch, batch_adj)
    print("Output shape:", output.shape)


if __name__ == "__main__":
    example_usage2()


Output shape: (64, 32)


# Utils

## get_model(config: dict, hidden_channels: List[int]) -> Model, Params

In [64]:
def get_model(config: dict, hidden_channels: Sequence[int]):
    n_nodes = config["n_nodes"]

    if config["model"]=="GIN":
        model = GINModel(hidden_channels)
        params = model.init(jax.random.PRNGKey(0), 
                   jnp.ones([n_nodes, config["in_channels"]]),
                   jnp.ones([n_nodes, n_nodes]))
    else:
        model_name = config["model"]
        raise ValueError(f"No model of name: {model_name} found")
    return model, params



# Dataset


## Loading Standard Dataset: 70/20/10 Split
X_train, y_train, adj_train

In [66]:
import jax
import jax.numpy as jnp
import haiku as hk
import optax

from spektral.datasets import Cora
from spektral.utils import normalized_adjacency
import numpy as np

# GIN layer and GIN model functions (copy from the previous Haiku-based code snippet)

# Load the CORA dataset
data = Cora()
dataset = data[0]
adj = normalized_adjacency(dataset.a)
node_feats = dataset.x
labels = np.argmax(dataset.y, axis=-1)


# Set the ratios
train_ratio = 0.7
valid_ratio = 0.2
test_ratio = 0.1

# Calculate the number of samples for each set
n_samples = dataset.n_nodes
n_train = int(train_ratio * n_samples)
n_valid = int(valid_ratio * n_samples)

# Create indices for splitting
indices = np.arange(n_samples)
np.random.shuffle(indices)

# Split the indices
train_indices = indices[:n_train]
valid_indices = indices[n_train:(n_train + n_valid)]
test_indices = indices[(n_train + n_valid):]

# Split the data
adj_train = adj[train_indices, :][:, train_indices]
X_train = node_feats[train_indices]
y_train = labels[train_indices]

adj_val = adj[valid_indices, :][:, valid_indices]
X_val = node_feats[valid_indices]
y_val = labels[valid_indices]

adj_test = adj[test_indices, :][:, test_indices]
X_test = node_feats[test_indices]
y_test = labels[test_indices]


# Convert data to JAX arrays
adj_train = jnp.array(adj_train.todense(), dtype=jnp.float32)
X_train = jnp.array(X_train, dtype=jnp.float32)
y_train = jnp.array(y_train, dtype=jnp.int32)

adj_val = jnp.array(adj_val.todense(), dtype=jnp.float32)
X_val = jnp.array(X_val, dtype=jnp.float32)
y_val = jnp.array(y_val, dtype=jnp.int32)

adj_test = jnp.array(adj_test.todense(), dtype=jnp.float32)
X_test = jnp.array(X_test, dtype=jnp.float32)
y_test = jnp.array(y_test, dtype=jnp.int32)


  self._set_arrayXarray(i, j, x)


In [24]:
dataset

Graph(n_nodes=2708, n_node_features=1433, n_edge_features=None, n_labels=7)

# Training

## Setup

## Training Functions

In [81]:
# Needed for JAX apply_model quirk
NUM_CLASSES = dataset.n_labels

@jax.jit
def apply_model(state, x, adj, labels):
    """Computes gradients, loss and accuracy for a single batch."""
    def loss_fn(params):
        logits = state.apply_fn(params, x, adj)
        one_hot = jax.nn.one_hot(labels, NUM_CLASSES)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy

@jax.jit
def update_model(state, grads):
    return state.apply_gradients(grads=grads)

def train_epoch(state, x, adj, labels):
    """Train for a single epoch."""
    # BACKLOG: Add batches here
    grads, train_loss, train_accuracy = apply_model(state, x, adj, labels)
    state = update_model(state, grads)
    return state, train_loss, train_accuracy

def create_train_state(rng, config) -> train_state.TrainState:
    """Creates initial `TrainState`."""
    n_nodes = config["n_nodes"]
    hidden_channels = (config["num_layers"]-1)* [config["hidden_channels"]] \
                      + [config["out_channels"]]
    
    model, params = get_model(config, hidden_channels)
    tx = optax.sgd(config["lr"], config["momentum"])
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx)

def train_and_evaluate(config: dict, neptune_client=None) -> train_state.TrainState:
    """Execute model training and evaluation loop.
    Args:
      config: Hyperparameter configuration for training and evaluation.
    Returns:
      The train state (which includes the `.params`).
    """
    
    # Create rng
    rng = jax.random.PRNGKey(0)
    neptune_client["params"]["rng"] = rng
    rng, current_rng = jax.random.split(rng)

    # Early stopping measures
    best_acc_this_config = 0.0
    best_train_tracker = []
    best_state_this_config = None

    for try_n in range(config["retry_num"]):
        # Create training state
        state = create_train_state(current_rng, config)
        rng, current_rng = jax.random.split(rng)

        # Early stopping parameters
        epochs_without_improvement = 0
        best_test_accuracy = 0.0
        best_model_state = None

        # List to store test accuracies
        try_train_tracker = []
        val_accuracies = []

        best_perform = True

        for epoch in range(1, config["epoch"] + 1):            
            state, train_loss, train_accuracy = train_epoch(state, X_train, adj_train, y_train)
            _, val_loss, val_accuracy = apply_model(state, X_val, adj_val, y_val)

            val_accuracies.append(val_accuracy)
            # Update best test accuracy and model state
            if val_accuracy > best_test_accuracy:
                best_test_accuracy = val_accuracy
                best_model_state = state
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
            
            # Early stopping
            if epochs_without_improvement >= config["early_stopping_threshold"]:
                best_perform = best_test_accuracy > best_acc_this_config
                print(f"Early stopping at epoch {epoch}.")
                break

            # Log
            try_train_tracker.append({
                "train_loss": train_loss,
                "train_accuracy": train_accuracy,
                "val_accuracy": val_accuracy,
                "val_loss": val_loss,
            })

            if epoch % 10 == 0:
                print(
                    'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, val_loss: %.4f, val_accuracy: %.2f'
                    % (epoch, train_loss, train_accuracy * 100, val_loss,
                      val_accuracy * 100))
        if best_perform:
            best_acc_this_config = best_test_accuracy
            best_model_this_config = best_model_state
            best_train_tracker = try_train_tracker

            neptune_client["params"]["best_val_accuracy"] = best_acc_this_config
            neptune_client["params"]["stop_epoch"] = epoch
            val_acc_floats = [arr.item() for arr in val_accuracies]
            neptune_client["params"]["val_acc_std"] = statistics.stdev(val_acc_floats)
            best_perform = False

    # Plot the best run
    for i,epoch_detail in enumerate(best_train_tracker):
        neptune_client['train/train_loss'].append(epoch_detail["train_loss"])
        neptune_client['train/train_accuracy'].append(epoch_detail["train_accuracy"])
        neptune_client['train/val_loss'].append(epoch_detail["val_loss"])
        neptune_client['train/val_accuracy'].append(epoch_detail["val_accuracy"])

    neptune_client.stop()
    return best_state_this_config




## Main Loop

In [82]:
# lr_list = [0.01, 0.1]
# num_layer_list = [1,2,3]
# hidden_ch_list = [16,32,64]
lr_list = [0.1]
num_layer_list = [2]
hidden_ch_list = [32]

prod_params = list(itertools.product(lr_list, num_layer_list,hidden_ch_list))

dataset_param_name = "Cora"



for (lr, n_layers,hidden_ch) in prod_params:
    run = neptune.init_run(
        capture_hardware_metrics=True,
        capture_stderr=True,
        capture_stdout=True,
    )
    config = {"lr": lr, 
              "optimizer": "SGD", 
              "loss": "ce",
              "epoch": 300, 
              "batch_size": 16, 
              "momentum": 0.9, 
              "early_stopping_threshold": 10,
              "retry_num": 5,
              # Dataset Related,
              "dataset": dataset_param_name,
              # Model Related
              "model": "GIN",
              "num_layers": n_layers,
              "hidden_channels": hidden_ch,
              "n_nodes": dataset.n_nodes,
              "in_channels": dataset.n_node_features,
              "out_channels": dataset.n_labels,
              }
    run["parameters"] = config

    state = train_and_evaluate(config, run)

https://app.neptune.ai/miyamura80/Slither-Graph-CFG/e/SLIT-1
epoch: 10, train_loss: 1.4122, train_accuracy: 50.03, val_loss: 1.6148, val_accuracy: 46.03
epoch: 20, train_loss: 0.6581, train_accuracy: 82.80, val_loss: 1.0586, val_accuracy: 70.43
epoch: 30, train_loss: 0.3996, train_accuracy: 88.50, val_loss: 0.8306, val_accuracy: 73.94
epoch: 40, train_loss: 0.2882, train_accuracy: 91.45, val_loss: 0.7657, val_accuracy: 75.42
Early stopping at epoch 44.
epoch: 10, train_loss: 1.4122, train_accuracy: 50.03, val_loss: 1.6148, val_accuracy: 46.03
epoch: 20, train_loss: 0.6581, train_accuracy: 82.80, val_loss: 1.0586, val_accuracy: 70.43
epoch: 30, train_loss: 0.3996, train_accuracy: 88.50, val_loss: 0.8306, val_accuracy: 73.94
epoch: 40, train_loss: 0.2882, train_accuracy: 91.45, val_loss: 0.7657, val_accuracy: 75.42
Early stopping at epoch 44.
epoch: 10, train_loss: 1.4122, train_accuracy: 50.03, val_loss: 1.6148, val_accuracy: 46.03
epoch: 20, train_loss: 0.6581, train_accuracy: 82.80, v