In [None]:
from GraphColor.dataloader import ColorDataset, ColorMultiDataset, RandColoring, ColoringOneHot
from torch_geometric.nn.models import GAT
from torch_geometric.loader import DataLoader
import torch_geometric.transforms
import torch
import torch.multiprocessing as mp
from functools import partial
import torch.nn.functional as F
import os
#from torch.utils.tensorboard import SummaryWriter
import wandb
from numpy.random import default_rng
import math

#%%
from torch_geometric.utils import to_networkx
from torch_geometric.nn import SAGEConv, GATv2Conv, GCNConv, global_mean_pool, BatchNorm, LayerNorm
import matplotlib.pyplot as plt
import networkx as nx
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Model hyperparameters; merged into the W&B run config further down.
hypers = dict(
    num_features=32,   # embedding output width, i.e. the GNN input width
    embedding_dim=64,  # hidden width of the first conv layer
    n_colors=21,       # embedding vocabulary size / modulus for node values
)
def min_size(data, n):
    """Pre-filter predicate: keep a graph only if ``data.x`` has more than ``n`` rows.

    :param data: graph object exposing a node-feature matrix ``data.x``
    :param n: exclusive lower bound on the number of rows (nodes)
    :return: True when ``data.x.shape[0] > n``
    """
    node_count = data.x.shape[0]
    return node_count > n
NUM_PROCESSES = 4
# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Passed to the dataset as `pre_transform`; presumably applied once at processing time.
# Disabled alternatives: ColoringOneHot(hypers['num_features'], cat=False), RandColoring(hypers['num_features'])
pre_transforms = torch_geometric.transforms.Compose([
    torch_geometric.transforms.ToUndirected(),
])
# Passed to the dataset as `transform`: moves each graph to the selected device.
transforms = torch_geometric.transforms.Compose([
    torch_geometric.transforms.ToDevice(device),
])

# Dataset pre-filter: bind the threshold so the predicate only takes the graph,
# keeping graphs with more than 50 nodes.
# Without this filter the dataset holds 11929 graphs.
# Alternative form: torch_geometric.transforms.ComposeFilters([partial(min_size, n=50)])
filters = partial(min_size, n=50)

import torch
from torch.nn import Dropout, Linear
import torch.nn.functional as F
# from torch.nn import Linear


class AmazonNet(torch.nn.Module):
    """
    Two-layer GraphSAGE network modelled on the one used in
    "Graph Coloring with Physics-Inspired Graph Neural Networks".

    Layout: SAGEConv(num_features -> hidden_dim) -> ReLU
            -> SAGEConv(hidden_dim -> num_classes) -> Dropout(p=0.2).
    The output holds one unnormalised score per node and class; callers apply
    softmax over dim 1.

    :param num_features: width of the input node features
    :param hidden_dim: width of the hidden SAGEConv layer
    :param num_classes: number of output scores per node
    :param n_heads: accepted for interface compatibility; currently unused
    """

    def __init__(self, num_features, hidden_dim, num_classes, n_heads=3):
        super().__init__()
        self.conv1 = SAGEConv(num_features, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, num_classes)
        self.dropout = torch.nn.Dropout(p=0.2)

    def forward(self, x, edge_index, batch):
        """Return per-node class scores; ``batch`` is accepted but not used."""
        hidden = F.relu(self.conv1(x, edge_index))
        scores = self.conv2(hidden, edge_index)
        # NOTE(review): dropout acts on the final scores rather than between the
        # two conv layers; confirm this matches the architecture being recreated.
        return self.dropout(scores)



graph_dataset = ColorMultiDataset(root='data/', pre_transform=pre_transforms, transform=transforms, pre_filter=filters)

# Sanity-check every graph. A ValueError from validate() marks a malformed graph.
for i, data in enumerate(graph_dataset):
    try:
        if not data.validate():
            print(f"Error in data entry No:{i} name:{data.name}")
    except ValueError:
        print(f"ValueError in data entry No:{i} name:{data.name}")

