1. Positional Encoding

In [None]:
import torch
import math

def positional_encoding(seq_len, d_model):
    """Build a sinusoidal positional-encoding table.

    Even columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)``, where ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        seq_len: Number of positions (rows) to generate.
        d_model: Embedding width (columns). May be even or odd.

    Returns:
        A tensor of shape ``(seq_len, d_model)``.
    """
    pe = torch.zeros(seq_len, d_model)
    position = torch.arange(0, seq_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    angles = position * div_term  # (seq_len, ceil(d_model / 2))
    pe[:, 0::2] = torch.sin(angles)
    # For odd d_model there is one fewer odd column than even columns, so the
    # cosine half must drop the last frequency to match the slice width.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    return pe


2. Self-Attention

In [None]:
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Compute softmax(Q K^T / sqrt(d_k)) V.

    Args:
        Q: Query tensor whose last dimension is ``d_k``.
        K: Key tensor; its last two dimensions are transposed before the
            product with ``Q``.
        V: Value tensor multiplied by the attention weights.
        mask: Optional tensor broadcastable to the score matrix. Positions
            where ``mask == 0`` are excluded from attention.

    Returns:
        A tuple ``(output, weights)``: the attended values and the attention
        weights (softmax over the last dimension of the scores).
    """
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Use the most negative finite value of the score dtype instead of a
        # fixed -1e9, which is not representable in float16. A finite fill
        # (rather than -inf) keeps fully masked rows uniform instead of NaN.
        fill_value = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(mask == 0, fill_value)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights


3. Multi-Head Attention

In [None]:
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``scaled_dot_product_attention``.

    Queries, keys and values are each linearly projected, split into
    ``heads`` chunks of width ``d_k = d_model // heads``, attended per head,
    then concatenated and passed through a final linear layer.

    Args:
        heads: Number of attention heads.
        d_model: Model width; must be divisible by ``heads``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``heads``.
    """

    def __init__(self, heads, d_model):
        super().__init__()
        # Fail early: otherwise heads * d_k != d_model and the reshape in
        # forward() fails with an unrelated-looking size error.
        if d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by heads ({heads})"
            )
        self.heads = heads
        self.d_k = d_model // heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)

    def _split_heads(self, x):
        """Reshape (batch, length, d_model) to (batch, heads, length, d_k)."""
        return x.view(x.size(0), x.size(1), self.heads, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Args:
            q: Query tensor of shape ``(batch, query_len, d_model)``.
            k: Key tensor of shape ``(batch, key_len, d_model)``.
            v: Value tensor of shape ``(batch, key_len, d_model)``.
            mask: Optional mask forwarded unchanged to
                ``scaled_dot_product_attention``; it must broadcast against
                scores of shape ``(batch, heads, query_len, key_len)``.

        Returns:
            Tensor of shape ``(batch, query_len, d_model)``.
        """
        B, T, _ = q.size()
        # Each input is reshaped with its own sequence length so that
        # cross-attention (key_len != query_len) works.
        q = self._split_heads(self.q_linear(q))
        k = self._split_heads(self.k_linear(k))
        v = self._split_heads(self.v_linear(v))

        out, _ = scaled_dot_product_attention(q, k, v, mask)
        # Merge heads back: (B, heads, T, d_k) -> (B, T, heads * d_k).
        out = out.transpose(1, 2).contiguous().view(B, T, self.heads * self.d_k)
        return self.out(out)


4. Feedforward Networks

In [None]:
class PositionwiseFeedforward(nn.Module):
    """Two-layer feedforward network: Linear -> ReLU -> Linear.

    Args:
        d_model: Input and output feature width.
        d_ff: Width of the hidden layer between the two linear maps.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Expand to the hidden width, then project back to the model width.
        self.linear1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.linear2 = nn.Linear(in_features=d_ff, out_features=d_model)

    def forward(self, x):
        """Apply the feedforward network to the last dimension of ``x``."""
        hidden = self.linear1(x)
        activated = F.relu(hidden)
        return self.linear2(activated)


English-to-French Translation using nn.Transformer

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerTranslator(nn.Module):
    """Sequence-to-sequence translator built on ``nn.Transformer``.

    Source and target token ids are embedded, summed with sinusoidal
    positional encodings, run through an encoder-decoder transformer and
    projected to target-vocabulary logits.

    Args:
        src_vocab_size: Size of the source vocabulary.
        tgt_vocab_size: Size of the target vocabulary.
        d_model: Embedding / model width.
        nhead: Number of attention heads.
        num_layers: Number of encoder layers and of decoder layers.
        max_len: Longest sequence (source or target) the positional table
            supports. Defaults to 100.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, nhead=8, num_layers=6,
                 max_len=100):
        super().__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        # Registered as a buffer so it follows the module across devices and
        # dtypes; non-persistent so the state_dict keys stay the same as when
        # this was a plain tensor attribute.
        self.register_buffer(
            "positional_encoding", positional_encoding(max_len, d_model), persistent=False
        )

        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead, num_encoder_layers=num_layers,
                                          num_decoder_layers=num_layers)
        self.out = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):
        """Compute target-vocabulary logits.

        Args:
            src: Source token ids of shape ``(batch, src_len)``.
            tgt: Target token ids of shape ``(batch, tgt_len)``.

        Returns:
            Logits of shape ``(tgt_len, batch, tgt_vocab_size)``; the
            sequence dimension comes first because the transformer is run
            in its default sequence-first layout.

        Raises:
            ValueError: If ``src_len`` or ``tgt_len`` exceeds the length of
                the positional-encoding table.
        """
        src_len, tgt_len = src.size(1), tgt.size(1)
        max_len = self.positional_encoding.size(0)
        if src_len > max_len or tgt_len > max_len:
            raise ValueError(
                f"sequence length (src={src_len}, tgt={tgt_len}) exceeds "
                f"the positional encoding table size ({max_len})"
            )

        src_emb = self.encoder_embedding(src) + self.positional_encoding[:src_len, :]
        tgt_emb = self.decoder_embedding(tgt) + self.positional_encoding[:tgt_len, :]
        src_emb = src_emb.transpose(0, 1)  # (seq_len, batch, d_model)
        tgt_emb = tgt_emb.transpose(0, 1)

        # Causal mask: -inf above the diagonal so target position i can only
        # attend to positions <= i and never sees the tokens it must predict.
        tgt_mask = torch.triu(
            torch.full((tgt_len, tgt_len), float("-inf"), device=tgt_emb.device),
            diagonal=1,
        )

        output = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask)
        return self.out(output)
