# Graph Autoencoder Parser
<i>Riordan Callil 2021</i><br>
Goal: We want to take a sentence and produce a parse tree. Our input will be a graph in which each node is a word, and edges connect the words in the order they are written; these edges simply represent adjacency. The output of the parser will be a new graph in which each edge represents a dependency.

Data loading
--------------
- Penn Treebank Dataset (PTB)



Model
--------

In [8]:
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import GCNConv
import torch

class VGAE(nn.Module):
    """Variational graph auto-encoder.

    A single ``GCNConv`` layer feeds two linear heads that produce, per node,
    the mean and the log-variance of a latent code. The sampled code is then
    decoded by ``VGD`` into a dense node-by-node score matrix.

    Args:
        input_dim: size of each node's input feature vector.
        hid_dim_1: output size of the GCN layer.
        hid_dim_2: size of the latent code per node.
    """

    def __init__(self, input_dim, hid_dim_1, hid_dim_2):
        super(VGAE, self).__init__()
        self.gc1 = GCNConv(input_dim, hid_dim_1)
        # Two heads on the same GCN output: fc1 -> mean, fc2 -> log-variance.
        self.fc1 = nn.Linear(hid_dim_1, hid_dim_2)
        self.fc2 = nn.Linear(hid_dim_1, hid_dim_2)
        self.decode = VGD()

    def encode(self, x, adj):
        """Return ``(mu, logvar)`` for node features ``x``.

        ``adj`` is handed straight to ``GCNConv``. In the demo below it is a
        ``(2, num_edges)`` edge-index tensor, not a dense adjacency matrix.
        """
        # NOTE(review): no non-linearity between gc1 and the heads; the
        # reference VGAE (Kipf & Welling, 2016) applies ReLU here — confirm
        # whether the linear encoder is intended.
        x = self.gc1(x, adj)
        return self.fc1(x), self.fc2(x)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + sigma * eps`` in training mode; return ``mu`` in eval mode."""
        if self.training:
            # logvar is log(sigma^2), so sigma = exp(logvar / 2). Using
            # exp(logvar) would scale the noise by the variance instead.
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        return mu

    def forward(self, x, adj):
        """Return ``(decoded scores, mu, logvar)`` for one graph."""
        mu, logvar = self.encode(x, adj)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

class VGD(nn.Module):
    """Inner-product decoder for node embeddings.

    Applies dropout with probability ``self.dropout``, active only in training
    mode. It then scores every pair of nodes by the dot product of their
    embeddings and squashes the scores with ``self.activation`` (a sigmoid).
    """

    def __init__(self):
        super().__init__()
        self.activation = torch.sigmoid
        self.dropout = 0.5

    def forward(self, z):
        """Return the ``(num_nodes, num_nodes)`` matrix ``activation(z @ z^T)``."""
        dropped = F.dropout(z, self.dropout, training=self.training)
        pair_scores = dropped.mm(dropped.t())
        return self.activation(pair_scores)

# Smoke test: a three-word chain 0 - 1 - 2 with 10-dimensional all-zero
# features. Each adjacency is listed in both directions (source row, target row).
edge_index = torch.tensor(
    [
        [0, 1, 1, 2],
        [1, 0, 2, 1],
    ],
    dtype=torch.long,
)
x = torch.zeros((3, 10), dtype=torch.float)
test_model = VGAE(10, 5, 1)
test_model(x, edge_index)

(tensor([[0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 1.0000]], grad_fn=<SigmoidBackward>),
 tensor([[0.0566],
         [0.0566],
         [0.0566]], grad_fn=<AddmmBackward>),
 tensor([[0.4465],
         [0.4465],
         [0.4465]], grad_fn=<AddmmBackward>))

Training
---------

Visualisations
--------------
Metrics: unlabelled attachment score (UAS) and labelled attachment score (LAS).