In [1]:
import math
import copy
import random
from typing import Optional


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [3]:
def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

In [4]:
def attention(query, key, value, mask: Optional[torch.Tensor]=None, dropout: Optional[nn.Module]=None):
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # mask: 1 for allowed positions, 0 for masked
        scores = scores.masked_fill(mask == 0, float('-inf'))
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

In [5]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # same mask applied to all heads
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1,2)
                             for l, x in zip(self.linears, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear
        x = x.transpose(1,2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)


In [6]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [7]:
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))

In [8]:
class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

In [9]:
class SublayerConnection(nn.Module):
    """A residual connection followed by layer norm."""
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply residual connection to any sublayer with the same size."""
        return x + self.dropout(sublayer(self.norm(x)))


In [10]:
class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

In [11]:
class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)
        self.size = size

    def forward(self, x, memory, src_mask, tgt_mask):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, memory, memory, src_mask))
        return self.sublayer[2](x, self.feed_forward)

In [12]:
class Encoder(nn.Module):
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

In [13]:
class Decoder(nn.Module):
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)

In [14]:
class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=512, N=6, h=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.src_embed = nn.Sequential(nn.Embedding(src_vocab, d_model), PositionalEncoding(d_model, dropout))
        self.tgt_embed = nn.Sequential(nn.Embedding(tgt_vocab, d_model), PositionalEncoding(d_model, dropout))

        attn = MultiHeadedAttention(h, d_model, dropout)
        ff = PositionwiseFeedForward(d_model, d_ff, dropout)
        encoder_layer = EncoderLayer(d_model, copy.deepcopy(attn), copy.deepcopy(ff), dropout)
        decoder_layer = DecoderLayer(d_model, copy.deepcopy(attn), copy.deepcopy(attn), copy.deepcopy(ff), dropout)

        self.encoder = Encoder(encoder_layer, N)
        self.decoder = Decoder(decoder_layer, N)
        self.out = nn.Linear(d_model, tgt_vocab)

        # Initialize parameters
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

    def forward(self, src, tgt, src_mask, tgt_mask):
        memory = self.encode(src, src_mask)
        dec = self.decode(memory, src_mask, tgt, tgt_mask)
        return self.out(dec)

In [15]:
def subsequent_mask(size):
    """Mask out subsequent positions."""
    attn_shape = (1, size, size)
    subsequent = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8)
    return subsequent == 0

In [16]:
def make_std_mask(tgt, pad):
    tgt_mask = (tgt != pad).unsqueeze(-2)
    tgt_mask = tgt_mask & subsequent_mask(tgt.size(-1)).type_as(tgt_mask)
    return tgt_mask

In [17]:
class CopyDataset(Dataset):
    def __init__(self, vocab_size=11, seq_len=10, size=10000, pad=0):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.size = size
        self.pad = pad
        self.data = [self._sample() for _ in range(size)]

    def _sample(self):
        # random sequence excluding pad (0) and special tokens
        seq = [random.randint(1, self.vocab_size-1) for _ in range(self.seq_len)]
        return torch.tensor(seq, dtype=torch.long), torch.tensor(seq, dtype=torch.long)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.data[idx]

In [18]:
class LabelSmoothing(nn.Module):
    def __init__(self, size, padding_idx, smoothing=0.0):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size

    def forward(self, x, target):
        # x: (batch*seq_len, vocab) - log probabilities assumed
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        return self.criterion(x, true_dist)

In [19]:
def run_transformer():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vocab_size = 12
    pad_idx = 0
    seq_len = 8

    model = Transformer(src_vocab=vocab_size, tgt_vocab=vocab_size, d_model=128, N=3, h=4, d_ff=256, dropout=0.1).to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    dataset = CopyDataset(vocab_size=vocab_size, seq_len=seq_len, size=5000, pad=pad_idx)
    loader = DataLoader(dataset, batch_size=64, shuffle=True)

    model.train()
    for epoch in range(1, 6):
        total_loss = 0
        for i, (src, tgt) in enumerate(loader):
            src = src.to(device)
            tgt = tgt.to(device)
            # prepare input and output
            tgt_input = torch.cat([torch.full((tgt.size(0),1), 1, dtype=torch.long, device=device), tgt[:,:-1]], dim=1)
            src_mask = (src != pad_idx).unsqueeze(-2)
            tgt_mask = make_std_mask(tgt_input, pad_idx).to(device)

            out = model(src, tgt_input, src_mask.to(device), tgt_mask)
            loss = criterion(out.view(-1, out.size(-1)), tgt.view(-1))

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
        print(f"Epoch {epoch} Loss: {total_loss/len(dataset):.4f}")

    model.eval()
    sample_src, sample_tgt = dataset[0]
    src = sample_src.unsqueeze(0).to(device)
    generated = [1]
    for _ in range(seq_len):
        tgt_in = torch.tensor(generated, dtype=torch.long, device=device).unsqueeze(0)
        tgt_mask = make_std_mask(tgt_in, pad_idx).to(device)
        out = model(src, tgt_in, (src!=pad_idx).unsqueeze(-2).to(device), tgt_mask)
        next_word = out.argmax(dim=-1)[0, -1].item()
        generated.append(next_word)
    print("Source:", sample_src.tolist())
    print("Target:", sample_tgt.tolist())
    print("Generated:", generated[1:])

In [20]:
run_transformer()

Epoch 1 Loss: 0.0378
Epoch 2 Loss: 0.0165
Epoch 3 Loss: 0.0027
Epoch 4 Loss: 0.0016
Epoch 5 Loss: 0.0009
Source: [1, 9, 4, 5, 10, 3, 3, 7]
Target: [1, 9, 4, 5, 10, 3, 3, 7]
Generated: [1, 9, 4, 5, 10, 3, 3, 7]
