<a href="https://colab.research.google.com/github/JackWittmayer/Transformer-Implementation/blob/main/EDTransformerDevNotebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook was used when first implementing and training the model. I have since moved the code into individual files, making this notebook out of date.

I am keeping it around because it's useful to see the history of it.

If you want to run a notebook in Colab with the new files, use `EDTransformer.ipynb`.

In [None]:
!pip install tokenizers



In [None]:
import re
import string
import os
import pickle
from unicodedata import normalize
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.nn.functional import log_softmax, pad

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer
from tokenizers.processors import TemplateProcessing

import random
import time

import numpy as np
import math
import matplotlib.pyplot as plt

import sys
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import corpus_bleu

from datetime import datetime
from enum import Enum
import logging

In [None]:
# Seed both RNGs used by this notebook so runs are repeatable.
SEED = 25
torch.manual_seed(SEED)
random.seed(SEED)

# Pick the notebook-wide device with the same cuda -> mps -> cpu preference
# that main() uses, so tensors created from this global end up on the same
# device as the model.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(device)

cuda


In [None]:
class Embedding(nn.Module):
    """Learned lookup table mapping token ids to embedding vectors."""

    def __init__(self, vocab_size, embedding_size):
        """Create a table with ``vocab_size`` rows of width ``embedding_size``."""
        super().__init__()
        self.table = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_size
        )

    def forward(self, sequence):
        """Return the table row for every id in ``sequence``."""
        return self.table(sequence)

In [None]:
class Unembedding(nn.Module):
    """Linear projection from embedding space to one score per vocabulary entry."""

    def __init__(self, vocab_size, embedding_size):
        """Create the ``embedding_size -> vocab_size`` linear layer (with bias)."""
        super().__init__()
        self.weight = nn.Linear(in_features=embedding_size, out_features=vocab_size)

    def forward(self, x):
        """Project the last dimension of ``x`` onto the vocabulary."""
        logits = self.weight(x)
        return logits

In [None]:
class PositionalEmbedding(nn.Module):
    """Learned positional embeddings indexed by position along the last axis."""

    def __init__(self, embedding_size, max_sequence_length, device):
        """
        Args:
            embedding_size: width of each positional vector.
            max_sequence_length: number of positions the table can represent.
            device: stored for backward compatibility only; ``forward`` places
                the position indices on whatever device the table lives on.
        """
        super().__init__()
        self.table = nn.Embedding(max_sequence_length, embedding_size)
        self.device = device

    def forward(self, sequence):
        """Return positional vectors shaped ``sequence.shape + (embedding_size,)``.

        Only the shape of ``sequence`` is used; position ``i`` along the last
        axis is looked up in row ``i`` of the table.
        """
        length = sequence.shape[-1]
        # Build the indices on the table's own device so the lookup keeps
        # working after the module is moved with ``.to(...)``.
        positions = torch.arange(
            length, dtype=torch.long, device=self.table.weight.device
        )
        # ``expand`` broadcasts 0..length-1 over any leading dimensions
        # (including none), without assuming the input is 2-D.
        return self.table(positions.expand(sequence.shape))

In [None]:
class MaskStrategy(Enum):
    """Selects which attention mask ``MultiHeadedAttention`` builds."""

    UNMASKED = 1  # padding mask only
    MASKED = 2  # padding mask combined with a lower-triangular (causal) mask


