In [1]:
import torch
import torch.nn as nn

In [2]:
class Encoder(nn.Module):
    """Stack of ``nn.TransformerEncoderLayer`` modules that encodes a source sequence.

    The layers are built with PyTorch's default ``batch_first=False``, so inputs
    are sequence-first: ``src`` is ``(src_len, batch, d_model)`` (or
    ``(src_len, d_model)`` when unbatched).

    Args:
        d_model: Feature size of the inputs and outputs.
        nhead: Number of attention heads (PyTorch requires it to divide ``d_model``).
        num_layers: Number of stacked encoder layers. Defaults to 1, the value
            that was previously hard-coded.
    """

    def __init__(self, d_model, nhead, num_layers=1):
        super().__init__()
        # Attribute name kept as ``encoder`` so existing state_dict keys still load.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead), num_layers=num_layers
        )

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Encode an already-embedded source sequence.

        Args:
            src: Source tensor in the sequence-first layout described above.
            src_mask: Optional ``(src_len, src_len)`` self-attention mask.
            src_key_padding_mask: Optional ``(batch, src_len)`` mask; with a
                bool mask, ``True`` marks padding positions to ignore.

        Returns:
            Tensor with the same shape as ``src``.
        """
        return self.encoder(
            src, mask=src_mask, src_key_padding_mask=src_key_padding_mask
        )

class Decoder(nn.Module):
    """Stack of ``nn.TransformerDecoderLayer`` modules that attend to encoder memory.

    The layers use PyTorch's default ``batch_first=False`` layout: ``tgt`` is
    ``(tgt_len, batch, d_model)`` and ``memory`` is ``(src_len, batch, d_model)``.

    Args:
        d_model: Feature size of the inputs and outputs.
        nhead: Number of attention heads (PyTorch requires it to divide ``d_model``).
        num_layers: Number of stacked decoder layers. Defaults to 1, the value
            that was previously hard-coded.
    """

    def __init__(self, d_model, nhead, num_layers=1):
        super().__init__()
        # Attribute name kept as ``decoder`` so existing state_dict keys still load.
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead), num_layers=num_layers
        )

    def forward(self, tgt, memory, tgt_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Decode an already-embedded target sequence against ``memory``.

        Args:
            tgt: Target tensor in the sequence-first layout described above.
            memory: Encoder output in the same layout.
            tgt_mask: Optional ``(tgt_len, tgt_len)`` self-attention mask,
                typically causal so a position cannot see later positions.
            tgt_key_padding_mask: Optional ``(batch, tgt_len)`` padding mask.
            memory_key_padding_mask: Optional ``(batch, src_len)`` padding mask
                applied to ``memory`` in cross-attention.

        Returns:
            Tensor with the same shape as ``tgt``.
        """
        return self.decoder(
            tgt,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )

class Generator(nn.Module):
    """Linear head that maps hidden states of size ``d_model`` to ``vocab_size`` scores.

    No softmax is applied; the outputs are raw (unnormalised) scores.
    """

    def __init__(self, d_model, vocab_size):
        super().__init__()
        # The name ``projection`` determines the state_dict keys; keep it stable.
        self.projection = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Project the last dimension of ``x`` from ``d_model`` to ``vocab_size``."""
        return self.projection(x)

class EncoderDecoderTransformer(nn.Module):
    """Chain an encoder, a decoder and an output generator into one module.

    The three parts are registered in the order encoder, decoder, generator,
    which fixes both the state_dict keys and the parameter iteration order.
    """

    def __init__(self, encoder, decoder, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Encode ``src``, decode ``tgt`` against it and apply the generator.

        ``src_mask`` goes to the encoder and ``tgt_mask`` to the decoder; both
        are forwarded positionally and unchanged.
        """
        encoded = self.encoder(src, src_mask)
        hidden = self.decoder(tgt, encoded, tgt_mask)
        scores = self.generator(hidden)
        return scores

def inference(model, src, src_mask=None, max_len=100, *, tgt_embed=None,
              sos_token=None, eos_token=None):
    """Greedily decode one target sequence for a single source sequence.

    The transformer layers in this notebook use PyTorch's default
    ``batch_first=False``, so every tensor here is sequence-first. None of the
    model classes embed tokens, so the caller must supply ``tgt_embed`` to turn
    generated token ids back into decoder inputs.

    Args:
        model: Object exposing ``encoder(src, src_mask)``,
            ``decoder(tgt, memory, tgt_mask)`` and ``generator(x)``, e.g. an
            ``EncoderDecoderTransformer``. It is switched to eval mode.
        src: Already-embedded source of shape ``(src_len, 1, d_model)``
            (a batch of one, because the generated target has batch size 1).
        src_mask: Optional encoder self-attention mask, passed to ``model.encoder``.
        max_len: Maximum number of tokens to generate.
        tgt_embed: Callable mapping a ``(tgt_len, 1)`` LongTensor of token ids
            to a ``(tgt_len, 1, d_model)`` decoder input (e.g. the
            ``nn.Embedding`` plus positional encoding used in training).
        sos_token: Start-of-sequence id; defaults to the global ``SOS_TOKEN``.
        eos_token: End-of-sequence id; defaults to the global ``EOS_TOKEN``.

    Returns:
        List of generated token ids, excluding the start token and the
        end-of-sequence token.

    Raises:
        ValueError: If ``tgt_embed`` is not provided.
    """
    if tgt_embed is None:
        raise ValueError(
            "inference() needs tgt_embed: the model has no target embedding, so "
            "token ids must be mapped to (tgt_len, 1, d_model) decoder inputs"
        )
    # Fall back to the notebook-level constants for backward compatibility.
    if sos_token is None:
        sos_token = SOS_TOKEN
    if eos_token is None:
        eos_token = EOS_TOKEN

    model.eval()
    with torch.no_grad():
        memory = model.encoder(src, src_mask)
        # Generated token ids, sequence-first: shape (tgt_len, batch=1).
        ys = torch.full((1, 1), sos_token, dtype=torch.long, device=src.device)
        decoded = []
        for _ in range(max_len):
            tgt_len = ys.size(0)
            # True above the diagonal blocks attention to future positions; it
            # must be rebuilt each step because the target grows by one token.
            causal_mask = torch.ones(
                tgt_len, tgt_len, dtype=torch.bool, device=src.device
            ).triu(1)
            output = model.decoder(tgt_embed(ys), memory, causal_mask)
            # Last time step (dim 0 in the sequence-first layout): (1, d_model).
            scores = model.generator(output[-1])
            next_token = scores.argmax(-1)  # shape (1,)
            token = next_token.item()
            if token == eos_token:
                break
            decoded.append(token)
            ys = torch.cat([ys, next_token.unsqueeze(0)], dim=0)
        return decoded