In [None]:
import math
from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn
from torch.utils.data import dataset
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import WikiText2
from torchtext.vocab import build_vocab_from_iterator

In [None]:
# Build the token vocabulary from the WikiText2 training split; any token not
# seen during construction is mapped to the '<unk>' index.
tokenizer = get_tokenizer('basic_english')
train_iter = WikiText2(split='train')
vocab = build_vocab_from_iterator((tokenizer(line) for line in train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])

In [None]:
def data_process(raw_text_iter: dataset.IterableDataset,
                 vocab_fn: Optional[Callable] = None,
                 tokenizer_fn: Optional[Callable] = None) -> torch.Tensor:
    """Converts raw text into a flat 1-D tensor of token indices.

    Each line is tokenized, mapped to vocabulary indices, and the per-line
    index tensors are concatenated in order. Lines that produce no tokens
    (e.g. blank lines) are dropped.

    Args:
        raw_text_iter: iterable of raw text lines (e.g. a WikiText2 split).
        vocab_fn: callable mapping a list of tokens to a list of indices;
            defaults to the notebook-level ``vocab``.
        tokenizer_fn: callable mapping a line to a list of tokens; defaults
            to the notebook-level ``tokenizer``.

    Returns:
        LongTensor of shape [total_num_tokens]. If no line yields a token the
        result is an empty LongTensor of shape [0] (``torch.cat`` on an empty
        sequence would raise instead).
    """
    if vocab_fn is None:
        vocab_fn = vocab
    if tokenizer_fn is None:
        tokenizer_fn = tokenizer
    pieces = [torch.tensor(vocab_fn(tokenizer_fn(item)), dtype=torch.long)
              for item in raw_text_iter]
    non_empty = [t for t in pieces if t.numel() > 0]
    if not non_empty:
        return torch.empty(0, dtype=torch.long)
    return torch.cat(non_empty)

In [None]:
# Re-create the three split iterators and flatten each one into a single
# 1-D stream of token indices.
train_iter, val_iter, test_iter = WikiText2()
train_data, val_data, test_data = map(data_process, (train_iter, val_iter, test_iter))

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def batchify(data: torch.Tensor, bsz: int) -> torch.Tensor:
    """Reshape a flat token stream into ``bsz`` parallel column streams.

    Trailing tokens that do not fill a complete column are discarded. Column
    ``j`` of the result holds the ``j``-th contiguous chunk of length
    ``N // bsz`` from the input, in order.

    Args:
        data: Tensor, shape [N]
        bsz: int, batch size (number of columns)

    Returns:
        Tensor of shape [N // bsz, bsz], moved to the notebook-level ``device``.
    """
    rows = data.size(0) // bsz
    usable = rows * bsz
    chunks = data[:usable].view(bsz, rows)
    return chunks.t().contiguous().to(device)

batch_size, eval_batch_size = 20, 10
# NOTE(review): this cell rebinds train_data/val_data/test_data, so it is not
# idempotent — re-running it would try to batchify already-2-D tensors.
train_data = batchify(train_data, batch_size)  # shape [seq_len, batch_size]
val_data, test_data = (batchify(split, eval_batch_size) for split in (val_data, test_data))

In [None]:
bptt = 35