# Random 80/20 train/test split, seeded so it is reproducible across runs.
SPLIT_SEED = 42
rng = default_rng(SPLIT_SEED)
choice = rng.permutation(len(graph_dataset))
idx = math.floor(len(graph_dataset) * 0.8)
# Index with plain Python int lists (dtype-independent) drawn from the permutation.
train_set = graph_dataset[choice[:idx].tolist()]
# Slice to the end: the former `[idx:-1]` silently dropped the last graph.
test_set = graph_dataset[choice[idx:].tolist()]
loader_train = DataLoader(train_set, batch_size=1, shuffle=True, num_workers=0, pin_memory=False)
loader_test = DataLoader(test_set, batch_size=1, shuffle=True, num_workers=0, pin_memory=False)

print("...Creating Model...")

config = {
        "learning_rate": 0.02,
        'feature_rep': "Embed:",
        "dataset": "BA",
        "epochs": 1000,
        "log_interval": 1,
        **hypers
    }
from itertools import chain
wandb.init(
    # set the wandb project where this run will be logged
    project="Graphs-AAS",
    name="Recreation AmazonNet Original",
    # track hyperparameters and run metadata
    config=config
)

# Trainable lookup table that turns each (hashed) node value into the GNN input.
embed = torch.nn.Embedding(config['n_colors'], config['num_features'])
embed.to(device)
# NOTE(review): the output width is the dataset's num_classes, while the input
# vocabulary uses config['n_colors']; confirm which is meant to be the colour count.
model = AmazonNet(config['num_features'], config['embedding_dim'], loader_train.dataset.num_classes)
model.to(device)
wandb.watch(model, log_freq=1)
# Optimise the embedding together with the GNN. It is part of the forward pass,
# but was previously left out of the optimizer and stayed at its random init.
params = chain(model.parameters(), embed.parameters())
optimizer = torch.optim.AdamW(params, lr=config['learning_rate'])
criterion = torch.nn.CrossEntropyLoss()

In [None]:
print("...Start Training...")
pre_pools = []
post_pools = []
first_convs = []
names = []


def hash_tensor(x):
    """Map integer node values into the embedding vocabulary ``[0, n_colors)``.

    ``%`` on a tensor is already element-wise, so no ``torch.vmap`` wrapper is
    needed. This also works for 0-d tensors, which vmap cannot map over.
    """
    return x % config['n_colors']
def pots_loss_func(probs, adj_tensor):
    """
    Soft colouring loss computed from class probabilities.

    For every node pair the overlap ``probs[i] . probs[j]`` is weighted by the
    matching adjacency entry and summed. The total is halved because a symmetric
    adjacency matrix lists every edge twice. For a 0/1 symmetric adjacency this
    is the expected number of monochromatic edges.

    :param probs: Probability vector of each node belonging to each class
    :type probs: torch.Tensor
    :param adj_tensor: Adjacency matrix containing internode weights (broadcast against probs @ probs.T)
    :type adj_tensor: torch.Tensor
    :return: Loss given the current soft assignments
    :rtype: torch.Tensor (0-d)
    """
    pair_overlap = probs @ probs.T
    weighted = adj_tensor * pair_overlap
    return weighted.sum() / 2
def loss_func_color_hard(coloring, nx_graph):
    """
    Count the edges whose two endpoints received the same colour.

    Self-loops are ignored. Every entry of ``nx_graph.edges`` is counted, so a
    directed graph storing both (u, v) and (v, u) counts a conflict twice.

    :param coloring: Vector of class assignments (colours), indexable by node id
    :type coloring: torch.Tensor
    :param nx_graph: Graph to evaluate the assignment on
    :type nx_graph: networkx.Graph
    :return: Number of conflicting edges (plain ``0`` when there are no edges)
    :rtype: torch.Tensor or int
    """
    return sum(
        1 * (coloring[u] == coloring[v]) * (u != v)
        for u, v in nx_graph.edges
    )


