The Encoder outputs a Tensor of shape (32, 10, 128), with the values being (batch size, sequence length, hidden size).
The output of the Encoder will get turned into key and value later.

In [131]:
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: torch.Tensor):
    # Expand pos_embedding along the batch_size dimension
        pos_embedding = self.pos_embedding[:token_embedding.size(0), :].expand(-1, token_embedding.size(1), -1)
        return self.dropout(token_embedding + pos_embedding)


# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: torch.Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

In [132]:
input_vocab_size, target_vocab_size

(2769, 1418)

In [133]:
input_vocab_size = input_lang.n_words # 2769
target_vocab_size = output_lang.n_words # 1418
vocab_size = max(input_vocab_size, target_vocab_size)
batch_size = 16 # not the issue
num_layers = 32 # not the issue
input_dim = 32
dropout = 0.2
n_heads = 2

class TransformerSeq2Seq (nn.Module):
    def __init__(self):
        super(TransformerSeq2Seq, self).__init__()
        self.embedding = nn.Embedding(vocab_size, input_dim)
        self.pos_encoder = PositionalEncoding(input_dim, dropout, maxlen=MAX_LENGTH+1)
        
        encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=n_heads)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        decoder_layer = nn.TransformerDecoderLayer(d_model=input_dim, nhead=n_heads)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        
        self.linear = nn.Linear(input_dim, vocab_size)
        self.log_softmax = nn.LogSoftmax(dim=-1)  # Add this line
        
    def get_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None):
        src = self.embedding(src)
        src = self.pos_encoder(src)
        src = self.transformer_encoder(src, src_key_padding_mask=src_key_padding_mask)

        tgt_mask = self.get_mask(tgt.size(0)).to(device)
        tgt = self.embedding(tgt)
        tgt = self.pos_encoder(tgt)
        
        output = self.transformer_decoder(
            tgt = tgt, 
            memory = src, 
            tgt_mask = tgt_mask, # to avoid looking at the future tokens (the ones on the right)
            tgt_key_padding_mask = tgt_key_padding_mask, # to avoid working on padding
            memory_key_padding_mask = src_key_padding_mask # avoid looking on padding of the src
        )
        
        output = self.linear(output)
        return self.log_softmax(output)
    
    def generate(self, src, src_key_padding_mask=None):
        ''' src has dimension of LEN x 1 '''
        src = self.embedding(src)
        src = self.pos_encoder(src)
        src = self.transformer_encoder(src, src_key_padding_mask=src_key_padding_mask)
        
        inputs = [SOS_token]
        for i in range(MAX_LENGTH):
            tgt = torch.LongTensor([inputs]).view(-1,1).to(device)
            tgt_mask = self.get_mask(i+1).to(device)
            
            tgt = self.embedding(tgt)
            tgt = self.pos_encoder(tgt)

            print("Shape of src:", src.shape)
            print("Shape of tgt:", tgt.shape)

            output = self.transformer_decoder(
                tgt=tgt, 
                memory=src, 
                tgt_mask=tgt_mask,
                memory_key_padding_mask = src_key_padding_mask )
            
            output = self.linear(output)
            output = self.log_softmax(output)
            output = output[-1] # the last timestep
            values, indices = output.max(dim=-1)
            pred_token = indices.item()
            inputs.append(pred_token)

        return inputs[1:]

In [134]:
from torch.nn.utils.rnn import pad_sequence

def indexesFromSentence(lang, sentence):
    return [SOS_token] + [lang.word2index[word] for word in sentence.split(' ')]

def tensorFromSentence(lang, sentence, max_length=MAX_LENGTH):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)

def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    return torch.cat((torch.tensor([SOS_token]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_token])))

# ``src`` and ``tgt`` language text transforms to convert raw strings into tensors indices
text_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    text_transform[ln] = sequential_transforms(token_transform[ln], #Tokenization
                                               vocab_transform[ln], #Numericalization
                                               tensor_transform) # Add BOS/EOS and create tensor

# function to collate data samples into batch tensors
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform[SRC_LANGUAGE](src_sample.rstrip("\n")))
        tgt_batch.append(text_transform[TGT_LANGUAGE](tgt_sample.rstrip("\n")))

    src_batch = pad_sequence(src_batch, padding_value=PAD_token)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_token)
    return src_batch, tgt_batch

