# GraphRNNAutomaton

### Imports

In [1]:
import time, tqdm
#import wandb
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.nn import RNN, Linear, Dropout 
from exporter import read_automatas
from automata import Automata

# Prefer the first CUDA device when one is available, otherwise fall back to CPU.
# NOTE(review): DEVICE is only printed in this notebook; the model is never moved to it.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(DEVICE)

cuda:0


### Data

In [2]:
# Name of the automaton property whose dataset is loaded below.
automata_property = 'unique_accepting'

In [3]:
# Dataset directory for the selected property.
path = f'./data/{automata_property}_property_automatas'
# NOTE(review): judging by get_targets/get_nodes, each element is presumably a
# dict with a 'transitions' matrix and a 'final_states' collection — confirm
# against exporter.read_automatas.
automatas = read_automatas(path)

### Creating Model

In [4]:
class EdgeMLP(nn.Module):
    """MLP mapping a node hidden state to per-node output probabilities.

    The output has ``2*m + 3`` entries, each squashed into (0, 1) by a final
    sigmoid: ``2*m + 1`` connection probabilities followed by the "final"
    and "end" probabilities (the split performed by ``unfold_pred``).

    Args:
        m: look-back window size; fixes the output size ``2*m + 3``.
        input_dim: size of the incoming hidden state.
        dropout: dropout probability applied after each of the three hidden
            layers (defaults to the previously hard-coded 0.3).
    """

    def __init__(self, m, input_dim, dropout=.3):
        super().__init__()
        self.m = m

        self.l1 = Linear(in_features=input_dim, out_features=512)
        self.l2 = Linear(in_features=512, out_features=256)
        self.l3 = Linear(in_features=256, out_features=512)
        self.l4 = Linear(in_features=512, out_features=2*m+3)

        self.dropout = Dropout(p=dropout)

    def forward(self, x):
        """Return ``2*m + 3`` probabilities along the last dimension of ``x``."""
        # torch.sigmoid is the canonical form; F.sigmoid was a deprecated
        # alias in older PyTorch releases.
        res = self.dropout(torch.sigmoid(self.l1(x)))
        res = self.dropout(F.leaky_relu(self.l2(res), negative_slope=.02))
        res = self.dropout(F.leaky_relu(self.l3(res), negative_slope=.02))
        # Final sigmoid: the outputs are consumed as probabilities by BCELoss
        # during training and by Bernoulli sampling during generation.
        return torch.sigmoid(self.l4(res))

In [5]:
class NodeRNN(nn.Module):
    """Single-step node-level RNN of the automaton generator.

    Each call consumes one unbatched, length-1 input sequence and returns the
    updated hidden state of the top RNN layer.

    Args:
        input_dim: size of the per-node input vector.
        hidden_dim: size of the RNN hidden state.
        num_layers: number of stacked RNN layers.
        dropout: dropout between stacked RNN layers (PyTorch only applies it
            when ``num_layers > 1``).
    """

    def __init__(self, input_dim, hidden_dim, num_layers=1, dropout=0):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.num_layers = num_layers
        self.rnn = RNN(input_dim, hidden_dim, num_layers, dropout=dropout)

    def forward(self, x, h):
        """Run one RNN step.

        Args:
            x: unbatched input of shape ``(1, input_dim)`` (a length-1 sequence).
            h: hidden state of shape ``(num_layers, hidden_dim)``.

        Returns:
            The new hidden state of the LAST layer only, shape ``(hidden_dim,)``.
            Callers that feed this back as ``h`` therefore only work with
            ``num_layers == 1``.
        """
        _, h = self.rnn(x, h)
        return h[-1]

    def _device(self):
        # Create start tensors where the weights live, so that moving the
        # module with ``.to(device)`` does not cause a device mismatch.
        return next(self.parameters()).device

    def get_sos(self):
        """Return the all-zeros start-of-sequence input, shape ``(1, input_dim)``."""
        return torch.zeros((1, self.input_dim), device=self._device())

    def get_initial_hidden(self):
        """Return an all-zeros initial state, shape ``(num_layers, hidden_dim)``."""
        # One row per layer: the previous fixed ``(1, hidden_dim)`` shape made
        # the RNN reject the state whenever num_layers > 1.
        return torch.zeros((self.num_layers, self.hidden_dim), device=self._device())

