In [47]:
import json
from sentence_transformers import SentenceTransformer
from lxml import html, etree
import torch
import torch.nn.functional as F
import dgl
from copy import deepcopy

# Tree LSTM Implementation

In [None]:
class Tree:
    def __init__(self, h_size):
        self.dgl_graph = dgl.DGLGraph()
        self.h_size = h_size

    def add_node(self, parent_id=None, tensor:torch.Tensor = torch.Tensor()):
        self.dgl_graph.add_nodes(1, data={'x': tensor.unsqueeze(0),
                                          'h': tensor.new_zeros(size=(1, self.h_size)),
                                          'c': tensor.new_zeros(size=(1, self.h_size))})
        added_node_id = self.dgl_graph.number_of_nodes() - 1
        if parent_id:
            self.dgl_graph.add_edges(added_node_id, parent_id)
        return added_node_id

    def add_node_bottom_up(self, child_ids, tensor: torch.Tensor):
        self.dgl_graph.add_nodes(1, data={'x': tensor.unsqueeze(0),
                                          'h': tensor.new_zeros(size=(1, self.h_size)),
                                          'c': tensor.new_zeros(size=(1, self.h_size))})
        added_node_id = self.dgl_graph.number_of_nodes() - 1
        for child_id in child_ids:
            self.dgl_graph.add_edges(child_id, added_node_id)
        return added_node_id

    def add_link(self, child_id, parent_id):
        self.dgl_graph.add_edges(child_id, parent_id)


class BatchedTree:
    def __init__(self, tree_list):
        graph_list = []
        for tree in tree_list:
            graph_list.append(tree.dgl_graph)
        self.batch_dgl_graph = dgl.batch(graph_list)

    def get_hidden_state(self):
        graph_list = dgl.unbatch(self.batch_dgl_graph)
        hidden_states = []
        max_nodes_num = max([graph.num_nodes() for graph in graph_list])
        for graph in graph_list:
            hiddens = graph.ndata['h']
            node_num, hidden_num = hiddens.size()
            if len(hiddens) < max_nodes_num:
                padding = hiddens.new_zeros(size=(max_nodes_num - node_num, hidden_num))
                hiddens = torch.cat((hiddens, padding), dim=0)
            hidden_states.append(hiddens)
        return torch.stack(hidden_states)


class TreeLSTM(torch.nn.Module):
    def __init__(self,
                 x_size,
                 h_size,
                 dropout,
                 cell_type='n_ary',
                 n_ary=None,
                 num_stacks=2):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        self.dropout = torch.nn.Dropout(dropout)
        if cell_type == 'n_ary':
            self.cell = NaryTreeLSTMCell(n_ary, x_size, h_size)
        else:
            self.cell = ChildSumTreeLSTMCell(x_size, h_size)
        self.num_stacks = num_stacks
        self.linear = torch.nn.Linear(h_size, 7)

    def forward(self, batch: BatchedTree):
        batches = [deepcopy(batch) for _ in range(self.num_stacks)]
        for stack in range(self.num_stacks):
            cur_batch = batches[stack]
            if stack > 0:
                prev_batch = batches[stack - 1]
                cur_batch.batch_dgl_graph.ndata['x'] = prev_batch.batch_dgl_graph.ndata['h']
            # cur_batch.batch_dgl_graph.register_message_func(self.cell.message_func)
            # cur_batch.batch_dgl_graph.register_reduce_func(self.cell.reduce_func)
            # cur_batch.batch_dgl_graph.register_apply_node_func(self.cell.apply_node_func)
            cur_batch.batch_dgl_graph.ndata['iou'] = self.cell.W_iou(self.dropout(batch.batch_dgl_graph.ndata['x']))
            dgl.prop_nodes_topo(cur_batch.batch_dgl_graph, self.cell.message_func, self.cell.reduce_func, apply_node_func = self.cell.apply_node_func)
            h = cur_batch.get_hidden_state()
        logits = self.linear(h)

        return batches, logits




class NaryTreeLSTMCell(torch.nn.Module):
    def __init__(self, n_ary, x_size, h_size):
        super(NaryTreeLSTMCell, self).__init__()
        self.n_ary = n_ary
        self.h_size = h_size
        self.W_iou = torch.nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = torch.nn.Linear(n_ary * h_size, 3 * h_size, bias=False)
        self.b_iou = torch.nn.Parameter(torch.zeros(1, 3 * h_size), requires_grad=True)
        self.U_f = torch.nn.Linear(n_ary * h_size, n_ary * h_size)

    def message_func(self, edges):
        return {'h': edges.src['h'], 'c': edges.src['c']}

    def reduce_func(self, nodes):
        h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
        padding_hs = self.n_ary - nodes.mailbox['h'].size(1)
        padding = h_cat.new_zeros(size=(nodes.mailbox['h'].size(0), padding_hs * self.h_size))
        h_cat = torch.cat((h_cat, padding), dim=1)
        f = torch.sigmoid(self.U_f(h_cat)).view(nodes.mailbox['h'].size(0), self.n_ary, self.h_size)
        padding_cs = self.n_ary - nodes.mailbox['c'].size(1)
        padding = h_cat.new_zeros(size=(nodes.mailbox['c'].size(0), padding_cs, self.h_size))
        c = torch.cat((nodes.mailbox['c'], padding), dim=1)
        c = torch.sum(f * c, 1)
        return {'iou': nodes.data['iou'] + self.U_iou(h_cat), 'c': c}

    def apply_node_func(self, nodes):
        iou = nodes.data['iou'] + self.b_iou
        i, o, u = torch.chunk(iou, 3, 1)
        i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
        c = i * u + nodes.data['c']
        h = o * torch.tanh(c)
        return {'h': h, 'c': c}