class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries are projected from ``x``; keys and values are projected from ``z``.
    The projections are split evenly across ``num_heads`` heads, attended
    independently, concatenated and projected to ``d_out``.
    """

    def __init__(
        self,
        num_heads,
        d_attn,
        d_x,
        d_z,
        d_out,
        d_mid,
        maskStrategy,
        p_dropout
    ):
        """
        Args:
            num_heads: number of attention heads.
            d_attn: total query/key width (split across heads).
            d_x: feature width of the query source ``x``.
            d_z: feature width of the key/value source ``z``.
            d_out: feature width of the output.
            d_mid: total value width (split across heads).
            maskStrategy: a ``MaskStrategy`` member.
            p_dropout: dropout probability applied to the attention weights.

        Raises:
            ValueError: if ``d_attn`` or ``d_mid`` is not divisible by
                ``num_heads`` (the per-head split in ``forward`` would fail).
        """
        super().__init__()
        if d_attn % num_heads != 0 or d_mid % num_heads != 0:
            raise ValueError(
                "d_attn and d_mid must both be divisible by num_heads "
                f"(got d_attn={d_attn}, d_mid={d_mid}, num_heads={num_heads})"
            )
        self.num_heads = num_heads
        self.d_attn = d_attn
        self.d_x = d_x
        self.d_z = d_z
        self.d_out = d_out
        self.d_mid = d_mid
        self.maskStrategy = maskStrategy
        self.weight_query = nn.Linear(d_x, d_attn)
        self.weight_key = nn.Linear(d_z, d_attn)
        self.weight_value = nn.Linear(d_z, d_mid)
        self.weight_out = nn.Linear(d_mid, d_out)
        self.dropout = nn.Dropout(p_dropout)

    def _split_heads(self, projected, batch_size, length):
        """Reshape (batch, length, width) to (batch, heads, length, width // heads)."""
        return projected.view(batch_size, length, self.num_heads, -1).transpose(1, 2)

    def forward(self, z, x, padding_mask):
        """Attend from ``x`` (queries) over ``z`` (keys and values).

        Args:
            z: key/value source, shaped (batch, length_z, d_z).
            x: query source, shaped (batch, length_x, d_x).
            padding_mask: non-zero where a key position may be attended to;
                must broadcast as (batch, length_z).

        Returns:
            Tensor shaped (batch, length_x, d_out).

        Raises:
            ValueError: if ``self.maskStrategy`` is not a known ``MaskStrategy``.
        """
        length_z = z.shape[-2]
        length_x = x.shape[-2]
        batch_size = x.shape[0]

        queries = self._split_heads(self.weight_query(x), batch_size, length_x)
        keys = self._split_heads(self.weight_key(z), batch_size, length_z)
        values = self._split_heads(self.weight_value(z), batch_size, length_z)

        # Integer division keeps the expected shapes as ints.
        head_attn = self.d_attn // self.num_heads
        head_mid = self.d_mid // self.num_heads
        assert queries.shape == (batch_size, self.num_heads, length_x, head_attn)
        assert keys.shape == (batch_size, self.num_heads, length_z, head_attn)
        assert values.shape == (batch_size, self.num_heads, length_z, head_mid)

        # (batch, 1, length_z): True where the key position is attendable.
        key_mask = padding_mask.unsqueeze(-2) != 0
        if self.maskStrategy == MaskStrategy.UNMASKED:
            mask = key_mask
        elif self.maskStrategy == MaskStrategy.MASKED:
            # Build the causal mask on the same device as the activations
            # instead of guessing from global hardware availability; a CPU
            # model on a CUDA-capable machine would otherwise mix devices.
            causal = torch.tril(
                torch.ones(length_x, length_z, dtype=torch.bool, device=x.device)
            )
            mask = causal & key_mask
        else:
            raise ValueError(f"Unknown mask strategy: {self.maskStrategy!r}")
        # Insert the head axis so the mask broadcasts over all heads.
        mask = mask.unsqueeze(1)

        v_out = self.attention(queries, keys, values, mask, self.dropout)
        assert v_out.shape == (batch_size, self.num_heads, length_x, head_mid)

        # Merge heads back into a single feature axis.
        v_out = v_out.transpose(1, 2).reshape(batch_size, length_x, -1)
        output = self.weight_out(v_out)
        assert output.shape == (batch_size, length_x, self.d_out)
        return output

    @staticmethod
    def attention(queries, keys, values, mask, dropout):
        """Scaled dot-product attention.

        Args:
            queries: (..., length_x, d_head).
            keys: (..., length_z, d_head).
            values: (..., length_z, d_value).
            mask: broadcastable to (..., length_x, length_z); entries equal to
                zero/False are excluded from attention.
            dropout: callable applied to the softmax weights.

        Returns:
            Tensor shaped (..., length_x, d_value).
        """
        scores = torch.matmul(queries, torch.transpose(keys, -2, -1))
        # A large negative score drives the softmax weight of masked
        # positions to zero.
        scores = scores.masked_fill(mask == 0, -1e9)
        d_attn = keys.shape[-1]
        scaled_scores = scores / math.sqrt(d_attn)
        softmax_scores = dropout(torch.softmax(scaled_scores, -1))
        return torch.matmul(softmax_scores, values)

    def disable_subsequent_mask(self):
        """Switch to padding-only masking."""
        self.maskStrategy = MaskStrategy.UNMASKED

    def enable_subsequent_mask(self):
        """Switch to causal + padding masking."""
        self.maskStrategy = MaskStrategy.MASKED


In [None]:
class LayerNorm(nn.Module):
    """Normalizes the last axis to zero mean / unit variance, then rescales.

    ``scale`` starts at ones and ``offset`` at zeros; both are learned.
    """

    def __init__(self, feature_length):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(feature_length))
        self.offset = nn.Parameter(torch.zeros(feature_length))

    def forward(self, activations):
        """Apply the normalization over the last dimension of ``activations``."""
        centered = activations - activations.mean(dim=-1, keepdim=True)
        # Biased variance; the 1e-6 term keeps the division finite when all
        # features are equal.
        spread = torch.sqrt(
            activations.var(dim=-1, keepdim=True, unbiased=False) + 1e-6
        )
        return (centered / spread) * self.scale + self.offset

