### Karate Club Graph Clustering
In this notebook the GFlowNet is applied to the well-known Karate Club graph (Zachary's karate club network). Note that training is computationally expensive.

In [1]:
import torch
import pandas as pd
from Core.Core import GraphNet, GraphNetNodeOrder, check_gpu, GibbsSampleStates

Check whether PyTorch can use the GPU:

In [2]:
check_gpu()

Cuda is not available
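`check_gpu` comes from the project's `Core` module, and its source is not shown here. The stand-in below assumes it does little more than query `torch.cuda.is_available()`. It also returns a `torch.device` for later use. The name `pick_device` is hypothetical, and this notebook does not show whether `GraphNet` moves its tensors to such a device.

```python
import torch


def pick_device() -> torch.device:
    """Hypothetical stand-in for check_gpu(): report CUDA availability and
    return the device that tensors could be placed on."""
    if torch.cuda.is_available():
        print(f"Cuda is available: {torch.cuda.get_device_name(0)}")
        return torch.device("cuda")
    print("Cuda is not available")
    return torch.device("cpu")


device = pick_device()
```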


Training alternates between fitting the network and drawing a fresh batch of 'nSamples' samples, which serve as an estimate of the empirical distribution. Every 'epochInterval' epochs the current network is used to draw this batch, and training runs from 'minEpochs' up to 'maxEpochs'. The boolean 'nodeOrder' specifies whether to train with an explicit node order (using GraphNetNodeOrder instead of GraphNet); this order is randomly regenerated on every forward pass through the graph. 'GibbsStart' specifies whether the initial samples are drawn from the Gibbs sampler rather than from the untrained GFlowNet. Finally, 'GibbsProportion' sets the fraction of each new batch that comes from the Gibbs sampler. That sampler is initialised with a clustering taken from the GFlowNet's samples of the previous round, and the remaining samples are drawn from the GFlowNet itself.

In [3]:
# Parameters to set:
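nSamples = 10         # Samples drawn per round; must be greater than 1
epochInterval = 1     # Epochs trained between sampling rounds
minEpochs = 0         # Starting epoch, used when continuing an earlier run
maxEpochs = 2         # Final epoch
nodeOrder = True      # Use GraphNetNodeOrder (random node order per forward pass)
GibbsStart = False    # Draw the initial samples with the Gibbs sampler
GibbsProportion = .6  # Fraction of each round's samples drawn by the Gibbs sampler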
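As a quick check of how these parameters interact, the sketch below recomputes the per-round sample split and the checkpoint epochs in plain Python. It mirrors the expressions used later in the notebook (`int(nSamples * GibbsProportion)` and the `range(...)` of the training loop). The helper `training_schedule` is hypothetical and not part of `Core`.

```python
def training_schedule(nSamples, epochInterval, minEpochs, maxEpochs, GibbsProportion):
    """Hypothetical helper: reproduce the sample split and checkpoint epochs of the notebook."""
    if nSamples <= 1:
        raise ValueError("nSamples must be greater than 1")
    nGibbs = int(nSamples * GibbsProportion)  # same truncation as in the notebook
    nGFlow = nSamples - nGibbs                 # samples drawn from the GFlowNet each round
    if nGFlow < 1:
        # The training loop picks a random GFlowNet sample, so at least one is needed.
        raise ValueError("GibbsProportion leaves no GFlowNet samples")
    nRounds = (maxEpochs - minEpochs) // epochInterval
    # Postfixes used for saved weights and samples; 0 is the untrained network.
    checkpointEpochs = [0] + [i * epochInterval for i in range(1, nRounds + 1)]
    return nGibbs, nGFlow, checkpointEpochs


print(training_schedule(10, 1, 0, 2, .6))
# (6, 4, [0, 1, 2])
```

With the values above, each round uses 6 Gibbs samples and 4 GFlowNet samples, which matches the 4/4 GFlowNet sampling progress bars in the training output. The postfixes count from 0 even when `minEpochs > 0`, exactly as the loop computes them.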

The sampled clusterings are saved in the 'Data' folder and the weights in the 'Weights' folder. Each filename follows the pattern built below; a prefix can be added to distinguish between runs.

In [4]:
prefix = ''
nodeOrderString = '_o' if nodeOrder else ''
filepathSamples = f'Data/{prefix}Karate{minEpochs}_{maxEpochs}_{nSamples}{nodeOrderString}_Samples_'
filepathWeights = f'Weights/{prefix}Karate{minEpochs}_{maxEpochs}_{nSamples}{nodeOrderString}'

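The graph's adjacency matrix is loaded from a CSV file and the network is initialised. Here the number of hidden layers and the number of hidden units are specified. Both are only passed to GraphNetNodeOrder; when 'nodeOrder' is False, GraphNet is built with its defaults.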

In [5]:
nLayers = 5
nHidden = 64

Adj_karate = torch.tensor(pd.read_csv("Data/Adj_karate.csv", header=None, dtype=int).to_numpy())
n = Adj_karate.shape[0]
net = GraphNetNodeOrder(nNodes=n, nLayers=nLayers, nHidden=nHidden) if nodeOrder else GraphNet(nNodes=n)
net.save(prefix=filepathWeights, postfix=str(0))
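The notebook assumes that `Data/Adj_karate.csv` holds a square, symmetric 0/1 adjacency matrix. A small sanity check could verify this before the network is built. It uses `check_adjacency`, a hypothetical helper that is not part of `Core`. If the file contains the standard Karate Club graph, `check_adjacency(Adj_karate)` should report 34 nodes and 78 undirected edges.

```python
import torch


def check_adjacency(A: torch.Tensor) -> int:
    """Validate an undirected, unweighted adjacency matrix and return its edge count."""
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"expected a square matrix, got shape {tuple(A.shape)}")
    if not torch.equal(A, A.T):
        raise ValueError("adjacency matrix is not symmetric")
    if not torch.all((A == 0) | (A == 1)):
        raise ValueError("adjacency matrix must contain only 0/1 entries")
    if torch.any(torch.diagonal(A) != 0):
        raise ValueError("adjacency matrix has self-loops")
    return int(A.sum().item()) // 2  # each undirected edge appears twice


# Toy example: the path graph 0 - 1 - 2
A = torch.tensor([[0, 1, 0],
                  [1, 0, 1],
                  [0, 1, 0]])
print(check_adjacency(A))  # 2
```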

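Next, the initial samples are drawn, either from the Gibbs sampler or from the still untrained GFlowNet depending on 'GibbsStart', and saved as epoch 0.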

In [6]:
X1 = GibbsSampleStates(Adj_karate, nSamples=nSamples, N=n) if GibbsStart \
        else net.sample_forward(Adj_karate, nSamples=nSamples, timer=True)
torch.save(X1, filepathSamples + f'{0}.pt')
nGibbs = int(nSamples * GibbsProportion)

Sampling: 100%|██████████| 10/10 [00:10<00:00,  1.06s/it]


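Finally, the training loop. Each round trains the network for 'epochInterval' epochs and then draws a new batch. The batch consists of nGibbs samples from the Gibbs sampler, seeded with a clustering from one of the previous GFlowNet samples, plus nSamples - nGibbs samples from the GFlowNet. The network's weights and the new samples are saved every round.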

In [7]:
for i in range(1, ((maxEpochs - minEpochs) // epochInterval) + 1):
    net.train(X1, epochs=epochInterval)  # Train for another epoch interval
    # Pick one random sample from the GFlowNet part of the previous batch (rows nGibbs onward).
    # Each row holds the flattened adjacency matrix followed by the flattened n x n state.
    z = X1[nGibbs:][torch.randint(nSamples - nGibbs, (1,))][0][net.n_nodes ** 2:].reshape((net.n_nodes, net.n_nodes))
    # Convert the state to a column of per-node cluster labels for the Gibbs sampler:
    z = net.get_clustering_list(z)[0].reshape((-1, 1))
    # Draw the next batch: nGibbs Gibbs samples followed by the GFlowNet samples
    gibbsSamples = GibbsSampleStates(Adj_karate, nSamples=nGibbs, N=net.n_nodes, z=z)
    gflowSamples = net.sample_forward(Adj_karate,
                                      nSamples=nSamples - nGibbs,
                                      timer=True,
                                      saveFilename=filepathSamples + f'{i * epochInterval}')
    X1 = torch.concat((gibbsSamples, gflowSamples), dim=0)
    net.save(prefix=filepathWeights, postfix=str(epochInterval * i))

Training: 100%|██████████| 1/1 [00:00<00:00, 23.26it/s]
Gibbs Sampling: 100%|██████████| 12/12 [00:00<00:00, 35.75it/s]
Sampling: 100%|██████████| 4/4 [00:04<00:00,  1.04s/it]
Training: 100%|██████████| 1/1 [00:00<00:00, 37.03it/s]
Gibbs Sampling: 100%|██████████| 12/12 [00:00<00:00, 36.84it/s]
Sampling: 100%|██████████| 4/4 [00:04<00:00,  1.03s/it]
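The slicing in the loop relies on two layout facts. Each row of X1 holds the flattened adjacency matrix followed by the flattened n x n state. The GFlowNet samples are stacked after the nGibbs Gibbs samples. The sketch below reproduces that indexing on a random toy tensor. `split_row` and `pick_gflownet_state` are hypothetical helpers, not part of `Core`, and what the state matrix encodes is left to `net.get_clustering_list`.

```python
import torch


def split_row(row: torch.Tensor, n: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Split one flattened sample into its (n, n) adjacency and (n, n) state parts."""
    if row.numel() != 2 * n * n:
        raise ValueError(f"expected {2 * n * n} entries, got {row.numel()}")
    adjacency = row[: n * n].reshape((n, n))
    state = row[n * n:].reshape((n, n))
    return adjacency, state


def pick_gflownet_state(X: torch.Tensor, nGibbs: int, n: int) -> torch.Tensor:
    """Pick a random row from the GFlowNet block (rows nGibbs onward) and return its state."""
    nGFlow = X.shape[0] - nGibbs
    idx = nGibbs + torch.randint(nGFlow, (1,)).item()
    return split_row(X[idx], n)[1]


n, nGibbs, nSamples = 4, 6, 10
X = torch.rand((nSamples, 2 * n * n))  # toy stand-in for X1
state = pick_gflownet_state(X, nGibbs, n)
print(state.shape)  # torch.Size([4, 4])
```

Splitting the indexing into two small functions makes the row layout explicit and guards against a row of the wrong length. The notebook's one-liner does the same slicing inline.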