def get_batch(source: torch.Tensor, i: int,
              max_len: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Slice one input window and its next-token targets out of ``source``.

    Args:
        source: Tensor, shape [full_seq_len, batch_size]
        i: int, first row of the window; must satisfy
            0 <= i <= full_seq_len - 2 so at least one (input, target) row
            pair exists.
        max_len: optional cap on the window length; defaults to the global
            ``bptt``.

    Returns:
        tuple (data, target), where data has shape [seq_len, batch_size] and
        target has shape [seq_len * batch_size], with
        seq_len = min(max_len, full_seq_len - 1 - i). ``target`` is the
        window shifted one row ahead, flattened row-major so element k lines
        up with ``data.reshape(-1)[k]``.

    Raises:
        ValueError: if ``i`` is out of range or ``max_len`` < 1 (these would
            otherwise yield empty or misaligned windows silently).
    """
    if max_len is None:
        max_len = bptt
    if max_len < 1:
        raise ValueError(f"max_len must be >= 1, got {max_len}")
    last_start = len(source) - 2
    if not 0 <= i <= last_start:
        raise ValueError(
            f"i={i} out of range [0, {last_start}] for source of length {len(source)}")
    seq_len = min(max_len, len(source) - 1 - i)
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len].reshape(-1)
    return data, target


In [None]:
# Sanity check: shapes of the first training window and its flattened targets.
data, target = get_batch(train_data, 0)
print(*(t.shape for t in (data, target)))

torch.Size([35, 20]) torch.Size([700])


## Start of Transformer implementation

In [None]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to a sequence of embeddings.

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    The table is built on the fly from the input's sequence length and
    feature size, so neither ``d_model`` nor a maximum length has to be fixed
    at construction time (``PositionalEncoding()`` works with no arguments).

    Args:
        dropout: dropout probability applied to ``X + PE`` (default 0.1).
        batch_first: if True (default), 3-D inputs are [batch, seq, d_model],
            matching the batch-first attention modules in this notebook. If
            False, they are [seq, batch, d_model]. 2-D inputs are always
            treated as [seq, d_model].
    """

    def __init__(self, dropout: float = 0.1, batch_first: bool = True):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.batch_first = batch_first

    def forward(self, X):
        """Return ``dropout(X + PE)``; ``X`` must be floating point with features last."""
        seq_dim = 1 if (self.batch_first and X.dim() > 2) else 0
        seq_len, d_model = X.size(seq_dim), X.size(-1)
        position = torch.arange(seq_len, dtype=torch.float32, device=X.device).unsqueeze(1)
        # 10000^(-2i/d_model), computed in log space for numerical stability.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32, device=X.device)
            * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [seq_len, ceil(d_model / 2)]
        pe = torch.zeros(seq_len, d_model, device=X.device)
        pe[:, 0::2] = torch.sin(angles)
        # With odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Reshape so PE broadcasts over every non-sequence, non-feature axis.
        shape = [1] * X.dim()
        shape[seq_dim] = seq_len
        shape[-1] = d_model
        return self.dropout(X + pe.view(shape).to(X.dtype))

In [None]:
class LayerNorm(nn.Module):
    """Layer normalization over the trailing ``normalized_shape`` dimensions.

    y = (x - mean) / sqrt(var + eps) * weight + bias, where mean and the
    biased variance are taken over the trailing dims.

    Args:
        normalized_shape: int or tuple giving the trailing dims to normalize,
            and the shape of the learnable ``weight``/``bias``. If None, only
            the last dimension is normalized and no affine parameters exist.
        eps: small constant added to the variance to avoid division by zero.
    """

    def __init__(self, normalized_shape=None, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        if normalized_shape is None:
            self.normalized_shape = None
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        else:
            if isinstance(normalized_shape, int):
                normalized_shape = (normalized_shape,)
            self.normalized_shape = tuple(normalized_shape)
            self.weight = nn.Parameter(torch.ones(self.normalized_shape))
            self.bias = nn.Parameter(torch.zeros(self.normalized_shape))

    def forward(self, X):
        ndims = 1 if self.normalized_shape is None else len(self.normalized_shape)
        dims = tuple(range(-ndims, 0))
        mean = X.mean(dim=dims, keepdim=True)
        # Biased variance (divide by N), matching the original LayerNorm paper.
        var = (X - mean).pow(2).mean(dim=dims, keepdim=True)
        X_hat = (X - mean) / torch.sqrt(var + self.eps)
        if self.weight is not None:
            X_hat = X_hat * self.weight + self.bias
        return X_hat

In [None]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization.

    Computes ``LayerNorm(X + Dropout(Y))``, where ``X`` is the sub-layer input
    (the residual) and ``Y`` is the sub-layer output.

    Args:
        norm_shape: shape passed to ``LayerNorm``.
        dropout: dropout probability applied to ``Y`` before the addition.
    """

    def __init__(self, norm_shape, dropout):
        super().__init__()
        self.layer_norm = LayerNorm(norm_shape)
        self.dropout = nn.Dropout(dropout)

    def forward(self, X, Y):
        residual_sum = self.dropout(Y) + X
        return self.layer_norm(residual_sum)

In [None]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``Dropout(softmax(Q K^T / sqrt(d_k))) V``.

    Args:
        dropout: dropout probability applied to the attention weights.
        num_heads: number of heads the inputs were split into. It is stored for
            reference only; the computation works on any leading dims.
    """

    def __init__(self, dropout, num_heads):
        super().__init__()
        # Honour the requested rate (nn.Dropout() alone would default to p=0.5).
        self.dropout = nn.Dropout(dropout)
        self.num_heads = num_heads

    def forward(self, q, k, v, mask=None):
        """
        Args:
            q: [..., seq_q, d_k]
            k: [..., seq_k, d_k]
            v: [..., seq_k, d_v]
            mask: optional tensor broadcastable to [..., seq_q, seq_k]. Nonzero
                (True) entries are BLOCKED, i.e. receive ~zero attention weight.

        Returns:
            Tensor of shape [..., seq_q, d_v].
        """
        d_k = q.size(-1)
        scores = q.matmul(k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Most negative finite value instead of -inf, so a fully masked row
            # yields uniform weights rather than NaN.
            scores = scores.masked_fill(mask.to(torch.bool), torch.finfo(scores.dtype).min)
        weights = torch.softmax(scores, dim=-1)
        return self.dropout(weights).matmul(v)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention.

    Projects queries, keys and values with learned linear maps and splits the
    model dimension into ``num_heads`` heads. It then runs scaled dot-product
    attention in every head in parallel, concatenates the heads and applies
    an output projection.

    Args:
        num_heads: number of attention heads; must divide ``d_model``.
        d_model: feature size of inputs and outputs.
        dropout: dropout probability for the attention weights (default 0.0).

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, num_heads, d_model, dropout=0.0):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        self.d_model = d_model
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_q = nn.Linear(d_model, d_model)
        self.w_out = nn.Linear(d_model, d_model)
        self.attention = ScaledDotProductAttention(dropout, num_heads)

    def split(self, X):
        """Split features into heads: [..., seq, d_model] -> [..., num_heads, seq, d_head]."""
        *lead, seq_len, d_model = X.shape
        X = X.reshape(*lead, seq_len, self.num_heads, d_model // self.num_heads)
        return X.transpose(-3, -2)

    def concat(self, X):
        """Inverse of ``split``: [..., num_heads, seq, d_head] -> [..., seq, num_heads * d_head]."""
        *lead, num_heads, seq_len, d_head = X.shape
        return X.transpose(-3, -2).reshape(*lead, seq_len, num_heads * d_head)

    def forward(self, q, k, v, mask=None):
        """
        Args:
            q: [..., seq_q, d_model]; k, v: [..., seq_k, d_model]
                (assumed batch-first, e.g. [batch, seq, d_model]).
            mask: optional, broadcastable to [..., num_heads, seq_q, seq_k];
                True entries are blocked (see ScaledDotProductAttention).

        Returns:
            Tensor of shape [..., seq_q, d_model].
        """
        # Linear projections, then split the feature axis across heads.
        q = self.split(self.w_q(q))
        k = self.split(self.w_k(k))
        v = self.split(self.w_v(v))

        out = self.attention(q, k, v, mask=mask)
        return self.w_out(self.concat(out))

In [None]:
class EncoderBlock(nn.Module):
    """One Transformer encoder layer (post-norm).

    Self-attention and a position-wise feed-forward network, each wrapped in
    ``AddNorm`` (residual + dropout + LayerNorm).

    Args:
        num_heads: number of attention heads.
        ffn_in: model dimension (d_model), i.e. input/output feature size.
        ffn_hidden: hidden size of the position-wise feed-forward network.
        ffn_output: FFN output size. Defaults to ``ffn_in`` and must equal it,
            because the FFN output is added back onto its input.
        dropout: dropout probability for attention weights and residual branches.

    Raises:
        ValueError: if ``ffn_output`` differs from ``ffn_in``.
    """

    def __init__(self, num_heads, ffn_in, ffn_hidden, ffn_output=None, dropout=0.1):
        super().__init__()
        if ffn_output is None:
            ffn_output = ffn_in
        if ffn_output != ffn_in:
            raise ValueError(
                f"ffn_output ({ffn_output}) must equal ffn_in ({ffn_in}) for the residual connection")
        self.attention = MultiHeadAttention(num_heads, ffn_in, dropout=dropout)
        # Separate AddNorms so each sub-layer gets its own LayerNorm parameters.
        self.addnorm1 = AddNorm(ffn_in, dropout)
        self.ffn = nn.Sequential(nn.Linear(ffn_in, ffn_hidden),
                                 nn.ReLU(),
                                 nn.Linear(ffn_hidden, ffn_output))
        self.addnorm2 = AddNorm(ffn_in, dropout)

    def forward(self, X, mask=None):
        """
        Args:
            X: [batch, seq, ffn_in] (assumed batch-first, as MultiHeadAttention expects).
            mask: optional self-attention mask; True entries are blocked.

        Returns:
            Tensor with the same shape as ``X``.
        """
        # Self-attention sub-layer + AddNorm (residual is X).
        Y = self.addnorm1(X, self.attention(X, X, X, mask=mask))
        # Position-wise FFN sub-layer + AddNorm (residual is Y).
        return self.addnorm2(Y, self.ffn(Y))

In [None]:
class TransformerEncoder(nn.Module):
    """Token embedding + positional encoding followed by a stack of EncoderBlocks.

    Args:
        vocab_size: number of distinct token ids.
        num_blocks: number of stacked EncoderBlocks.
        num_heads: attention heads per block.
        num_hiddens: model dimension (d_model), i.e. the embedding size.
        ffn_in: FFN input size; must equal ``num_hiddens``.
        ffn_hidden: hidden size of each block's feed-forward network.
        dropout: dropout probability used throughout (default 0.1).

    Raises:
        ValueError: if ``ffn_in`` differs from ``num_hiddens``.
    """

    def __init__(self, vocab_size, num_blocks, num_heads,
                 num_hiddens, ffn_in, ffn_hidden, dropout=0.1):
        super().__init__()
        if ffn_in != num_hiddens:
            raise ValueError(
                f"ffn_in ({ffn_in}) must equal num_hiddens ({num_hiddens})")
        self.num_hiddens = num_hiddens
        self.pos_encoding = PositionalEncoding(dropout=dropout)
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.blocks = nn.ModuleList(
            EncoderBlock(num_heads, num_hiddens, ffn_hidden, num_hiddens, dropout=dropout)
            for _ in range(num_blocks))

    def forward(self, X, mask=None):
        """
        Args:
            X: token ids, [batch, seq] (assumed batch-first).
            mask: optional self-attention mask passed to every block; True = blocked.

        Returns:
            Encoded sequence, [batch, seq, num_hiddens].
        """
        # Scale embeddings by sqrt(d_model) before adding positional encodings.
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        for blk in self.blocks:
            X = blk(X, mask=mask)
        return X

In [None]:
class DecoderBlock(nn.Module):
    """One Transformer decoder layer (post-norm).

    Masked self-attention, encoder-decoder (cross) attention and a
    position-wise FFN, each followed by ``AddNorm``.

    Args:
        num_hiddens: model dimension (d_model).
        ffn_hiddens: hidden size of the position-wise FFN.
        num_heads: number of attention heads.
        dropout: dropout probability for attention weights and residual branches.
    """

    def __init__(self, num_hiddens, ffn_hiddens, num_heads, dropout):
        super().__init__()
        self.attention1 = MultiHeadAttention(num_heads, num_hiddens, dropout=dropout)
        self.addnorm1 = AddNorm(num_hiddens, dropout)
        self.attention2 = MultiHeadAttention(num_heads, num_hiddens, dropout=dropout)
        self.addnorm2 = AddNorm(num_hiddens, dropout)
        self.ffn = nn.Sequential(nn.Linear(num_hiddens, ffn_hiddens),
                                 nn.ReLU(),
                                 nn.Linear(ffn_hiddens, num_hiddens))
        self.addnorm3 = AddNorm(num_hiddens, dropout)

    def forward(self, dec, enc, trg_mask, src_mask=None):
        """
        Args:
            dec: decoder input, [batch, tgt_len, num_hiddens] (assumed batch-first).
            enc: encoder output (memory), [batch, src_len, num_hiddens].
            trg_mask: self-attention mask, broadcastable to
                [..., tgt_len, tgt_len]; True = blocked (normally causal).
            src_mask: optional cross-attention mask over encoder positions,
                broadcastable to [..., tgt_len, src_len]; True = blocked.

        Returns:
            Tensor of shape [batch, tgt_len, num_hiddens].
        """
        # AddNorm(X, Y) computes LayerNorm(X + Dropout(Y)): the residual goes
        # first and the sub-layer output second.
        # 1-2. Masked self-attention + AddNorm
        x = self.addnorm1(dec, self.attention1(dec, dec, dec, mask=trg_mask))
        # 3-4. Encoder-decoder attention + AddNorm
        x = self.addnorm2(x, self.attention2(x, enc, enc, mask=src_mask))
        # 5-6. Position-wise FFN + AddNorm
        return self.addnorm3(x, self.ffn(x))

In [None]:
class TransformerDecoder(nn.Module):
    """Token embedding + positional encoding, a stack of DecoderBlocks and a
    final projection to vocabulary logits.

    Args:
        vocab_size: number of distinct token ids (also the logit size).
        num_blocks: number of stacked DecoderBlocks.
        num_heads: attention heads per block.
        num_hiddens: model dimension (d_model).
        ffn_hiddens: hidden size of each block's feed-forward network.
        dropout: dropout probability used throughout (default 0.1).
    """

    def __init__(self, vocab_size, num_blocks, num_heads,
                 num_hiddens, ffn_hiddens, dropout=0.1):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(dropout=dropout)
        self.blocks = nn.ModuleList(
            DecoderBlock(num_hiddens, ffn_hiddens, num_heads, dropout)
            for _ in range(num_blocks))
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def forward(self, X, enc, trg_mask=None, src_mask=None):
        """
        Args:
            X: target token ids, [batch, tgt_len] (assumed batch-first).
            enc: encoder output, [batch, src_len, num_hiddens].
            trg_mask: decoder self-attention mask; True = blocked.
            src_mask: cross-attention mask over encoder positions; True = blocked.

        Returns:
            Logits of shape [batch, tgt_len, vocab_size].
        """
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        for blk in self.blocks:
            X = blk(X, enc, trg_mask, src_mask)
        return self.dense(X)

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder wrapper around an already-constructed encoder and decoder.

    Args:
        encoder: module called as ``encoder(X, mask=src_mask)``.
        decoder: module called as ``decoder(tgt, memory, tgt_mask, src_mask)``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, X, tgt=None, src_mask=None, tgt_mask=None):
        """
        Args:
            X: source token ids, [batch, src_len].
            tgt: target token ids, [batch, tgt_len]. Defaults to ``X``, which
                is the language-modelling setup where the decoder reads the
                same tokens.
            src_mask: mask used for encoder self-attention and for
                cross-attention; True = blocked. None means no masking. For
                language modelling pass a causal mask here too, otherwise
                cross-attention can see future tokens.
            tgt_mask: decoder self-attention mask; True = blocked. Defaults
                to a causal (upper-triangular) mask over ``tgt``.

        Returns:
            Whatever the decoder returns (logits for TransformerDecoder).
        """
        if tgt is None:
            tgt = X
        if tgt_mask is None:
            n = tgt.size(-1)
            tgt_mask = torch.triu(
                torch.ones(n, n, dtype=torch.bool, device=tgt.device), diagonal=1)
        memory = self.encoder(X, mask=src_mask)
        return self.decoder(tgt, memory, tgt_mask, src_mask)

In [None]:
torch.manual_seed(0)  # reproducible weight init and dropout

# Hyperparameters (small model so an epoch is tractable).
ntokens = len(vocab)
num_hiddens = 200   # embedding / model dimension (d_model)
num_heads = 2       # must divide num_hiddens
ffn_hiddens = 200   # hidden size of the position-wise FFN
num_blocks = 2      # encoder and decoder layers each
dropout = 0.2
lr = 1e-3
num_epochs = 100

encoder = TransformerEncoder(ntokens, num_blocks, num_heads,
                             num_hiddens, num_hiddens, ffn_hiddens, dropout=dropout)
decoder = TransformerDecoder(ntokens, num_blocks, num_heads,
                             num_hiddens, ffn_hiddens, dropout=dropout)
net = Transformer(encoder, decoder).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=lr)


def run_epoch(source: torch.Tensor, train: bool) -> float:
    """Run one pass over ``source`` ([seq_len, batch]) and return the mean loss per window."""
    net.train(train)
    total_loss, num_windows = 0.0, 0
    with torch.set_grad_enabled(train):
        for start in range(0, source.size(0) - 1, bptt):
            inputs, targets = get_batch(source, start)
            seq_len = inputs.size(0)
            # Causal mask (True = blocked), used for encoder self-attention,
            # decoder self-attention and cross-attention, so no position can
            # see the token it has to predict.
            causal = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=source.device),
                diagonal=1)
            # Model is batch-first: [seq, batch] -> [batch, seq].
            logits = net(inputs.t(), src_mask=causal, tgt_mask=causal)
            # get_batch flattens targets seq-major, so put logits back in
            # [seq, batch, ntokens] order before flattening.
            loss = criterion(logits.transpose(0, 1).reshape(-1, ntokens), targets)
            if train:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5)
                optimizer.step()
            total_loss += loss.item()
            num_windows += 1
    return total_loss / max(num_windows, 1)


for epoch in range(num_epochs):
    train_loss = run_epoch(train_data, train=True)
    val_loss = run_epoch(val_data, train=False)
    print(f"epoch {epoch + 1:3d} | train loss {train_loss:.3f} | "
          f"val loss {val_loss:.3f} | val ppl {math.exp(val_loss):.2f}")