In [1]:
import torch
from torch import nn
import torch.distributions as ds

In [244]:
def reset_lstm(lstm):
    """Re-initialise every parameter of ``lstm`` in place, keyed on its name.

    * names containing ``"bias"`` are filled with the constant 5,
    * names containing ``"ih"`` get Xavier-uniform values,
    * names containing ``"hh"`` get an orthogonal matrix.

    Raises:
        ValueError: if a parameter name matches none of the patterns above.
    """
    for param_name, tensor in lstm.named_parameters():
        # The checks are ordered: "bias_ih"/"bias_hh" must hit the bias rule first.
        if "bias" in param_name:
            nn.init.constant_(tensor, val=5)
            continue
        if "ih" in param_name:
            nn.init.xavier_uniform_(tensor)
            continue
        if "hh" in param_name:
            nn.init.orthogonal_(tensor)
            continue
        raise ValueError("Problem")

In [245]:
def masked_softmax(logits, mask):
    """Softmax over the last dimension restricted to entries where ``mask`` is non-zero.

    Args:
        logits: unnormalised scores.
        mask: tensor broadcastable against ``logits``; zero marks an excluded
            entry. Non-zero values act as multiplicative weights on the
            unnormalised probabilities (1.0 for an ordinary binary mask).

    Returns:
        Probabilities that sum to 1 along the last dimension. Rows whose mask
        is entirely zero fall back to a uniform distribution over all entries.
    """
    valid = mask != 0
    empty_rows = ~valid.any(dim=-1, keepdim=True)

    # Exclude masked entries *before* the softmax. Masking afterwards lets a
    # large masked-out logit underflow every valid probability to 0, which
    # turns the renormalisation into 0 / 0 = NaN.
    neg_inf = logits.new_full((), float("-inf"))
    masked_logits = torch.where(valid, logits, neg_inf)
    # A fully masked row would be all -inf (softmax -> NaN); use zeros so the
    # row becomes uniform and carries no gradient back to the logits.
    masked_logits = torch.where(empty_rows, torch.zeros_like(masked_logits), masked_logits)

    # Weight 1 on excluded entries is harmless for non-empty rows (their
    # probability is already 0) and yields the uniform fallback for empty rows.
    weights = torch.where(valid, mask.to(dtype=logits.dtype), logits.new_ones(()))
    probs = torch.softmax(masked_logits, dim=-1) * weights
    return probs / probs.sum(dim=-1, keepdim=True)

In [246]:
def gumbel_softmax(logits, temperature, mask=None):
    """Draw a straight-through Gumbel-softmax sample over the last dimension.

    Args:
        logits: unnormalised scores.
        temperature: divisor applied to the noisy logits before the softmax.
        mask: optional mask forwarded to ``masked_softmax``; when ``None`` a
            plain softmax is used.

    Returns:
        A tensor shaped like ``logits`` whose forward value is the one-hot
        argmax of the relaxed sample, while gradients flow through the relaxed
        (soft) sample.
    """
    epsilon = 1e-20

    # Gumbel(0, 1) noise, drawn on the same device (and in the same floating
    # dtype) as the logits so the addition below also works off-CPU.
    noise_dtype = logits.dtype if logits.is_floating_point() else torch.get_default_dtype()
    unif = torch.rand(logits.shape, dtype=noise_dtype, device=logits.device)
    gumbel_noise = -(-(unif + epsilon).log() + epsilon).log()

    # relaxed (soft) sample
    new_logits = (logits + gumbel_noise) / temperature
    if mask is None:
        y = new_logits.softmax(dim=-1)
    else:
        y = masked_softmax(new_logits, mask)

    # hard one-hot sample at the argmax of the relaxed sample
    y_st = torch.zeros_like(y).scatter_(-1, y.argmax(dim=-1, keepdim=True), 1.0)
    # straight-through: forward value is y_st, backward gradient is that of y
    y = (y_st - y).detach() + y
    return y

