In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import nltk
import random
import numpy as np
from collections import Counter
import nltk
flatten = lambda l: [item for sublist in l for item in sublist]

* https://github.com/rguthrie3/DeepDependencyParsingProblemSet/tree/master/data

In [2]:
USE_CUDA = torch.cuda.is_available()

In [3]:
def getBatch(batch_size,train_data):
    random.shuffle(train_data)
    sindex=0
    eindex=batch_size
    while eindex < len(train_data):
        batch = train_data[sindex:eindex]
        temp = eindex
        eindex = eindex+batch_size
        sindex = temp
        yield batch
    
    if eindex >= len(train_data):
        batch = train_data[sindex:]
        yield batch

In [79]:
def make_vector(sents, word2index):
    idxs = list(map(lambda w: word2index[w] if w in word2index.keys() else word2index["<unk>"], sents))
    tensor = Variable(torch.LongTensor(idxs)).cuda() if USE_CUDA else  Variable(torch.LongTensor(idxs))
    return tensor.view(1,-1)

In [22]:
class TrainsitionState(object):
    
    def __init__(self,tagged_sent):
        self.root = ('ROOT','<root>',-1)
        self.stack=[self.root]
        self.buffer=[(s[0],s[1],i) for i,s in enumerate(tagged_sent)]
        self.address = [s[0] for s in tagged_sent] + [self.root[0]]
        self.arcs=[]
        self.terminal=False
        
    def __str__(self):
        return 'stack : %s \nbuffer : %s' % (str([s[0] for s in self.stack]), str([b[0] for b in self.buffer]))
    
    def shift(self):
        
        if len(self.buffer)>=1:
            self.stack.append(self.buffer.pop(0))
        else:
            print("Empty buffer")
            
    def left_arc(self,relation=None):
        
        if len(self.stack)>=2:
            arc={}
            s2 = self.stack[-2]
            s1 = self.stack[-1]
            arc['graph_id'] = len(self.arcs)
            arc['form'] = s1[0]
            arc['addr'] = s1[2]
            arc['head']=s2[2]
            arc['pos'] = s1[1]
            if relation:
                arc['relation']=relation
            self.arcs.append(arc)
            self.stack.pop(-2)
            
        elif self.stack==[self.root]:
            print("Element Lacking")
    
    def right_arc(self,relation=None):
        
        if len(self.stack)>=2:
            arc={}
            s2 = self.stack[-2]
            s1 = self.stack[-1]
            arc['graph_id'] = len(self.arcs)
            arc['form'] = s2[0]
            arc['addr'] = s2[2]
            arc['head']=s1[2]
            arc['pos'] = s2[1]
            if relation:
                arc['relation']=relation
            self.arcs.append(arc)
            self.stack.pop(-1)
            
        elif self.stack==[self.root]:
            print("Element Lacking")
    
    def get_left_most(self,index):
        left=['<NULL>','<NULL>',None]
        
        if index==None: return left
        for arc in self.arcs:
            if arc['head']==index:
                left=[arc['form'],arc['pos'],arc['addr']]
                break
        return left
    
    def get_right_most(self,index):
        right=['<NULL>','<NULL>',None]
        
        if index==None: return right
        for arc in reversed(self.arcs):
            if arc['head']==index:
                right=[arc['form'],arc['pos'],arc['addr']]
                break
        return right
    
    def is_done(self):
        return len(self.buffer)==0 and self.stack==[self.root]

In [23]:
temp = TrainsitionState(nltk.pos_tag("He has good control .".split()))
print(temp)
temp.shift()
temp.shift()
print(temp)
temp.left_arc()
print(temp)
print(temp.arcs)
temp.shift()
temp.shift()
print(temp)
temp.left_arc()
print(temp)
temp.right_arc()
print(temp)
temp.shift()
temp.right_arc()
print(temp)
temp.right_arc()
print(temp)
print(temp.arcs)
temp.is_done()

stack : ['ROOT'] 
buffer : ['He', 'has', 'good', 'control', '.']
stack : ['ROOT', 'He', 'has'] 
buffer : ['good', 'control', '.']
stack : ['ROOT', 'has'] 
buffer : ['good', 'control', '.']
[{'head': 0, 'addr': 1, 'graph_id': 0, 'pos': 'VBZ', 'form': 'has'}]
stack : ['ROOT', 'has', 'good', 'control'] 
buffer : ['.']
stack : ['ROOT', 'has', 'control'] 
buffer : ['.']
stack : ['ROOT', 'has'] 
buffer : ['.']
stack : ['ROOT', 'has'] 
buffer : []
stack : ['ROOT'] 
buffer : []
[{'head': 0, 'addr': 1, 'graph_id': 0, 'pos': 'VBZ', 'form': 'has'}, {'head': 2, 'addr': 3, 'graph_id': 1, 'pos': 'NN', 'form': 'control'}, {'head': 3, 'addr': 1, 'graph_id': 2, 'pos': 'VBZ', 'form': 'has'}, {'head': 4, 'addr': 1, 'graph_id': 3, 'pos': 'VBZ', 'form': 'has'}, {'head': 1, 'addr': -1, 'graph_id': 4, 'pos': '<root>', 'form': 'ROOT'}]


