![image.png](attachment:image.png)

Original paper: https://arxiv.org/abs/1711.08028


This notebook shows an example of how to solve a Sudoku puzzle with a recurrent relational network (RRN).

x is the input feature vector. <br>
h is the hidden state of the node.<br>
o is the output.<br>
m is the message passed between nodes, which can be modeled with an MLP.<br>

The input vector x can also be produced by an MLP from the cell's inputs (cell value, row, and column).<br>
Each cell in the Sudoku grid is a node in the graph.

![image.png](attachment:image.png)

In [83]:
import torch
import torch.nn.functional as F

In [31]:
# MLP that generates the node input feature vector x_i (the same shape of
# network can also serve as the message function between nodes).
class InMLP(torch.nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
        hidden_dim: width of the hidden layer; defaults to
            ``max(in_dim, out_dim)``.

    The output layer has no activation, so outputs may be negative.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = max(in_dim, out_dim)
        self.relu = torch.nn.ReLU()
        # Two distinct layers: a single Linear(in_dim, out_dim) applied twice
        # would only accept its own output when in_dim == out_dim.
        self.linear = torch.nn.Linear(in_dim, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Map a ``(..., in_dim)`` tensor to ``(..., out_dim)``."""
        return self.linear2(self.relu(self.linear(x)))

In [84]:
# MLP that maps a node's hidden state h_i to its output o_i.
class OutMLP(torch.nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Used in this notebook as ``OutMLP(32, 1)``: 32 is the LSTM hidden size.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
        hidden_dim: width of the hidden layer; defaults to
            ``max(in_dim, out_dim)``, so a small ``out_dim`` does not squeeze
            the hidden layer down to that size.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = max(in_dim, out_dim)
        self.relu = torch.nn.ReLU()
        # Separate layers so the second one accepts hidden_dim features.
        self.linear = torch.nn.Linear(in_dim, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Map a ``(..., in_dim)`` tensor to ``(..., out_dim)``; no output activation."""
        return self.linear2(self.relu(self.linear(x)))

In [4]:
# Recurrent node-update function: one LSTM shared by all nodes.
class LSTM(torch.nn.Module):
    """Thin wrapper around :class:`torch.nn.LSTM` (``batch_first=False``).

    Args:
        in_dim: number of input features (96 in this notebook).
        hid_dim: hidden-state size (32 in this notebook).
        n_layers: number of stacked LSTM layers (1 in this notebook).
    """

    def __init__(self, in_dim, hid_dim, n_layers):
        super().__init__()
        self.lstm = torch.nn.LSTM(in_dim, hid_dim, n_layers)

    def forward(self, x, hx=None):
        """Run the LSTM over ``x``.

        Args:
            x: ``(seq_len, batch, in_dim)``, or ``(seq_len, in_dim)`` for
                unbatched input, as ``torch.nn.LSTM`` expects.
            hx: optional ``(h_0, c_0)`` state, e.g. returned by a previous
                call, so recurrent steps can continue where they left off;
                ``None`` starts from zeros.

        Returns:
            ``(output, (h_n, c_n))`` exactly as ``torch.nn.LSTM`` returns it;
            index 0 holds the per-step hidden states.
        """
        return self.lstm(x, hx)

In [73]:
def preprocess(d, row, col):
    """Encode a single Sudoku cell into a 96-dim input feature vector.

    Three fresh ``Embedding(16, 16)`` tables (digit, row, column) are built
    and each is applied to its index tensor. The three results are
    concatenated along dim 0, flattened, and passed through a fresh
    ``InMLP(48, 96)``.

    NOTE(review): the embedding tables are discarded when the call returns,
    so they cannot be trained. Because every module is newly initialised,
    the same cell maps to a different vector on each call. Concatenating
    along dim 0 and then flattening means only one-element (or scalar)
    index tensors produce the 48 features InMLP expects. RRN is the place
    to keep these modules as registered submodules.

    Returns:
        ``(x, net)``: the encoded vector and the InMLP instance that made it.
    """
    # Build all three tables first (digit, row, column order), then look up.
    tables = [torch.nn.Embedding(16, 16) for _ in range(3)]
    looked_up = [table(idx) for table, idx in zip(tables, (d, row, col))]
    stacked = torch.cat(looked_up)
    encoder = InMLP(48, 96)
    return encoder(stacked.flatten()), encoder

In [76]:
# Encode one cell (digit 3 at row 3, column 3) and show the vector and encoder.
x, net = preprocess(*(torch.tensor([3]) for _ in range(3)))
print(x, net, sep="\n")

tensor([ 2.7750e-01,  1.0491e+00, -1.3678e+00, -2.5899e-01,  8.1340e-01,
        -7.6282e-01, -5.2294e-01,  9.4355e-01, -4.0130e-02, -1.8793e-01,
         2.9365e-01,  8.4923e-01, -1.0424e+00,  9.3863e-02, -1.0368e+00,
        -3.5215e-01,  2.3206e-01, -5.3005e-01,  1.0708e+00,  1.8314e+00,
        -1.3838e-01, -1.0330e-03,  1.7078e-01, -5.3633e-01, -3.2656e-01,
         1.1710e+00,  2.8755e-01,  9.4489e-03,  8.1821e-01,  3.5780e-01,
         1.0696e+00, -8.0348e-01,  6.8669e-01, -1.0829e+00, -1.2657e+00,
         9.6731e-01,  8.3017e-01, -2.2961e-01, -8.3212e-01, -2.8902e-01,
         1.3323e+00,  9.3642e-01,  7.4951e-01,  1.8815e+00, -5.2058e-01,
         7.7603e-02,  6.0791e-01, -8.3076e-01,  2.0323e-01,  3.8496e-01,
         9.3031e-02,  4.8259e-01,  1.5106e+00, -6.2850e-01,  2.1813e-01,
         2.0737e-01, -1.1934e+00,  1.0627e+00,  1.6528e-01,  5.3368e-01,
        -3.5828e-01, -4.2735e-02, -6.1172e-01, -1.4117e+00, -1.8862e-01,
        -3.5565e-01,  4.1457e-01,  2.0340e-01, -2.7

In [None]:
# in_tensor = (d, row, col)
class RRN(torch.nn.Module):
    """Recurrent relational network for Sudoku (work in progress).

    Every grid cell is a node described by three integer index tensors
    ``(d, row, col)``: its digit, row, and column. ``forward`` currently
    does three things:

    1. Encodes the cells.
    2. Runs a single LSTM step with all cells as the batch.
    3. Maps each hidden state to an output.

    Message passing between cells is not implemented yet.
    """

    def __init__(self):
        super().__init__()
        # The encoder is built once here rather than inside preprocess, so its
        # weights are registered submodule parameters. They then appear in
        # self.parameters() / state_dict and give the same encoding for the
        # same cell on every call.
        self.emb_d = torch.nn.Embedding(16, 16)
        self.emb_row = torch.nn.Embedding(16, 16)
        self.emb_col = torch.nn.Embedding(16, 16)
        self.in_mlp = InMLP(48, 96)
        self.lstm = LSTM(96, 32, 1)
        # NOTE(review): out_dim=1 yields one scalar per cell; a per-digit
        # classifier would need one logit per digit. Confirm intent.
        self.out_mlp = OutMLP(32, 1)

    def preprocess(self, d, row, col):
        """Encode cells into input feature vectors.

        Args:
            d, row, col: integer index tensors of the same shape, with values
                in ``[0, 16)`` (the embedding table size).

        Returns:
            ``(x, net)``: ``x`` has shape ``d.shape + (96,)`` and ``net`` is
            the InMLP that produced it.
        """
        # Concatenate along the feature axis so each cell keeps its own
        # 48-dim vector. Concatenating along dim 0 would stack the three
        # embeddings as extra rows and mix different cells together.
        x = torch.cat((self.emb_d(d), self.emb_row(row), self.emb_col(col)), dim=-1)
        return self.in_mlp(x), self.in_mlp

    def forward(self, x):
        """Run one (message-free) update step.

        Args:
            x: tuple ``(d, row, col)`` of index tensors accepted by preprocess.

        Returns:
            Tensor of shape ``d.shape + (1,)``, one output per cell.
        """
        d, row, col = x
        feats, _ = self.preprocess(d, row, col)
        # TODO: add messages from neighbouring cells and sum them into feats.
        # torch.nn.LSTM expects (seq_len, batch, features): use one time step
        # and treat every cell as a batch element.
        seq = feats.reshape(1, -1, feats.shape[-1])
        # The LSTM wrapper returns (output, (h_n, c_n)); only output is used.
        hidden, _ = self.lstm(seq)
        out = self.out_mlp(hidden)
        return out.reshape(*d.shape, out.shape[-1])