In [None]:
class FeedForward(nn.Module):
    """Two-layer position-wise MLP: linear -> ReLU -> linear -> dropout."""

    def __init__(self, hiddenLayerWidth, d_e, p_dropout):
        """
        Args:
            hiddenLayerWidth: width of the hidden layer.
            d_e: input and output feature width.
            p_dropout: dropout probability applied to the output.
        """
        super().__init__()
        self.mlp1 = nn.Parameter(torch.empty(d_e, hiddenLayerWidth))
        self.mlp2 = nn.Parameter(torch.empty(hiddenLayerWidth, d_e))
        self.mlp1_bias = nn.Parameter(torch.zeros(hiddenLayerWidth))
        self.mlp2_bias = nn.Parameter(torch.zeros(d_e))
        self.dropout = nn.Dropout(p_dropout)
        # torch.rand would give all-positive weights with mean 0.5, which
        # inflates activations by roughly the layer width. Xavier keeps them
        # zero-centred and scaled to the fan-in/fan-out.
        nn.init.xavier_uniform_(self.mlp1)
        nn.init.xavier_uniform_(self.mlp2)

    def forward(self, activations):
        """Apply the MLP to the last dimension of ``activations``."""
        hidden = torch.matmul(activations, self.mlp1) + self.mlp1_bias
        hidden = hidden.relu()
        projected = torch.matmul(hidden, self.mlp2) + self.mlp2_bias
        return self.dropout(projected)

In [None]:
class EncoderLayer(nn.Module):
    """One encoder block: unmasked self-attention followed by a feed-forward MLP.

    Each sub-layer first normalizes its input and then adds the sub-layer
    output to that *normalized* input.
    """

    def __init__(
        self, num_heads, d_attn, d_x, d_z, d_out, d_mid, d_mlp, p_dropout
    ):
        super().__init__()
        self.multi_head_attention = MultiHeadedAttention(
            num_heads=num_heads,
            d_attn=d_attn,
            d_x=d_x,
            d_z=d_z,
            d_out=d_out,
            d_mid=d_mid,
            maskStrategy=MaskStrategy.UNMASKED,
            p_dropout=p_dropout,
        )
        self.layer_norm1 = LayerNorm(d_z)
        self.feed_forward = FeedForward(
            hiddenLayerWidth=d_mlp, d_e=d_z, p_dropout=p_dropout
        )
        self.layer_norm2 = LayerNorm(d_z)

    def forward(self, z, padding_mask):
        """Run the block on ``z``; ``padding_mask`` is forwarded to attention."""
        attn_in = self.layer_norm1(z)
        attn_out = attn_in + self.multi_head_attention(attn_in, attn_in, padding_mask)
        mlp_in = self.layer_norm2(attn_out)
        return mlp_in + self.feed_forward(mlp_in)


class Encoder(nn.Module):
    """Stack of ``num_layers`` ``EncoderLayer`` blocks plus a final layer norm."""

    def __init__(
        self,
        num_layers,
        num_heads,
        d_attn,
        d_x,
        d_z,
        d_out,
        d_mid,
        d_mlp,
        p_dropout
    ):
        super().__init__()
        # Every layer is built with the same hyper-parameters.
        self.layers = nn.ModuleList(
            [
                EncoderLayer(
                    num_heads, d_attn, d_x, d_z, d_out, d_mid, d_mlp, p_dropout
                )
                for _ in range(num_layers)
            ]
        )
        self.final_norm = LayerNorm(d_z)

    def forward(self, z, padding_mask):
        """Pass ``z`` through each layer in order, then normalize the result."""
        hidden = z
        for block in self.layers:
            hidden = block(hidden, padding_mask)
        return self.final_norm(hidden)


In [None]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention, then an MLP.

    Each sub-layer first normalizes its input and then adds the sub-layer
    output to that *normalized* input.
    """

    def __init__(
        self, num_heads, d_attn, d_x, d_z, d_out, d_mid, d_mlp, p_dropout
    ):
        super().__init__()
        attention_dims = (num_heads, d_attn, d_x, d_z, d_out, d_mid)
        self.multi_head_self_attention = MultiHeadedAttention(
            *attention_dims, MaskStrategy.MASKED, p_dropout
        )
        self.layer_norm1 = LayerNorm(d_x)
        self.multi_head_global_attention = MultiHeadedAttention(
            *attention_dims, MaskStrategy.UNMASKED, p_dropout
        )
        self.layer_norm2 = LayerNorm(d_x)
        self.feed_forward = FeedForward(d_mlp, d_x, p_dropout)
        self.layer_norm3 = LayerNorm(d_x)

    def forward(self, z, x, src_mask, tgt_mask):
        """Update decoder activations ``x`` using encoder output ``z``.

        ``tgt_mask`` is passed to the self-attention over ``x``; ``src_mask``
        is passed to the cross-attention whose keys/values come from ``z``.
        """
        self_in = self.layer_norm1(x)
        self_out = self_in + self.multi_head_self_attention(self_in, self_in, tgt_mask)
        cross_in = self.layer_norm2(self_out)
        cross_out = cross_in + self.multi_head_global_attention(z, cross_in, src_mask)
        mlp_in = self.layer_norm3(cross_out)
        return mlp_in + self.feed_forward(mlp_in)

    def disable_subsequent_mask(self):
        """Switch this layer's self-attention to padding-only masking."""
        self.multi_head_self_attention.disable_subsequent_mask()