True

In [124]:
import pydot

In [279]:
def plot_tree(state,image_name):
    graph = pydot.Dot(graph_type='graph')
    for arc in state.arcs:
        edge = pydot.Edge(arc['form'],state.address[arc['head']])
        graph.add_edge(edge)
    
    graph.write_png('t_graph.png')

# Data load & Preprocessing 

In [24]:
def get_feat(transition_state,word2index,tag2index,label2index=None):
    word_feats=[]
    tag_feats = []
    
    word_feats.append(transition_state.stack[-1][0]) if len(transition_state.stack)>=1 and \
    transition_state.stack[-1][0] in word2index.keys() else word_feats.append('<NULL>') # s1
    word_feats.append(transition_state.stack[-2][0]) if len(transition_state.stack)>=2 and \
    transition_state.stack[-2][0] in word2index.keys() else word_feats.append('<NULL>') # s2
    word_feats.append(transition_state.stack[-3][0]) if len(transition_state.stack)>=3 and \
    transition_state.stack[-3][0] in word2index.keys() else word_feats.append('<NULL>') # s3
    
    tag_feats.append(transition_state.stack[-1][1]) if len(transition_state.stack)>=1 and \
    transition_state.stack[-1][1] in tag2index.keys() else tag_feats.append('<NULL>') # st1
    tag_feats.append(transition_state.stack[-2][1]) if len(transition_state.stack)>=2 and \
    transition_state.stack[-2][1] in tag2index.keys() else tag_feats.append('<NULL>') # st2
    tag_feats.append(transition_state.stack[-3][1]) if len(transition_state.stack)>=3 and \
    transition_state.stack[-3][1] in tag2index.keys() else tag_feats.append('<NULL>') # st3
    
    
    word_feats.append(transition_state.buffer[0][0]) if len(transition_state.buffer)>=1 and \
    transition_state.buffer[0][0] in word2index.keys() else word_feats.append('<NULL>') # b1
    word_feats.append(transition_state.buffer[1][0]) if len(transition_state.buffer)>=2 and \
    transition_state.buffer[1][0] in word2index.keys() else word_feats.append('<NULL>') # b2
    word_feats.append(transition_state.buffer[2][0]) if len(transition_state.buffer)>=3 and \
    transition_state.buffer[2][0] in word2index.keys() else word_feats.append('<NULL>') # b3
    
    tag_feats.append(transition_state.buffer[0][1]) if len(transition_state.buffer)>=1 and \
    transition_state.buffer[0][1] in tag2index.keys() else tag_feats.append('<NULL>') # bt1
    tag_feats.append(transition_state.buffer[1][1]) if len(transition_state.buffer)>=2 and \
    transition_state.buffer[1][1] in tag2index.keys() else tag_feats.append('<NULL>') # bt2
    tag_feats.append(transition_state.buffer[2][1]) if len(transition_state.buffer)>=3 and \
    transition_state.buffer[2][1] in tag2index.keys() else tag_feats.append('<NULL>') # bt3
    
    
    lc_s1 = transition_state.get_left_most(transition_state.stack[-1][2]) if len(transition_state.stack)>=1 \
    else transition_state.get_left_most(None)
    rc_s1 = transition_state.get_right_most(transition_state.stack[-1][2]) if len(transition_state.stack)>=1 \
    else transition_state.get_right_most(None)
    
    lc_s2 = transition_state.get_left_most(transition_state.stack[-2][2]) if len(transition_state.stack)>=2 \
    else transition_state.get_left_most(None)
    rc_s2 = transition_state.get_right_most(transition_state.stack[-2][2]) if len(transition_state.stack)>=2 \
    else transition_state.get_right_most(None)
    
    words, tags, _ = zip(*[lc_s1,rc_s1,lc_s2,rc_s2])
    
    word_feats.extend(words)
    
    tag_feats.extend(tags)
    
    
    return make_vector(word_feats,word2index), make_vector(tag_feats,tag2index)

In [26]:
data = open('../DeepDependencyParsingProblemSet/data/train.txt','r').readlines()
vocab = open('../DeepDependencyParsingProblemSet/data/vocab.txt','r').readlines()

In [80]:
train_data = [[nltk.pos_tag(d.split('|||')[0].split()), d.split('|||')[1][:-1].split()] for d in data]

