In [1]:
import torch
import sys
sys.path.append('../figet-hyperbolic-space/')
import figet
from figet.model_utils import CharEncoder, SelfAttentiveSum, sort_batch_by_length
import argparse

In [2]:
# Pick the compute device once so later cells (e.g. the ShimaokaMTNCI constructor call)
# can always reference `device`, on CPU-only machines too.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ShimaokaMTNCI.forward casts `positions` to float64, so every parameter must also be
# created as float64. Set this unconditionally so the CPU path gets matching dtypes.
# (Equivalent to the deprecated torch.set_default_tensor_type(torch.DoubleTensor).)
torch.set_default_dtype(torch.float64)

In [3]:
# Preprocessed figet/MTNCI bundle: vocabularies ('token', 'type', 'char') plus the
# train/dev/test Dataset objects.
# NOTE: torch.load unpickles arbitrary objects; only point it at trusted files. On newer
# PyTorch releases whose torch.load defaults to weights_only=True, these figet objects
# may be rejected. If so, pass weights_only=False for this trusted file.
LOPEZ_DATA_PATH = '../figet-hyperbolic-space/data/prep/MTNCI/data.pt'
lopez_data = torch.load(LOPEZ_DATA_PATH)
lopez_data

{'vocabs': {'token': <figet.Dict.TokenDict at 0x7fa858791c10>,
  'type': <figet.Dict.TypeDict at 0x7fa77c45e160>,
  'char': <figet.Dict.Dict at 0x7fa77c45e190>},
 'train': <figet.Dataset.Dataset at 0x7fa77c45e1c0>,
 'dev': <figet.Dataset.Dataset at 0x7fa8587abbe0>,
 'test': <figet.Dataset.Dataset at 0x7fa77c45e3a0>}

In [None]:
# Reference only: encoder defaults from the original figet command-line parser.
# The hyper-parameters used below (the `args` dict / argClass) match these values,
# except context_dropout, which is 0.5 below versus 0.2 here.
# parser.add_argument("--emb_size", default=300, type=int, help="Embedding size.")
# parser.add_argument("--char_emb_size", default=50, type=int, help="Char embedding size.")
# parser.add_argument("--positional_emb_size", default=25, type=int, help="Positional embedding size.")
# parser.add_argument("--context_rnn_size", default=200, type=int, help="RNN size of ContextEncoder.")
# parser.add_argument("--attn_size", default=100, type=int, help="Attention vector size.")
# parser.add_argument("--mention_dropout", default=0.5, type=float, help="Dropout rate for mention")
# parser.add_argument("--context_dropout", default=0.2, type=float, help="Dropout rate for context")

In [4]:
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack


class CharEncoder(nn.Module):
    """Character-level CNN encoder for mention spans.

    Embeds character ids, runs one 1-D convolution with kernel width 5, applies ReLU,
    then max-pools over the whole sequence, giving one vector per span.
    Shadows the CharEncoder imported from figet.model_utils in the first cell.
    """

    def __init__(self, char_vocab, args):
        super(CharEncoder, self).__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        char_dim = 100     # size of each character embedding (= conv input channels)
        kernel_width = 5   # conv window, in characters
        # Attribute names are kept unchanged so existing state_dict checkpoints still load.
        self.char_W = nn.Embedding(char_vocab.size(), char_dim, padding_idx=0)
        self.conv1d = nn.Conv1d(char_dim, args.char_emb_size, kernel_width)

    def forward(self, span_chars):
        """
        :param span_chars: char ids, batch x max_char_seq (max_char_seq must be >= 5,
                           the conv kernel width)
        :return: batch x args.char_emb_size
        """
        # Conv1d wants channels first: batch x char_dim x max_char_seq
        embedded = self.char_W(span_chars).transpose(1, 2)
        # No padding, so the sequence shrinks by kernel_width - 1:
        # batch x char_emb_size x (max_char_seq - 4)
        activations = F.relu(self.conv1d(embedded))
        # Pool over all remaining positions -> batch x char_emb_size x 1
        pooled = F.max_pool1d(activations, activations.size(2))
        return pooled.squeeze(2)