class Decoder(nn.Module):
    """Stack of ``num_layers`` ``DecoderLayer`` blocks plus a final layer norm."""

    def __init__(
        self,
        num_layers,
        num_heads,
        d_attn,
        d_x,
        d_z,
        d_out,
        d_mid,
        d_mlp,
        p_dropout
    ):
        super().__init__()
        # Every layer is built with the same hyper-parameters.
        self.layers = nn.ModuleList(
            [
                DecoderLayer(
                    num_heads, d_attn, d_x, d_z, d_out, d_mid, d_mlp, p_dropout
                )
                for _ in range(num_layers)
            ]
        )
        self.final_norm = LayerNorm(d_x)

    def forward(self, z, x, src_mask, tgt_mask):
        """Pass ``x`` through each layer (all see the same ``z``), then normalize."""
        hidden = x
        for block in self.layers:
            hidden = block(z, hidden, src_mask, tgt_mask)
        return self.final_norm(hidden)

    def disable_subsequent_mask(self):
        """Switch the self-attention of every layer to padding-only masking."""
        for block in self.layers:
            block.multi_head_self_attention.disable_subsequent_mask()


In [None]:
class EncoderDecoderTransformer(nn.Module):
    """Encoder-decoder transformer over token ids.

    Source and target use separate token embedding tables but share one
    positional embedding table and one embedding dropout. The decoder output
    is projected to vocabulary scores by ``unembedding``.
    """

    def __init__(
        self,
        num_encoder_layers,
        num_decoder_layers,
        num_heads,
        d_attn,
        d_x,
        d_z,
        d_out,
        d_mid,
        d_mlp,
        d_e,
        vocab_size,
        max_sequence_length,
        p_dropout,
        device
    ):
        super().__init__()
        self.src_embedding = Embedding(vocab_size, d_e)
        self.tgt_embedding = Embedding(vocab_size, d_e)
        self.unembedding = Unembedding(vocab_size, d_e)
        self.embedding_dropout = nn.Dropout(p_dropout)
        self.positionalEmbedding = PositionalEmbedding(d_e, max_sequence_length, device)

        # Hyper-parameters shared by the encoder and decoder stacks.
        stack_args = (num_heads, d_attn, d_x, d_z, d_out, d_mid, d_mlp, p_dropout)
        self.encoder = Encoder(num_encoder_layers, *stack_args)
        self.decoder = Decoder(num_decoder_layers, *stack_args)

    def forward(self, z, x, src_mask, tgt_mask):
        """Return vocabulary scores for each target position.

        Args:
            z: source token ids.
            x: target (decoder input) token ids.
            src_mask: mask forwarded to the encoder and to decoder cross-attention.
            tgt_mask: mask forwarded to decoder self-attention.
        """
        src_embedded = self.embedding_dropout(
            self.src_embedding(z) + self.positionalEmbedding(z)
        )
        memory = self.encoder(src_embedded, src_mask)

        tgt_embedded = self.embedding_dropout(
            self.tgt_embedding(x) + self.positionalEmbedding(x)
        )
        decoded = self.decoder(memory, tgt_embedded, src_mask, tgt_mask)
        return self.unembedding(decoded)

    def disable_subsequent_mask(self):
        """Switch every decoder self-attention to padding-only masking."""
        self.decoder.disable_subsequent_mask()


In [None]:
# Colab-only step: mount the user's Google Drive at /content/drive.
# `google.colab` is only importable inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
class SequencePairDataset(Dataset):
    """Dataset of (source line, target line) string pairs.

    Lines are taken from the same index window of two newline-separated
    documents and the pairs are ordered by the character length of the
    source line.
    """

    BOS_TOKEN = "[SOS]"
    EOS_TOKEN = "[EOS]"
    PAD_TOKEN = "[PAD]"
    UNK_TOKEN = "[UNK]"
    PAD_ID = 2

    def __init__(self, src_text, tgt_text, start_index, end_index):
        """Pair lines ``start_index:end_index`` of ``src_text`` and ``tgt_text``."""
        src_lines = self.to_sequences(src_text, start_index, end_index)
        tgt_lines = self.to_sequences(tgt_text, start_index, end_index)
        self.pairs = self.pair_sequences(src_lines, tgt_lines)

    def pair_sequences(self, src_sequences, tgt_sequences):
        """Zip the two lists and sort the pairs by source string length."""
        pairs = list(zip(src_sequences, tgt_sequences))
        pairs.sort(key=lambda pair: len(pair[0]))
        return pairs

    def to_sequences(self, doc, sequence_start_index, sequence_end_index):
        """Split a stripped document on newlines and return the requested slice."""
        lines = doc.strip().split("\n")
        window = slice(sequence_start_index, sequence_end_index)
        return lines[window]

    def add_special_tokens(self, sequence):
        """Wrap ``sequence`` in the BOS and EOS token strings, space-separated."""
        return f"{self.BOS_TOKEN} {sequence} {self.EOS_TOKEN}"

    def __len__(self):
        """Number of stored pairs."""
        return len(self.pairs)

    def __getitem__(self, index):
        """Return the ``(source, target)`` strings at ``index``."""
        source, target = self.pairs[index]
        return source, target

