In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.legacy.data import Field
from torchtext.legacy.data import TabularDataset
"""
from konlpy.tag import Hannanum
import pandas as pd

from nltk.tokenize import TreebankWordTokenizer
import nltk
"""
from torchtext.legacy.data import TabularDataset

In [None]:
class Encoder(nn.Module):
    """Bidirectional GRU encoder over padded, batch-first token-id sequences.

    Args:
        n_tokens: Number of rows in the embedding table (vocabulary size).
        n_inputs: Embedding dimension, also the GRU input size.
        n_hiddens: Hidden size of each GRU direction.
        padding_idx: Token id that marks padding. It is passed to the
            embedding and is also used to measure each sequence's length.
            ``None`` means "no padding": every sequence uses its full width.
    """

    def __init__(self, n_tokens, n_inputs, n_hiddens, padding_idx):
        super().__init__()
        self.n_hiddens = n_hiddens
        self.padding_idx = padding_idx
        # forward() indexes the batch on dim 0 (x.size(0) is the batch size),
        # so the GRU and the pack/unpack helpers must all be batch-first.
        self.batch_first = True
        self.embedding = nn.Embedding(n_tokens, n_inputs, padding_idx = padding_idx)
        self.bidirectional_gru = nn.GRU(
            n_inputs, n_hiddens, bidirectional=True, batch_first=self.batch_first
        )

    def init_hidden(self, batch_size):
        """Return a zero initial state of shape (2, batch_size, n_hiddens).

        The tensor is created from an existing parameter so it inherits the
        module's dtype and device.
        """
        weight = next(self.parameters())
        h0 = weight.new_zeros(2, batch_size, self.n_hiddens)
        return h0

    def _sequence_lengths(self, x):
        """Return per-row lengths (CPU int64): index of the last non-pad token + 1."""
        batch_size, width = x.size(0), x.size(1)
        if self.padding_idx is None:
            return torch.full((batch_size,), width, dtype=torch.long)
        not_pad = x != self.padding_idx
        positions = torch.arange(1, width + 1, device=x.device)
        lengths = (not_pad * positions).max(dim=1).values
        # pack_padded_sequence rejects zero lengths; an all-padding row is
        # encoded as a single padding step instead of crashing.
        return lengths.clamp(min=1).cpu()

    def forward(self, x):
        """Encode a batch of padded token-id sequences.

        Args:
            x: LongTensor of shape (batch, seq_len) holding token ids.

        Returns:
            output: (batch, longest_len, 2 * n_hiddens) per-step states of
                both directions; steps past a row's length are zero.
            hidden: (batch, 2 * n_hiddens) final forward and backward states
                concatenated on the feature dimension.
        """
        input_length = self._sequence_lengths(x)
        hidden = self.init_hidden(x.size(0))
        x = self.embedding(x)
        # enforce_sorted=False: batches are not guaranteed to arrive sorted
        # by decreasing length.
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_length, batch_first = self.batch_first, enforce_sorted = False
        )
        output, hidden = self.bidirectional_gru(x, hidden)
        # pad_packed_sequence returns (padded_tensor, lengths); only the
        # tensor is part of this module's result.
        output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first = self.batch_first)
        hidden = torch.cat([hidden[0], hidden[1]], dim = -1)
        return output, hidden



In [6]:
class Alignment(nn.Module):
    """Additive alignment score: ``e = v . tanh(W [h ; s])``.

    Args:
        n_hiddens: Size of ``h``'s feature dimension. ``s`` must have
            ``2 * n_hiddens`` features, since the linear layer takes
            ``3 * n_hiddens`` inputs.
    """

    def __init__(self, n_hiddens):
        # nn.Module must be initialised before any Parameter or sub-module
        # is assigned, otherwise the assignments below raise.
        super().__init__()
        self.n_hiddens = n_hiddens
        # Bare weight vector; scoring with it is the same as a bias-free
        # nn.Linear(n_hiddens, 1).
        self.v = nn.Parameter(nn.init.uniform_(torch.empty(n_hiddens)))
        self.align = nn.Linear(self.n_hiddens * 3, self.n_hiddens)

    def forward(self, h , s):
        """Score every time step.

        Args:
            h: (batch, time, n_hiddens) tensor.
            s: (batch, time, 2 * n_hiddens) tensor.

        Returns:
            (batch, time) tensor of unnormalised alignment energies.
        """
        e = torch.tanh(self.align(torch.cat([h, s], 2)))  # (batch, time, n_hiddens)
        # Dot each time step's feature vector with v; this equals the
        # batched product of a repeated v with the transposed energies.
        return torch.matmul(e, self.v)


In [7]:
class Attention(nn.Module):
    """Turns alignment energies into attention weights over time steps.

    Args:
        n_hiddens: Hidden size handed to the ``Alignment`` scorer.
    """

    def __init__(self, n_hiddens):
        super().__init__()
        # Store under the same name that is read on the next line; the
        # attribute used to be saved as ``n_hidden`` and then looked up as
        # ``n_hiddens``, which failed at construction time.
        self.n_hiddens = n_hiddens
        self.align = Alignment(self.n_hiddens)

    def forward(self, h_, s):
        """Compute attention weights.

        Args:
            h_: (batch, n_hiddens) query state.
            s: (time, batch, features) time-first sequence to attend over.

        Returns:
            (batch, 1, time) weights that sum to 1 over the time dimension.
        """
        time_step = s.size(0)
        # Copy the query once per time step, then make both tensors
        # batch-first for the scorer.
        h = h_.repeat(time_step, 1, 1).transpose(0, 1)
        s = s.transpose(0, 1)
        energy = self.align(h, s)
        return F.softmax(energy, dim=1).unsqueeze(1)


