<a href="https://colab.research.google.com/github/JackWittmayer/Transformer-Implementation/blob/main/EDTransformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install tokenizers



In [2]:
import re
import string
import os
import pickle
from unicodedata import normalize
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.nn.functional import log_softmax, pad

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer
from tokenizers.processors import TemplateProcessing

import random
import time

import numpy as np
import math
import matplotlib.pyplot as plt

import sys
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import corpus_bleu

from datetime import datetime

In [3]:
torch.manual_seed(25)
random.seed(25)
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [4]:
SAMPLE_X = torch.tensor([[3, 2, 0, 1], [1, 2, 3, 0]], dtype=torch.int32).to(device)
SAMPLE_Z = torch.tensor([4, 1, 7, 6], dtype=torch.int32).to(device)

In [5]:
def printIfVerbose(verbose, tag, value):
    if verbose:
        print(tag, value)

In [6]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        self.table = nn.Embedding(vocab_size, embedding_size).to(device)

    def forward(self, sequence):
        embeddings = self.table(sequence)
        return embeddings

In [7]:
def test_embedding():
    torch.manual_seed(25)
    vocab_size = 4
    embedding = Embedding(vocab_size, 4)
    print("weight:", embedding.table.weight)
    print("SAMPLE_X: ", SAMPLE_X)
    output = embedding(SAMPLE_X)
    print("output:", output)
    for j in range(len(output)):
        #print("sample:", sample)
        for i in range(vocab_size):
            assert output[j, i, :].eq(embedding.table.weight[SAMPLE_X[j, i]]).all()
test_embedding()

weight: Parameter containing:
tensor([[ 0.0877, -0.6113,  0.3441, -1.2916],
        [-0.5874,  0.8060,  1.3200,  0.4826],
        [ 1.6671, -0.2342,  0.1074,  1.7852],
        [ 0.7874, -0.2466,  0.2384, -0.6746]], device='cuda:0',
       requires_grad=True)
SAMPLE_X:  tensor([[3, 2, 0, 1],
        [1, 2, 3, 0]], device='cuda:0', dtype=torch.int32)
output: tensor([[[ 0.7874, -0.2466,  0.2384, -0.6746],
         [ 1.6671, -0.2342,  0.1074,  1.7852],
         [ 0.0877, -0.6113,  0.3441, -1.2916],
         [-0.5874,  0.8060,  1.3200,  0.4826]],

        [[-0.5874,  0.8060,  1.3200,  0.4826],
         [ 1.6671, -0.2342,  0.1074,  1.7852],
         [ 0.7874, -0.2466,  0.2384, -0.6746],
         [ 0.0877, -0.6113,  0.3441, -1.2916]]], device='cuda:0',
       grad_fn=<EmbeddingBackward0>)


In [8]:
class Unembedding(nn.Module):
    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        self.weight = nn.Linear(embedding_size, vocab_size).to(device)

    def forward(self, x):
        return self.weight(x)

In [9]:
def test_unembedding():
    torch.manual_seed(25)
    vocab_size = 10
    embedding_size = 4
    sequence_length = 4
    batch_size = 2
    input = torch.rand(batch_size, sequence_length, embedding_size).to(device)
    unembedding = Unembedding(vocab_size, embedding_size)

    print("weight:", unembedding.weight)
    print("input: ", input)
    output = unembedding(input)
    print("output:", output)
    assert output.shape == (batch_size, sequence_length, vocab_size)
test_unembedding()

weight: Linear(in_features=4, out_features=10, bias=True)
input:  tensor([[[0.7518, 0.1929, 0.0629, 0.9118],
         [0.3828, 0.2990, 0.5933, 0.2911],
         [0.2416, 0.5582, 0.0481, 0.3497],
         [0.3520, 0.9528, 0.0284, 0.8488]],

        [[0.3947, 0.5181, 0.9726, 0.8813],
         [0.0056, 0.3056, 0.9384, 0.7949],
         [0.4399, 0.1766, 0.8739, 0.1425],
         [0.4682, 0.6254, 0.3040, 0.7923]]], device='cuda:0')