In [None]:
class TrainAndValidationSequenceDatasets:
    """Builds train and validation ``SequencePairDataset``s from two text files.

    Both datasets are sliced from the same pair of documents using separate
    line-index windows.
    """

    def __init__(
        self,
        src_filename,
        tgt_filename,
        src_vocab_size,
        tgt_vocab_size,
        train_start_index,
        train_end_index,
        val_start_index,
        val_end_index,
    ):
        """
        Args:
            src_filename: path of the source-language text file.
            tgt_filename: path of the target-language text file.
            src_vocab_size: unused here; kept for interface compatibility.
            tgt_vocab_size: unused here; kept for interface compatibility.
            train_start_index, train_end_index: line window of the training set.
            val_start_index, val_end_index: line window of the validation set.
        """
        src_text = self.load_doc(src_filename)
        tgt_text = self.load_doc(tgt_filename)
        self.train_dataset = SequencePairDataset(
            src_text, tgt_text, train_start_index, train_end_index
        )
        self.val_dataset = SequencePairDataset(
            src_text, tgt_text, val_start_index, val_end_index
        )

    def load_doc(self, filename):
        """Read and return the whole file as text.

        The file is decoded as UTF-8 explicitly so non-ASCII characters are
        read identically on every platform, and the handle is closed even if
        reading fails.
        """
        with open(filename, mode="rt", encoding="utf-8") as file:
            return file.read()


In [None]:
class PadCollate:
    """DataLoader ``collate_fn`` that BPE-tokenizes and pads string pairs.

    On construction a BPE tokenizer is trained for each of the two files and
    pickled next to it (``<filename>_tokenizer``). Calling the instance on a
    batch of ``(source, target)`` strings returns two lists of encodings,
    each padded to the longest sequence in that batch.
    """

    TOKENIZER_SUFFIX = "_tokenizer"

    def __init__(self, src_filename, tgt_filename, src_vocab_size, tgt_vocab_size):
        self.src_tokenizer, self.tgt_tokenizer = self.setup_tokenizers(
            src_filename,
            tgt_filename,
            src_vocab_size,
            tgt_vocab_size,
            src_filename + self.TOKENIZER_SUFFIX,
            tgt_filename + self.TOKENIZER_SUFFIX,
        )

    @staticmethod
    def _train_tokenizer(filename, vocab_size):
        """Train a whitespace-pre-tokenized BPE tokenizer on ``filename``."""
        print("creating tokenizer for " + filename)
        tokenizer = Tokenizer(BPE(unk_token=SequencePairDataset.UNK_TOKEN))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = BpeTrainer(
            vocab_size=vocab_size,
            # Order matters: the position in this list becomes the token id,
            # so PAD_TOKEN lands on SequencePairDataset.PAD_ID (2).
            special_tokens=[
                SequencePairDataset.BOS_TOKEN,
                SequencePairDataset.EOS_TOKEN,
                SequencePairDataset.PAD_TOKEN,
                SequencePairDataset.UNK_TOKEN,
            ],
        )
        tokenizer.train([filename], trainer=trainer)
        return tokenizer

    @staticmethod
    def _save_tokenizer(tokenizer, path):
        """Pickle ``tokenizer`` to ``path``, closing the file afterwards."""
        with open(path, "wb") as handle:
            pickle.dump(tokenizer, handle)

    def setup_tokenizers(
        self,
        src_filename,
        tgt_filename,
        src_vocab_size,
        tgt_vocab_size,
        src_tokenizer_name,
        tgt_tokenizer_name,
    ):
        """Train, configure and pickle the source and target tokenizers.

        Only the target tokenizer gets a post-processor that wraps every
        sequence in the BOS/EOS special tokens.

        Returns:
            ``(src_tokenizer, tgt_tokenizer)``.
        """
        src_tokenizer = self._train_tokenizer(src_filename, src_vocab_size)
        self._save_tokenizer(src_tokenizer, src_tokenizer_name)

        tgt_tokenizer = self._train_tokenizer(tgt_filename, tgt_vocab_size)
        bos = SequencePairDataset.BOS_TOKEN
        eos = SequencePairDataset.EOS_TOKEN
        # Use the token strings and ids from the trained vocabulary rather
        # than a hard-coded "[BOS]" label and literal ids.
        tgt_tokenizer.post_processor = TemplateProcessing(
            single=f"{bos} $A {eos}",
            special_tokens=[
                (bos, tgt_tokenizer.token_to_id(bos)),
                (eos, tgt_tokenizer.token_to_id(eos)),
            ],
        )
        self._save_tokenizer(tgt_tokenizer, tgt_tokenizer_name)
        return src_tokenizer, tgt_tokenizer

    def __call__(self, batch):
        """Tokenize a batch of ``(source, target)`` string pairs.

        Returns:
            ``(src_encodings, tgt_encodings)``: two lists of tokenizer
            encodings, each padded with ``PAD_TOKEN`` to the longest
            sequence of its own list.
        """
        src_texts = [pair[0] for pair in batch]
        tgt_texts = [pair[1] for pair in batch]

        # Padding/truncation settings live on the tokenizer objects and may
        # have been changed elsewhere, so they are (re)applied on every call.
        # Without an explicit length, enable_padding pads to the longest
        # sequence of the batch, so one encoding pass is enough.
        for tokenizer in (self.src_tokenizer, self.tgt_tokenizer):
            tokenizer.no_truncation()
            tokenizer.enable_padding(
                pad_id=SequencePairDataset.PAD_ID,
                pad_token=SequencePairDataset.PAD_TOKEN,
            )

        src_tokenized = self.src_tokenizer.encode_batch(src_texts)
        tgt_tokenized = self.tgt_tokenizer.encode_batch(tgt_texts)
        return src_tokenized, tgt_tokenized