In [6]:
class AutomatonRNN(nn.Module):
    """Automaton generator: a node-level RNN followed by an edge MLP.

    At each step the node RNN advances its hidden state, and the edge MLP
    maps that state to the ``2*m + 3`` output probabilities.

    Args:
        m: look-back window size (number of previous nodes a node can link to).
        hidden_dim: hidden size of the node RNN, also the edge MLP input size.
    """

    def __init__(self, m, hidden_dim):
        super().__init__()
        self.m = m
        prediction_size = 2 * m + 3
        # The node RNN's input has the same size as the edge model's output,
        # because each step's prediction is fed back as the next input.
        self.node_rnn = NodeRNN(prediction_size, hidden_dim)
        self.edge_model = EdgeMLP(m, input_dim=hidden_dim)

    def forward(self, x, h):
        """Advance one node step and return ``(edge_probs, node_hidden)``."""
        node_hidden = self.node_rnn(x, h)
        edge_probs = self.edge_model(node_hidden)
        return edge_probs, node_hidden

    def get_sos(self):
        """Start-of-sequence input, delegated to the node RNN."""
        return self.node_rnn.get_sos()

    def get_initial_hidden(self):
        """Initial hidden state, delegated to the node RNN."""
        return self.node_rnn.get_initial_hidden()
        

### Obtaining data from a graph

In [7]:
def get_targets(graph, node, m):
    """Build the training targets for ``node`` of ``graph``.

    Args:
        graph: mapping with a ``'transitions'`` matrix and a ``'final_states'``
            collection. In-edges are read as ``transitions[prev, node]`` and
            out-edges as ``transitions[node, prev]``.
        node: index of the node being predicted.
        m: size of the look-back window of previous nodes.

    Returns:
        ``(y_conns, y_final)``.
        ``y_conns`` is a float32 tensor of length ``2*m + 1`` laid out as
        ``[padding, in-edges from the window, self-loop, out-edges to the
        window, padding]``. Each padding run has ``max(0, m - node)`` zeros.
        ``y_final`` is a 1-element float tensor equal to 1 iff ``node`` is final.
    """
    transitions = graph['transitions']
    window_start = max(0, node - m)
    pad = np.zeros(max(0, m - node), dtype=np.float32)

    incoming = np.array(transitions[window_start:node, node], dtype=np.float32)
    self_loop = np.array([transitions[node, node]], dtype=np.float32)
    outgoing = np.array(transitions[node, window_start:node], dtype=np.float32)

    y_conns = torch.tensor(np.concatenate((pad, incoming, self_loop, outgoing, pad)))
    y_final = torch.tensor([int(node in graph['final_states'])], dtype=torch.float)
    return y_conns, y_final

def get_nodes(graph):
    """Return the node indices of ``graph`` as ``range(n)``.

    ``n`` is the length (row count) of ``graph['transitions']``.
    """
    node_count = len(graph['transitions'])
    return range(node_count)

In [8]:
def unfold_pred(res, m):
    """Split a ``2*m + 3`` prediction vector into its three parts.

    Returns:
        ``(conns, final_prob, end_prob)``:
        - ``conns``: the first ``2*m + 1`` entries (connection probabilities).
        - ``final_prob``: the next entry, the "node is final" probability.
        - ``end_prob``: the entry after that, the "sequence ends here" probability.
    """
    n_conns = 2 * m + 1
    return res[:n_conns], res[n_conns], res[n_conns + 1]

### Loss function

In [9]:
def compose_loss(y_hat, y):
    """Sum the binary cross-entropy losses of the three prediction parts.

    Args:
        y_hat: ``(conns_hat, final_prob_hat, end_prob_hat)``, the predicted
            probabilities (e.g. as split by ``unfold_pred``).
        y: ``(conns, final_prob, end_prob)``, the targets, each shaped like
            its counterpart in ``y_hat``.

    Returns:
        A scalar tensor: mean BCE over the connections, plus the BCE of the
        final flag, plus the BCE of the end flag.
    """
    conns_hat, final_hat, end_hat = y_hat
    conns, final, end = y
    # F.binary_cross_entropy with its default 'mean' reduction computes
    # the same value as nn.BCELoss().
    parts = (
        F.binary_cross_entropy(conns_hat, conns),
        F.binary_cross_entropy(final_hat, final),
        F.binary_cross_entropy(end_hat, end),
    )
    # The total loss is the sum of all three parts.
    return parts[0] + parts[1] + parts[2]

