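Based on Boris Knyazev's "Tutorial on Graph Neural Networks for Computer Vision and Beyond (Part 1)": https://medium.com/@BorisAKnyazev/tutorial-on-graph-neural-networks-for-computer-vision-and-beyond-part-1-3d9fada3b80d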

In [None]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
from scipy.spatial.distance import cdist
from varname import nameof

In [2]:
# Side length of an MNIST image; every pixel becomes one graph node.
img_size = 28

# col[i, j] == j and row[i, j] == i for the pixel at grid position (i, j).
col, row = np.meshgrid(np.arange(img_size), np.arange(img_size), indexing='xy')

# One (col, row) pair per pixel, scaled by the image size.
coord = np.stack([col, row], axis=-1).reshape(img_size**2, 2) / img_size

# Pairwise distances between all pixels, turned into a Gaussian affinity matrix.
dist = cdist(coord, coord)
sigma = 0.2*np.pi
A = np.exp(-np.square(dist) / sigma**2)

In [7]:
# Sanity check: one (col, row) coordinate pair per pixel of the 28x28 grid.
coord.shape

(784, 2)

Simple fully connected layer for MNIST

In [9]:
class FullyNet(nn.Module):
    """Baseline classifier: a single fully connected layer on flattened pixels.

    The layer maps 784 input features (a flattened 28x28 image) to 10 class
    scores and includes a bias term.
    """

    def __init__(self):
        super(FullyNet, self).__init__()
        self.fc = nn.Linear(784, 10, bias=True)

    def forward(self, x):
        """Return class scores for a batch of images.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions flatten to 784 values per sample.

        Returns:
            Tensor of shape (batch, 10) with unnormalized class scores.
        """
        # Flatten everything but the batch dimension before the linear layer.
        return self.fc(x.view(x.size(0), -1))

Convolutional net

In [10]:
class ConvNet(nn.Module):
    """One large-kernel convolution, max pooling, then a bias-free linear layer.

    The convolution uses a 28x28 kernel with padding 14, which turns a 28x28
    input into a 29x29 map with 10 channels. Pooling with a window of 7 leaves
    a 4x4 map, so the classifier sees 4 * 4 * 10 features.
    """

    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=28,
                              stride=1, padding=14)
        self.fc = nn.Linear(in_features=4 * 4 * 10, out_features=10, bias=False)

    def forward(self, x):
        """Return class scores of shape (batch, 10) for a batch of images."""
        activations = F.relu(self.conv(x))
        pooled = F.max_pool2d(activations, 7)
        # Flatten channel and spatial dimensions for the linear classifier.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

Graph net

In [None]:
class GraphNet(nn.Module):
    """Single graph layer over image pixels followed by a linear classifier.

    Every pixel of an ``img_size x img_size`` image is treated as a graph node
    (N = img_size**2 nodes). The forward pass multiplies the flattened image by
    an (N, N) adjacency matrix and feeds the result to a bias-free linear layer
    with 10 outputs.

    Args:
        img_size: Side length of the (square) input images.
        pred_edge: If True, the adjacency matrix is predicted on every forward
            pass by a small MLP from the normalized coordinates of each pair of
            pixels. If False, a fixed matrix from
            ``precompute_adjacency_images`` is stored as the buffer ``A``.
        verbose: Forwarded to ``precompute_adjacency_images``; only relevant
            when ``pred_edge`` is False.
    """

    def __init__(self, img_size=28, pred_edge=False, verbose=True):
        super(GraphNet, self).__init__()
        self.pred_edge = pred_edge
        N = img_size**2
        self.fc = nn.Linear(N, 10, bias=False)
        if pred_edge:
            col, row = np.meshgrid(np.arange(img_size), np.arange(img_size))
            coord = np.stack((col, row), axis=2).reshape(-1, 2)
            # Standardize each coordinate axis; the epsilon guards a zero std.
            coord = (coord - np.mean(coord, axis=0)) / (np.std(coord, axis=0) + 1e-5)
            coord = torch.from_numpy(coord).float()  # (N, 2)
            # Pair up every node with every other node: (N, N, 4).
            # An alternative edge feature would be the absolute coordinate
            # difference, torch.abs(coord[:, :, [0, 1]] - coord[:, :, [2, 3]]).
            coord = torch.cat((coord.unsqueeze(0).repeat(N, 1, 1),
                               coord.unsqueeze(1).repeat(1, N, 1)), dim=2)
            # MLP mapping a pair of coordinates to one edge weight in (-1, 1).
            self.pred_edge_fc = nn.Sequential(nn.Linear(4, 64),
                                              nn.ReLU(),
                                              nn.Linear(64, 1),
                                              nn.Tanh())
            self.register_buffer('coord', coord)
        else:
            A = self.precompute_adjacency_images(img_size, verbose=verbose)
            self.register_buffer('A', A)

    @staticmethod
    def precompute_adjacency_images(img_size, verbose=True):
        """Build the fixed, normalized (N, N) adjacency matrix for a pixel grid.

        Args:
            img_size: Side length of the square image; N = img_size**2.
            verbose: If True (default), print intermediate shapes and the
                top-left 10x10 corner of the result.

        Returns:
            Float tensor of shape (N, N).
        """
        col, row = np.meshgrid(np.arange(img_size), np.arange(img_size))
        if verbose:
            print("row shape: ", row.shape)
            print("col shape: ", col.shape)

        coord = np.stack((col, row), axis=2).reshape(-1, 2) / img_size
        if verbose:
            print("coord shape: ", coord.shape)

        dist = cdist(coord, coord)
        if verbose:
            print("dist shape: ", dist.shape)

        sigma = 0.05*np.pi
        # NOTE: the distance is not squared here, so this is an exponential
        # kernel rather than the Gaussian used in the exploratory cell earlier
        # in the notebook. Left as is to keep the numeric results unchanged.
        A = np.exp(-dist/sigma**2)
        if verbose:
            print("A shape: ", A.shape)

        # Sparsify: drop very weak edges.
        A[A < 0.01] = 0
        A = torch.from_numpy(A).float()

        # Symmetric degree normalization: D^-1/2 * A * D^-1/2.
        D = A.sum(1)
        if verbose:
            print("D shape: ", D.shape)

        D_hat = (D + 1e-5)**(-0.5)
        if verbose:
            print("D_hat shape: ", D_hat.shape)

        A_hat = D_hat.view(-1, 1) * A*D_hat.view(1, -1)
        if verbose:
            print("A_hat shape: ", A_hat.shape)

        # Shift the surviving edge weights down by a constant.
        nonzero = A_hat > 0.0001
        A_hat[nonzero] = A_hat[nonzero] - 0.2

        if verbose:
            print(A_hat[:10, :10])
        return A_hat

    def forward(self, x):
        """Return class scores of shape (batch, 10).

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions flatten to N = img_size**2 values per sample.
        """
        B = x.size(0)
        if self.pred_edge:
            # Drop only the trailing singleton from (N, N, 1). A bare squeeze()
            # would also collapse the node dimensions when N == 1.
            self.A = self.pred_edge_fc(self.coord).squeeze(-1)

        # Aggregate each node's neighbors: (B, N, N) @ (B, N, 1) -> (B, N).
        avg_neighbor_features = (torch.bmm(self.A.unsqueeze(0).expand(B, -1, -1),
                                 x.view(B, -1, 1)).view(B, -1))
        return self.fc(avg_neighbor_features)

In [None]:
def train(model, device, train_loader, optimizer, epoch, log_interval):
    """Run one training epoch with cross-entropy loss.

    Args:
        model: Module to optimize; switched to training mode.
        device: Device the batches are moved to before the forward pass.
        train_loader: Iterable of (inputs, labels) batches exposing ``dataset``
            and ``__len__``.
        optimizer: Optimizer stepping the model's parameters.
        epoch: Epoch number, used only in the progress message.
        log_interval: A progress line is printed every this many batches.
    """
    model.train()
    progress_template = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), labels)
        loss.backward()
        optimizer.step()

        if step % log_interval != 0:
            continue
        samples_seen = step * len(inputs)
        percent_done = 100. * step / len(train_loader)
        print(progress_template.format(
            epoch, samples_seen, len(train_loader.dataset),
            percent_done, loss.item()))

In [None]:
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(
    '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [None]:
def main(model_mode='graph', pred_edge=True, batch_size=64, test_batch_size=1000,
         epochs=10, lr=0.001, seed=1, log_interval=200, use_cuda=False):
    """Train and evaluate one of the notebook's MNIST classifiers.

    Called with no arguments this reproduces the notebook's hard-coded setup:
    a ``GraphNet`` with predicted edges, trained on the CPU for 10 epochs.

    Args:
        model_mode: Which model to train: 'fc' (``FullyNet``), 'graph'
            (``GraphNet``) or 'conv' (``ConvNet``).
        pred_edge: For 'graph' only: predict edges instead of using the
            predefined adjacency matrix. Ignored for the other models.
        batch_size: Input batch size for training.
        test_batch_size: Input batch size for testing.
        epochs: Number of epochs to train.
        lr: SGD learning rate.
        seed: Random seed passed to ``torch.manual_seed``.
        log_interval: How many batches to wait before logging training status.
        use_cuda: Train on the GPU when one is available; falls back to the
            CPU otherwise.

    Raises:
        ValueError: If ``model_mode`` is not one of 'fc', 'graph', 'conv'.
    """
    if model_mode not in ('fc', 'graph', 'conv'):
        raise ValueError(
            "model_mode must be 'fc', 'graph' or 'conv', got %r" % (model_mode,))

    torch.manual_seed(seed)

    # Only claim CUDA when it is really usable, so the device and the loader
    # options below always agree.
    use_cuda = bool(use_cuda) and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Pinned memory and a worker process only pay off for host-to-GPU copies.
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # Same preprocessing for both splits: tensor conversion + normalization.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True, transform=transform),
        batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transform),
        batch_size=test_batch_size, shuffle=False, **kwargs)

    if model_mode == 'fc':
        model = FullyNet()
    elif model_mode == 'graph':
        model = GraphNet(pred_edge=pred_edge)
    else:
        model = ConvNet()
    model.to(device)
    print(model)

    # The convolutional model is trained with a much stronger weight decay.
    weight_decay = 1e-1 if model_mode == 'conv' else 1e-4
    optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)
    print('number of trainable parameters: %d' %
          sum(p.numel() for p in model.parameters() if p.requires_grad))

    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch, log_interval)
        test(model, device, test_loader)

In [27]:
# Run the experiment: train for 10 epochs, evaluating on the test set after each one.
main()

GraphNet(
  (fc): Linear(in_features=784, out_features=10, bias=False)
  (pred_edge_fc): Sequential(
    (0): Linear(in_features=4, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=1, bias=True)
    (3): Tanh()
  )
)
number of trainable parameters: 8225

Test set: Average loss: 0.9349, Accuracy: 7209/10000 (72%)


Test set: Average loss: 0.5357, Accuracy: 8434/10000 (84%)


Test set: Average loss: 0.6193, Accuracy: 8264/10000 (83%)


Test set: Average loss: 0.4077, Accuracy: 8815/10000 (88%)


Test set: Average loss: 0.3731, Accuracy: 8959/10000 (90%)


Test set: Average loss: 0.3999, Accuracy: 8826/10000 (88%)


Test set: Average loss: 0.3670, Accuracy: 8937/10000 (89%)


Test set: Average loss: 0.3637, Accuracy: 8952/10000 (90%)


Test set: Average loss: 0.3741, Accuracy: 8932/10000 (89%)


Test set: Average loss: 0.3466, Accuracy: 9018/10000 (90%)