def train(data, patience=1000, tolerance=1e-4, seed=1):
    """
    Optimise the global ``model``/``embed`` on one graph with the soft colouring loss.

    Relies on the notebook globals ``model``, ``embed``, ``hash_tensor``,
    ``optimizer``, ``config``, ``pots_loss_func`` and ``loss_func_color_hard``.
    Training stops early once the loss has failed to decrease by more than
    ``tolerance`` for ``patience`` consecutive epochs.

    :param data: graph to colour (uses ``x``, ``edge_index``, ``batch``, ``num_nodes``)
    :param patience: consecutive non-improving epochs tolerated before stopping
    :param tolerance: loss changes at or below this size count as no improvement
    :param seed: currently unused
    :return: ``(probas, coloring, loss, epoch)`` from the last epoch run
    """
    model.train()
    # Tracking uses detached losses so old epochs' autograd graphs can be freed.
    best_loss = torch.tensor(float('Inf'))
    best_cost = torch.tensor(float('Inf'))
    best_coloring = None
    prev_loss = 1.  # initial loss value (arbitrary)
    cnt = 0  # consecutive epochs without sufficient improvement
    adj_mat = torch_geometric.utils.to_dense_adj(data.edge_index, data.batch, max_num_nodes=data.num_nodes)
    # Hard cost is evaluated on *this* graph (not a global). to_undirected=True
    # keeps one copy of each edge, so a symmetric edge_index is not counted twice.
    nx_graph = torch_geometric.utils.to_networkx(data, to_undirected=True)
    for epoch in range(config['epochs']):
        node_input = torch.squeeze(embed(hash_tensor(data.x.long())))
        out = model(node_input, data.edge_index, data.batch)  # single forward pass
        probas = F.softmax(out, dim=1)
        loss = pots_loss_func(probas, adj_mat)
        coloring = torch.argmax(probas, dim=1)
        cost_hard = loss_func_color_hard(coloring, nx_graph)

        loss_value = loss.detach()
        if cost_hard < best_cost:
            best_loss = loss_value
            best_cost = cost_hard
            best_coloring = coloring
        if (abs(loss_value - prev_loss) <= tolerance) | ((loss_value - prev_loss) > 0):
            cnt += 1
        else:
            cnt = 0
        prev_loss = loss_value
        if cnt >= patience:
            print(f'Stopping early on epoch {epoch}. Patience count: {cnt}')
            break
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        wandb.log({'train/pots_loss': loss.item(),
                   'train/color_cost': float(cost_hard),
                   })

    print(f'Final coloring: {coloring}, soft loss: {loss}')
    print(f'Best hard cost: {best_cost}, soft loss at best: {best_loss}')
    return probas, coloring, loss, epoch
# Undirected view: ToUndirected makes edge_index symmetric, and a directed
# conversion would make loss_func_color_hard count every conflict twice.
graph_nx = torch_geometric.utils.to_networkx(graph_dataset[0], to_undirected=True)
train(graph_dataset[0])
wandb.finish()


In [None]:
# Node-feature matrix of the first graph.
graph_dataset[0]['x']

In [None]:
# Evaluation pass over the test split, collecting per-graph losses, labels and
# intermediate activations for the analysis cells below.
# NOTE(review): this cell appears to target a different model. AmazonNet.forward
# returns a single tensor, so the 4-way unpacking below will raise. It also feeds
# raw data.x instead of the embedded input used in train(). Confirm the intended
# model before running this cell.
model.eval()
# NOTE(review): rebinds the global name `loss` to a criterion object.
loss = torch.nn.CrossEntropyLoss()
data0 = graph_dataset[0]
data1 = graph_dataset[1]
loss_series = []
df_pre = []
df_post = []
labels = []
first_convs = []
for data in test_set:
    #input = torch.squeeze(embed(hash_tensor(data.x.long())))
    out, pre, post, first_conv = model(data.x, data.edge_index, data.batch)
    # NOTE(review): CrossEntropyLoss expects raw logits and applies log-softmax
    # itself; passing softmax output applies softmax twice. Confirm intent.
    out = F.softmax(out, dim=1)
    loss_series.append(loss(out, data.y).cpu().detach().numpy())
    df_pre.append(pre.cpu().detach().numpy())
    df_post.append(post.cpu().detach().numpy())
    # First entry of data.y per graph; used as the grouping label in later plots.
    labels.append(data.y.cpu().detach().numpy()[0])
    first_convs.append(first_conv.cpu().detach().numpy())