### Training Loop

In [10]:
def train_model(model, optim, dataloader, criterion, epochs, log_every=1):
    """Train ``model`` node by node on every graph in ``dataloader``.

    For each graph, the model is unrolled one node at a time from the
    start-of-sequence input. After every node, an optimizer step is taken on
    the loss between the model's prediction and the target vector
    ``[connections, is_final, is_last_node]``. The hidden state and the
    prediction are detached before being fed to the next step, so gradients
    never flow across nodes.

    Args:
        model: an ``AutomatonRNN``-like module exposing ``m``, ``get_sos()``
            and ``get_initial_hidden()``.
        optim: optimizer over ``model``'s parameters.
        dataloader: iterable of graph dicts accepted by ``get_targets``.
        criterion: callable ``criterion(pred, target) -> loss tensor``.
        epochs: number of passes over ``dataloader``.
        log_every: print a summary every ``log_every`` epochs. ``0`` or
            ``None`` disables printing.

    Returns:
        ``(model, mean_loss)``. ``mean_loss`` is the per-graph average of the
        summed node losses in the last epoch (0.0 if nothing was processed).
    """
    device = next(model.parameters()).device
    model.train()  # make sure dropout is active while training
    mean_loss = 0.0
    for epoch in range(epochs):
        start_time = time.time()
        loss_val = 0.0
        n_graphs = 0
        # `tqdm` is imported as a module, so the progress bar is tqdm.tqdm.
        for graph in tqdm.tqdm(dataloader):
            x = model.get_sos().to(device)
            h = model.get_initial_hidden().to(device)
            iter_loss = 0.0
            graph_nodes = get_nodes(graph)
            last_node = len(graph_nodes) - 1
            for node in graph_nodes:
                optim.zero_grad()
                # Targets: connections within the look-back window, then the
                # "node is final" and "node is the last one" flags.
                y_conns, y_final = get_targets(graph, node, model.m)
                y_end = torch.tensor([float(node == last_node)])
                y = torch.cat((y_conns, y_final, y_end)).to(device)

                # Run one iteration of the model and take an optimizer step.
                pred, hidden = model(x, h)
                loss = criterion(pred, y)
                # x and h are detached below, so each step's autograd graph is
                # independent and does not need to be retained.
                loss.backward()
                optim.step()

                # NOTE(review): the next input is the model's own (detached)
                # prediction, not the ground-truth row (no teacher forcing).
                h = hidden.reshape(1, -1).detach()
                x = pred.reshape(1, -1).detach()

                iter_loss += loss.item()

            loss_val += iter_loss
            n_graphs += 1
            #wandb.log({'train_loss':iter_loss})
        # Counting graphs avoids len() on unsized iterables and a
        # ZeroDivisionError on an empty dataloader.
        mean_loss = loss_val / n_graphs if n_graphs else 0.0
        # The original `epoch%0` always raised ZeroDivisionError.
        if log_every and epoch % log_every == 0:
            print(f"Epoch {epoch}, duration: {time.time()-start_time}s -- TRAIN: loss {mean_loss}")

    return model, mean_loss

### Wandb

### Creating model and optimizer

In [11]:
# Model hyper-parameters: look-back window size and RNN hidden size.
m = 20
hidden_dim = 256
automaton_rnn = AutomatonRNN(m, hidden_dim)
# A uniform weight of 100 on all 2*m + 3 outputs scales the BCE loss by 100.
criterion = nn.BCELoss(weight=torch.full((2 * m + 3,), 100.0))

optim = torch.optim.Adam(automaton_rnn.parameters(), lr=0.002)

### Run training!