In [81]:
train_x,train_y = list(zip(*train_data))
train_x_f = flatten(train_x)
sents,pos_tags = list(zip(*train_x_f))

In [82]:
tag2index = {v:i for i,v in enumerate(set(pos_tags))}
tag2index['<root>']=len(tag2index)
tag2index['<NULL>']=len(tag2index)

In [83]:
vocab = [v.split('\t')[0] for v in vocab]
word2index = {v:i for i,v in enumerate(vocab)}
word2index['ROOT'] = len(word2index)
word2index['<NULL>'] = len(word2index)

In [84]:
actions = ['SHIFT','REDUCE_L','REDUCE_R']
action2index = {v:i for i,v in enumerate(actions)}

In [161]:
p_train=[]

In [162]:
for tx,ty in train_data:
    state = TrainsitionState(tx)
    transition = ty+['REDUCE_R']
    while len(transition)!=0:
        feat = get_feat(state,word2index,tag2index)
        action = transition.pop(0)
        actionTensor = Variable(torch.LongTensor([action2index[action]])).view(1,-1).cuda() if USE_CUDA \
        else Variable(torch.LongTensor([action2index[action]])).view(1,-1)
        p_train.append([feat,actionTensor])
        if action=='SHIFT':
            state.shift()
        elif action=='REDUCE_R':
            state.right_arc()
        elif action=='REDUCE_L':
            state.left_arc()

In [163]:
p_train[0]

[(Variable containing:
   9151  9152  9152  2106     2   353  9152  9152  9152  9152
  [torch.LongTensor of size 1x10], Variable containing:
     43    44    44    30    34    29    44    44    44    44
  [torch.LongTensor of size 1x10]), Variable containing:
  0
 [torch.LongTensor of size 1x1]]

In [187]:
class NeuralDependencyParser(nn.Module):
    
    def __init__(self,w_size,w_embed_dim,t_size,t_embed_dim,hidden_size,target_size):
        
        super(NeuralDependencyParser, self).__init__()
        
        self.w_embed =  nn.Embedding(w_size,w_embed_dim)
        self.t_embed = nn.Embedding(t_size,t_embed_dim)
        self.hidden_size = hidden_size
        self.target_size = target_size
        self.linear = nn.Linear((w_embed_dim+t_embed_dim)*10,self.hidden_size)
        self.out = nn.Linear(self.hidden_size,self.target_size)
        
        self.w_embed.weight.data.uniform_(-0.01, 0.01) # init
        self.t_embed.weight.data.uniform_(-0.01, 0.01) # init
        
    def forward(self,words,tags):
        
        wem = self.w_embed(words).view(words.size(0),-1)
        tem = self.t_embed(tags).view(tags.size(0),-1)
        inputs = torch.cat([wem,tem],1)
        h1 = torch.pow(self.linear(inputs),3) # cube function
        preds = -self.out(h1)
        return F.log_softmax(preds)

In [None]:
STEP = 5
W_EMBED_SIZE = 50
T_EMBED_SIZE = 10
HIDDEN_SIZE=512
LR = 0.01

In [188]:
model = NeuralDependencyParser(len(word2index),W_EMBED_SIZE,len(tag2index),T_EMBED_SIZE,HIDDEN_SIZE,len(action2index))
if USE_CUDA:
    model = model.cuda()

loss_function = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(),lr=LR)

In [191]:
losses=[]

In [192]:
for i, batch in enumerate(getBatch(32,p_train)):
    
    model.zero_grad()
    inputs, targets = list(zip(*batch))
    words, tags = list(zip(*inputs))
    words = torch.cat(words)
    tags = torch.cat(tags)
    targets = torch.cat(targets)
    preds = model(words,tags)
    loss = loss_function(preds,targets.view(-1))
    loss.backward()
    optimizer.step()
    
    losses.append(loss.data.cpu().tolist()[0] if USE_CUDA else loss.data.tolist()[0])
    
    if i % 100==0:
        print(np.mean(losses))
        losses=[]

1.09684097767
0.737392584085
0.618838954717
0.693304691613
1.06236269116
1.83809881449
4.97554338366
9.61246664047
12.1350490642
12.1166139519
9.60266982526
8.15794750815
7.09933471501
6.81534143507
6.49781926632
5.87462216884
5.22908880115
4.69333528399
4.52874009252
4.62807468295
4.31034019113
3.43664512917
3.45341291964
3.25474904835
3.85090582579
3.66804424644
3.28917575121
3.16772120774
3.17027136356
2.78850893259
2.66806945607
3.0550786984
3.04333923131
3.15591797143
2.94243722975
2.90874896646
2.83968224376
2.8125838387
3.11807466447
3.24308602437


KeyboardInterrupt: 

In [185]:
losses[0]

1.0932177305221558