In [2]:
import numpy as np
import torch

import matplotlib.pyplot as plt



In [3]:
data = np.loadtxt('MNIST/mnist_train.csv', delimiter=',',)

In [52]:
labels = data[:,0]
images = data[:,1:]>0

In [74]:
# distances from center in an 28x28 image
x = np.linspace(-13.5, 13.5, 28)


dists = np.sqrt(np.add.outer(x**2, x**2))

In [None]:
from torch_geometric.nn import GCNConv, Linear, GraphConv
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from typing import Union, Tuple
from torch_geometric.typing import Adj, OptTensor, Size, OptPairTensor, Tensor
from torch_geometric.utils import spmm

class CustomGraphConv(MessagePassing):
	def __init__(
		self,
		aggr: str = 'mean',
		bias: bool = True,
		out_channels: int = 8,
		**kwargs,
	):
		super().__init__(aggr=aggr, **kwargs)


		self.lin_rel = Linear(2, out_channels=out_channels, bias=bias)

		self.reset_parameters()

	def reset_parameters(self):
		super().reset_parameters()
		self.lin_rel.reset_parameters()
		# self.lin_root.reset_parameters()


	def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
				edge_weight: OptTensor = None, size: Size = None, add_root_weight : bool = True) -> Tensor:

		print("x.shape")

		msg = self.propagate(edge_index, x=x[:,0], edge_weight=edge_weight,
							 size=None)
		
		# propagate again
		print("msg.shape")
		print(msg.shape)

		msg = torch.cat([x, msg], dim=1)


		out = self.lin_rel(msg)
		
		return out


	def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:
		print(self.aggr)
		return spmm(adj_t, x[0], reduce=self.aggr)
	

class CustomGraphNCA(torch.nn.Module):
	def __init__(self):
		super(CustomGraphNCA, self).__init__()
		N = 32
		self.conv1 = CustomGraphConv(out_channels=N)
		self.lin = Linear(N, N)
		self.lin2 = Linear(N, 10)

	def forward(self, x, edge_index):
		x = F.relu(self.conv1(x, edge_index))
		x = F.relu(self.lin(x))
		x = F.relu(self.lin2(x))


		# take weighted average of the guesses
		# out = F.softmax(x, dim=1)
		out = F.relu(x)
		# out = torch.sum(out, dim=0)
		# nonzero mean
		# nonzero = torch.nonzero(out)
		# out = out[nonzero].mean(dim=0)
		out = out.sum(dim=0)
		# out = F.softmax(out, dim=0)
		
		return out
	

# define edges for a 28x28 grid where each node is connected to its 8 neighbors
N = 28

adjacency = np.zeros((N*N, N*N))


for i in range(N):
	for j in range(N):
		neighbors = []
		if i > 0:
			neighbors.append((i-1)*N+j)
		if i < N-1:
			neighbors.append((i+1)*N+j)
		if j > 0:
			neighbors.append(i*N+j-1)
		if j < N-1:
			neighbors.append(i*N+j+1)
		if i > 0 and j > 0:
			neighbors.append((i-1)*N+j-1)
		if i > 0 and j < N-1:
			neighbors.append((i-1)*N+j+1)
		if i < N-1 and j > 0:
			neighbors.append((i+1)*N+j-1)
		if i < N-1 and j < N-1:
			neighbors.append((i+1)*N+j+1)
		for n in neighbors:
			adjacency[i*N+j, n] = 1
			adjacency[n, i*N+j] = 1




# make the adjacency matrix into edge list
edges = []
for i in range(N*N):
	for j in range(N*N):
		if adjacency[i,j] == 1:
			edges.append((i,j))

edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()




In [121]:
len(images)

60000

In [128]:
model = CustomGraphNCA()
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


accuracy = torch.tensor(0, dtype=torch.float32)
iteration = 0
for img, tg in zip(images, labels):
    image = torch.tensor(img, dtype=torch.float32).unsqueeze(1)
    image_with_dists = torch.cat([image, torch.tensor(dists.reshape(-1, 1), dtype=torch.float32)], dim=1)
    target = torch.tensor(tg, dtype=torch.float32)

    optimizer.zero_grad()
    output = model(image_with_dists, edge_index)

    # output in range (-inf, 0) make into logits
    

    loss_val = loss(output.view(1, -1), target.view(1).long())

    # print(output.shape)
    loss_val.backward()
    optimizer.step()
    iteration += 1

    props = F.softmax(output, dim=0)

    accuracy += (props.argmax() == tg).float()
    if iteration % 1000 == 0:
        print(f'Iteration {iteration} | {iteration/len(images)*100:.3}, loss: {loss_val.item()}')
        print(f'Accuracy: {accuracy/1000}')
        accuracy = torch.tensor(0, dtype=torch.float32)

    if iteration % 10000 == 0:
        plt.imshow(img.reshape(28,28))
        plt.title('Predicted: %d, True: %d' % (props.argmax(), tg))
        plt.show()
        print(props)



x.shape
torch.Size([784, 2])
x0.shape
torch.Size([784])


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got -2)

In [None]:
1/10