In [None]:
# Train for two epochs; returns the model and the last epoch's mean per-graph loss.
automaton_rnn, training_loss = train_model(automaton_rnn, optim, automatas, criterion, epochs=2)

In [None]:
# Report the mean per-graph loss of the last training epoch.
print(f'Final training loss {training_loss}')

In [28]:
class Graph:
    """Directed graph sampled node by node from model predictions.

    Attributes:
        nodes: dict mapping each node index to the set of its successor nodes.
        final_nodes: set of node indices sampled as final (accepting).
    """

    def __init__(self):
        self.nodes = {}
        self.final_nodes = set()

    @staticmethod
    def _sample(p):
        """Return True with probability ``p`` (a single Bernoulli draw)."""
        return np.random.random() < float(p)

    def add_node(self, node, conns, is_final):
        """Add ``node`` and sample its edges from connection probabilities.

        Args:
            node: index of the new node. Nodes must be added in order
                0, 1, 2, ... because in-edges are attached to earlier nodes.
            conns: ``2*m + 1`` probabilities laid out like ``get_targets``
                builds them: ``[padding, in-edge probs, self-loop prob,
                out-edge probs, padding]``. The window covers nodes
                ``max(0, node - m) .. node - 1`` in ascending order, and
                each padding run has length ``max(0, m - node)``.
            is_final: whether ``node`` is an accepting node.
        """
        self.nodes[node] = set()
        m = (len(conns) - 1) // 2
        pad = max(0, m - node)
        # Absolute index of the oldest node in the look-back window. The
        # slices below are relative to it; the original code used the
        # relative index directly, which is wrong once node > m.
        window_start = max(0, node - m)
        in_conns = conns[pad:m]
        loop_p = conns[m]
        out_conns = conns[m + 1:len(conns) - pad]

        for offset, (p_in, p_out) in enumerate(zip(in_conns, out_conns)):
            other = window_start + offset
            if self._sample(p_in):
                self.nodes[other].add(node)
            if self._sample(p_out):
                self.nodes[node].add(other)

        if self._sample(loop_p):
            self.nodes[node].add(node)

        if is_final:
            self.final_nodes.add(node)

In [29]:
def generate(model, max_nodes):
    """Sample one automaton graph from ``model``.

    Starting from the start-of-sequence input, the model is stepped once per
    node. Each prediction is split with ``unfold_pred`` and used to sample:
    - the node's edges (via ``Graph.add_node``),
    - whether the node is final,
    - whether generation stops after this node.
    The raw prediction is fed back as the next input, as in ``train_model``.

    Args:
        model: an ``AutomatonRNN``-like module exposing ``m``, ``get_sos()``
            and ``get_initial_hidden()``.
        max_nodes: maximum number of nodes to generate (``<= 0`` yields an
            empty graph).

    Returns:
        The sampled ``Graph``.
    """
    was_training = model.training
    # no_grad() does not disable dropout; eval() does.
    model.eval()
    try:
        with torch.no_grad():
            graph = Graph()
            x = model.get_sos()
            h = model.get_initial_hidden()
            # range() caps the count at max_nodes. The old `node > max_nodes`
            # check allowed max_nodes + 1 nodes.
            for node in range(max_nodes):
                pred, hidden = model(x, h)
                conns, final_prob, end_prob = unfold_pred(pred, model.m)
                is_final = np.random.random() < float(final_prob)
                graph.add_node(node, conns, is_final)
                if np.random.random() < float(end_prob):
                    break
                x = pred.reshape(1, -1)
                h = hidden.reshape(1, -1)
    finally:
        # Restore whichever mode the caller had the model in.
        model.train(was_training)
    return graph


In [35]:
# Sample one automaton from the trained model, passing max_nodes=25.
graph = generate(automaton_rnn, 25)

In [36]:
# Show the sampled final (accepting) nodes and the adjacency (node -> set of successor nodes).
print(f'Final Nodes: {graph.final_nodes}')
print(f'Nodes: {graph.nodes}')

Final Nodes: set()
Nodes: {0: set(), 1: set(), 2: {1, 2}, 3: set(), 4: set(), 5: set(), 6: set(), 7: {6}, 8: {3}}