labels = pd.Series(labels, dtype="category")

In [None]:
# Distribution of graph labels in the test split.
sns.histplot(data=labels)

This histogram shows that the test set is fairly balanced, apart from class 7.
The test set should therefore be representative of all datasets.

## Investigate whether the nodes in a graph have different features PRE pooling
Because the graphs have different numbers of nodes, they can't be stacked and have to be inspected individually.
Aggregate measures can be used instead.
For each graph, the mean and the standard deviation of the node activations are taken feature-wise.

These graph-level aggregates are then compared across all graphs, grouped by solver (graph label).

This behaviour occurs with both encoding schemes.


In [None]:
# Per-graph summaries: reduce axis 0 (assumed to be the node axis) of each
# activation matrix to one std / mean per feature, then stack graphs row-wise.
first_convs_std = np.array([np.std(act, axis=0) for act in first_convs])
first_convs_mean = np.array([np.mean(act, axis=0) for act in first_convs])

In [None]:
# Per-graph mean activations grouped by graph label.
# Left panel: per-label mean of each feature. Right panel: per-label spread.
mean_stats = pd.DataFrame(first_convs_mean)
mean_stats['label'] = labels
fig, axs = plt.subplots(1, 2)
fig.suptitle('Mean of activations')
# observed=False keeps the current pandas default explicitly (silences the FutureWarning).
by_label = mean_stats.groupby('label', observed=False)
by_label.mean().T.plot(kind='bar', ax=axs[0], title='mean across graphs')
by_label.std().T.plot(kind='bar', ax=axs[1], title='std across graphs')
for ax in axs.flat:
    # Hide x ticks/labels: there is one bar group per feature dimension.
    ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)

# savefig fails if the output directory is missing, so create it first.
os.makedirs('graphs', exist_ok=True)
fig.savefig('graphs/AmzNet_RandColoring_mean.png', format='png')

In [None]:
# Per-graph activation standard deviations grouped by graph label.
# Left panel: per-label mean of each feature. Right panel: per-label spread.
std_stats = pd.DataFrame(first_convs_std)
std_stats['label'] = labels
fig, axs = plt.subplots(1, 2)
fig.suptitle('STD of activations')
# observed=False keeps the current pandas default explicitly (silences the FutureWarning).
by_label = std_stats.groupby('label', observed=False)
by_label.mean().T.plot(kind='bar', ax=axs[0], title='mean across graphs')
by_label.std().T.plot(kind='bar', ax=axs[1], title='std across graphs')
for ax in axs.flat:
    # Hide x ticks/labels: there is one bar group per feature dimension.
    ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)

# savefig fails if the output directory is missing, so create it first.
os.makedirs('graphs', exist_ok=True)
fig.savefig('graphs/AmzNet_RandColoring_std.png', format='png')

Across all graph labels, the activation statistics show no differentiation between labels (**TODO:** verify this observation).
From **[TODO: citation]** we know that each layer in a GNN needs to be expressive for the overall architecture to be useful.
It is therefore warranted to investigate a single good layer rather than stacking several.

In [None]:
# Feature-wise spread of the `post` activations across all test graphs.
# df_post is a list of per-graph arrays (a list has no .std()), so stack it first.
height = pd.DataFrame(np.vstack(df_post)).std()

y_pos = np.arange(len(height))

fig, ax = plt.subplots()
ax.bar(y_pos, height)
ax.set(title='STD of post activations', xlabel='feature dimension', ylabel='std')
plt.show()

In [None]:
# Feature-wise standard deviation of the stacked `post` activations.
pd.DataFrame(data=np.vstack(df_post)).std(axis=0)

In [None]:
# Close the W&B run (the training cell also calls wandb.finish() after train()).
wandb.finish()