In [None]:
def decode(x, tokenizer):
    """Greedy-decode a (length, vocab) score matrix to text.

    Takes the arg-max over the last axis, prints the chosen ids and returns
    ``tokenizer.decode`` of them. (Soft-maxing first is unnecessary: it does
    not change which entry is largest.)
    """
    token_ids = torch.argmax(x, dim=-1).tolist()
    print("argmax x:", token_ids)
    return tokenizer.decode(token_ids)


def _batch_to_tensors(src_batch, tgt_batch, device):
    """Turn two lists of encodings into model inputs on ``device``.

    Each encoding must expose ``ids`` and ``attention_mask`` lists of equal
    length within its batch.

    Returns:
        ``(encoder_input, decoder_input, desired_output, src_masks, tgt_masks)``
        where ``decoder_input`` drops the last target token and
        ``desired_output`` drops the first one (teacher forcing).
    """
    src_tokens = torch.tensor(
        [sequence.ids for sequence in src_batch], dtype=torch.long, device=device
    )
    tgt_tokens = torch.tensor(
        [sequence.ids for sequence in tgt_batch], dtype=torch.long, device=device
    )
    src_masks = torch.tensor(
        [sequence.attention_mask for sequence in src_batch],
        dtype=torch.int32,
        device=device,
    )
    tgt_masks = torch.tensor(
        [sequence.attention_mask for sequence in tgt_batch],
        dtype=torch.int32,
        device=device,
    )[:, :-1]
    return src_tokens, tgt_tokens[:, :-1], tgt_tokens[:, 1:], src_masks, tgt_masks


def _forward_loss(model, loss_function, src_batch, tgt_batch, device):
    """Run one batch through ``model`` and return ``(loss, output, desired_output)``."""
    encoder_input, decoder_input, desired_output, src_masks, tgt_masks = (
        _batch_to_tensors(src_batch, tgt_batch, device)
    )
    output = model(encoder_input, decoder_input, src_masks, tgt_masks)
    # CrossEntropyLoss expects (N, C, L), so move the vocabulary axis forward.
    loss = loss_function(output.transpose(-1, -2), desired_output)
    return loss, output, desired_output