In [247]:
def cat_entropy(logits, mask):
    """Entropy of the masked categorical distribution along the last dimension.

    Probabilities come from ``masked_softmax``; a tiny constant is added before
    the log so zero-probability entries stay finite. Only entries with a
    non-zero mask contribute, and rows whose mask sums to exactly 1 (a single
    available choice) are forced to zero entropy.
    """
    probs = masked_softmax(logits, mask) + 1e-17
    weighted_log = probs.log() * probs * mask
    # 0.0 for rows with exactly one available option, 1.0 otherwise
    has_choice = (mask.sum(-1) != 1.).float()
    return -weighted_log.sum(-1) * has_choice

def cat_norm_entropy(logits, mask):
    """Entropy from ``cat_entropy`` divided by the log of the number of unmasked entries.

    Small constants are added inside the log and to the denominator so that
    rows with zero or one unmasked entry do not divide by zero.
    """
    raw_entropy = cat_entropy(logits, mask)
    num_options = mask.sum(-1)
    max_entropy = (num_options + 1e-17).log()
    return raw_entropy / (max_entropy + 1e-17)

def cat_logprob(logits, mask, values):
    """Log-probability of the selected category under the masked distribution.

    Args:
        logits: unnormalised scores; the last dimension indexes categories.
        mask: mask passed to ``masked_softmax``.
        values: one-hot encoded selections, same shape as ``logits``; the
            selected category is recovered with ``argmax`` over the last
            dimension.

    Returns:
        Tensor with the last dimension removed. Rows whose mask sums to zero
        yield 0.
    """
    probs = masked_softmax(logits, mask)
    chosen = values.argmax(-1, keepdim=True)
    # Gather first, then take the log: applying log to the whole vector hits
    # the exact zeros at masked entries, and the backward pass of log there is
    # 0 / 0 = NaN even though those entries are never selected.
    # squeeze(-1) drops only the gathered dimension, so size-1 batch dimensions
    # are preserved.
    log_prob = torch.gather(probs, -1, chosen).squeeze(-1).log()
    return log_prob * (mask.sum(-1) != 0.).float()

def get_seqmask(seqlens):
    """Build a padding mask from per-sequence lengths.

    Returns a float tensor of shape ``batch X maxlen`` (``maxlen`` being the
    largest entry of ``seqlens``) that holds 1.0 at positions smaller than the
    sequence's length and 0.0 elsewhere.
    """
    num_seqs = seqlens.shape[0]
    longest = seqlens.max()
    # position indices 0..maxlen-1, matched to the dtype/device of seqlens
    positions = torch.arange(longest).unsqueeze(0).expand(num_seqs, -1)
    positions = positions.long().to(seqlens)
    lengths_col = seqlens.unsqueeze(-1)
    return (positions < lengths_col).float()

In [248]:
class BTreeLSTMCell(nn.Module):
    """Binary tree-LSTM cell that merges a left and a right child state.

    A single linear layer maps the concatenated child hidden states to five
    gate pre-activations (in order: ``i``, ``fl``, ``fr``, ``o``, ``g``), each
    of size ``hdim``.

    Args:
        hdim: size of the hidden and cell states.
        dropout_prob: dropout probability applied to the candidate ``g``;
            ``None`` disables dropout.
    """

    def __init__(self, hdim, dropout_prob=None):
        super().__init__()
        self.hdim = hdim
        self.linear = nn.Linear(in_features=2 * self.hdim, out_features=5 * self.hdim)
        if dropout_prob is None:
            # nn.Identity instead of a lambda keeps the module picklable.
            self.dropout = nn.Identity()
        else:
            self.dropout = nn.Dropout(dropout_prob)
        self.reset_parameters()

    def reset_parameters(self):
        """Orthogonal weights; zero bias except the two forget gates, which start at 1."""
        nn.init.orthogonal_(self.linear.weight)
        with torch.no_grad():
            self.linear.bias.zero_()
            # bias entries hdim..3*hdim belong to the fl and fr chunks
            self.linear.bias[self.hdim:3 * self.hdim].fill_(1)

    def forward(self, hl, cl, hr, cr):
        """Compose children ``(hl, cl)`` and ``(hr, cr)`` into a parent ``(hp, cp)``.

        All inputs share the shape ``batch X seqlen X hdim``; both outputs have
        that shape too.
        """
        h = torch.cat([hl, hr], dim=-1)
        i, fl, fr, o, g = self.linear(h).chunk(chunks=5, dim=-1)
        # Out-of-place activations: the chunks are views of one multi-output
        # op, and autograd rejects in-place modification of such views.
        cp = self.dropout(g.tanh()) * i.sigmoid() + cl * fl.sigmoid() + cr * fr.sigmoid()
        hp = o.sigmoid() * cp.tanh()
        return hp, cp