class MentionEncoder(nn.Module):
    """Encode a mention as [self-attentive word summary ; character-CNN features].

    The word-level part pools the mention's word embeddings with SelfAttentiveSum.
    The character-level part comes from CharEncoder. Both are concatenated and
    passed through dropout.
    """

    def __init__(self, char_vocab, args):
        super(MentionEncoder, self).__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.char_encoder = CharEncoder(char_vocab, args)
        # SelfAttentiveSum is project-defined (figet.model_utils); the second argument is
        # presumably the attention hidden size. Confirm against figet.
        self.attentive_weighted_average = SelfAttentiveSum(args.emb_size, 1)
        self.dropout = nn.Dropout(args.mention_dropout)

    def forward(self, mentions, mention_chars, word_lut):
        """
        :param mentions: mention token ids, batch x mention_length
        :param mention_chars: mention char ids, batch x max_char_seq
        :param word_lut: embedding module mapping token ids to emb_size vectors
        :return: batch x (word-summary dim + args.char_emb_size), after dropout
        """
        mention_embeds = word_lut(mentions)             # batch x mention_length x emb_size

        weighted_avg_mentions, _ = self.attentive_weighted_average(mention_embeds)
        char_embed = self.char_encoder(mention_chars)
        output = torch.cat((weighted_avg_mentions, char_embed), 1)
        # Move to the configured device instead of hard-coding `.cuda()`, which raises
        # on CPU-only machines. This is a no-op when the tensor is already there.
        return self.dropout(output).to(self.device)


class ContextEncoder(nn.Module):
    """Encode the sentence context around a mention.

    Word embeddings are concatenated with learned positional features, passed through
    dropout and a bidirectional LSTM (over packed variable-length sequences), then
    pooled with self-attention.
    """

    def __init__(self, args):
        # Initialise nn.Module before setting any attribute: assigning submodules or
        # parameters earlier raises, and plain attributes only worked by accident.
        super(ContextEncoder, self).__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.emb_size = args.emb_size
        self.pos_emb_size = args.positional_emb_size
        self.rnn_size = args.context_rnn_size
        self.hidden_attention_size = 100
        self.pos_linear = nn.Linear(1, self.pos_emb_size)
        self.context_dropout = nn.Dropout(args.context_dropout)
        self.rnn = nn.LSTM(self.emb_size + self.pos_emb_size, self.rnn_size, bidirectional=True, batch_first=True)
        self.attention = SelfAttentiveSum(self.rnn_size * 2, self.hidden_attention_size)  # x2 because of bidirectional

    def forward(self, contexts, positions, context_len, word_lut, hidden=None):
        """
        :param contexts: token ids, batch x max_seq_len
        :param positions: one scalar per token, batch x max_seq_len
        :param context_len: batch x 1
        :param word_lut: embedding module mapping token ids to emb_size vectors
        :param hidden: unused; kept for interface compatibility
        :return: whatever self.attention returns. Callers unpack it as (vector, attn).
        """
        positional_embeds = self.get_positional_embeddings(positions)   # batch x max_seq_len x pos_emb_size
        ctx_word_embeds = word_lut(contexts)                            # batch x max_seq_len x emb_size
        ctx_embeds = torch.cat((ctx_word_embeds, positional_embeds), 2)

        ctx_embeds = self.context_dropout(ctx_embeds)

        rnn_output = self.sorted_rnn(ctx_embeds, context_len)

        return self.attention(rnn_output)

    def get_positional_embeddings(self, positions):
        """Map each scalar position through pos_linear.

        :param positions: batch x max_seq_len
        :return: batch x max_seq_len x pos_emb_size
        """
        pos_embeds = self.pos_linear(positions.view(-1, 1))                     # batch * max_seq_len x pos_emb_size
        return pos_embeds.view(positions.size(0), positions.size(1), -1)        # batch x max_seq_len x pos_emb_size

    @staticmethod
    def _lengths_for_pack(lengths):
        """Return sequence lengths as a CPU int64 tensor, the form pack_padded_sequence requires."""
        return torch.as_tensor(lengths, dtype=torch.int64).cpu()

    def sorted_rnn(self, ctx_embeds, context_len):
        """Run the LSTM over length-sorted, packed sequences and restore the original batch order."""
        sorted_inputs, sorted_sequence_lengths, restoration_indices = sort_batch_by_length(ctx_embeds, context_len)
        # Recent PyTorch rejects GPU-resident lengths in pack_padded_sequence (RuntimeError),
        # so always hand it a CPU int64 tensor.
        lengths = self._lengths_for_pack(sorted_sequence_lengths)
        packed_sequence_input = pack(sorted_inputs, lengths, batch_first=True)
        packed_sequence_output, _ = self.rnn(packed_sequence_input, None)
        unpacked_sequence_tensor, _ = unpack(packed_sequence_output, batch_first=True)
        return unpacked_sequence_tensor.index_select(0, restoration_indices)