def train_model(
    encoder_decoder_transformer,
    train_dataloader,
    val_dataloader,
    src_tokenizer,
    tgt_tokenizer,
    device,
    state_dict_filename,
    epochs=1000,
    patience=2,
):
    """Train with AdamW and early stopping on the average validation loss.

    All parameters with more than one dimension are re-initialized with
    Xavier-uniform before training. After every epoch the model is validated
    in eval mode without gradient tracking; the state dict is saved to
    ``state_dict_filename`` whenever the average validation loss improves,
    and training stops after ``patience`` consecutive epochs without
    improvement.

    Args:
        encoder_decoder_transformer: model called as
            ``model(src_ids, tgt_ids, src_mask, tgt_mask)``.
        train_dataloader, val_dataloader: iterables yielding
            ``(src_encodings, tgt_encodings)`` batches.
        src_tokenizer: unused; kept for interface compatibility.
        tgt_tokenizer: used to decode one sample per epoch for inspection.
        device: device the batch tensors are created on.
        state_dict_filename: where the best state dict is written.
        epochs: maximum number of epochs.
        patience: consecutive non-improving epochs tolerated before stopping.

    Raises:
        ValueError: if either data loader yields no batches.
    """
    torch.manual_seed(25)

    model = encoder_decoder_transformer
    print("number of parameters:", sum(p.numel() for p in model.parameters()))
    opt = optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
    # ignore_index=2 is the padding id (SequencePairDataset.PAD_ID).
    loss_function = nn.CrossEntropyLoss(label_smoothing=0.1, ignore_index=2)
    training_step = 0
    validation_step = 0
    best_val_loss = float("inf")
    num_fails = 0
    # Large models need this to actually train
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    was_training = model.training
    try:
        for epoch in range(epochs):
            epoch_time_start = time.time()
            train_losses = []
            val_losses = []

            model.train()
            for src_batch, tgt_batch in train_dataloader:
                loss, train_output, desired_train = _forward_loss(
                    model, loss_function, src_batch, tgt_batch, device
                )
                opt.zero_grad()
                loss.backward()
                opt.step()
                train_losses.append(loss.item())
                if training_step % 20 == 0:
                    print("Completed training step", training_step)
                training_step += 1

            # Validate without dropout and without building autograd graphs.
            model.eval()
            with torch.no_grad():
                for src_batch, tgt_batch in val_dataloader:
                    loss, val_output, desired_val = _forward_loss(
                        model, loss_function, src_batch, tgt_batch, device
                    )
                    val_losses.append(loss.item())
                    if validation_step % 20 == 0:
                        print("Completed validation step", validation_step)
                    validation_step += 1

            if not train_losses or not val_losses:
                raise ValueError(
                    "train_dataloader and val_dataloader must each yield at least one batch"
                )

            print("epoch", epoch, "took", time.time() - epoch_time_start)
            print("avg training loss:", sum(train_losses) / len(train_losses))
            avg_val_loss = sum(val_losses) / len(val_losses)
            print("avg validation loss:", avg_val_loss)

            # Show the first sample of the last train/validation batch.
            print("expected train output", tgt_tokenizer.decode(desired_train[0].tolist()))
            print("decoded train output:", decode(train_output[0], tgt_tokenizer))
            print(
                "expected validation output",
                tgt_tokenizer.decode(desired_val[0].tolist()),
            )
            print("decoded validation output:", decode(val_output[0], tgt_tokenizer))

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(model.state_dict(), state_dict_filename)
                print("Saved model state dict to", state_dict_filename)
                num_fails = 0
            else:
                print("Average validation loss did not decrease from ", best_val_loss)
                num_fails += 1
                print(
                    "Failed to decrease the average validation loss", num_fails, "times."
                )
                if num_fails >= patience:
                    print("Stopping training")
                    break
            print()
            print()
    finally:
        # Leave the module in the mode the caller handed it over in.
        model.train(was_training)


In [None]:
def main():
    """Build the datasets, tokenizers and model, then train the model.

    Returns:
        ``(model, pad_collate)``: the transformer as left by ``train_model``
        and the collate object holding the source/target tokenizers, so that
        later cells can run predictions with them.
    """
    folder = "drive/MyDrive/colab data/"
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    num_encoder_layers = 4
    num_decoder_layers = 4
    num_heads = 8
    d_attn = 256
    d_x = 256
    d_z = 256
    d_out = 256
    d_mid = 256
    d_mlp = 512
    d_e = 256
    max_sequence_length = 100
    p_dropout = 0.1
    vocab_size = 10000
    batch_size = 256
    # Line windows of the corpus used for training and validation.
    train_start_index, train_end_index = 0, 28250
    val_start_index, val_end_index = 28250, 29000

    # os.path.join avoids the doubled "/" that plain concatenation produced.
    enRawName = os.path.join(folder, "multi30kEnTrain.txt")
    deRawName = os.path.join(folder, "multi30kDeTrain.txt")
    saveDirectory = "./"
    nameSuffix = ""
    state_dict_filename = (
        saveDirectory
        + "encoder_decoder_transformer_state_dict_"
        + datetime.today().strftime("%Y-%m-%d %H")
        + nameSuffix
    )

    train_and_validation_sequence_datasets = TrainAndValidationSequenceDatasets(
        enRawName,
        deRawName,
        vocab_size,
        vocab_size,
        train_start_index,
        train_end_index,
        val_start_index,
        val_end_index,
    )
    # A single .to(device) moves every registered sub-module, including the
    # embedding tables, so they do not need to be moved individually.
    custom_encoder_decoder_transformer = EncoderDecoderTransformer(
        num_encoder_layers,
        num_decoder_layers,
        num_heads,
        d_attn,
        d_x,
        d_z,
        d_out,
        d_mid,
        d_mlp,
        d_e,
        vocab_size,
        max_sequence_length,
        p_dropout,
        device
    ).to(device)

    train_dataset = train_and_validation_sequence_datasets.train_dataset
    val_dataset = train_and_validation_sequence_datasets.val_dataset
    pad_collate = PadCollate(enRawName, deRawName, vocab_size, vocab_size)
    train_dataloader = DataLoader(
        train_dataset, batch_size=batch_size, collate_fn=pad_collate
    )
    val_dataloader = DataLoader(
        val_dataset, batch_size=batch_size, collate_fn=pad_collate
    )
    train_model(
        custom_encoder_decoder_transformer,
        train_dataloader,
        val_dataloader,
        pad_collate.src_tokenizer,
        pad_collate.tgt_tokenizer,
        device,
        state_dict_filename
    )
    return custom_encoder_decoder_transformer, pad_collate