In [249]:
class BTreeLSTMBase(nn.Module):
    """Shared machinery for the binary tree-LSTM parser and composer.

    Holds the leaf transformation (an LSTM followed by a linear layer that
    produces the initial hidden/cell states) and the tree-LSTM cell, and
    implements the merge step that builds one tree layer from the previous one.

    Args:
        idim: input feature size.
        hdim: tree-LSTM hidden/cell size.
        tdim: hidden size of the leaf LSTM.
        dropout_prob: forwarded to ``BTreeLSTMCell``.
    """

    def __init__(self, idim, hdim, tdim, dropout_prob=None):
        super().__init__()
        # batch_first=True: inputs are batch X seqlen X idim. With the default
        # (time-major) layout the LSTM would recur across the batch dimension.
        self.leaftransformer_lstm = nn.LSTM(idim, tdim, batch_first=True)
        self.leaftransformer_linear = nn.Linear(tdim, 2*hdim)

        self.treelstm_cell = BTreeLSTMCell(hdim, dropout_prob)

        # Call the base implementation explicitly: a subclass override may
        # touch attributes that do not exist yet at this point.
        BTreeLSTMBase.reset_parameters(self)

    def reset_parameters(self):
        """Re-initialise the leaf transformer and the tree-LSTM cell."""
        nn.init.orthogonal_(self.leaftransformer_linear.weight)
        nn.init.constant_(self.leaftransformer_linear.bias, val=0)
        self.treelstm_cell.reset_parameters()
        self.leaftransformer_lstm.reset_parameters()

    def transform_leafs(self, x):
        """Map inputs to initial ``(h, c)`` leaf states.

        Args:
            x: Shape = batch X seqlen X idim

        Returns:
            Tuple ``(h, c)``, each of Shape = batch X seqlen X hdim.
        """
        # Shape = batch X seqlen X tdim
        x = self.leaftransformer_lstm(x)[0]
        # Shape = batch X seqlen X 2*hdim
        x = self.leaftransformer_linear(x).tanh()
        # Shape = (batch X seqlen X hdim, batch X seqlen X hdim)
        return x.chunk(chunks=2, dim=-1)

    def compose(self, composition, hl, cl, hr, cr, hp, cp, mask):
        """Build the next layer by merging one adjacent pair per sequence.

        Args:
            composition: one-hot merge position, Shape = batch X seqlen
            hl, cl: left-child states, Shape = batch X seqlen X hdim
            hr, cr: right-child states, same shape
            hp, cp: candidate parent states for every adjacent pair, same shape
            mask: padding mask, Shape = batch X seqlen

        Returns:
            ``(hp, cp)`` for the next layer: positions left of the merge copy
            the left child, the merge position takes the composed parent and
            positions right of it copy the right child. Where ``mask`` is 0 the
            left child is carried over unchanged.
        """
        cumsum = torch.cumsum(composition, dim=-1)

        # Shape = batch X maxlen X 1
        # for broadcasting
        # ml is 1 strictly left of the merge position, mr is 1 strictly right
        # of it (cumsum is 1 from the merge position onward, so subtracting the
        # one-hot composition removes the merge position itself).
        ml = (1 - cumsum).unsqueeze(-1)
        mr = (cumsum - composition).unsqueeze(-1)
        mask = mask.unsqueeze(-1)
        composition = composition.unsqueeze(-1)

        # next layer
        hp = mask * (ml * hl + mr * hr + composition * hp) + (1 - mask) * hl
        cp = mask * (ml * cl + mr * cr + composition * cp) + (1 - mask) * cl
        return hp, cp

    def forward(self, *inputs):
        raise NotImplementedError

