# GraphRNNAutomaton

### Imports

In [4]:
import time, tqdm
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import RNN, Linear, Dropout 
from sklearn.metrics import accuracy_score

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


ModuleNotFoundError: No module named 'wandb'

### Data

### Creating Model

In [None]:
class EdgeMLP(nn.Module):
    """MLP head mapping a hidden-state vector to connection / final / end probabilities.

    The last layer emits ``2*m + 3`` sigmoid-activated values, which are split
    into ``2*m + 1`` connection probabilities, one "final" probability and one
    "end" probability.

    Args:
        m: size parameter; the head outputs ``2*m + 3`` values.
        input_dim: width of the input features (last dimension of ``x``).
    """

    def __init__(self, m, input_dim):
        super().__init__()  # must run before any submodule is assigned
        self.m = m

        # Layer widths: input_dim -> 512 -> 256 -> 512 -> 2m+3
        self.l1 = Linear(in_features=input_dim, out_features=512)
        self.l2 = Linear(in_features=512, out_features=256)
        self.l3 = Linear(in_features=256, out_features=512)
        self.l4 = Linear(in_features=512, out_features=2 * m + 3)

        self.dropout = Dropout(p=.3)

    def forward(self, x):
        """Compute the three prediction heads for ``x``.

        Args:
            x: tensor whose last dimension is ``input_dim``; any leading
                dimensions (including none) are preserved.

        Returns:
            ``(conns, final_prob, end_prob)``: ``conns`` has shape
            ``x.shape[:-1] + (2*m + 1,)``; ``final_prob`` and ``end_prob``
            have shape ``x.shape[:-1]``. All values lie in ``[0, 1]``.
        """
        res = torch.sigmoid(self.l1(x))
        res = self.dropout(res)
        res = F.leaky_relu(self.l2(res), negative_slope=.02)
        res = self.dropout(res)
        res = F.leaky_relu(self.l3(res), negative_slope=.02)
        res = self.dropout(res)
        res = torch.sigmoid(self.l4(res))

        # Slice along the last axis so both unbatched (1-D) and batched input work.
        n_conns = 2 * self.m + 1
        conns = res[..., :n_conns]
        final_prob = res[..., n_conns]
        end_prob = res[..., n_conns + 1]

        return conns, final_prob, end_prob

In [None]:
class NodeRNN(nn.Module):
    """Thin wrapper around ``torch.nn.RNN`` that exposes only the top layer's final state.

    Args:
        input_dim: feature size of each input element.
        hidden_dim: size of the hidden state.
        num_layers: number of stacked RNN layers.
        dropout: dropout probability between stacked layers (passed to ``RNN``).
    """

    def __init__(self, input_dim, hidden_dim, num_layers=1, dropout=0):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.num_layers = num_layers
        self.rnn = RNN(input_dim, hidden_dim, num_layers, dropout=dropout)

    def forward(self, x, h):
        """Run the RNN over ``x`` starting from ``h``; return the last layer's final hidden state.

        NOTE(review): only ``h_n[-1]`` is returned, so it lacks the leading
        ``num_layers`` axis that ``torch.nn.RNN`` expects for ``h`` — confirm
        how callers feed it back on the next step.
        """
        _, h_n = self.rnn(x, h)
        return h_n[-1]

    def get_initial_hidden(self):
        """Return an all-zero unbatched initial state of shape ``(num_layers, hidden_dim)``.

        The tensor matches the device and dtype of the RNN's weights, so it can
        be passed straight to ``forward`` after the module has been moved.
        """
        weight = self.rnn.weight_ih_l0
        return torch.zeros(
            (self.num_layers, self.hidden_dim), device=weight.device, dtype=weight.dtype
        )