def get_dataloader(batch_size):
    input_lang, output_lang, pairs = prepareData('en', 'it', True)

    n = len(pairs)
    input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
    target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)

    for idx, (inp, tgt) in enumerate(pairs):
        inp_ids = indexesFromSentence(input_lang, inp)
        tgt_ids = indexesFromSentence(output_lang, tgt)
        inp_ids = inp_ids[:MAX_LENGTH-1] + [EOS_token]  # truncate or pad inp_ids
        tgt_ids = tgt_ids[:MAX_LENGTH-1] + [EOS_token]  # truncate or pad tgt_ids
        input_ids[idx, :len(inp_ids)] = inp_ids
        target_ids[idx, :len(tgt_ids)] = tgt_ids

    train_data = TensorDataset(torch.LongTensor(input_ids).to(device),
                               torch.LongTensor(target_ids).to(device))

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size, collate_fn=collate_fn)
    return input_lang, output_lang, train_dataloader


In [135]:
def train_epoch(dataloader, model, optimizer, criterion):
    total_loss = 0
    for i, data in enumerate(dataloader, 1):
        input_tensor, target_tensor = data
        input_tensor, target_tensor = input_tensor.to(device), target_tensor.to(device)

        # Right shift the target tensor for the decoder's input
        sos_token = torch.full((1, target_tensor.size(1)), SOS_token, device=device, dtype=torch.long)
        target_tensor_shifted = torch.cat([sos_token, target_tensor[:-1]], dim=0)

        optimizer.zero_grad()

        # Pass the shifted target tensor to the model
        output = model(input_tensor, target_tensor_shifted)

        loss = criterion(
            output.view(-1, output.size(-1)),
            target_tensor.view(-1)
        )
        loss.backward()

        optimizer.step()

        total_loss += loss.item()

        # Print progress
        print(f"\rBatch {i}/{len(dataloader)}: Loss = {loss.item()}", end="")

    print()  # Ensure the next print starts on a new line
    return total_loss / len(dataloader)


In [136]:
import time
import math

def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

In [137]:
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker
import numpy as np

def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)

In [138]:
def train(train_dataloader, model, n_epochs, learning_rate=0.001,
               print_every=100, plot_every=100):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    for epoch in range(1, n_epochs + 1):
        loss = train_epoch(train_dataloader, model, optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),
                                        epoch, epoch / n_epochs * 100, print_loss_avg))

        if epoch % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    showPlot(plot_losses)

In [139]:
def evaluate(model, sentence, input_lang, output_lang, max_length=MAX_LENGTH):
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)
        # Pad the input_tensor to max_length
        if input_tensor.size(0) < max_length:
            input_tensor = F.pad(input_tensor, (0, max_length - input_tensor.size(0)), 'constant', 0)
        decoded_ids = model.generate(input_tensor)
        decoded_words = []
        for idx in decoded_ids:
            if idx == EOS_token:  # use the global variable EOS_token
                decoded_words.append('<EOS>')
                break
            else:
                decoded_words.append(output_lang.index2word[idx])
        return decoded_words

def evaluateRandomly(model, n=10):
    for i in range(n):
        pair = random.choice(pairs)
        print('>', pair[0])
        print('=', pair[1])
        output_words = evaluate(model, pair[0], input_lang, output_lang)
        output_sentence = ' '.join(output_words)
        print('<', output_sentence)
        print('')

In [140]:
input_lang.n_words, output_lang.n_words

(2769, 1418)

In [141]:
input_lang, output_lang, train_dataloader = get_dataloader(batch_size)

model = TransformerSeq2Seq().to(device)

train(train_dataloader, model, n_epochs=1, print_every=1, plot_every=30, learning_rate=0.0001)

Reading lines...
Read 374932 sentence pairs
Trimmed to 12262 sentence pairs
Counting words...
Counted words:
it 2769
en 1418


RuntimeError: The size of tensor a (16) must match the size of tensor b (6) at non-singleton dimension 0

In [92]:
model.eval()
evaluateRandomly(model)

> sei spericolato
= you re reckless
Shape of src: torch.Size([1, 9, 66])
Shape of tgt: torch.Size([1, 1, 66])


RuntimeError: shape '[1, 2, 33]' is invalid for input of size 594