In [250]:
class BTreeLSTMParser(BTreeLSTMBase):
    """Tree-LSTM that chooses, layer by layer, which adjacent pair to merge.

    Every candidate parent is scored by a dot product with the learned query
    vector ``q``. During training the merge is sampled with a straight-through
    Gumbel-softmax; otherwise the highest-probability merge is taken.

    Args:
        idim, hdim, tdim, dropout_prob: see ``BTreeLSTMBase``.
        gumbel_temperature: temperature passed to ``gumbel_softmax``.
    """

    def __init__(self, idim, hdim, tdim, gumbel_temperature, dropout_prob=None):
        super().__init__(idim, hdim, tdim, dropout_prob)
        # values are set by reset_parameters below
        self.q = nn.Parameter(torch.empty(hdim))
        # temperature for gumbel softmax
        self.gumbel_temperature = gumbel_temperature
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the base modules and draw ``q`` from N(0, 0.01^2)."""
        super().reset_parameters()
        nn.init.normal_(self.q, mean=0, std=0.01)

    def sample_composition(self, query_weights, mask):
        """Return a one-hot merge decision (batch X seqlen) from the scores."""
        if self.training:
            # sample from gumbel_softmax if training
            composition = gumbel_softmax(query_weights, self.gumbel_temperature, mask)
        else:
            # greedy if not
            probs = masked_softmax(query_weights, mask)
            composition = torch.zeros_like(probs).scatter_(-1, probs.argmax(dim=-1, keepdim=True), 1.0)
        return composition

    def step(self, h, c, mask, eval_composition):
        """Move up one tree layer.

        Args:
            h, c: current layer states, Shape = batch X seqlen X hdim
            mask: padding mask for the adjacent pairs, Shape = batch X (seqlen-1)
            eval_composition: merge decision to use instead of sampling, or
                ``None`` to sample.

        Returns:
            ``(h, c, composition, query_weights)`` where ``h``/``c`` have one
            position fewer than the inputs.
        """
        # get left and right sides
        hl, hr = h[:, :-1], h[:, 1:]
        cl, cr = c[:, :-1], c[:, 1:]
        # composed states
        hp, cp = self.treelstm_cell(hl, cl, hr, cr)

        # get composition query weights
        query_weights = torch.matmul(hp, self.q)
        if eval_composition is None:
            # sample is not given
            composition = self.sample_composition(query_weights, mask)
        else:
            # use provided mergers if available
            composition = eval_composition

        # perform composition
        hp, cp = self.compose(composition, hl, cl, hr, cr, hp, cp, mask)
        return hp, cp, composition, query_weights

    def forward(self, x, mask, eval_tree_compositions=None):
        """Parse a batch and score the merge decisions.

        Args:
            x: Shape = batch X seqlen X idim
            mask: padding mask, Shape = batch X seqlen
            eval_tree_compositions: optional sequence of merge decisions (one
                per layer) to score instead of sampling new ones.

        Returns:
            ``(tree_compositions, log_probs, entropy, norm_entropy)``: the list
            of per-layer merge decisions, and per-example sums over layers of
            the log-probability, entropy and normalised entropy. The last is
            divided by ``mask[:, 2:].sum(-1)``.
        """
        # transform the leafs
        h, c = self.transform_leafs(x)

        # Per-example accumulators. Starting from tensors (rather than using
        # sum() over a list) keeps the outputs tensors even when seqlen == 1
        # and the loop below never runs.
        zeros = h.new_zeros(h.shape[0])
        entropy = zeros
        norm_entropy = zeros
        log_probs = zeros
        tree_compositions = []
        for i in range(x.shape[1]-1):
            # get the relevant mask (1 less than the previous one)
            rel_mask = mask[:, i+1:]
            # perform a step (move up a layer)
            eval_composition = None if eval_tree_compositions is None else eval_tree_compositions[i]
            h, c, composition, query_weights = self.step(h, c, rel_mask, eval_composition)
            tree_compositions.append(composition)
            entropy = entropy + cat_entropy(query_weights, rel_mask)
            norm_entropy = norm_entropy + cat_norm_entropy(query_weights, rel_mask)
            log_probs = log_probs + cat_logprob(query_weights, rel_mask, composition)

        norm_entropy = norm_entropy / (mask[:, 2:].sum(-1) + 1e-17)

        return tree_compositions, log_probs, entropy, norm_entropy

In [251]:
class BTreeLSTMComposer(BTreeLSTMBase):
    """Tree-LSTM that follows a given sequence of merges and returns the root state."""

    def __init__(self, idim, hdim, tdim, dropout_prob=None):
        super().__init__(idim, hdim, tdim, dropout_prob)

    def forward(self, x, mask, tree_compositions):
        """Compose the leaves bottom-up according to ``tree_compositions``.

        ``tree_compositions[i]`` is the merge decision applied at layer ``i``;
        the padding mask for that layer is ``mask[:, i+1:]``. Returns the final
        hidden state with its length-1 sequence dimension squeezed out.
        """
        # leaf states
        hidden, cell = self.transform_leafs(x)

        num_layers = x.shape[1] - 1
        for layer in range(num_layers):
            left_h, right_h = hidden[:, :-1], hidden[:, 1:]
            left_c, right_c = cell[:, :-1], cell[:, 1:]
            parent_h, parent_c = self.treelstm_cell(left_h, left_c, right_h, right_c)
            layer_mask = mask[:, layer + 1:]
            hidden, cell = self.compose(
                tree_compositions[layer],
                left_h, left_c, right_h, right_c,
                parent_h, parent_c,
                layer_mask,
            )
        # root
        return hidden.squeeze(1)

In [262]:
class Model(nn.Module):
    """Classifier that parses a token sequence into a binary tree and composes it.

    A ``BTreeLSTMParser`` picks the tree, a ``BTreeLSTMComposer`` with its own
    embeddings composes the tokens along that tree, and a linear layer maps the
    root state to ``odim`` class scores. The per-example "rewards" returned by
    this model are the unreduced cross-entropy losses of those scores.
    """

    def __init__(self, vocab_size, idim, hdim, p_tdim, c_tdim, odim, gumbel_temperature):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, idim)
        self.parser = BTreeLSTMParser(idim, hdim, p_tdim, gumbel_temperature)
        self.tree_embeddings = nn.Embedding(vocab_size, idim)
        self.composer = BTreeLSTMComposer(idim, hdim, c_tdim)
        self.linear = nn.Linear(hdim, odim)

        # exponential moving average of the reward variance and its decay
        self.running_reward_var = 1.0
        self.norm_alpha = 0.9
        self.criterion = nn.CrossEntropyLoss(reduction='none')
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise embeddings, the output layer, the parser and the composer."""
        nn.init.normal_(self.word_embeddings.weight, 0.0, 0.01)
        nn.init.normal_(self.tree_embeddings.weight, 0.0, 0.01)
        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.constant_(self.linear.bias, val=0)
        self.parser.reset_parameters()
        self.composer.reset_parameters()

    def single_pass(self, x, mask, labels):
        """Parse, compose and classify once.

        Returns:
            ``(tree_compositions, log_probs, entropy, norm_entropy, logits,
            rewards)`` where ``rewards`` is the per-example cross-entropy.
        """
        tree_compositions, log_probs, entropy, norm_entropy = self.parser(self.word_embeddings(x), mask)
        out = self.composer(self.tree_embeddings(x), mask, tree_compositions)
        logits = self.linear(out)
        rewards = self.criterion(logits, labels)
        return tree_compositions, log_probs, entropy, norm_entropy, logits, rewards

    def get_baseline(self, x, mask, labels):
        """Per-example rewards of a gradient-free pass in eval mode.

        The module's previous train/eval mode is restored afterwards, also when
        the pass raises.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self.single_pass(x, mask, labels)[-1]
        finally:
            self.train(was_training)

    def normalize_rewards(self, rewards):
        """Scale rewards by the running standard deviation (never below 1).

        Updates ``running_reward_var`` as an exponential moving average of the
        batch variance. Batches with fewer than two rewards leave it untouched:
        their sample variance is NaN and would poison the average for good.
        """
        with torch.no_grad():
            if rewards.numel() > 1:
                self.running_reward_var = self.norm_alpha * self.running_reward_var + \
                                            (1 - self.norm_alpha) * rewards.var()
            # running_reward_var is still a plain float until the first update
            running_var = torch.as_tensor(self.running_reward_var,
                                          dtype=rewards.dtype, device=rewards.device)
            return rewards / running_var.sqrt().clamp(min=1.0)

    def forward(self, x, mask, labels):
        """Run the model on a batch.

        Returns:
            ``(predictions, tree_compositions, loss, rewards, log_probs,
            entropy, norm_entropy)``. ``loss`` is the mean cross-entropy. In
            training mode ``rewards`` is the cross-entropy minus the eval-mode
            baseline, normalised by ``normalize_rewards``; otherwise it is the
            raw per-example cross-entropy. ``rewards`` is always detached.
        """
        tree_compositions, log_probs, entropy, norm_entropy, logits, rewards = self.single_pass(x, mask, labels)
        loss = rewards.mean()
        if self.training:
            baseline = self.get_baseline(x, mask, labels)
            rewards = self.normalize_rewards(rewards - baseline)
        predictions = logits.argmax(dim=-1)
        return predictions, tree_compositions, loss, rewards.detach(), log_probs, entropy, norm_entropy

    def evaluate(self, x, mask, eval_tree_compositions):
        """Score given merge decisions with the parser.

        Returns:
            ``(log_probs, norm_entropy)`` as produced by the parser.
        """
        _, log_probs, _, norm_entropy = self.parser(self.word_embeddings(x), mask, eval_tree_compositions)
        return log_probs, norm_entropy

In [263]:
# Configuration for the toy run below.
# Seed the global RNG so parameter initialisation and Gumbel sampling are
# reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)

batch = 5
seqlen = 10
vocab_size = 20
idim = 100
tdim = 300
# NOTE(review): `hdim` was never defined anywhere in the notebook (it only
# existed as leftover kernel state), so building the model raised NameError on
# a fresh kernel. 300 is an assumed value -- confirm the intended size.
hdim = 300
p_tdim = 256
c_tdim = 150
odim = 10
gumbel_temperature = 1
model = Model(vocab_size, idim, hdim, p_tdim, c_tdim, odim, gumbel_temperature)

In [264]:
# Toy batch: 5 token-id sequences padded to length 10.
x = torch.tensor([
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    [12, 18, 1, 0, 16, 1, 1, 2, 3, 3],
    [0, 0, 0, 1, 1, 1, 2, 2, 3, 3],
    [1, 1, 1, 1, 1, 5, 5, 5, 5, 5],
    [2, 3, 2, 6, 11, 15, 16, 14, 15, 12],
])
# Padding mask (1.0 = real token, 0.0 = padding), built from the true lengths
# 5, 10, 6, 1 and 9 of the five sequences.
mask = (torch.arange(10).unsqueeze(0) < torch.tensor([5, 10, 6, 1, 9]).unsqueeze(1)).float()
# One class label per sequence.
labels = torch.tensor([1, 2, 4, 2, 3])

In [269]:
# One forward pass in training mode: merges are sampled and the returned
# rewards are baseline-subtracted and normalised.
model.train()
(
    predictions,
    tree_compositions,
    loss,
    rewards,
    log_probs,
    entropy,
    norm_entropy,
) = model(x=x, mask=mask, labels=labels)

In [270]:
# Re-score the merge decisions sampled above with the parser.
eval_log_probs, eval_norm_entropy = model.evaluate(
    x=x,
    mask=mask,
    eval_tree_compositions=tree_compositions,
)