In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.datasets as dset
import torchvision.transforms as T
import torch.nn.functional as F

import numpy as np
import math

USE_GPU = True
dtype = torch.float32
# Fixed seed so the random dataset, weight init and DataLoader shuffling are
# reproducible under Restart Kernel -> Run All.
SEED = 0

torch.manual_seed(SEED)  # seeds the default generator (and CUDA generators)

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print_every = 100
print('using device:', device)

using device: cuda


In [19]:
class MultiHeadAttention(nn.Module):
    """Thin wrapper around ``nn.MultiheadAttention`` (batch_first=True) for self- and cross-attention.

    Args:
        embed_dim: model width (nn.MultiheadAttention requires it to be divisible by num_heads).
        num_heads: number of attention heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` (queries) over ``context`` (keys and values).

        Args:
            x: query tensor, (batch, query_len, embed_dim) because of batch_first=True.
            context: key/value tensor; defaults to ``x`` (self-attention).
            mask: forwarded unchanged as ``attn_mask`` (e.g. an additive -inf causal mask).

        Returns:
            The attention output, (batch, query_len, embed_dim).
        """
        if context is None:
            context = x
        # The averaged attention weights are discarded, so don't compute them;
        # need_weights=False also lets PyTorch use its fused scaled-dot-product path.
        out, _ = self.attn(x, context, context, attn_mask=mask, need_weights=False)
        return out
    
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(d_model -> d_ff), ReLU, Linear(d_ff -> d_model).

    Applied independently to every position of the last-axis feature vectors.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)
    
class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder layer: self-attention, then feed-forward.

    Each sublayer reads a LayerNorm-ed copy of the stream and its output is added
    back onto the un-normalized residual stream.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model, d_ff)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        return x + self.ff(self.norm2(x))
    
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch-first input.

    Even feature index 2k at position p gets sin(p * 10000^(-2k/d_model)); the following
    odd index gets the matching cos. The table is precomputed once and stored as a buffer.

    Args:
        d_model: feature width of the inputs. Odd widths are supported (the last sin
            column simply has no cos partner).
        max_len: longest sequence length the precomputed table covers.
    """

    def __init__(self, d_model, max_len=64):
        super().__init__()
        pe = torch.zeros((1, max_len, d_model))
        positions = torch.arange(max_len)[:, None]
        pows = torch.pow(10000, -torch.arange(0, d_model, 2) / d_model)
        angles = positions * pows  # (max_len, ceil(d_model / 2))

        pe[0, :, 0::2] = torch.sin(angles)
        # There are only d_model // 2 odd columns; for odd d_model drop the last angle.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x + pe`` for the first ``x.size(1)`` positions.

        Raises:
            ValueError: if the sequence (dim 1) is longer than ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len={max_len}"
            )
        return x + self.pe[:, :seq_len, :]
    
class Encoder(nn.Module):
    """Stack of pre-norm encoder blocks with sinusoidal positions and a final LayerNorm.

    Inputs are already feature vectors of width d_model: positions are added directly,
    there is no embedding layer here.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ff):
        super().__init__()
        self.pe = PositionalEncoding(d_model)
        blocks = [EncoderBlock(d_model, num_heads, d_ff) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)
        # Pre-norm blocks leave the residual stream un-normalized, so normalize at the end.
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        hidden = self.pe(x)
        for block in self.layers:
            hidden = block(hidden)
        return self.norm(hidden)
    