In [None]:
class AutomatonRNN(nn.Module):
    """Node-level RNN followed by an ``EdgeMLP`` prediction head.

    Args:
        m: size parameter shared with ``EdgeMLP``; node inputs have width ``2*m + 3``.
        hidden_dim: hidden size of the node RNN, also the ``EdgeMLP`` input width.
    """

    def __init__(self, m, hidden_dim):
        super().__init__()  # must run before any submodule is assigned
        self.m = m
        # Node inputs have the same width (2m+3) as EdgeMLP's final layer output.
        self.node_rnn = NodeRNN(2 * m + 3, hidden_dim)
        self.edge_model = EdgeMLP(m, input_dim=hidden_dim)

    def forward(self, x, h):
        """Advance the node RNN from ``h`` on input ``x`` and predict from the new state.

        Returns:
            ``(edge_model(hidden), hidden)`` where ``hidden`` is the node RNN's output.
        """
        hidden = self.node_rnn(x, h)
        return self.edge_model(hidden), hidden

    def get_initial_hidden(self):
        """Return the node RNN's initial hidden state."""
        return self.node_rnn.get_initial_hidden()
        

### Training Loop

In [None]:
def get_targets(graph, node):
    """Return the ``(connections, final, end)`` training targets for ``node``.

    TODO: placeholder — both arguments are currently ignored and every target
    is the integer 0.
    """
    zero_target = 0
    return zero_target, zero_target, zero_target

In [None]:
def train_model(model, criterion, optim, dataloader):
    """Run one training epoch of ``model`` over ``dataloader``.

    Each graph is unrolled node by node. The prediction for one node is packed
    back into a single tensor and fed in as the next node's input, together
    with the returned hidden state. A graph's per-node losses are summed, and
    one backward pass / optimizer step is taken per graph (backprop through
    the whole unrolled sequence).

    Args:
        model: module exposing ``get_initial_hidden()`` and
            ``node_rnn.input_dim``, whose ``model(x, h)`` returns
            ``((conns, final_prob, end_prob), hidden)`` (as ``AutomatonRNN`` does).
        criterion: callable mapping ``(predictions, targets)`` tuples to a
            scalar loss tensor.
        optim: optimizer over ``model``'s parameters.
        dataloader: sized iterable of graphs, each exposing ``nodes``.

    Returns:
        ``(model, mean_loss)``: the trained model and the summed per-graph loss
        divided by ``len(dataloader)`` (``0.0`` when the dataloader is empty).
    """
    model.train()  # make sure dropout is active while training
    loss_val = 0.0
    for i, graph in enumerate(dataloader):
        start_time = time.time()
        optim.zero_grad()

        # NOTE(review): all-zeros start-of-sequence input. It replaces the
        # former `"SOS".to(DEVICE)` placeholder, which fails because a str
        # has no `.to`. Swap in a real SOS token / get_initial_x() if one exists.
        x = torch.zeros((1, model.node_rnn.input_dim), device=DEVICE)
        h = model.get_initial_hidden().to(DEVICE)

        node_losses = []
        for node in graph.nodes:  # TODO: check whether x uses a batch dimension
            y_conns, y_final, y_end = get_targets(graph, node)

            # One recurrent step; `h` carries state to the next node.
            # NOTE(review): NodeRNN returns only h_n[-1]; confirm that shape is
            # accepted as `h` on the next call.
            (conns, final_prob, end_prob), h = model(x, h)

            node_losses.append(
                criterion((conns, final_prob, end_prob), (y_conns, y_final, y_end))
            )

            # Pack this step's prediction into one input tensor of width
            # (2m + 1) + 1 + 1 for the next step (a tuple cannot be fed to the RNN).
            x = torch.cat(
                (conns, final_prob.unsqueeze(-1), end_prob.unsqueeze(-1)), dim=-1
            )

        if not node_losses:
            continue  # empty graph: nothing to back-propagate

        # One backward/step per graph. Stepping after every node would
        # back-propagate through the shared hidden-state graph a second time,
        # and would accumulate gradients because zero_grad() runs once per graph.
        loss = sum(node_losses)
        loss.backward()
        optim.step()

        iter_loss = loss.item()
        loss_val += iter_loss
        # TODO: wandb.log({'train_loss': iter_loss}) once wandb is installed

        if i % 5 == 0:
            print(f"Iter {i}, duration: {time.time()-start_time}s -- TRAIN: loss {iter_loss}")

    n_graphs = len(dataloader)
    return model, (loss_val / n_graphs if n_graphs else 0.0)

### Wandb

### Training