In [222]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch_geometric.data import Data
import networkx as nx
import matplotlib.pyplot as plt


height = 5
width = 3
hidden_dim = 5

n_inputs = 2
n_outputs = 1

def build_edges(n_inputs: int, n_outputs: int, height: int, width: int):
    """
    Builds edges like 2d_grid_graph
    """
    #hidden neurons
    edge_list = list(nx.grid_2d_graph(height, width).edges())
    node_list = list(nx.grid_2d_graph(height, width).nodes())


    #replace each element of edge_list with its index in node_list
    for i in range(len(edge_list)):
        edge_list[i] = (node_list.index(edge_list[i][0]), node_list.index(edge_list[i][1]))
        
    edges = torch.tensor(edge_list)
    
    #input neurons
    input_edges = torch.tensor([
        [
            [x, (height*width) + y] for x in range(width)
        ] for y in range(n_inputs)
    ]).view(-1, 2)

    #output neurons
    output_edges = torch.tensor([
        [
            [(height*width)-(x+1), (height*width) + y+n_inputs] for x in range(width)
        ] for y in range(n_outputs)
    ]).view(-1, 2)

    #merge edges and input_edges
    edges = torch.cat((edges, input_edges, output_edges), dim=0).transpose(0,1)
    return edges



edges = build_edges(n_inputs, n_outputs, height, width)

n_nodes = height*width + n_inputs + n_outputs

hidden_dim = 5
type_dict = {"hidden": [1, 0, 0], "input": [0, 1, 0], "output": [0, 0, 1]}
total_hidden_dim = hidden_dim + len(type_dict["hidden"]) #hidden data + type

# def set_node_state(input_data: torch.Tensor, output_data: torch.Tensor):
x = torch.zeros(n_nodes, total_hidden_dim)

n_hidden_nodes = height*width
x[:n_hidden_nodes] = torch.concat((torch.zeros(hidden_dim), torch.tensor(type_dict["hidden"])))
x[n_hidden_nodes:n_hidden_nodes+n_inputs] = torch.concat((torch.zeros(hidden_dim), torch.tensor(type_dict["input"])))
x[n_hidden_nodes+n_inputs:n_hidden_nodes+n_inputs+n_outputs] = torch.concat((torch.zeros(hidden_dim), torch.tensor(type_dict["output"])))




data = Data(edge_index=edges, x=x)

graph = utils.to_networkx(data, to_undirected=True, remove_self_loops = True)
nx.draw(graph)


<IPython.core.display.Javascript object>

In [250]:
#Pytorch geometric graph classification
from torch_geometric.nn import GCNConv


class UpdateRule(torch.nn.Module):
    def __init__(self, width = 64):
        super(UpdateRule, self).__init__()
        torch.manual_seed(12345)

        self.conv1 = GCNConv(total_hidden_dim, width)
        # self.conv1 = GCNConv(hidden_dim, width)
        # self.conv2 = GCNConv(width, width)
        self.conv3 = GCNConv(width, hidden_dim)

    def forward(self, x, edge_index):
        # x, types = remove_type(x.clone())
        types = get_type(x.clone(), hidden_dim=hidden_dim)
        # print(x)
        x = self.conv1(x, edge_index)
        x = x.relu()
        # x = self.conv2(x, edge_index)
        # x = x.relu()
        x = self.conv3(x, edge_index)
        x = x.tanh()
        
        x = torch.cat([x.clone(), types], dim=-1)
        return x
update_rule = UpdateRule()


In [255]:
%matplotlib notebook
import torch_geometric.utils as utils
from utils import reset_inputs, get_type, remove_type, get_output
from tqdm.notebook import tqdm
import time

fig = plt.figure()
ax = fig.add_subplot(111)
plt.ion()


class ANDDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data = [torch.tensor([-1,-1]), torch.tensor([-1,1]), torch.tensor([1,-1]), torch.tensor([1,1])]
        self.target = [-1,1,1,-1]
        # self.data = [torch.tensor([0,0]), torch.tensor([1,1])]
        # self.target = [0,1]
        # self.data = [torch.tensor([1,1])]
        # self.target = [1]
        
    def __getitem__(self, index):
        return self.data[index], self.target[index]
    def __len__(self):
        return len(self.data)

dataset_loader = DataLoader(ANDDataset(), batch_size=1, shuffle=True)
update_rule = UpdateRule()