# Superseded draft: ShimaokaMTNCI.forward below implements this. The draft references
# `self` outside a class, so it would not run as written. Kept for reference only.
# def get_shimaoka(input, mention_encoder, context_encoder):
#     contexts, positions, context_len = input[0], input[1], input[2]
#     mentions, mention_chars = input[3], input[4]
#     type_indexes = input[5]

#     mention_vec = mention_encoder(mentions, mention_chars, self.word_lut)
#     context_vec, attn = context_encoder(contexts, positions, context_len, self.word_lut)

#     input_vec = torch.cat((mention_vec, context_vec), dim=1)
#     return input_vec

In [6]:
from MTNCI import MTNCI

In [7]:
class argClass():
    """Minimal argparse.Namespace-like holder for the encoder hyper-parameters.

    Every key in the ``args`` dict becomes an attribute. Keys not supplied fall back
    to the defaults in ``_DEFAULTS``. Previously the dict was silently ignored and
    only the hard-coded values were used.
    """

    _DEFAULTS = {
        'emb_size': 300,
        'char_emb_size': 50,
        'positional_emb_size': 25,
        'context_rnn_size': 200,
        'attn_size': 100,
        'mention_dropout': 0.5,
        'context_dropout': 0.5,
    }

    def __init__(self, args=None):
        """
        :param args: optional dict of hyper-parameter overrides (None = all defaults)
        """
        values = dict(self._DEFAULTS)
        values.update(args or {})
        for name, value in values.items():
            setattr(self, name, value)


# Encoder hyper-parameters. argClass gives attribute access (args.emb_size, ...);
# keep its built-in values in agreement with this dict.
encoder_hparams = {
    'emb_size': 300,
    'char_emb_size': 50,
    'positional_emb_size': 25,
    'context_rnn_size': 200,
    'attn_size': 100,
    'mention_dropout': 0.5,
    'context_dropout': 0.5,
}
args = argClass(encoder_hparams)
vocabs = lopez_data['vocabs']