class ChildSumTreeLSTMCell(torch.nn.Module):
    def __init__(self, x_size, h_size):
        super(ChildSumTreeLSTMCell, self).__init__()
        self.W_iou = torch.nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = torch.nn.Linear(h_size, 3 * h_size, bias=False)
        self.b_iou = torch.nn.Parameter(torch.zeros(1, 3 * h_size), requires_grad=True)
        self.U_f = torch.nn.Linear(h_size, h_size)

    def message_func(self, edges):
        return {'h': edges.src['h'], 'c': edges.src['c']}

    def reduce_func(self, nodes):
        h_tild = torch.sum(nodes.mailbox['h'], 1)
        f = torch.sigmoid(self.U_f(nodes.mailbox['h']))
        c = torch.sum(f * nodes.mailbox['c'], 1)
        return {'iou': nodes.data['iou'] + self.U_iou(h_tild), 'c': c}

    def apply_node_func(self, nodes):
        iou = nodes.data['iou'] + self.b_iou
        i, o, u = torch.chunk(iou, 3, 1)
        i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
        c = i * u + nodes.data['c']
        h = o * torch.tanh(c)
        return {'h': h, 'c': c}

# Pre processing and training

In [None]:
model = SentenceTransformer('sentence-transformers/all-roberta-large-v1', device='cuda')
H_SIZE = 30
X_SIZE = 1024

In [66]:
class Node:
    def __init__(self, parent_id = None):
        self.children = []
        self.data = ""
        self.embedding = None
        self.id = None
        self.parent_id = parent_id
    def add_child(self, child):
        self.children.append(child)
    def set_data(self, tag, text):
        self.data = ""
        if tag is etree.Comment:
            self.data += "comment"
        elif tag != None:
            self.data += tag

        if text != None:
            self.data += text

    def set_id(self, id):
        self.id = id



def loadData():
    f = open('/content/drive/MyDrive/data.json', 'r')
    val = f.read()
    f.close()
    d = json.loads(val)
    return d

def idTree(htmlTree):
    nodeIds = dict()
    count = 0
    for i in htmlTree.getiterator():
        nodeIds[i] = count
        count += 1
    return nodeIds

def convertCodeToTree(code):
    t = html.fromstring(code)
    tree = Node()

    nodeIds = idTree(t)

    # Want dict mapping object -> id


    def parseTree(htmlTree, treeObject):
        tree_id = nodeIds[htmlTree]
        treeObject.set_data(htmlTree.tag, htmlTree.text)
        treeObject.set_id(tree_id)
        for treeChild in htmlTree.getchildren():
            new_tree = Node(parent_id = tree_id)
            parseTree(treeChild, new_tree)
            treeObject.add_child(new_tree)
        return treeObject

    return parseTree(t, tree)


def embbedNode(nodes):
    nodes.embedding =  torch.tensor(model.encode([nodes.data])[0])
    for i in nodes.children:
        embbedNode(i)

def makeTree(t):
    finTree = Tree(H_SIZE)


    def addNodes(root):
        finTree.add_node(parent_id = root.parent_id, tensor=root.embedding)
        for i in root.children:
            addNodes(i)

    addNodes(t)
    return finTree

def batchTrees(treeList):
    return BatchedTree(treeList)

def trainLSTM(data, labels):

    model = TreeLSTM(x_size=X_SIZE, h_size=H_SIZE, dropout=0.3, cell_type='child_sum', num_stacks=1)
    optimizer = torch.optim.Adagrad(model.parameters(),
                          lr=0.01,
                          weight_decay=0.0001)

    for epoch in range(20):

      g = data.batch_dgl_graph
      n = g.number_of_nodes()
      b, logits = model(data)

      logp = F.log_softmax(logits, 2)
      logp = torch.reshape(logp, (logp.shape[0]*logp.shape[1], logp.shape[2]))
      labels = torch.reshape(labels, (-1,))
      loss = F.nll_loss(logp, labels, reduction='sum')
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      pred = torch.argmax(logits, 2)
      pred = torch.reshape(pred, (-1,))

      acc = float(torch.sum(torch.eq(labels, pred))) / len(labels)
      print("Epoch {:05d} | Loss {:.4f} | Acc {:.4f} |".format(
          epoch, loss.item(), acc))
    return model(data)


In [67]:
data = loadData()
treeList = []
for i in data['urls'][0:2]:
    trep = convertCodeToTree(data['code'][i])
    embbedNode(trep)
    lstmTree = makeTree(trep)
    treeList.append(lstmTree)

lstmTreeBatch = batchTrees(treeList)





AttributeError: 'numpy.ndarray' object has no attribute 'unsqueeze'

In [9]:
# model = trainLSTM(lstmTreeBatch)

In [19]:
# data = loadData()

In [51]:
test = data['code'][data['urls'][0]]
# t = html.fromstring(test)
val = convertCodeToTree(test)

In [42]:
for i in t.getchildren():
    print(i.tag)

head
body


In [64]:
trep.embedding