if __name__ == "__main__":
    # Bind the results at notebook scope: the prediction cell further down
    # reads `encoder_decoder_transformer` and `pad_collate` as globals.
    encoder_decoder_transformer, pad_collate = main()


creating tokenizer for drive/MyDrive/colab data//multi30kEnTrain.txt
creating tokenizer for drive/MyDrive/colab data//multi30kDeTrain.txt
<generator object Module.parameters at 0x78093ba04430>
Completed training step 0
Completed training step 20
Completed training step 40
Completed training step 60
Completed training step 80
Completed training step 100
Completed validation step 0
epoch 0 took 25.332783699035645
avg training loss: 7.700793476792069
avg validation loss: 6.781032244364421
expected train output Ein schwarz - rot - weißes Rennwagen saust im Vordergrund auf einer grauen Strecke mit einer blauen Ban de , im Hintergrund ist eine verschwommen e Menschenmenge zu sehen .
argmax x: [109, 124, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 14, 14, 14, 14, 14, 14, 1, 14, 14, 1, 14, 1, 14, 14, 1, 1, 14, 1, 1, 1, 1, 1, 14, 1, 14, 14, 1, 1, 1, 1, 14, 1, 1]
decoded train output: Ein Mann , , , , , , , , , , , , , , . . . . . . . . . . . . . . . .
expected validation output Zwei

In [None]:
def predict_from_tokens(
    model,
    input,
    src_tokenizer,
    tgt_tokenizer,
    max_length=100,
    bos_id=0,
    eos_id=1,
    verbose=False,
):
    """Greedy-decode a translation of ``input``.

    The model is run in eval mode without gradients; its previous
    train/eval mode is restored afterwards. The decoder's causal mask is left
    as the model has it, so generation uses the same masking the model is
    configured with rather than permanently switching it off.

    Args:
        model: module called as ``model(src_ids, tgt_ids, src_mask, tgt_mask)``
            returning scores shaped (1, tgt_length, vocab).
        input: source sentence (string).
        src_tokenizer: tokenizer used to encode ``input``.
        tgt_tokenizer: tokenizer used to decode progress output when
            ``verbose`` is true.
        max_length: maximum number of tokens to generate.
        bos_id: id the target sequence starts with.
        eos_id: id that ends generation.
        verbose: print the partial translation after every step.

    Returns:
        Tensor shaped (1, generated_length + 1) holding the BOS id followed
        by the generated ids (including the EOS id if it was produced).
    """
    # Single-sentence inference needs neither padding nor truncation; both
    # may have been left enabled on the tokenizers by earlier batching.
    for tokenizer in (src_tokenizer, tgt_tokenizer):
        tokenizer.no_padding()
        tokenizer.no_truncation()

    # Create the inputs on the model's own device instead of relying on a
    # notebook-level global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    src_ids = src_tokenizer.encode(input).ids
    src_sequence = torch.tensor(src_ids, dtype=torch.long, device=run_device).unsqueeze(0)
    src_mask = torch.ones(src_sequence.shape, dtype=torch.int32, device=run_device)
    tgt_sequence = torch.tensor([[bos_id]], dtype=torch.long, device=run_device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_length):
                tgt_mask = torch.ones(
                    tgt_sequence.shape, dtype=torch.int32, device=run_device
                )
                scores = model(src_sequence, tgt_sequence, src_mask, tgt_mask)
                # Only the last position predicts the next token.
                next_token = torch.argmax(scores[0, -1], dim=-1)
                tgt_sequence = torch.cat(
                    (tgt_sequence, next_token.view(1, 1).to(tgt_sequence.device)),
                    dim=-1,
                )
                if verbose:
                    print(
                        "partial prediction:",
                        tgt_tokenizer.decode(tgt_sequence[0].tolist()),
                    )
                if next_token.item() == eos_id:
                    break
    finally:
        model.train(was_training)
    return tgt_sequence

# Translate a sample sentence, then print the generated token ids and their
# decoded text. `encoder_decoder_transformer` and `pad_collate` must already
# exist as notebook-level (global) names when this cell runs.
tgt_sequence = predict_from_tokens(encoder_decoder_transformer, "A man on the sea", pad_collate.src_tokenizer, pad_collate.tgt_tokenizer)
print(tgt_sequence)
print(pad_collate.tgt_tokenizer.decode(tgt_sequence[0].tolist()))

NameError: name 'encoder_decoder_transformer' is not defined