In [1]:
import torch

## Seq2seq

In [17]:
class Encoder(torch.nn.Module):
    """Embed a sequence of token ids and encode it with a GRU.

    Args:
        emb_dim: Size of each token embedding vector.
        voc_size: Number of entries in the embedding table (source vocabulary).
        hidden_dim: Size of the GRU hidden state. Defaults to 100, the value
            that was previously hard-coded.
        num_layers: Number of stacked GRU layers. Defaults to 1.
    """

    def __init__(self, emb_dim, voc_size, hidden_dim=100, num_layers=1):
        super().__init__()
        self.emb = torch.nn.Embedding(voc_size, emb_dim)
        # batch_first=True: the GRU consumes (batch, seq_len, emb_dim) inputs.
        self.gru = torch.nn.GRU(emb_dim, hidden_dim, num_layers, batch_first=True)

    def forward(self, x):
        """Encode token ids.

        Args:
            x: Integer tensor of token ids, e.g. shape (batch, seq_len).

        Returns:
            A tuple ``(outputs, hidden)`` exactly as returned by the GRU:
            the per-step outputs and the final hidden state.
        """
        embedded = self.emb(x)
        outputs, hidden = self.gru(embedded)
        return outputs, hidden

In [26]:
class Decoder(torch.nn.Module):
    """Embed target token ids, run a GRU from a given initial hidden state,
    and project each step onto a probability distribution over the vocabulary.

    Args:
        emb_dim: Size of each token embedding vector.
        voc_size: Target vocabulary size (embedding rows and output classes).
        hidden_dim: Size of the GRU hidden state and of the input to the output
            projection. Defaults to 100, the value that was previously
            hard-coded in two separate places.
        num_layers: Number of stacked GRU layers. Defaults to 1.
    """

    def __init__(self, emb_dim, voc_size, hidden_dim=100, num_layers=1):
        super().__init__()
        self.emb = torch.nn.Embedding(voc_size, emb_dim)
        self.gru = torch.nn.GRU(emb_dim, hidden_dim, num_layers, batch_first=True)
        # The projection input size must match the GRU hidden size.
        self.lin = torch.nn.Linear(hidden_dim, voc_size)

    def forward(self, x, h):
        """Decode a target sequence.

        Args:
            x: Integer tensor of target token ids, e.g. shape (batch, seq_len).
            h: Initial GRU hidden state; its last dimension must equal
                ``hidden_dim``.

        Returns:
            Tensor of probabilities over the vocabulary for every step
            (softmax over the last dimension). The updated GRU hidden state
            is discarded.

        Note:
            The result is already normalized. Loss functions that expect raw
            logits (e.g. ``torch.nn.CrossEntropyLoss``) should not be fed this
            output directly.
        """
        embedded = self.emb(x)
        outputs, _ = self.gru(embedded, h)
        activated = torch.nn.functional.relu(outputs)
        logits = self.lin(activated)
        return torch.nn.functional.softmax(logits, dim=-1)

In [27]:
class Seq2seq(torch.nn.Module):
    """Encoder-decoder wrapper.

    The encoder and decoder are built from the classes passed in, each called
    as ``cls(emb_dim, vocabulary_size)``.

    Args:
        emb_dim: Embedding size handed to both sub-modules.
        voc_size_x: Vocabulary size handed to the encoder.
        voc_size_target: Vocabulary size handed to the decoder.
        length: Accepted for interface compatibility; not used by this class.
        encoder: Callable/class producing the encoder module.
        decoder: Callable/class producing the decoder module.
    """

    def __init__(self, emb_dim, voc_size_x, voc_size_target, length, encoder, decoder):
        super().__init__()
        self.encoder = encoder(emb_dim, voc_size_x)
        self.decoder = decoder(emb_dim, voc_size_target)

    def forward(self, x_inp, x_tgt):
        """Encode ``x_inp``, then decode ``x_tgt`` from the ReLU of the
        encoder's final hidden state. Returns whatever the decoder returns.
        """
        # Only the encoder's hidden state is used; its per-step outputs are dropped.
        _, enc_hidden = self.encoder(x_inp)
        return self.decoder(x_tgt, torch.nn.functional.relu(enc_hidden))

In [28]:
# Dummy source/target batches: one sequence of 10 token ids each,
# drawn uniformly from [1, 10000) (the upper bound is exclusive).
x_inp = torch.randint(low=1, high=10000, size=(1, 10))
x_tgt = torch.randint(low=1, high=10000, size=(1, 10))

In [29]:
# Build the model with 300-dim embeddings and 10000-token source/target
# vocabularies, then run a single forward pass on the dummy batch.
seq = Seq2seq(
    emb_dim=300,
    voc_size_x=10000,
    voc_size_target=10000,
    length=10,
    encoder=Encoder,
    decoder=Decoder,
)
seq(x_inp=x_inp, x_tgt=x_tgt)

tensor([[[1.0295e-04, 1.3806e-04, 9.8142e-05,  ..., 9.6626e-05,
          1.0747e-04, 8.9388e-05],
         [7.5145e-05, 1.2423e-04, 9.7938e-05,  ..., 8.4141e-05,
          8.9150e-05, 9.2981e-05],
         [8.4381e-05, 1.3692e-04, 8.5195e-05,  ..., 8.7998e-05,
          1.1183e-04, 7.9228e-05],
         ...,
         [1.0744e-04, 1.4628e-04, 9.2292e-05,  ..., 9.3539e-05,
          1.1310e-04, 8.4102e-05],
         [8.4830e-05, 1.2515e-04, 1.0501e-04,  ..., 9.0506e-05,
          1.0236e-04, 8.7274e-05],
         [1.0487e-04, 1.0975e-04, 1.0038e-04,  ..., 8.4768e-05,
          1.1647e-04, 1.0437e-04]]], grad_fn=<SoftmaxBackward0>)