output: tensor([[[-4.9334e-01,  3.9030e-01, -3.4348e-03,  2.0479e-01, -1.7155e-01,
           3.0325e-01,  9.5298e-01, -9.2740e-01,  3.5209e-01, -2.1405e-02],
         [-6.2972e-02, -6.9980e-02, -5.4304e-02,  1.2675e-01, -5.5075e-01,
           1.8844e-01,  8.6408e-01, -5.4956e-01,  4.7789e-01,  7.8078e-02],
         [-2.9110e-01,  3.9284e-02,  7.2926e-02,  2.0875e-01, -3.4683e-01,
           1.1962e-01,  7.2445e-01, -5.6804e-01,  4.2547e-01,  1.0732e-02],
         [-3.4147e-01,  6.4171e-02, -5.8167e-02,  1.8160e-01, -2.1651e-01,
          -1.4458e-01,  1.0255e+0

In [10]:
class PositionalEmbedding(nn.Module):
    def __init__(self, embedding_size, max_sequence_length):
        super().__init__()
        self.table = nn.Embedding(max_sequence_length, embedding_size).to(device)

    def forward(self, sequence):
        positions = torch.zeros(sequence.shape, dtype=torch.int32)
        positions[:, ::] = torch.arange(0, sequence.shape[-1])
        #print("positions", positions)
        positional_embeddings = self.table(positions.to(device))
        return positional_embeddings

In [11]:
def test_positional_embedding():
    embedding_size = 8
    max_sequence_length = 10
    batch_size = 2
    positional_embedding = PositionalEmbedding(embedding_size, max_sequence_length)
    output = positional_embedding(SAMPLE_X)
    print("output:", output)
    assert output.shape == (batch_size, SAMPLE_X.shape[-1], embedding_size)
test_positional_embedding()

output: tensor([[[ 0.3412, -0.2888, -0.4162, -1.2433,  0.3364, -2.1897, -0.2086,
           0.0196],
         [ 0.2461, -0.0812, -0.4464, -1.2595,  0.5963, -1.3647, -0.7684,
           0.3472],
         [-0.0142,  0.1426,  2.0701, -0.1623, -0.4448,  0.8318, -0.2930,
          -0.2068],
         [-0.8096,  1.2487,  0.5594, -0.3657,  0.5478, -1.4327, -1.4111,
          -0.4237]],

        [[ 0.3412, -0.2888, -0.4162, -1.2433,  0.3364, -2.1897, -0.2086,
           0.0196],
         [ 0.2461, -0.0812, -0.4464, -1.2595,  0.5963, -1.3647, -0.7684,
           0.3472],
         [-0.0142,  0.1426,  2.0701, -0.1623, -0.4448,  0.8318, -0.2930,
          -0.2068],
         [-0.8096,  1.2487,  0.5594, -0.3657,  0.5478, -1.4327, -1.4111,
          -0.4237]]], device='cuda:0', grad_fn=<EmbeddingBackward0>)


In [12]:
def attention(queries, keys, values, mask, dropout, verbose):
    printIfVerbose(verbose, "queries:", queries)
    printIfVerbose(verbose, "keys:", keys)
    printIfVerbose(verbose, "values:", values)
    keys_transposed = torch.transpose(keys, -2, -1)
    printIfVerbose(verbose, "keys_transposed:", keys_transposed)
    scores = torch.matmul(queries, keys_transposed)
    #assert scores.shape == (keys.shape[0], keys.shape[-1], queries.shape[-1])
    printIfVerbose(verbose, "scores:", scores)
    printIfVerbose(verbose, "scores:", scores.shape)
    printIfVerbose(verbose, "masks:", mask.shape)
    scores = scores.masked_fill(mask == 0, -1e9)
    printIfVerbose(verbose, "masked scores:", scores)
    d_attn = keys.shape[-1]
    scaled_scores = scores / math.sqrt(d_attn)
    printIfVerbose(verbose, "scaled_scores:", scaled_scores)
    softmax_scores = torch.softmax(scaled_scores, -1)
    softmax_scores = dropout(softmax_scores)
    printIfVerbose(verbose, "softmax_scores:", softmax_scores)
    printIfVerbose(verbose, "softmax_socres shape:", softmax_scores.shape)
    printIfVerbose(verbose, "values:", values)
    v_out = torch.matmul(softmax_scores, values)
    return v_out

In [13]:
def test_attention():
    d_attn = 4
    length_x = 4
    length_z = 3
    batch_size = 2
    d_out = 2

    queries = torch.rand(batch_size, length_x, d_attn)
    keys = torch.rand(batch_size, length_z, d_attn)
    values = torch.rand(batch_size, length_z, d_out)
    mask = torch.tril(torch.ones(length_x, length_z) == 1)
    padding_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=torch.int32)

    v_out = attention(queries, keys, values, mask, nn.Dropout(0.1), True)
    #print("output:", v_out)
    assert v_out.shape == (batch_size, length_x, d_out)
test_attention()

queries: tensor([[[0.4961, 0.6278, 0.3572, 0.5220],
         [0.1997, 0.5286, 0.4723, 0.0238],
         [0.1838, 0.2010, 0.1765, 0.8587],
         [0.7776, 0.1199, 0.8638, 0.1066]],

        [[0.1084, 0.8448, 0.7043, 0.9275],
         [0.3953, 0.2704, 0.6228, 0.6078],
         [0.7686, 0.3296, 0.4959, 0.0065],
         [0.9125, 0.8358, 0.6698, 0.4129]]])
keys: tensor([[[0.0129, 0.5052, 0.5967, 0.3134],
         [0.1648, 0.4834, 0.2368, 0.7654],
         [0.9255, 0.3393, 0.5612, 0.0953]],

        [[0.5582, 0.5739, 0.5244, 0.6292],
         [0.7426, 0.3134, 0.7793, 0.9385],
         [0.1588, 0.3427, 0.3863, 0.2306]]])
values: tensor([[[0.1533, 0.0876],
         [0.9218, 0.8859],
         [0.6448, 0.5202]],

        [[0.3174, 0.8487],
         [0.8658, 0.5804],
         [0.1021, 0.1329]]])
keys_transposed: tensor([[[0.0129, 0.1648, 0.9255],
         [0.5052, 0.4834, 0.3393],
         [0.5967, 0.2368, 0.5612],
         [0.3134, 0.7654, 0.0953]],

        [[0.5582, 0.7426, 0.1588],
       

In [14]:
from enum import Enum
class MaskStrategy(Enum):
    UNMASKED = 1
    MASKED = 2

In [15]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, num_heads, d_attn, d_x, d_z, d_out, d_mid, maskStrategy, p_dropout, verbose):
        super().__init__()
        self.verbose = verbose
        self.num_heads = num_heads
        self.d_attn = d_attn
        self.d_x = d_x
        self.d_z = d_z
        self.d_out = d_out
        self.d_mid = d_mid
        self.maskStrategy = maskStrategy
        self.weight_query = nn.Linear(d_x, d_attn).to(device)
        self.weight_key = nn.Linear(d_z, d_attn).to(device)
        self.weight_value = nn.Linear(d_z, d_mid).to(device)
        self.weight_out = nn.Linear(d_mid, d_out).to(device)
        self.dropout = nn.Dropout(p_dropout)

    def forward(self, z, x, padding_mask):
        length_z = z.shape[-2]
        length_x = x.shape[-2]
        batch_size = x.shape[0]

        queries = self.weight_query(x).view(batch_size, length_x, self.num_heads, -1).transpose(1, 2)
        keys = self.weight_key(z).view(batch_size, length_z, self.num_heads, -1).transpose(1, 2)
        values = self.weight_value(z).view(batch_size, length_z, self.num_heads, -1).transpose(1, 2)

        assert queries.shape == (batch_size, self.num_heads, length_x, self.d_attn / self.num_heads)
        assert keys.shape == (batch_size, self.num_heads, length_z, self.d_attn / self.num_heads)
        assert values.shape == (batch_size, self.num_heads, length_z, self.d_mid / self.num_heads)

        if self.maskStrategy == MaskStrategy['UNMASKED']:
            mask = padding_mask.unsqueeze(-2)
        elif self.maskStrategy == MaskStrategy['MASKED']:
            padding_mask = padding_mask.unsqueeze(-2)
            mask = torch.tril(torch.ones(length_x, length_z) == 1).to(device)
            printIfVerbose(self.verbose, "padding mask:", padding_mask.shape)
            printIfVerbose(self.verbose, "mask tril", mask)
            mask = mask & padding_mask
            printIfVerbose(self.verbose, "merged mask:", mask)
        mask = mask.unsqueeze(1)
        printIfVerbose(self.verbose, "mask", mask)
        printIfVerbose(self.verbose, "mask", mask.shape)
        v_out = attention(queries, keys, values, mask, self.dropout, self.verbose)
        printIfVerbose(self.verbose, "v_out shape", v_out.shape)
        assert v_out.shape == (batch_size, self.num_heads, length_x, self.d_mid / self.num_heads)
        printIfVerbose(self.verbose, "v_out:", v_out)
        printIfVerbose(self.verbose, "v_out shape before:", v_out.shape)
        v_out = v_out.transpose(1, 2).reshape(batch_size, length_x, -1)
        printIfVerbose(self.verbose, "v_out shape:", v_out.shape)
        printIfVerbose(self.verbose, "v_out reshaped:", v_out)
        output = self.weight_out(v_out)
        printIfVerbose(self.verbose, "output shape", output.shape)
        assert output.shape == (batch_size, length_x, self.d_out)
        return output

    def disable_subsequent_mask(self):
        self.maskStrategy = MaskStrategy['UNMASKED']

    def enable_subsequent_mask(self):
        self.maskStrategy = MaskStrategy['MASKED']


In [16]:
def test_multi_headed_attention_encoder_fixed():
    num_heads = 1
    d_attn = 4
    d_x = 4
    d_z = 4
    d_out = 1
    d_mid = 3
    length_z = 3
    batch_size = 1
    padding_mask = torch.tensor([[1, 1, 0]], dtype=torch.int32).to(device)

    multi_headed_attention = MultiHeadedAttention(num_heads, d_attn, d_x, d_z, d_out, d_mid, MaskStrategy['UNMASKED'], 0.0, True).to(device)
    z = torch.tensor([[[1, 0, 1, 0], [0, 2, 0, 2], [1, 1, 1, 1]]], dtype=torch.float32).to(device)
    #print("z:", z
    output = multi_headed_attention(z, z, padding_mask)
    #print("output:", output)
    assert output.shape == (batch_size, length_z, d_out)
test_multi_headed_attention_encoder_fixed()

mask tensor([[[[1, 1, 0]]]], device='cuda:0', dtype=torch.int32)
mask torch.Size([1, 1, 1, 3])
queries: tensor([[[[-0.8182,  0.9362,  0.3153,  0.4970],
          [ 0.8439,  0.4537,  0.0540,  1.0113],
          [-0.2946,  0.9878,  0.2263,  0.8214]]]], device='cuda:0',
       grad_fn=<TransposeBackward0>)
keys: tensor([[[[-0.9612,  0.6361, -0.7727,  1.0188],
          [-0.5443, -1.1392, -0.0485,  1.6070],
          [-1.0001, -0.1185, -0.8126,  1.5968]]]], device='cuda:0',
       grad_fn=<TransposeBackward0>)
values: tensor([[[[ 0.6987, -0.6661, -0.9703],
          [ 0.6950, -0.5483, -0.6054],
          [ 0.9841, -0.7190, -1.0553]]]], device='cuda:0',
       grad_fn=<TransposeBackward0>)
keys_transposed: tensor([[[[-0.9612, -0.5443, -1.0001],
          [ 0.6361, -1.1392, -0.1185],
          [-0.7727, -0.0485, -0.8126],
          [ 1.0188,  1.6070,  1.5968]]]], device='cuda:0',
       grad_fn=<TransposeBackward0>)
scores: tensor([[[[1.6446, 0.1622, 1.2447],
          [0.4659, 0.6464, 0.673

In [17]:
def test_multi_headed_attention_encoder():
    num_heads = 4
    d_attn = 4
    d_x = 4
    d_z = 4
    d_out = 1
    d_mid = 4
    length_z = 3
    batch_size = 3
    padding_mask = torch.tensor([[1, 1, 0], [1, 1, 0], [1, 1, 1]], dtype=torch.int32).to(device)

    multi_headed_attention = MultiHeadedAttention(num_heads, d_attn, d_x, d_z, d_out, d_mid, MaskStrategy['UNMASKED'], 0.0, True).to(device)
    z = torch.tensor([[[1, 0, 1, 0], [0, 2, 0, 2], [1, 1, 1, 1]],
                      [[1, 0, 1, 0], [0, 2, 0, 2], [1, 1, 1, 1]],
                      [[1, 0, 1, 0], [0, 2, 0, 2], [1, 1, 1, 1]]], dtype=torch.float32).to(device)
    #print("z:", z
    output = multi_headed_attention(z, z, padding_mask)
    #print("output:", output)
    assert output.shape == (batch_size, length_z, d_out)
test_multi_headed_attention_encoder()

mask tensor([[[[1, 1, 0]]],


        [[[1, 1, 0]]],


        [[[1, 1, 1]]]], device='cuda:0', dtype=torch.int32)
mask torch.Size([3, 1, 1, 3])
queries: tensor([[[[-0.4361],
          [-1.3334],
          [-0.8685]],

         [[-0.2439],
          [ 0.0877],
          [-0.1861]],

         [[ 0.5704],
          [-0.6079],
          [ 0.1567]],

         [[-0.9431],
          [-0.2173],
          [-0.9932]]],


        [[[-0.4361],
          [-1.3334],
          [-0.8685]],

         [[-0.2439],
          [ 0.0877],
          [-0.1861]],

         [[ 0.5704],
          [-0.6079],
          [ 0.1567]],

         [[-0.9431],
          [-0.2173],
          [-0.9932]]],


        [[[-0.4361],
          [-1.3334],
          [-0.8685]],

         [[-0.2439],
          [ 0.0877],
          [-0.1861]],

         [[ 0.5704],
          [-0.6079],
          [ 0.1567]],

         [[-0.9431],
          [-0.2173],
          [-0.9932]]]], device='cuda:0', grad_fn=<TransposeBackward0>)
keys: tensor([

In [18]:
def test_multi_headed_attention_encoder_decoder():
    num_heads = 4
    d_attn = 4
    d_x = 4
    d_z = 4
    d_out = 4
    d_mid = 4
    length_x = 3
    length_z = 3
    batch_size = 4
    padding_mask = torch.tensor([[1, 1, 0]], dtype=torch.int32).to(device)

    multi_headed_attention = MultiHeadedAttention(num_heads, d_attn, d_x, d_z, d_out, d_mid, MaskStrategy['UNMASKED'], 0.0, True).to(device)
    x = torch.rand(batch_size, length_x, d_x).to(device)
    z = torch.rand(batch_size, length_z, d_z).to(device)
    output = multi_headed_attention(z, x, padding_mask)
    print("output:", output)
    assert output.shape == (batch_size, length_x, d_out)
test_multi_headed_attention_encoder_decoder()

mask tensor([[[[1, 1, 0]]]], device='cuda:0', dtype=torch.int32)
mask torch.Size([1, 1, 1, 3])
queries: tensor([[[[-1.0411e-01],
          [-2.0415e-01],
          [-2.1368e-01]],

         [[ 5.9094e-01],
          [ 1.3446e-01],
          [ 3.9921e-01]],

         [[-7.2403e-02],
          [-3.6016e-02],
          [-2.3203e-01]],

         [[ 9.5991e-01],
          [ 5.6997e-01],
          [ 7.0617e-01]]],


        [[[-2.4329e-01],
          [-1.9518e-01],
          [-3.2224e-01]],

         [[ 1.6908e-01],
          [ 3.7678e-01],
          [-9.6929e-02]],

         [[-1.3857e-01],
          [-2.0413e-01],
          [-2.2908e-01]],

         [[ 6.4452e-01],
          [ 7.0216e-01],
          [ 2.3602e-01]]],


        [[[-2.8424e-01],
          [-7.9439e-02],
          [-5.5061e-01]],

         [[-7.3381e-04],
          [ 2.5058e-01],
          [ 8.3200e-02]],

         [[-8.2078e-02],
          [ 1.1935e-01],
          [-5.9572e-01]],

         [[ 5.0838e-01],
          [ 6.3067e-

In [19]:
def test_multi_headed_attention_decoder_self():
    num_heads = 8
    d_attn = 8
    d_x = 8
    d_out = 8
    d_mid = 8
    length_x = 3
    batch_size = 4
    padding_mask = torch.tensor([[1, 1, 0], [1, 1, 0], [1, 0, 0], [1, 1, 1]], dtype=torch.int32).to(device)

    multi_headed_attention = MultiHeadedAttention(num_heads, d_attn, d_x, d_x, d_out, d_mid, MaskStrategy['UNMASKED'], 0.0, True).to(device)
    multi_headed_attention.enable_subsequent_mask()
    x = torch.rand(batch_size, length_x, d_x).to(device)
    output = multi_headed_attention(x, x, padding_mask)
    print("output:", output)
    assert output.shape == (batch_size, length_x, d_out)
test_multi_headed_attention_decoder_self()

padding mask: torch.Size([4, 1, 3])
mask tril tensor([[ True, False, False],
        [ True,  True, False],
        [ True,  True,  True]], device='cuda:0')
merged mask: tensor([[[1, 0, 0],
         [1, 1, 0],
         [1, 1, 0]],

        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 0]],

        [[1, 0, 0],
         [1, 0, 0],
         [1, 0, 0]],

        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]], device='cuda:0', dtype=torch.int32)
mask tensor([[[[1, 0, 0],
          [1, 1, 0],
          [1, 1, 0]]],


        [[[1, 0, 0],
          [1, 1, 0],
          [1, 1, 0]]],


        [[[1, 0, 0],
          [1, 0, 0],
          [1, 0, 0]]],


        [[[1, 0, 0],
          [1, 1, 0],
          [1, 1, 1]]]], device='cuda:0', dtype=torch.int32)
mask torch.Size([4, 1, 3, 3])
queries: tensor([[[[-0.2834],
          [-0.0325],
          [-0.1563]],

         [[-0.3054],
          [-0.5627],
          [-0.3479]],

         [[ 0.5136],
          [ 0.6240],
          [ 0.5573]],

    

In [20]:
class LayerNorm(nn.Module):
    def __init__(self, feature_length):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(feature_length))
        self.offset = nn.Parameter(torch.zeros(feature_length))

    def forward(self, activations):
        mean = torch.mean(activations, -1, keepdim=True)
        #print("mean:", mean)
        #print("activations - mean", activations - mean)
        variance = torch.var(activations, -1, keepdim=True, unbiased=False)
        normalized_activations = (activations - mean) / torch.sqrt(variance + 1e-6)
        return (normalized_activations * self.scale) + self.offset

In [21]:
def test_layer_norm():
    feature_length = 4
    length_x = 3
    batch_size = 5
    layer_norm = LayerNorm(feature_length)

    activations = torch.rand(batch_size, length_x, feature_length)

    print("activations:", activations)
    print("layer_normed:", layer_norm(activations))
    assert layer_norm(activations).shape == activations.shape

test_layer_norm()

activations: tensor([[[0.5992, 0.4175, 0.3157, 0.9827],
         [0.7162, 0.8391, 0.7896, 0.7189],
         [0.4911, 0.8424, 0.8687, 0.0647]],

        [[0.5847, 0.6227, 0.4764, 0.1564],
         [0.2567, 0.7449, 0.0281, 0.6179],
         [0.3166, 0.7580, 0.0795, 0.9892]],

        [[0.3658, 0.3713, 0.9095, 0.2346],
         [0.4220, 0.1864, 0.4479, 0.7047],
         [0.4035, 0.5706, 0.9836, 0.7179]],

        [[0.6039, 0.7702, 0.2363, 0.8699],
         [0.3465, 0.5148, 0.4237, 0.1254],
         [0.5822, 0.1521, 0.6643, 0.8174]],

        [[0.0329, 0.2912, 0.9359, 0.8245],
         [0.2528, 0.2244, 0.9205, 0.2973],
         [0.2273, 0.3442, 0.5416, 0.2200]]])
layer_normed: tensor([[[ 0.0802, -0.6342, -1.0341,  1.5881],
         [-0.9662,  1.4206,  0.4597, -0.9141],
         [-0.2319,  0.8459,  0.9264, -1.5403]],

        [[ 0.6799,  0.8870,  0.0893, -1.6562],
         [-0.5448,  1.1687, -1.3470,  0.7231],
         [-0.6132,  0.6214, -1.2765,  1.2682]],

        [[-0.4030, -0.3816,  1.6

In [22]:
class FeedForward(nn.Module):
    def __init__(self, hiddenLayerWidth, d_e, p_dropout):
        super().__init__()
        self.mlp1 = nn.Parameter(torch.rand(d_e, hiddenLayerWidth))
        self.mlp2 = nn.Parameter(torch.rand(hiddenLayerWidth, d_e))
        self.mlp1_bias = nn.Parameter(torch.zeros(hiddenLayerWidth))
        self.mlp2_bias = nn.Parameter(torch.zeros(d_e))
        self.dropout = nn.Dropout(p_dropout)

    def forward(self, activations):
        activations = torch.matmul(activations, self.mlp1) + self.mlp1_bias
        activations = activations.relu()
        activations = torch.matmul(activations, self.mlp2) + self.mlp2_bias
        activations = self.dropout(activations)
        return activations


In [23]:
def test_feed_forward():
    hiddenLayerWidth = 3
    d_e = 4
    feed_forward = FeedForward(hiddenLayerWidth, d_e, 0.1)
    activations = torch.rand(10, 5, d_e)

    print("activations:", activations)
    output = feed_forward(activations)
    print("feed forward:", output)
    assert output.shape == activations.shape

test_feed_forward()

activations: tensor([[[0.7340, 0.5315, 0.3123, 0.1216],
         [0.5905, 0.7827, 0.5246, 0.2818],
         [0.3251, 0.9486, 0.7465, 0.8361],
         [0.2152, 0.9253, 0.1532, 0.9265],
         [0.2236, 0.7171, 0.8762, 0.0173]],

        [[0.8524, 0.3550, 0.3433, 0.9243],
         [0.8896, 0.5362, 0.7230, 0.5878],
         [0.9095, 0.1898, 0.3825, 0.5456],
         [0.3005, 0.6267, 0.9540, 0.6107],
         [0.5523, 0.0066, 0.2676, 0.8121]],

        [[0.9950, 0.5809, 0.3101, 0.2751],
         [0.7424, 0.8780, 0.5843, 0.2509],
         [0.9813, 0.0220, 0.9977, 0.5921],
         [0.8650, 0.0410, 0.4478, 0.3561],
         [0.9486, 0.0837, 0.5711, 0.3485]],

        [[0.4403, 0.4605, 0.8871, 0.5350],
         [0.1355, 0.8206, 0.3544, 0.7126],
         [0.6834, 0.8336, 0.2763, 0.3298],
         [0.4444, 0.8771, 0.9834, 0.1246],
         [0.5706, 0.4069, 0.6391, 0.0773]],

        [[0.1389, 0.9018, 0.4024, 0.0963],
         [0.7001, 0.1936, 0.8411, 0.3009],
         [0.3786, 0.3271, 0.6998,

In [24]:
class EncoderLayer(nn.Module):
    def __init__(self, num_heads, d_attn, d_x, d_z, d_out, d_mid, d_mlp, p_dropout, verbose):
        super().__init__()
        self.verbose = verbose
        self.multi_head_attention = MultiHeadedAttention(num_heads, d_attn, d_x, d_z, d_out, d_mid, MaskStrategy['UNMASKED'], p_dropout, verbose)
        self.layer_norm1 = LayerNorm(d_z)
        self.feed_forward = FeedForward(d_mlp, d_z, p_dropout)
        self.layer_norm2 = LayerNorm(d_z)

    def forward(self, z, padding_mask):
        z = self.layer_norm1(z)
        z = z + self.multi_head_attention(z, z, padding_mask)
        z = self.layer_norm2(z)
        z = z + self.feed_forward(z)
        return z

In [25]:
class Encoder(nn.Module):
    def __init__(self, num_layers, num_heads, d_attn, d_x, d_z, d_out, d_mid, d_mlp, p_dropout, verbose):
        super().__init__()
        self.layers = []
        for i in range(num_layers):
            encoder_layer = EncoderLayer(num_heads, d_attn, d_x, d_z, d_out, d_mid, d_mlp, p_dropout, verbose)
            self.layers.append(encoder_layer)
        self.layers = nn.ModuleList(self.layers)
        self.final_norm = LayerNorm(d_z)

    def forward(self, z, padding_mask):
        for layer in self.layers:
            z = layer(z, padding_mask)
        return self.final_norm(z)

In [26]:
class DecoderLayer(nn.Module):
    def __init__(self, num_heads, d_attn, d_x, d_z, d_out, d_mid, d_mlp, p_dropout, verbose):
        super().__init__()
        self.verbose = verbose
        self.multi_head_self_attention = MultiHeadedAttention(num_heads, d_attn, d_x, d_z, d_out, d_mid, MaskStrategy['MASKED'], p_dropout, verbose)
        self.layer_norm1 = LayerNorm(d_x)
        self.multi_head_global_attention = MultiHeadedAttention(num_heads, d_attn, d_x, d_z, d_out, d_mid, MaskStrategy['UNMASKED'], p_dropout, verbose)
        self.layer_norm2 = LayerNorm(d_x)
        self.feed_forward = FeedForward(d_mlp, d_x, p_dropout)
        self.layer_norm3 = LayerNorm(d_x)

    def forward(self, z, x, src_mask, tgt_mask):
        x = self.layer_norm1(x)
        x = x + self.multi_head_self_attention(x, x, tgt_mask)
        x = self.layer_norm2(x)
        x = x + self.multi_head_global_attention(z, x, src_mask)
        x = self.layer_norm3(x)
        x = x + self.feed_forward(x)
        return x

    def disable_subsequent_mask(self):
        self.multi_head_self_attention.disable_subsequent_mask()

In [27]:
class Decoder(nn.Module):
    def __init__(self, num_layers, num_heads, d_attn, d_x, d_z, d_out, d_mid, d_mlp, p_dropout, verbose):
        super().__init__()
        self.layers = []
        for i in range(num_layers):
            decoder_layer = DecoderLayer(num_heads, d_attn, d_x, d_z, d_out, d_mid, d_mlp, p_dropout, verbose)
            self.layers.append(decoder_layer)
        self.layers = nn.ModuleList(self.layers)
        self.final_norm = LayerNorm(d_x)

    def forward(self, z, x, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(z, x, src_mask, tgt_mask)
        return self.final_norm(x)

    def disable_subsequent_mask(self):
        for layer in self.layers:
            layer.multi_head_self_attention.disable_subsequent_mask()

In [28]:
class EncoderDecoderTransformer(nn.Module):
    def __init__(self, num_encoder_layers, num_decoder_layers, num_heads, d_attn, d_x, d_z, d_out, d_mid, d_mlp, d_e, vocab_size, max_sequence_length, p_dropout, verbose):
        super().__init__()
        self.verbose = verbose
        self.src_embedding = Embedding(vocab_size, d_e)
        self.tgt_embedding = Embedding(vocab_size, d_e)
        self.unembedding = Unembedding(vocab_size, d_e)
        self.embedding_dropout = nn.Dropout(p_dropout)
        self.positionalEmbedding = PositionalEmbedding(d_e, max_sequence_length)
        self.encoder = Encoder(num_encoder_layers, num_heads, d_attn, d_x, d_z, d_out, d_mid, d_mlp, p_dropout, verbose)
        self.decoder = Decoder(num_decoder_layers, num_heads, d_attn, d_x, d_z, d_out, d_mid, d_mlp, p_dropout, verbose)

    def forward(self, z, x, src_mask, tgt_mask):
        z = self.src_embedding(z) + self.positionalEmbedding(z)
        z = self.embedding_dropout(z)
        z = self.encoder(z, src_mask)
        x = self.tgt_embedding(x) + self.positionalEmbedding(x)
        x = self.embedding_dropout(x)
        x = self.decoder(z, x, src_mask, tgt_mask)
        #print("x after decoder:", x.shape)
        x = self.unembedding(x)
        #print("x after unembedding:", x.shape)
        return x

    def disable_subsequent_mask(self):
        self.decoder.disable_subsequent_mask()




In [29]:
enRawName = "drive/MyDrive/colab data/multi30kEnTrain.txt"
deRawName = "drive/MyDrive/colab data/multi30kDeTrain.txt"
en30kVal = "drive/MyDrive/colab data/multi30kEnVal.txt"
de30kVal = "drive/MyDrive/colab data/multi30kDeVal.txt"
englishCleanName = "data/english_tokens.pkl"
germanCleanName = "data/german_tokens.pkl"
englishSortedName = "data/englishSorted.pkl"
germanSortedName = "data/germanSorted.pkl"

truncEn = "drive/MyDrive/colab data/truncEn.pkl"
truncDe = "drive/MyDrive/colab data/truncDe.pkl"

enTokenizerName = "drive/MyDrive/colab data/enTokenizer.pkl"
deTokenizerName = "drive/MyDrive/colab data/deTokenizer.pkl"
pairsName = "drive/MyDrive/colab data/pairs.pkl"
folder = "drive/MyDrive/colab data/"

enTrainingFileName = folder + "enTraining"
deTrainingFileName = folder + "deTraining"
enTestFileName = folder + "enTest"
deTestFileName = folder + "deTest"
enValFileName = folder + "enValidation"
deValFileName = folder + "deValidation"

enCombinedFileName = folder + "enCombined"
deCombinedFileName = folder + "deCombined"

In [30]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [31]:
# class SentenceDataset(Dataset):

#     TOKENIZER_SUFFIX = "_tokenizer"
#     BOS_TOKEN = "[SOS]"
#     EOS_TOKEN = "[EOS]"
#     PAD_TOKEN = "[PAD]"
#     UNK_TOKEN = "[UNK]"

#     def __init__(self, src_filename, tgt_filename, src_vocab_size, tgt_vocab_size, sequence_start_index, sequence_end_index):
#         src_sequences = self.to_sequences(self.load_doc(src_filename), sequence_start_index, sequence_end_index)
#         tgt_sequences = self.to_sequences(self.load_doc(tgt_filename), sequence_start_index, sequence_end_index)
#         src_sequences = [self.add_special_tokens(sequence) for sequence in src_sequences]
#         tgt_sequences = [self.add_special_tokens(sequence) for sequence in tgt_sequences]
#         self.src_tokenizer, self.tgt_tokenizer = self.setup_tokenizers(src_filename, tgt_filename, src_vocab_size, tgt_vocab_size, src_filename + SentenceDataset.TOKENIZER_SUFFIX, tgt_filename + SentenceDataset.TOKENIZER_SUFFIX)
#         # src_tokenized = self.src_tokenizer.encode_batch(src_sequences)
#         # tgt_tokenized = self.tgt_tokenizer.encode_batch(tgt_sequences)
#         # src_tensors = [torch.IntTensor(sequence.ids) for sequence in src_tokenized]
#         # tgt_tensor = [torch.IntTensor(sequence.ids) for sequence in tgt_tokenized]
#         self.pairs = self.pair_sequences(src_sequences, tgt_sequences)
#         #print("pairs", self.pairs)

#     # load doc into memory
#     def load_doc(self, filename):
#         # open the file as read only
#         file = open(filename, mode='rt')
#         # read all text
#         text = file.read()
#         # close the file
#         file.close()
#         return text

#     def add_special_tokens(self, sequence):
#         sequence = self.BOS_TOKEN + " " + sequence + " " + self.EOS_TOKEN
#         return sequence

#     def pair_sequences(self, src_sequences, tgt_sequences):
#         paired_sequences = list(zip(src_sequences, tgt_sequences))
#         sorted_pairs = sorted(paired_sequences, key=lambda x: len(x[0]))
#         return sorted_pairs

#     # split a loaded document into sequences
#     def to_sequences(self, doc, sequence_start_index, sequence_end_index):
#         sequences = doc.strip().split('\n')
#         return sequences[sequence_start_index:sequence_end_index]

#     def setup_tokenizers(self, src_filename, tgt_filename, src_vocab_size, tgt_vocab_size, src_tokenizer_name, tgt_tokenizer_name):
#         print("creating tokenizer for " + src_filename)
#         src_tokenizer = Tokenizer(BPE(unk_token=SentenceDataset.UNK_TOKEN))
#         src_tokenizer.pre_tokenizer = Whitespace()
#         # src_tokenizer.post_processor = TemplateProcessing(
#         #     single="[BOS] $A [EOS]",
#         #     special_tokens=[("[BOS]", 0), ("[EOS]", 1)],
#         # )
#         trainer = BpeTrainer(vocab_size = src_vocab_size, special_tokens=[SentenceDataset.BOS_TOKEN, SentenceDataset.EOS_TOKEN, SentenceDataset.PAD_TOKEN, SentenceDataset.UNK_TOKEN])
#         src_tokenizer.train([src_filename], trainer=trainer)
#         pickle.dump(src_tokenizer, open(src_tokenizer_name, "wb"))

#         print("creating tokenizer for " + tgt_filename)
#         tgt_tokenizer = Tokenizer(BPE(unk_token=SentenceDataset.UNK_TOKEN))
#         tgt_tokenizer.pre_tokenizer = Whitespace()
#         trainer = BpeTrainer(vocab_size = tgt_vocab_size, special_tokens=[SentenceDataset.BOS_TOKEN, SentenceDataset.EOS_TOKEN, SentenceDataset.PAD_TOKEN, SentenceDataset.UNK_TOKEN])
#         tgt_tokenizer.train([tgt_filename], trainer=trainer)
#         pickle.dump(tgt_tokenizer, open(tgt_tokenizer_name, "wb"))
#         return src_tokenizer, tgt_tokenizer

#     def __len__(self):
#         return len(self.pairs)

#     def __getitem__(self, index):
#         src_seq, tgt_seq = self.pairs[index]
#         return src_seq, tgt_seq


In [32]:
class SequencePairDataset(Dataset):
    BOS_TOKEN = "[SOS]"
    EOS_TOKEN = "[EOS]"
    PAD_TOKEN = "[PAD]"
    UNK_TOKEN = "[UNK]"
    PAD_ID = 2

    def __init__(self, src_text, tgt_text, start_index, end_index):
        src_sequences = self.to_sequences(src_text, start_index, end_index)
        #tgt_sequences = self.to_sequences(tgt_text, start_index, end_index)
        tgt_sequences = self.to_sequences(tgt_text, start_index, end_index)
        #src_sequences = [self.add_special_tokens(sequence) for sequence in src_sequences]
        #tgt_sequences = [self.add_special_tokens(sequence) for sequence in tgt_sequences]
        self.pairs = self.pair_sequences(src_sequences, tgt_sequences)

    def pair_sequences(self, src_sequences, tgt_sequences):
        paired_sequences = list(zip(src_sequences, tgt_sequences))
        sorted_pairs = sorted(paired_sequences, key=lambda x: len(x[0]))
        return sorted_pairs

    # split a loaded document into sequences
    def to_sequences(self, doc, sequence_start_index, sequence_end_index):
        sequences = doc.strip().split('\n')
        return sequences[sequence_start_index : sequence_end_index]

    def add_special_tokens(self, sequence):
        sequence = self.BOS_TOKEN + " " + sequence + " " + self.EOS_TOKEN
        return sequence

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, index):
        src_seq, tgt_seq = self.pairs[index]
        return src_seq, tgt_seq

In [33]:
class TrainAndValidationSequenceDatasets():
    def __init__(self, src_filename, tgt_filename, src_vocab_size, tgt_vocab_size, train_start_index, train_end_index, val_start_index, val_end_index):
        src_text = self.load_doc(src_filename)
        tgt_text = self.load_doc(tgt_filename)
        self.train_dataset = SequencePairDataset(src_text, tgt_text, train_start_index, train_end_index)
        self.val_dataset = SequencePairDataset(src_text, tgt_text, val_start_index, val_end_index)

        # load doc into memory
    def load_doc(self, filename):
        # open the file as read only
        file = open(filename, mode='rt')
        # read all text
        text = file.read()
        # close the file
        file.close()
        return text

In [34]:
import copy
class PadCollate:
    TOKENIZER_SUFFIX = "_tokenizer"

    def __init__(self, src_filename, tgt_filename, src_vocab_size, tgt_vocab_size):
        self.src_tokenizer, self.tgt_tokenizer = self.setup_tokenizers(src_filename, tgt_filename, src_vocab_size, tgt_vocab_size, src_filename + self.TOKENIZER_SUFFIX, tgt_filename + self.TOKENIZER_SUFFIX)

    def setup_tokenizers(self, src_filename, tgt_filename, src_vocab_size, tgt_vocab_size, src_tokenizer_name, tgt_tokenizer_name):
        print("creating tokenizer for " + src_filename)
        src_tokenizer = Tokenizer(BPE(unk_token=SequencePairDataset.UNK_TOKEN))
        src_tokenizer.pre_tokenizer = Whitespace()
        # src_tokenizer.post_processor = TemplateProcessing(
        #     single="[BOS] $A [EOS]",
        #     special_tokens=[("[BOS]", 0), ("[EOS]", 1)],
        # )
        trainer = BpeTrainer(vocab_size = src_vocab_size, special_tokens=[SequencePairDataset.BOS_TOKEN, SequencePairDataset.EOS_TOKEN, SequencePairDataset.PAD_TOKEN, SequencePairDataset.UNK_TOKEN])
        src_tokenizer.train([src_filename], trainer=trainer)
        pickle.dump(src_tokenizer, open(src_tokenizer_name, "wb"))

        print("creating tokenizer for " + tgt_filename)
        tgt_tokenizer = Tokenizer(BPE(unk_token=SequencePairDataset.UNK_TOKEN))
        tgt_tokenizer.pre_tokenizer = Whitespace()
        trainer = BpeTrainer(vocab_size = tgt_vocab_size, special_tokens=[SequencePairDataset.BOS_TOKEN, SequencePairDataset.EOS_TOKEN, SequencePairDataset.PAD_TOKEN, SequencePairDataset.UNK_TOKEN])
        tgt_tokenizer.train([tgt_filename], trainer=trainer)
        tgt_tokenizer.post_processor = TemplateProcessing(
            single="[BOS] $A [EOS]",
            special_tokens=[("[BOS]", 0), ("[EOS]", 1)],
        )
        pickle.dump(tgt_tokenizer, open(tgt_tokenizer_name, "wb"))
        return src_tokenizer, tgt_tokenizer

    def __call__(self, batch):
        # max_len_src = max([len(pair[0].split()) for pair in batch])
        # max_len_tgt = max([len(pair[1].split()) for pair in batch])

        #tgt_sequence_lengths

        self.src_tokenizer.no_padding()
        self.tgt_tokenizer.no_padding()

        self.src_tokenizer.no_truncation()
        self.tgt_tokenizer.no_truncation()

        src_tokenized = self.src_tokenizer.encode_batch([pair[0] for pair in batch])
        tgt_tokenized = self.tgt_tokenizer.encode_batch([pair[1] for pair in batch])

        max_len_src = max([len(sequence) for sequence in src_tokenized])
        max_len_tgt = max([len(sequence) for sequence in tgt_tokenized])

        # print("max len src:", max_len_src)
        # print("max len tgt:", max_len_tgt)

        self.src_tokenizer.enable_padding(pad_id = SequencePairDataset.PAD_ID, pad_token = SequencePairDataset.PAD_TOKEN)
        self.src_tokenizer.enable_truncation(max_length=max_len_src)
        self.tgt_tokenizer.enable_padding(pad_id = SequencePairDataset.PAD_ID, pad_token = SequencePairDataset.PAD_TOKEN)
        self.tgt_tokenizer.enable_truncation(max_length=max_len_tgt)

        # print("src batch:", [pair[0] for pair in batch])
        # print("tgt batch:", [pair[1] for pair in batch])

        src_tokenized = self.src_tokenizer.encode_batch([pair[0] for pair in batch])
        tgt_tokenized = self.tgt_tokenizer.encode_batch([pair[1] for pair in batch])
        # src_tokenized = [sequence.ids for sequence in src_tokenized]
        # tgt_tokenized = [sequence.ids for sequence in tgt_tokenized]
        # src_tensors = torch.IntTensor(src_tokenized)
        # tgt_tensor = torch.IntTensor(tgt_tokenized)

        return src_tokenized, tgt_tokenized

In [35]:
vocab_size = 10000
train_and_validation_sequence_datasets = TrainAndValidationSequenceDatasets(enRawName, deRawName, vocab_size, vocab_size, 0, 28250, 28250, 29000)
train_dataset = train_and_validation_sequence_datasets.train_dataset
val_dataset = train_and_validation_sequence_datasets.val_dataset

In [36]:
print(train_dataset.__getitem__(0))

('A dog in a car.', 'Ein Hund in einem Auto.')


In [37]:
pad_collate = PadCollate(enRawName, deRawName, vocab_size, vocab_size)
train_dataloader = DataLoader(train_dataset, batch_size=128, collate_fn = pad_collate)
val_dataloader = DataLoader(val_dataset, batch_size=128, collate_fn = pad_collate)

creating tokenizer for drive/MyDrive/colab data/multi30kEnTrain.txt
creating tokenizer for drive/MyDrive/colab data/multi30kDeTrain.txt


In [38]:
i = 0
for src, tgt in train_dataloader:
    print(src[0].ids, tgt[0].ids)
    # print("decoded", sequenceDataset.src_tokenizer.decode_batch([sequence.ids for sequence in src]))
    # print("tgt", tgt)

    # print("mask:", src[0].attention_mask)
    break

[30, 177, 83, 57, 238, 15, 2, 2] [0, 109, 200, 100, 111, 786, 14, 1, 2, 2, 2, 2]


In [39]:
def decode(x, tokenizer):
    x = torch.softmax(x, -1)
    #print("x softmax:", x)
    x = torch.argmax(x, dim=-1)
    x = x.tolist()
    print("argmax x:", x)
    return tokenizer.decode(x)

In [40]:
def test_decode(tokenizer):
    x = torch.tensor([[0, 5], [10, 20]], dtype=torch.float32)
    words = decode(x, tokenizer)
    print(words)

test_decode(pad_collate.tgt_tokenizer)

argmax x: [1, 1]



In [41]:
num_encoder_layers = 4
num_decoder_layers = 4
num_heads = 8
d_attn = 256
d_x = 256
d_z = 256
d_out = 256
d_mid = 256
d_mlp = 512
d_e = 256
max_sequence_length = 100
p_dropout = 0.1

In [120]:
encoder_decoder_transformer = EncoderDecoderTransformer(num_encoder_layers, num_decoder_layers, num_heads, d_attn, d_x, d_z, d_out, d_mid, d_mlp, d_e, vocab_size, max_sequence_length, 0.1, False).to(device)

def train_model(encoder_decoder_transformer, train_dataloader, val_dataloader, src_tokenizer, tgt_tokenizer):
    torch.manual_seed(25)
    epochs = 1000
    print(encoder_decoder_transformer.parameters())
    nameSuffix = "-maxData"
    state_dict_filename = folder + "encoder_decoder_transformer_state_dict_" + datetime.today().strftime('%Y-%m-%d %H') + nameSuffix
    opt = optim.AdamW(encoder_decoder_transformer.parameters(), lr=0.0001, weight_decay=0.0001)
    loss_function = nn.CrossEntropyLoss(label_smoothing=0.1, ignore_index=2)
    #labelSmoothing = LabelSmoothing(2000, PADDING_IDX, 0.1)
    training_step = 0
    validation_step = 0
    best_val_loss = 100
    num_fails = 0
    # Large models need this to actually train
    for p in encoder_decoder_transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    for i in range(epochs):
        epoch_time_start = time.time()
        dataloader_iter = iter(train_dataloader)
        train_losses = []
        val_losses = []
        for src_batch, tgt_batch in dataloader_iter:
            # print("x:", sequence_x)
            # print("z:", sequence_z)
            # sequence_x, sequence_z = sequenceDataset.__getitem__(i)
            src_tokens = torch.IntTensor([sequence.ids for sequence in src_batch]).to(device)
            encoder_input = src_tokens
            train_tgt_tokens = torch.IntTensor([sequence.ids for sequence in tgt_batch]).to(device)
            decoder_input = train_tgt_tokens[:, :-1]
            decoder_desired_output_train = train_tgt_tokens[:, 1:]
            src_masks = torch.IntTensor([sequence.attention_mask for sequence in src_batch]).to(device)
            tgt_masks = torch.IntTensor([sequence.attention_mask for sequence in tgt_batch])[:, :-1].to(device)
            # print("src masks", src_masks)
            # print("tgt masks", tgt_masks)
            train_output = encoder_decoder_transformer(encoder_input, decoder_input, src_masks, tgt_masks)
            #print("output", train_output)
            output_transpose = train_output.transpose(-1, -2) # output needs to be N, C, other dimension for torch cross entropy
            loss = loss_function(output_transpose, decoder_desired_output_train.long())
            opt.zero_grad()
            loss.backward()
            opt.step()
            train_losses.append(loss.item())
            if (training_step % 20 == 0):
                print("Completed training step", training_step)
            training_step += 1

        for src_batch, tgt_batch in val_dataloader:
            src_tokens = torch.IntTensor([sequence.ids for sequence in src_batch]).to(device)
            encoder_input = src_tokens
            val_tgt_tokens = torch.IntTensor([sequence.ids for sequence in tgt_batch]).to(device)
            decoder_input = val_tgt_tokens[:, :-1]
            decoder_desired_output_val = val_tgt_tokens[:, 1:]
            src_masks = torch.IntTensor([sequence.attention_mask for sequence in src_batch]).to(device)
            tgt_masks = torch.IntTensor([sequence.attention_mask for sequence in tgt_batch])[:, :-1].to(device)
            val_output = encoder_decoder_transformer(encoder_input, decoder_input, src_masks, tgt_masks)
            output_transpose = val_output.transpose(-1, -2) # output needs to be N, C, other dimension for torch cross entropy
            loss = loss_function(output_transpose, decoder_desired_output_val.long())
            val_losses.append(loss.item())
            if (validation_step % 20 == 0):
                print("Completed validation step", validation_step)
            validation_step += 1

        print("epoch", i, "took", time.time() - epoch_time_start)
        print("avg training loss:", sum(train_losses) / len(train_losses))
        avg_val_loss = sum(val_losses)/ len(val_losses)
        print("avg validation loss:", avg_val_loss)
        expected_train_output = tgt_tokenizer.decode(decoder_desired_output_train[0].tolist())
        print("expected train output", expected_train_output)
        decoded_output = decode(train_output[0], tgt_tokenizer)
        print("decoded train output:", decoded_output)
        expected_val_output = tgt_tokenizer.decode(decoder_desired_output_val[0].tolist())
        print("expected validation output", expected_val_output)
        decoded_output = decode(val_output[0], tgt_tokenizer)
        print("decoded validation output:", decoded_output)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(encoder_decoder_transformer.state_dict(), state_dict_filename)
            print("Saved model state dict to", state_dict_filename)
            num_fails = 0
        else:
            print("Average validation loss did not decrease from ", best_val_loss)
            num_fails += 1
            print("Failed to decrease the average validation loss", num_fails, "times.")
            if num_fails >= 2:
                print("Stopping training")
                break
        print()
        print()

In [47]:
custom_encoder_decoder_transformer = EncoderDecoderTransformer(num_encoder_layers, num_decoder_layers, num_heads, d_attn, d_x, d_z, d_out, d_mid, d_mlp, d_e, vocab_size, max_sequence_length, p_dropout, False).to(device)



In [127]:
# For some reason the Pytorch transformer doesn't have its own embedding layers. Adding them here.
class ExtendedPytorchTransformer(nn.Module):
    def __init__(self, d_model, num_heads, num_encoder_layers, num_decoder_layers, d_mlp, p_dropout, batch_first = True, norm_first = True):
        super().__init__()
        self.src_embedding = Embedding(vocab_size, d_model)
        self.tgt_embedding = Embedding(vocab_size, d_model)
        self.unembedding = Unembedding(vocab_size, d_model)
        self.transformer = nn.Transformer(d_model, num_heads, num_encoder_layers, num_decoder_layers, d_mlp, p_dropout, batch_first = batch_first, norm_first = norm_first)
        self.embedding_dropout = nn.Dropout(p_dropout)
        self.positionalEmbedding = PositionalEmbedding(d_model, max_sequence_length)

    def forward(self, src_sequence, tgt_sequence, src_mask, tgt_key_padding_mask):
        src_sequence = self.src_embedding(src_sequence) + self.positionalEmbedding(src_sequence)
        src_sequence = self.embedding_dropout(src_sequence)
        tgt_sequence = self.tgt_embedding(tgt_sequence) + self.positionalEmbedding(tgt_sequence)
        tgt_sequence = self.embedding_dropout(tgt_sequence)
        tgt_mask = self.get_tgt_mask(tgt_sequence.shape[1], tgt_sequence.shape[1])
        #print(src_sequence.shape, tgt_sequence.shape)
        src_mask = ~src_mask.bool()
        tgt_mask = ~tgt_mask.bool()
        tgt_key_padding_mask = ~tgt_key_padding_mask.bool()
        #print(src_mask.shape, tgt_mask.shape)
        transformer_out = self.transformer(src_sequence, tgt_sequence, src_key_padding_mask = src_mask, tgt_key_padding_mask = tgt_key_padding_mask, tgt_mask = tgt_mask)
        #print(transformer_out)
        return self.unembedding(transformer_out)


    def get_tgt_mask(self, length_x, length_z):
        mask = torch.tril(torch.ones(length_x, length_z) == 1).to(device)
        #print(mask.shape)
        return mask


In [128]:
pytorch_encoder_decoder_transformer = ExtendedPytorchTransformer(d_e, num_heads, num_encoder_layers, num_decoder_layers, d_mlp, p_dropout, batch_first = True, norm_first = True).to(device)

In [129]:
train_model(pytorch_encoder_decoder_transformer, train_dataloader, val_dataloader, pad_collate.src_tokenizer, pad_collate.tgt_tokenizer)

<generator object Module.parameters at 0x7f70c30c52a0>
Completed training step 0
Completed training step 20
Completed training step 40
Completed training step 60
Completed training step 80
Completed training step 100
Completed training step 120
Completed training step 140
Completed training step 160
Completed training step 180
Completed training step 200
Completed training step 220
Completed validation step 0
epoch 0 took 21.481504440307617
avg training loss: 6.997570639821738
avg validation loss: 6.18852702776591
expected train output Ein schwarz - rot - weißes Rennwagen saust im Vordergrund auf einer grauen Strecke mit einer blauen Ban de , im Hintergrund ist eine verschwommen e Menschenmenge zu sehen .
argmax x: [109, 124, 100, 100, 12, 12, 12, 12, 114, 12, 12, 111, 12, 12, 12, 111, 12, 111, 12, 14, 114, 14, 14, 14, 1, 14, 14, 14, 14, 14, 1, 14, 14, 14, 14, 111, 14, 14, 14, 14, 14, 114, 14, 111, 1, 14, 111, 14, 12]
decoded train output: Ein Mann in in , , , , und , , einem , , , ein

In [113]:
#state_dict_filename = folder + "encoder_decoder_transformer_state_dict_" + "2024-08-02 15"
state_dict = torch.load(state_dict_filename, map_location = device)
print(state_dict.keys())
encoder_decoder_transformer.load_state_dict(state_dict)

odict_keys(['src_embedding.table.weight', 'tgt_embedding.table.weight', 'unembedding.weight.weight', 'unembedding.weight.bias', 'positionalEmbedding.table.weight', 'encoder.layers.0.multi_head_attention.weight_query.weight', 'encoder.layers.0.multi_head_attention.weight_query.bias', 'encoder.layers.0.multi_head_attention.weight_key.weight', 'encoder.layers.0.multi_head_attention.weight_key.bias', 'encoder.layers.0.multi_head_attention.weight_value.weight', 'encoder.layers.0.multi_head_attention.weight_value.bias', 'encoder.layers.0.multi_head_attention.weight_out.weight', 'encoder.layers.0.multi_head_attention.weight_out.bias', 'encoder.layers.0.layer_norm1.scale', 'encoder.layers.0.layer_norm1.offset', 'encoder.layers.0.feed_forward.mlp1', 'encoder.layers.0.feed_forward.mlp2', 'encoder.layers.0.feed_forward.mlp1_bias', 'encoder.layers.0.feed_forward.mlp2_bias', 'encoder.layers.0.layer_norm2.scale', 'encoder.layers.0.layer_norm2.offset', 'encoder.layers.1.multi_head_attention.weight_qu

<All keys matched successfully>

In [116]:
def test_model_with_one_sample(model):
    dataloader_iter = iter(train_dataloader)
    for src_batch, tgt_batch in dataloader_iter:
        # print("x:", sequence_x)
        # print("z:", sequence_z)
        # sequence_x, sequence_z = sequenceDataset.__getitem__(i)
        src_batch = [src_batch[0]]
        tgt_batch = [tgt_batch[0]]
        src_tokens = torch.IntTensor([sequence.ids for sequence in src_batch]).to(device)
        encoder_input = src_tokens
        print("encoder input", encoder_input)
        decoded_input = pad_collate.src_tokenizer.decode(encoder_input[0].tolist())
        print("decoded input", decoded_input)
        train_tgt_tokens = torch.IntTensor([sequence.ids for sequence in tgt_batch]).to(device)
        decoder_input = train_tgt_tokens[:, :-1]
        print("decoder input", decoder_input)
        decoder_desired_output_train = train_tgt_tokens[:, 1:]
        src_masks = torch.IntTensor([sequence.attention_mask for sequence in src_batch]).to(device)
        tgt_masks = torch.IntTensor([sequence.attention_mask for sequence in tgt_batch])[:, :-1].to(device)
        # print("src masks", src_masks)
        # print("tgt masks", tgt_masks)
        train_output = encoder_decoder_transformer(encoder_input, decoder_input, src_masks, tgt_masks)
        print(train_output.shape)
        decoded_output = decode(train_output[0], pad_collate.tgt_tokenizer)
        print("decoded output", decoded_output)
        break

test_model_with_one_sample(encoder_decoder_transformer)

encoder input tensor([[  30,   93,   89,   94, 1602,   15,    2,    2]], device='cuda:0',
       dtype=torch.int32)
decoded input A man on the sea .
decoder input tensor([[   0,   30,   93,   89,   94, 1602,   15,    1,    2]],
       device='cuda:0', dtype=torch.int32)
torch.Size([1, 9, 10000])
argmax x: [30, 93, 89, 94, 1602, 15, 1, 2, 2]
decoded output A man on the sea .


In [None]:
#torch.save(encoder_decoder_transformer.state_dict(), state_dict_filename)

In [132]:
def predict_from_tokens(model, input, src_tokenizer, tgt_tokenizer):
    #model.disable_subsequent_mask()
    src_tokenizer.no_padding()
    tgt_tokenizer.no_padding()

    src_tokenizer.no_truncation()
    tgt_tokenizer.no_truncation()
    src_sequence = input
    print(src_sequence)
    src_sequence = src_tokenizer.encode(src_sequence)
    print(src_sequence)
    print(src_tokenizer.decode(src_sequence.ids))
    src_sequence = torch.IntTensor(src_sequence.ids).unsqueeze(0).to(device)
    print("src tokens", src_sequence)
    tgt_sequence = torch.IntTensor([0]).unsqueeze(0).to(device)
    src_mask = torch.ones(src_sequence.shape, dtype=torch.int32).to(device)
    print("decoder input", tgt_sequence)
    predictions = []
    with torch.no_grad():
        model.eval()
        length_gen = 100
        for i in range(length_gen):
            tgt_mask = torch.ones(tgt_sequence.shape, dtype=torch.int32).to(device)
            prediction = model(src_sequence, tgt_sequence, src_mask, tgt_mask)
            #print("prediction:", prediction)
            prediction = torch.softmax(prediction, -1)
            #print("softmax prediction:", prediction.shape)
            prediction = torch.argmax(prediction, dim=-1)
            print("argmax prediction:", prediction)
            print("actual prediction:", tgt_tokenizer.decode(prediction[0].tolist()))
            last_token = prediction[0][-1]
            tgt_sequence = torch.cat((tgt_sequence, last_token.unsqueeze(0).unsqueeze(0)), dim=-1)
            if last_token == 1:
                break
    return tgt_sequence

tgt_sequence = predict_from_tokens(pytorch_encoder_decoder_transformer, "A man sings a song.", pad_collate.src_tokenizer, pad_collate.tgt_tokenizer)
print(tgt_sequence)
print(pad_collate.tgt_tokenizer.decode(tgt_sequence[0].tolist()))

A man sings a song.
Encoding(num_tokens=6, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])
A man sings a song .
src tokens tensor([[  30,   93, 1561,   57, 2705,   15]], device='cuda:0',
       dtype=torch.int32)
decoder input tensor([[0]], device='cuda:0', dtype=torch.int32)
argmax prediction: tensor([[109]], device='cuda:0')
actual prediction: Ein
argmax prediction: tensor([[109, 124]], device='cuda:0')
actual prediction: Ein Mann
argmax prediction: tensor([[109, 124, 829]], device='cuda:0')
actual prediction: Ein Mann singt
argmax prediction: tensor([[109, 124, 829, 114]], device='cuda:0')
actual prediction: Ein Mann singt und
argmax prediction: tensor([[109, 124, 829, 114, 829]], device='cuda:0')
actual prediction: Ein Mann singt und singt
argmax prediction: tensor([[109, 124, 829, 114, 829, 103]], device='cuda:0')
actual prediction: Ein Mann singt und singt ein
argmax prediction: tensor([[ 109,  124,  829,  114,  829,  103, 3680]], de