class DecoderBlock(nn.Module):
    """Pre-norm decoder layer: masked self-attention, cross-attention, then feed-forward.

    Each sublayer reads a LayerNorm-ed copy of the stream and its output is added
    back onto the un-normalized residual stream.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn1 = MultiHeadAttention(d_model, num_heads)  # self-attention
        self.norm2 = nn.LayerNorm(d_model)
        self.attn2 = MultiHeadAttention(d_model, num_heads)  # cross-attention
        self.ff = FeedForward(d_model, d_ff)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out, mask):
        # mask applies only to the self-attention sublayer.
        x = x + self.attn1(self.norm1(x), mask=mask)
        # Queries come from the decoder stream; keys/values come from enc_out.
        x = x + self.attn2(self.norm2(x), context=enc_out)
        return x + self.ff(self.norm3(x))
    
class Decoder(nn.Module):
    """Stack of pre-norm decoder blocks with sinusoidal positions and a final LayerNorm.

    Self-attention in every block is restricted by an additive causal mask, so
    position i can only attend to positions <= i.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ff):
        super().__init__()
        self.pe = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([
            DecoderBlock(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, enc_out):
        """Decode ``x`` (batch, tgt_len, d_model) against the encoder output ``enc_out``."""
        seq_len = x.size(1)
        x = self.pe(x)

        # Build the mask in the query's dtype so a float mask matches non-fp32 inputs.
        mask = self.causal_mask(seq_len, x.device, dtype=x.dtype)

        for layer in self.layers:
            x = layer(x, enc_out, mask)

        return self.norm(x)

    def causal_mask(self, size, device, dtype=torch.float32):
        """Return a (size, size) additive mask: 0 on/below the diagonal, -inf above it."""
        full = torch.full((size, size), float('-inf'), device=device, dtype=dtype)
        return torch.triu(full, diagonal=1)

    # Backward-compatible alias for the original (misspelled) method name.
    casual_mask = causal_mask
    
class Transformer(nn.Module):
    """Encoder-decoder Transformer over continuous feature inputs, with a linear output head.

    Args:
        num_layers: number of blocks in each of the encoder and decoder stacks.
        d_model, num_heads, d_ff: widths / head count shared by both stacks.
        d_out: feature size predicted for each target position.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ff, d_out):
        super().__init__()
        stack_args = (num_layers, d_model, num_heads, d_ff)
        self.encoder = Encoder(*stack_args)
        self.decoder = Decoder(*stack_args)
        self.fc = nn.Linear(d_model, d_out)

    def forward(self, src, tgt):
        """Encode ``src``, decode ``tgt`` against it, and project every target position to d_out."""
        memory = self.encoder(src)
        return self.fc(self.decoder(tgt, memory))

In [20]:
# Smoke test: push random continuous "embeddings" through the model and check the output shape.
B = 2
src_len = 5
tgt_len = 4
d_model = 16
d_out = 10

src = torch.randn(B, src_len, d_model)
tgt = torch.randn(B, tgt_len, d_model)

model = Transformer(num_layers=2, d_model=d_model, num_heads=4, d_ff=64, d_out=d_out)

out = model(src, tgt)
# One d_out-wide prediction per *target* position; src_len must not appear in the shape.
assert out.shape == (B, tgt_len, d_out), out.shape
print(out.shape)

torch.Size([2, 4, 10])


In [15]:
class ReverseDataset(torch.utils.data.Dataset):
    """Synthetic sequence-reversal task: each target is its source with the time axis flipped.

    Sources are drawn once with ``torch.randn`` as a (num_samples, seq_len, d_model) tensor;
    item ``idx`` is the pair (source, reversed source), each of shape (seq_len, d_model).
    """

    def __init__(self, num_samples=1000, seq_len=8, d_model=16):
        self.data = torch.randn(num_samples, seq_len, d_model)
        # dims=[1] is the sequence axis of the stacked tensor.
        self.targets = self.data.flip(dims=[1])

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        source, target = self.data[idx], self.targets[idx]
        return source, target
    
# 2000 random sequences of length 8; shuffled mini-batches of 32 for training.
dataset = ReverseDataset(num_samples=2000, seq_len=8, d_model=d_model)
loader = DataLoader(dataset, batch_size=32, shuffle=True)


In [26]:
def train(model, optimizer, criterion, epochs, data_loader=None, target_device=None):
    """Train a seq2seq model with teacher forcing and print the mean loss per epoch.

    The decoder input is the target shifted right by one step, with an all-zero
    "BOS" frame in front; the model is trained to predict the unshifted target.

    Args:
        model: called as ``model(src, tgt_in)``; its output is compared to the target.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss ``criterion(pred, tgt_out)``.
        epochs: number of passes over the data.
        data_loader: iterable of ``(src, tgt_out)`` batches; defaults to the notebook-global ``loader``.
        target_device: device for the model and batches; defaults to the notebook-global ``device``.

    Returns:
        List of mean batch losses, one per epoch (NaN for an epoch with no batches).
    """
    if data_loader is None:
        data_loader = loader
    if target_device is None:
        target_device = device

    model = model.to(device=target_device)
    model.train()
    history = []
    for e in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for src, tgt_out in data_loader:
            # Batches come off the DataLoader on the CPU; move them next to the model.
            src = src.to(target_device)
            tgt_out = tgt_out.to(target_device)

            # BOS frame matches the target's batch size, width, dtype and device.
            bos = tgt_out.new_zeros(tgt_out.size(0), 1, tgt_out.size(-1))
            tgt_in = torch.cat([bos, tgt_out[:, :-1, :]], dim=1)

            pred = model(src, tgt_in)
            loss = criterion(pred, tgt_out)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches if num_batches else float('nan')
        history.append(avg_loss)
        print(f"Epoch {e+1}, loss={avg_loss:.4f}")
    return history


In [27]:
# Train on the synthetic reversal task. Predictions are regressed (MSE) onto the
# target frames, so the output head must be exactly as wide as the inputs.
d_model = 16
d_out = d_model
NUM_EPOCHS = 10
LR = 1e-4
assert dataset.data.size(-1) == d_model, "d_model must match the width ReverseDataset was built with"

# Force CPU: the DataLoader yields CPU tensors and train() moves the model to the global `device`.
device = torch.device('cpu')

model = Transformer(num_layers=2, d_model=d_model, num_heads=4, d_ff=64, d_out=d_out)
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.MSELoss()

train(model, optimizer, criterion, NUM_EPOCHS)

Epoch 1, loss=1.2791
Epoch 2, loss=1.1713
Epoch 3, loss=1.0997
Epoch 4, loss=1.0588
Epoch 5, loss=1.0381
Epoch 6, loss=1.0282
Epoch 7, loss=1.0205
Epoch 8, loss=1.0159
Epoch 9, loss=1.0124
Epoch 10, loss=1.0096