In [10]:
class Decoder(nn.Module):
    """Single-step attention decoder: embedding -> attention -> GRU -> maxout -> log-softmax.

    Args:
        n_outputs: Target vocabulary size (embedding rows and output classes).
        n_embeddings: Target embedding dimension.
        n_hiddens: Decoder GRU hidden size; the attended sequence must have
            ``2 * n_hiddens`` features.
        n_maxout: Width of the maxout layer's output.
    """

    def __init__(self, n_outputs, n_embeddings, n_hiddens, n_maxout):
        super().__init__()
        self.n_hiddens = n_hiddens
        self.embedding = nn.Embedding(n_outputs, n_embeddings)
        self.attention_layer = Attention(self.n_hiddens)
        self.gru = nn.GRU(n_embeddings + n_hiddens * 2, n_hiddens)

        self.maxout = Maxout(n_hiddens * 3 + n_embeddings, n_maxout, 2)
        self.out = nn.Linear(n_maxout, n_outputs)

    def forward(self, input, h, s):
        """Run one decoding step.

        Args:
            input: (batch,) LongTensor of previous target token ids.
            h: (1, batch, n_hiddens) previous decoder state.
            s: (batch, time, 2 * n_hiddens) batch-first sequence to attend over.

        Returns:
            out: (batch, n_outputs) log-probabilities for the next token.
            hidden: (1, batch, n_hiddens) updated decoder state.
        """
        embedded = self.embedding(input)
        if embedded.dim() == 2:
            # The GRU is time-first: add the length-1 time axis so the
            # embedding lines up with the (1, batch, features) context.
            embedded = embedded.unsqueeze(0)
        # The attention layer takes a time-first sequence.
        attention = self.attention_layer(h[-1], s.transpose(0, 1))  # (batch, 1, time)
        context = attention.bmm(s).transpose(0, 1)                  # (1, batch, 2 * n_hiddens)
        rnn_input = torch.cat([embedded, context], 2)
        _, hidden = self.gru(rnn_input, h)
        # Join previous state, embedding and context on the FEATURE axis;
        # the maxout layer is sized for n_hiddens * 3 + n_embeddings inputs.
        maxout_input = torch.cat([h, embedded, context], 2)
        out = self.maxout(maxout_input).squeeze(0)
        out = self.out(out)
        out = F.log_softmax(out, dim=1)
        return out, hidden

In [9]:
class Maxout(nn.Module):
    """Maxout layer: one linear map followed by a max over groups of units.

    The linear layer produces ``d_out * pool_size`` features; each output
    feature is the maximum of its ``pool_size`` consecutive candidates.
    """

    def __init__(self, d_in, d_out, pool_size):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.pool_size = pool_size
        self.lin = nn.Linear(d_in, d_out * pool_size)

    def forward(self, inputs):
        """Map ``(..., d_in)`` to ``(..., d_out)``, keeping all leading dimensions."""
        projected = self.lin(inputs)
        # Split the last axis into (d_out, pool_size) and reduce the pool axis.
        grouped_shape = (*inputs.shape[:-1], self.d_out, self.pool_size)
        pooled, _ = projected.view(grouped_shape).max(dim=-1)
        return pooled
        

In [8]:
class RNNsearch(nn.Module):
    """Encoder / attention-decoder sequence-to-sequence model.

    Args:
        n_tokens: Source vocabulary size.
        n_inputs: Source embedding dimension.
        n_outputs: Target vocabulary size.
        n_embeddings: Target embedding dimension.
        n_hiddens: Hidden size of each encoder direction and of the decoder.
        n_maxout: Width of the decoder's maxout layer.
        padding_idx: Source padding token id, forwarded to the encoder.
        device: Device on which the output buffer is allocated.
    """

    def __init__(self, n_tokens, n_inputs, n_outputs, n_embeddings, n_hiddens, n_maxout, padding_idx, device):
        super().__init__()
        self.n_outputs = n_outputs
        self.device = device

        self.encoder = Encoder(n_tokens, n_inputs, n_hiddens, padding_idx)
        self.decoder = Decoder(n_outputs, n_embeddings, n_hiddens, n_maxout)
        # The encoder summary has 2 * n_hiddens features (both directions
        # concatenated) while the decoder GRU state has n_hiddens, so the
        # initial decoder state is a learned projection of the summary.
        self.init_state = nn.Linear(n_hiddens * 2, n_hiddens)

    def forward(self, x, target):
        """Decode greedily for ``target.shape[1] - 1`` steps.

        Args:
            x: (batch, src_len) source token ids.
            target: (batch, tgt_len) target token ids. Only column 0 is read,
                as the first decoder input; later inputs are the model's own
                argmax predictions.

        Returns:
            (batch, tgt_len, n_outputs) tensor of log-probabilities. Position
            0 is left as zeros because nothing is predicted for the start
            token.
        """
        encoder_outputs, encoder_hidden = self.encoder(x)

        # (batch, 2 * n_hiddens) -> (1, batch, n_hiddens) for the decoder GRU.
        hidden = torch.tanh(self.init_state(encoder_hidden)).unsqueeze(0)

        input = target[:, 0]
        outputs = torch.zeros(target.shape[0], target.shape[1], self.n_outputs, device=self.device)
        for t in range(1, target.shape[1]):
            # Feed the state returned by the previous step; passing the
            # encoder state every time would discard the decoder's recurrence.
            output, hidden = self.decoder(input, hidden, encoder_outputs)
            outputs[:, t] = output

            input = output.argmax(1)

        return outputs




        

In [None]:
# Build a small encoder instance; keyword arguments make each size explicit.
encoder = Encoder(n_tokens=10, n_inputs=10, n_hiddens=10, padding_idx=0)