# Setup

## Installs

In [None]:
#!pip install torch==2.4.0
import torch

In [None]:
torch_version = torch.__version__
print(torch_version)
scatter_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
sparse_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"

#!pip install torch-scatter -f $scatter_src
#!pip install torch-sparse -f $sparse_src
!pip install torch-geometric
import torch

2.9.0+cu126
Looking in links: https://pytorch-geometric.com/whl/torch-2.9.0+cu126.html
Collecting torch-scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m108.0/108.0 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: torch-scatter
  Building wheel for torch-scatter (setup.py) ... [?25l[?25hdone
  Created wheel for torch-scatter: filename=torch_scatter-2.1.2-cp312-cp312-linux_x86_64.whl size=3857007 sha256=7cd0a793327f8f9b220cf810f913d3a58038df922fc2c1fe2db3b501d4946b92
  Stored in directory: /root/.cache/pip/wheels/84/20/50/44800723f57cd798630e77b3ec83bc80bd26a1e3dc3a672ef5
Successfully built torch-scatter
Installing collected packages: torch-scatter
Successfully installed torch-scatter-2.1.2
Looking in links: https://pytorch-geometric.com/whl/torch-2.9.0+cu126.html
Collecting torch-sparse
  Downloading torch_sparse-0.6.1

In [None]:
# ---- Constants ----
# Cellular-automaton setup
num_cells = 1000
# How many features the GNN predicts per node; leave at 1 for a one-dimensional GCA.
feature_size = 1
# Density range [lo, hi]: a cell is flipped when its density falls within this range.
lo, hi = 0.0, 0.4

# Optimisation settings
lr = 0.01
epochs = 5000
batch_size = 64  # Default in Grattarola et al. is 32 (reassigned later in the training hyperparameter cell)



In [None]:
import torch.nn as nn
import torch_geometric as pyg
import torch.nn.functional as F # Import F for ReLU

class GNCA(torch.nn.Module):
  '''General GNCA class, generalizable for many problem types (see Grattarola et al.).

  Architecture:
    1. MLP on each node's state (``mlp_pre``: two bias-free linear layers, each followed by the
       hidden activation)
    2. GCN aggregation of the neighbours' embeddings (``conv``), concatenated with the output of 1.
    3. Post-processing linear layer with ``output_dim`` units (the state size) followed by the
       output activation (``mlp_post``)
  Hidden layers have 256 units.

  Args:
    dims: ``(input_dim, output_dim)``; defaults to ``(2, 2)`` (planar coordinates in and out).
    activation: activation applied after each hidden linear layer of ``mlp_pre`` (default ``'relu'``).
    batch_norm: if True, insert ``BatchNorm1d`` after each hidden linear layer of ``mlp_pre``.
    output_activation: activation of the post-processing layer: ``'sigmoid'`` (default) for binary
      state spaces, ``'tanh'`` for states in [-1, 1] (e.g. sphere-normalized point clouds), and
      ``None`` / ``'linear'`` for no activation.

  Raises:
    ValueError: if ``activation`` or ``output_activation`` is not a supported name.
  '''

  # Supported activation names -> module classes (a fresh instance is created for every use).
  _ACTIVATIONS = {
      'relu': nn.ReLU,
      'elu': nn.ELU,
      'tanh': nn.Tanh,
      'sigmoid': nn.Sigmoid,
      'linear': nn.Identity,
      None: nn.Identity,
  }

  def __init__(self, dims=None, activation='relu', batch_norm=False, output_activation='sigmoid'):
    super().__init__()
    self.hidden_dim = 256
    self.activation = activation
    self.batch_norm = batch_norm
    self.output_activation = output_activation
    # dims is passed in as a pair (input_dim, output_dim)
    if dims is None:
      self.input_dim = 2  # assume planar embedding
      self.output_dim = 2  # assume planar output as well
    else:
      self.input_dim, self.output_dim = dims[0], dims[1]

    # Pre-processing MLP. Without batch norm the layout is Linear, act, Linear, act.
    pre_layers = []
    in_features = self.input_dim
    for _ in range(2):
      pre_layers.append(nn.Linear(in_features, self.hidden_dim, bias=False))
      if batch_norm:
        pre_layers.append(nn.BatchNorm1d(self.hidden_dim))
      pre_layers.append(self._make_activation(activation))
      in_features = self.hidden_dim
    self.mlp_pre = nn.Sequential(*pre_layers)

    # GCN layer for neighbour aggregation; takes (num_nodes, in_channels) node features.
    self.conv = pyg.nn.GCNConv(self.hidden_dim, self.hidden_dim, bias=False)

    # Input is the concatenation of the mlp_pre output and the GCN output.
    self.mlp_post = nn.Sequential(
        nn.Linear(2 * self.hidden_dim, self.output_dim, bias=False),
        self._make_activation(output_activation),
    )

  @classmethod
  def _make_activation(cls, name):
    '''Instantiate the activation module registered under ``name`` (case-insensitive).'''
    key = name.lower() if isinstance(name, str) else name
    if key not in cls._ACTIVATIONS:
      raise ValueError(
          f"Unsupported activation {name!r}; choose from 'relu', 'elu', 'tanh', 'sigmoid', 'linear' or None")
    return cls._ACTIVATIONS[key]()

  @staticmethod
  def _batched_edge_index(edge_index, batch_size, num_cells):
    '''Return an edge list over the batch_size * num_cells flattened nodes.

    An edge_index that only addresses nodes [0, num_cells) describes a single graph. It is
    replicated with node-id offsets so that every batch element gets its own disjoint copy.
    An edge_index that already addresses the flattened batch (e.g. from a PyG ``Batch``) is
    returned unchanged.
    '''
    if batch_size == 1 or edge_index.numel() == 0 or int(edge_index.max()) >= num_cells:
      return edge_index
    num_edges = edge_index.size(1)
    # Copy i of the edge list is shifted by i * num_cells.
    offsets = torch.arange(batch_size, device=edge_index.device).repeat_interleave(num_edges) * num_cells
    return edge_index.repeat(1, batch_size) + offsets

  def forward(self, x, edge_index):
    '''Apply one GNCA transition (see Grattarola et al. for details).

    Args:
      x: node states of shape (batch_size, num_cells, input_dim), or (num_cells, input_dim) for a
        single state.
      edge_index: (2, num_edges) edge list, either for one graph of num_cells nodes (replicated
        for every batch element) or already covering all batch_size * num_cells nodes.

    Returns:
      New states of shape (batch_size, num_cells, output_dim), or (num_cells, output_dim) for 2-D
      input. The last dimension is squeezed away when output_dim is 1.

    Raises:
      ValueError: if ``x`` is not 2-D or 3-D.
    '''
    single_state = x.dim() == 2
    if single_state:
      x = x.unsqueeze(0)
    elif x.dim() != 3:
      raise ValueError(f"expected x with 2 or 3 dimensions, got shape {tuple(x.shape)}")
    batch_size, num_cells, input_dim = x.shape

    # nn.Linear and GCNConv work on flattened (batch_size * num_cells, features) node features.
    h_x = self.mlp_pre(x.reshape(-1, input_dim))
    # Without per-graph offsets a single-graph edge_index would only connect the first batch element.
    flat_edge_index = self._batched_edge_index(edge_index, batch_size, num_cells)
    h_Nx = self.conv(h_x, flat_edge_index)

    h_concat = torch.cat([h_x, h_Nx], dim=-1)  # (batch_size * num_cells, 2 * hidden_dim)
    out = self.mlp_post(h_concat).reshape(batch_size, num_cells, self.output_dim)
    if single_state:
      out = out.squeeze(0)
    return out.squeeze(-1)  # drop the feature dimension when output_dim is 1

# Now we look at using GNCAs to emulate point clouds

## Imports

In [None]:
!pip install pygsp

In [None]:
import pygsp

Load the graphs we test on.

In [None]:
def get_cloud(name, **kwargs):
  '''Build a PyGSP graph and convert it to a PyG ``Data`` point cloud (credit: Grattarola et al.).

  Args:
    name: class name in ``pygsp.graphs`` (case-sensitive, e.g. ``"Grid2d"``, ``"Comet"``).
    **kwargs: forwarded to the PyGSP graph constructor.

  Returns:
    ``pyg.data.Data`` with ``x`` = node coordinates as a float32 tensor (num_nodes, coord_dim)
    and ``edge_index`` = (2, num_edges) int64 tensor with one column per non-zero entry of ``W``.
  '''
  graph_class = getattr(pygsp.graphs, name)
  graph = graph_class(**kwargs)

  # NOTE(review): some PyGSP graphs (e.g. random models) may be built without coordinates;
  # fall back to PyGSP's default layout for them.
  if getattr(graph, "coords", None) is None:
    graph.set_coordinates()

  # float32 to match the default dtype of the model weights.
  y = torch.as_tensor(graph.coords, dtype=torch.float32)

  # spektral used Graph(x=y, a=W) with an adjacency matrix; PyG needs an edge list instead.
  # Read the non-zeros straight from W. torch.tensor() cannot ingest a scipy sparse matrix,
  # and densifying it would cost O(N^2) memory.
  row, col = graph.W.nonzero()
  edge_index = torch.stack([
      torch.as_tensor(row, dtype=torch.long),
      torch.as_tensor(col, dtype=torch.long),
  ])
  return pyg.data.Data(x=y, edge_index=edge_index)

# Point clouds the GNCA is trained on.
graphs = [
    get_cloud("Grid2d", N1=20, N2=20),
    get_cloud("Bunny"),
    get_cloud("Minnesota"),
    get_cloud("Logo"),
    # Graphs below are new graphs we want to test on.
    get_cloud("SwissRoll", N=200),
    # The class is `Comet`; `pygsp.graphs.comet` is the submodule and is not callable.
    get_cloud("Comet", N=47, k=31),  # primes just in case
    get_cloud("BarabasiAlbert", N=150),  # 150-node random graph (Barabasi-Albert construction)
]

In [None]:
def normalize_sphere(graph):
    '''Center the node features on their mean and divide them by the largest absolute coordinate.

    Port of the NormalizeSphere transform used by the reference implementation. ``graph.x`` is
    replaced on the given object, and that same object is returned so calls can be chained.
    '''
    coords = graph.x
    centroid = coords.mean(dim=-2, keepdim=True)
    # The scale is taken from the uncentered features, as in the reference transform.
    max_abs = coords.abs().max()
    graph.x = (coords - centroid) / max_abs
    return graph

In [None]:
import torch

class stateCache:
  '''Replay cache of GNCA states from which training batches are sampled.

  Every slot starts out holding ``initial_state``. ``counter[i]`` records how many transition
  steps were applied to reach the state currently stored in slot ``i``.
  '''
  def __init__(self, initial_state, size=1024):
    '''Args:
      initial_state: state every slot starts from (e.g. a sphere-normalized coordinate tensor or
        graph object). It is stored by reference, not copied.
      size: number of slots in the cache.
    '''
    self.init_state = initial_state
    self.counter = torch.zeros(size)  # (size,) steps applied to reach each cached state
    self.cache = [initial_state for _ in range(size)]

  def sample(self, count):
    '''Draw ``count`` slots uniformly at random, with replacement.

    Returns:
      (states, idxs): list of the cached states and the (count,) int64 tensor of their slots.
    '''
    idxs = torch.randint(0, len(self.cache), (count,))
    return [self.cache[i] for i in idxs.tolist()], idxs

  def update(self, idxs, states, counts):
    '''Write new states back into their slots, then reset one random slot to the initial state.

    Args:
      idxs: slot indices (tensor or sequence of ints), e.g. as returned by ``sample``.
      states: one new state per index (a list, or a tensor whose first dimension indexes states).
      counts: steps applied to produce each new state; a scalar or one value per index.

    Raises:
      ValueError: if ``states`` or ``counts`` do not match ``idxs`` in length.
    '''
    # A Python list cannot be indexed by a tensor of indices, so write slot by slot.
    slots = torch.as_tensor(idxs).reshape(-1).tolist()
    new_states = list(states)
    steps = torch.as_tensor(counts, dtype=self.counter.dtype).reshape(-1)
    if steps.numel() == 1:
      steps = steps.repeat(len(slots))
    if len(new_states) != len(slots) or steps.numel() != len(slots):
      raise ValueError(
          f"got {len(slots)} indices, {len(new_states)} states and {steps.numel()} counts")

    # ``sample`` draws with replacement, so a slot may appear several times. Every copy started
    # from the same cached state, so add to the pre-update count; the last copy wins for both
    # the state and its count.
    previous = self.counter.clone()
    for slot, state, n_steps in zip(slots, new_states, steps.tolist()):
      self.cache[slot] = state
      self.counter[slot] = previous[slot] + n_steps

    # Replace one random slot with the initial state to avoid catastrophic forgetting.
    reset = int(torch.randint(0, len(self.cache), (1,)))
    self.cache[reset] = self.init_state
    self.counter[reset] = 0

  def initial(self):
    '''Return the initial state the cache was created with.'''
    return self.init_state



## Training GNCA
> To train the GNCA, we apply the transition for a given number of steps $t$ and use backpropagation through time (BPTT) to update the weights, with an MSE loss computed over mini-batches of size $K$ consisting of states $S^{(k)}$, for $k = 1, \dots, K$. This ensures that the GNCA will learn to converge to the target state in $t$ steps. Second, during training, we use a cache to store the states
>
> $$\tau^t_\theta\big(S^{(k)}\big)$$
>
> reached by the GNCA after each forward pass.
>
> Then, we use the cache as a replay memory and train the GNCA on batches of states $S^{(k)}$ sampled from the cache. For every batch, the cache is updated with the new states reached by the GNCA after $t$ steps, and one element of the cache is replaced with $\bar{S}$ to avoid catastrophic forgetting. The cache has a size of 1024 states and is initialised entirely with $\bar{S}$. By using the cache, the GNCA is trained also on states that result from a repeated application of the transition function. *This strategy encourages the GNCA to remain at the target state after reaching it, while also ensuring an adequate exploration of the state space during training.*
>
> (Grattarola et al.)

**In other words**, we will do **two training runs**: first, we train the GNCA on the target state. Then we train it on the cache produced by this first process, to encourage fixation to an attractor. As the quote says, both happen within every batch of a single training loop: states are sampled from the cache, evolved for $t$ steps and trained with BPTT, and the resulting states are written back to the cache.


In [None]:
# Training hyperparameters
cache_size = 1024  # matches the paper's cache size
# NOTE(review): the constants cell says the paper default is 32, whereas this value was meant to
# match the paper as well; confirm which is right.
batch_size = 8
batches_in_epoch = 10
# Candidate numbers of transition steps per batch (primes, deviating from the paper).
step_set = [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]
# The paper trains with an MSE loss. BCELoss requires predictions and targets in [0, 1], but
# sphere-normalized coordinates lie in [-1, 1].
loss_fn = torch.nn.MSELoss()
# No model exists at this point in the notebook, so the optimizer cannot be built here.
# Create it once the model exists, e.g. torch.optim.Adam(model.parameters(), lr=lr).

In [None]:
def run(graph, model, optimizer=None, loss_fn=None):
  '''Train ``model`` on one point cloud with BPTT and a replay cache of states.

  Each epoch runs ``batches_in_epoch`` batches. For every batch:
  - ``batch_size`` states are sampled from the cache;
  - they are evolved for a random number of steps drawn from ``step_set``;
  - the result is compared with the target by ``loss_fn``;
  - the evolved states are written back to the cache.

  Args:
    graph: PyG ``Data`` with node coordinates ``x`` (num_nodes, dim) and ``edge_index``
      (2, num_edges). A sphere-normalized copy of ``graph.x`` is the target; ``graph`` itself is
      not modified.
    model: module called as ``model(x, edge_index)``. ``x`` has shape (batch_size, num_nodes, dim)
      and ``edge_index`` indexes the batch_size * num_nodes flattened nodes. The model must return
      a tensor of the same shape as ``x``.
    optimizer: optimizer over ``model.parameters()``; defaults to ``torch.optim.Adam`` with ``lr``.
    loss_fn: callable ``loss_fn(prediction, target)``; defaults to ``torch.nn.MSELoss()``.

  Returns:
    ``(history, cache, model)``: list of per-epoch mean losses, the final ``stateCache``, and the
    trained model.

  Raises:
    ValueError: if the model output shape differs from the batched target shape.

  Reads the notebook globals ``epochs``, ``batch_size``, ``batches_in_epoch``, ``cache_size``,
  ``step_set`` and ``lr``.
  '''
  if optimizer is None:
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
  if loss_fn is None:
    loss_fn = torch.nn.MSELoss()

  # Match the model's parameters so float64 coordinates or CPU tensors do not clash with the weights.
  ref_param = next(model.parameters())
  # Normalize a copy so the caller's graph keeps its original coordinates.
  target = normalize_sphere(graph.clone()).x.to(device=ref_param.device, dtype=ref_param.dtype)
  num_nodes = target.shape[0]

  # The graph is fixed, so build the edge list for batch_size disjoint copies once:
  # copy i uses node ids [i * num_nodes, (i + 1) * num_nodes).
  edge_index = graph.edge_index.to(ref_param.device)
  batched_edge_index = torch.cat([edge_index + i * num_nodes for i in range(batch_size)], dim=1)
  batched_target = target.unsqueeze(0).expand(batch_size, *target.shape)

  def train_step(x, steps):
    '''Apply the model ``steps`` times, backpropagate through all of them, and take one optimizer step.

    Returns the final states and the scalar loss. The states are detached so that cached states
    do not keep the autograd graph alive; a later backward pass would otherwise hit freed buffers.
    '''
    model.train()
    optimizer.zero_grad()
    for _ in range(steps):
      x = model(x, batched_edge_index)
    if x.shape != batched_target.shape:
      raise ValueError(
          f"model output shape {tuple(x.shape)} does not match target shape {tuple(batched_target.shape)}")
    loss = loss_fn(x, batched_target)
    loss.backward()
    optimizer.step()
    return x.detach(), loss.item()

  history = []
  # NOTE(review): the cache is seeded with the normalized target itself, so training starts at the
  # goal state; confirm this matches the intended initial state S-bar.
  cache = stateCache(target, size=cache_size)
  for _ in range(epochs):
    epoch_loss = 0.0
    for _ in range(batches_in_epoch):
      states, idxs = cache.sample(batch_size)
      x = torch.stack(list(states))  # (batch_size, num_nodes, dim)

      # Random step count; deviates from the paper by using primes for better periodicity handling.
      step = step_set[int(torch.randint(0, len(step_set), (1,)))]

      out, step_loss = train_step(x, step)
      epoch_loss += step_loss
      # Write the evolved states back after every batch, as the paper describes.
      cache.update(idxs, out, torch.full((batch_size,), float(step)))
    history.append(epoch_loss / batches_in_epoch)
  return history, cache, model





Now we just need to call the `run()` function to train the model.

In [None]:
#to do

In [None]:
# Reference only: the original Grattarola et al. driver loop (Keras/spektral style), kept as an
# inert string literal so that executing this cell does nothing. Names it uses (NormalizeSphere,
# GNNCASimple, Adam(learning_rate=...), MeanSquaredError, sp_matrix_to_sp_tensor, args, np, plt,
# os, gaussian_filter1d) are not defined by the cells above.
# NOTE(review): there `run(graph)` returns `history` as a dict (history["loss"],
# history["steps_avg"], ...) and the cache exposes `initial_state()`. In this notebook `run()`
# returns a list of losses and `stateCache` exposes `initial()`, so adapt before porting.
''' Grattarola et al implementation:
for graph in graphs:
        graph = NormalizeSphere()(graph)

        model = GNNCASimple(activation=args.activation, batch_norm=False)
        optimizer = Adam(learning_rate=args.lr)
        loss_fn = MeanSquaredError()

        history, state_cache = run(graph)

        # Unpack data
        y = graph.x
        a = sp_matrix_to_sp_tensor(graph.a)

        # Run model for the twice the maximum number of steps in the cache
        x = state_cache.initial_state()
        x = x[None, ...]
        steps = 2 * int(np.max(state_cache.counter))
        zs = [x]
        for _ in range(steps):
            z = model([zs[-1], a], training=False)
            zs.append(z.numpy())
        zs = np.vstack(zs)
        z = zs[-1]

        out_dir = f"{args.outdir}/{graph.name}"
        os.makedirs(out_dir, exist_ok=True)
        with open(f"{out_dir}/config.txt", "w") as f:
            f.writelines([f"{k}={v}\n" for k, v, in vars(args).items()])
        np.savez(f"{out_dir}/run_point_cloud.npz", y=y, z=z, history=history, zs=zs)

        # Plot difference between target and output points
        plt.figure(figsize=(2.5, 2.5))
        cmap = plt.get_cmap("Set2")
        plt.scatter(*y[:, :2].T, color=cmap(0), marker=".", s=1)
        plt.tight_layout()
        plt.savefig(f"{out_dir}/target.pdf")

        plt.figure(figsize=(2.5, 2.5))
        cmap = plt.get_cmap("Set2")
        plt.scatter(*z[:, :2].T, color=cmap(1), marker=".", s=1)
        plt.tight_layout()
        plt.savefig(f"{out_dir}/endstate.pdf")

        # Plot loss and loss trend
        plt.figure(figsize=(2.6, 2.5))
        cmap = plt.get_cmap("Set2")
        plt.plot(history["loss"], alpha=0.3, color=cmap(0), label="Real")
        plt.plot(gaussian_filter1d(history["loss"], 50), color=cmap(0), label="Trend")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.xscale("log")
        plt.yscale("log")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"{out_dir}/loss.pdf")

        # Plot change between consecutive state
        plt.figure(figsize=(2.5, 2.5))
        cmap = plt.get_cmap("Set2")
        change = np.abs(zs[:-1] - zs[1:]).mean((1, 2))
        loss = np.array([loss_fn(y, zs[i]).numpy() for i in range(len(zs))])
        plt.plot(change, label="Abs. change", color=cmap(0))
        plt.plot(loss, label="Loss", color=cmap(1))
        plt.xlabel("Step")
        plt.xscale("log")
        plt.yscale("log")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"{out_dir}/change.pdf")

        # Plot evolution of states
        n_states = 10
        plt.figure(figsize=(n_states * 2.0, 2.1))
        for i in range(n_states):
            plt.subplot(1, n_states, i + 1)
            plt.scatter(*zs[i, :, :2].T, color=cmap(1), marker=".", s=1)
            plt.title(f"t={i}")
        plt.tight_layout()
        plt.savefig(f"{out_dir}/evolution.pdf")

        # Plot the average number of steps for the states in the cache
        plt.figure(figsize=(2.5, 2.5))
        cmap = plt.get_cmap("Set2")
        s_avg, s_std = np.array(history["steps_avg"]), np.array(history["steps_std"])
        s_max, s_min = np.array(history["steps_max"]), np.array(history["steps_min"])
        plt.plot(s_avg, label="Avg.", color=cmap(0))
        plt.fill_between(
            np.arange(len(s_std)),
            s_avg - s_std,
            s_avg + s_std,
            alpha=0.5,
            color=cmap(0),
        )
        plt.plot(s_max, linewidth=0.5, linestyle="--", color="k", label="Max")
        plt.xlabel("Epoch")
        plt.ylabel("Number of steps in cache")
        plt.legend()
        plt.xscale("log")
        plt.tight_layout()
        plt.savefig(f"{out_dir}/steps_in_cache.pdf")

    plt.show()
    '''