In [8]:
class ShimaokaMTNCI(MTNCI):
    """MTNCI fed by Shimaoka-style mention and context encoders.

    Owns a word-embedding lookup (``word_lut``), a MentionEncoder and a ContextEncoder.
    It concatenates their outputs into one feature vector of width ``feature_len``
    and passes it to ``MTNCI.forward``.
    """

    def __init__(self, argss, vocabs, device, *args, **kwargs):
        """
        :param argss: hyper-parameter object (emb_size, char_emb_size, context_rnn_size, ...)
        :param vocabs: dict of vocabularies; 'token' and 'char' are used here
        :param device: target device for the encoders; None = CUDA if available, else CPU
        :param args, kwargs: forwarded unchanged to MTNCI.__init__
        """
        super().__init__(*args, **kwargs)

        # Honour the caller's device (it used to be ignored), stored as a string as before.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = str(device)
        CHAR_VOCAB = 'char'
        self.word_lut = nn.Embedding(vocabs["token"].size_of_word2vecs(),
                                     argss.emb_size,
                                     padding_idx=0).to(self.device)

        # .to(device) instead of hard-coded .cuda(), so CPU-only machines work too.
        self.mention_encoder = MentionEncoder(vocabs[CHAR_VOCAB], argss).to(self.device)
        self.context_encoder = ContextEncoder(argss).to(self.device)
        # Width of [mention ; context]: bi-LSTM context (x2) + word summary + char CNN.
        self.feature_len = argss.context_rnn_size * 2 + argss.emb_size + argss.char_emb_size

    def init_params(self, word2vec):
        """Copy pretrained vectors into word_lut and freeze them (no gradient updates)."""
        self.word_lut.weight.data.copy_(word2vec)
        self.word_lut.weight.requires_grad = False

    def forward(self, input):
        """
        :param input: batch indexed as [0] contexts, [1] positions, [2] context_len,
                      [3] mentions, [4] mention_chars, [5] type_indexes (not used here)
        :return: output of MTNCI.forward on the concatenated feature vector
        """
        # Positions are cast to float64 to match pos_linear, whose parameters are
        # created under the notebook's float64 default dtype.
        contexts, positions, context_len = input[0], input[1].double(), input[2]
        mentions, mention_chars = input[3], input[4]

        mention_vec = self.mention_encoder(mentions, mention_chars, self.word_lut)

        # The second return value (presumably attention weights) is not needed here.
        context_vec, _ = self.context_encoder(contexts, positions, context_len, self.word_lut)

        input_vec = torch.cat((mention_vec, context_vec), dim=1)

        return super().forward(input_vec)


In [9]:
# Width of the feature vector handed to MTNCI. Mirrors ShimaokaMTNCI.feature_len:
# bidirectional context (2 * context_rnn_size) + word-level mention + char CNN.
SHIMAOKA_OUT = (2 * args.context_rnn_size) + args.emb_size + args.char_emb_size

# Output heads: one Euclidean and one Poincaré (hyperbolic) projection.
out_spec = [
    {'manifold': 'euclid', 'dim': [64, 10]},
    {'manifold': 'poincare', 'dim': [128, 128, 10]},
]

m = ShimaokaMTNCI(
    args,
    vocabs,
    device,
    input_d=SHIMAOKA_OUT,
    out_spec=out_spec,
    dims=[512, 512],
)


In [10]:
def get_dataset(data, batch_size, key):
    """Return ``data[key]`` configured to produce batches of ``batch_size``.

    The dataset object is modified in place (its batch size is set) and that same
    object is returned, not a copy.
    """
    selected = data[key]
    selected.set_batch_size(batch_size)
    return selected

EVAL_BATCH_SIZE = 1024  # batch size used when iterating the test split
test = get_dataset(lopez_data, EVAL_BATCH_SIZE, "test")

In [11]:
# Pretrained embedding matrix, later copied into the model's word_lut via init_params.
WORD2VEC_PATH = "../figet-hyperbolic-space/data/prep/MTNCI/word2vec.pt"
word2vec = torch.load(WORD2VEC_PATH)

In [12]:
# Load the pretrained vectors into the model's word lookup table (init_params also freezes it).
m.init_params(word2vec=word2vec)

In [25]:
# FIXME(review): `x` is not defined by any cell in this notebook. It leaked from a deleted
# or out-of-order cell (note the execution-count jump from In [12] to In [25]), so this cell
# raises NameError after Restart & Run All. Define `x` explicitly above (presumably a batch
# taken from `test`; confirm) or delete this cell.
len(x[0])

10

In [26]:
# NOTE(review): re-displays `lopez_data`. The output matches the load cell above (same
# object addresses), so this looks like a leftover scratch cell and can likely be removed.
lopez_data

{'vocabs': {'token': <figet.Dict.TokenDict at 0x7fa858791c10>,
  'type': <figet.Dict.TypeDict at 0x7fa77c45e160>,
  'char': <figet.Dict.Dict at 0x7fa77c45e190>},
 'train': <figet.Dataset.Dataset at 0x7fa77c45e1c0>,
 'dev': <figet.Dataset.Dataset at 0x7fa8587abbe0>,
 'test': <figet.Dataset.Dataset at 0x7fa77c45e3a0>}