optimizer = torch.optim.Adam(update_rule.parameters(), lr=0.0005)

losses = []
n_steps = 10
for epoch in range(200000):
    loss = 0
    # data = utils.from_networkx(G, group_node_attrs=all)
    x = data.x.float()

    for sub_epoc in range(5):
        for problem_data_x, problem_data_y in dataset_loader:
            
            for i in range(n_steps):
                x = reset_inputs(x, problem_data_x.float(), hidden_dim=hidden_dim).float()
                print(x[-1], problem_data_x.float())
                x = update_rule(x.clone(), data.edge_index)
                network_output = get_output(x.clone(), hidden_dim=hidden_dim)#.sigmoid()
            
            if sub_epoc > 2:
                loss += F.mse_loss(problem_data_y.float(), network_output)
    
    loss /= (5 * n_steps * len(dataset_loader))
    
    losses.append(loss.item())
        
        
    loss.backward()  
    optimizer.step()  
    optimizer.zero_grad()  
    print(f"\r Epoch {epoch} | Loss {loss} | Network out: {network_output} | {x.mean()}", end="")
    
    if epoch % 100 == 0:
        print()

ax.clear()
ax.plot(losses)
fig.canvas.draw()

<IPython.core.display.Javascript object>

tensor([0., 0., 0., 0., 0., 0., 0., 1.]) tensor([[1., 1.]])
tensor([-0.1578,  0.1196,  0.1525, -0.1017, -0.1590,  0.0000,  0.0000,  1.0000],
       grad_fn=<SelectBackward0>) tensor([[1., 1.]])
tensor([-0.0406,  0.1923,  0.1583,  0.0302, -0.2014,  0.0000,  0.0000,  1.0000],
       grad_fn=<SelectBackward0>) tensor([[1., 1.]])
tensor([-0.0978,  0.2066,  0.1592,  0.0041, -0.1826,  0.0000,  0.0000,  1.0000],
       grad_fn=<SelectBackward0>) tensor([[1., 1.]])
tensor([-0.0783,  0.2219,  0.1572,  0.0244, -0.1961,  0.0000,  0.0000,  1.0000],
       grad_fn=<SelectBackward0>) tensor([[1., 1.]])
tensor([-0.0875,  0.2250,  0.1562,  0.0233, -0.1914,  0.0000,  0.0000,  1.0000],
       grad_fn=<SelectBackward0>) tensor([[1., 1.]])
tensor([-0.0857,  0.2284,  0.1555,  0.0253, -0.1939,  0.0000,  0.0000,  1.0000],
       grad_fn=<SelectBackward0>) tensor([[1., 1.]])
tensor([-0.0864,  0.2294,  0.1550,  0.0262, -0.1934,  0.0000,  0.0000,  1.0000],
       grad_fn=<SelectBackward0>) tensor([[1., 1.]])
te

KeyboardInterrupt: 

In [None]:
import numpy as np
np.set_printoptions(precision=3)

n_steps=20
for problem_data_x, problem_data_y in dataset_loader:
    data = utils.from_networkx(G, group_node_attrs=all)
    x = data.x.float()
    for i in range(n_steps):
        x = reset_inputs(x, problem_data_x.float(), hidden_dim=hidden_dim).float()
        x = update_rule(x.clone(), data.edge_index)
        network_output = get_output(x.clone(), hidden_dim=hidden_dim)
        # print(network_output)
        # print(x[:, 0].detach().numpy())
        
    # print()
        
    print(f"Input: {problem_data_x} | Expected: {problem_data_y} | Actual: {network_output}")
    print()

Input: tensor([[-1,  1]]) | Expected: tensor([1]) | Actual: tensor([0.3333], grad_fn=<IndexBackward0>)

Input: tensor([[ 1, -1]]) | Expected: tensor([1]) | Actual: tensor([0.3333], grad_fn=<IndexBackward0>)

Input: tensor([[1, 1]]) | Expected: tensor([-1]) | Actual: tensor([0.3333], grad_fn=<IndexBackward0>)

Input: tensor([[-1, -1]]) | Expected: tensor([-1]) | Actual: tensor([-0.9933], grad_fn=<IndexBackward0>)



In [None]:
import numpy as np
np.set_printoptions(precision=3)

x[:, 0].detach().numpy()

In [None]:
#pytorch random tensor
torch.rand(1)