In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
def scaled_dot_product(query, key, value):
  """Scaled dot-product attention for batched 3-D tensors.

  The query/key similarity scores are divided by the square root of the
  query's last-dimension size, normalised with a softmax over the key axis,
  and used to take a weighted sum of ``value``.

  Args:
    query: tensor of shape (batch, len_q, d).
    key: tensor of shape (batch, len_k, d).
    value: tensor of shape (batch, len_k, d_v).

  Returns:
    Tensor of shape (batch, len_q, d_v).
  """
  # (batch, len_q, len_k): similarity of every query row with every key row.
  scores = torch.bmm(query, key.transpose(1, 2))
  scaled_scores = scores / (query.size(-1) ** 0.5)
  # Each query row's weights over the keys sum to one.
  weights = F.softmax(scaled_scores, dim=-1)
  return torch.bmm(weights, value)

In [None]:
class AttentionHead(nn.Module):
  """A single attention head: linear projections followed by attention.

  Args:
    dim_in: feature size of the incoming query/key/value tensors.
    dim_q: output size of the query projection.
    dim_k: output size of the key and value projections. The query/key
      product in ``scaled_dot_product`` requires ``dim_q == dim_k``.
  """

  def __init__(self, dim_in, dim_q, dim_k):
    super().__init__()
    self.q = nn.Linear(dim_in, dim_q)
    self.k = nn.Linear(dim_in, dim_k)
    self.v = nn.Linear(dim_in, dim_k)

  def forward(self, query, key, value):
    """Project the three inputs and apply scaled dot-product attention.

    Returns a tensor whose last dimension is ``dim_k``.
    """
    # The value projection is stored as `self.v`; it is applied to `value`.
    return scaled_dot_product(self.q(query), self.k(key), self.v(value))

In [None]:
class MultiHeadAttention(nn.Module):
  """Runs several ``AttentionHead`` modules and merges their outputs.

  Args:
    num_heads: number of independent attention heads.
    dim_in: feature size of the inputs, and of the merged output.
    dim_q: query projection size used by every head.
    dim_k: key/value projection size used by every head.
  """

  def __init__(self, num_heads, dim_in, dim_q, dim_k):
    super().__init__()
    head_modules = []
    for _ in range(num_heads):
      head_modules.append(AttentionHead(dim_in, dim_q, dim_k))
    self.heads = nn.ModuleList(head_modules)
    # Maps the concatenated head outputs back to the input feature size.
    self.linear = nn.Linear(num_heads*dim_k, dim_in)

  def forward(self, query, key, value):
    """Apply every head to the same inputs, concatenate, then project."""
    per_head = [head(query, key, value) for head in self.heads]
    # Heads are joined along the feature axis: (..., num_heads * dim_k).
    merged = torch.cat(per_head, dim=-1)
    return self.linear(merged)

In [None]:
def position_encoding(seq_len, dim_model, device = torch.device('cpu')):
  """Build a sinusoidal position table of shape (1, seq_len, dim_model).

  Every (position, feature) pair gets the phase
  ``position / 10000 ** (feature / dim_model)``; even feature indices hold the
  sine of that phase and odd feature indices hold the cosine.

  Args:
    seq_len: number of positions to encode.
    dim_model: number of features per position.
    device: device on which the table is created.

  Returns:
    Float tensor of shape (1, seq_len, dim_model).
  """
  positions = torch.arange(seq_len, dtype=torch.float, device=device)[None, :, None]
  features = torch.arange(dim_model, dtype=torch.float, device=device)[None, None, :]
  # Broadcasts (1, seq_len, 1) against (1, 1, dim_model).
  angles = positions/(1e4 ** (features/dim_model))

  even_feature = features.long()%2==0
  return torch.where(even_feature, angles.sin(), angles.cos())

In [None]:
# Two small 2x2 integer tensors used by the next cell to illustrate torch.cat.
a = torch.tensor([[1,2],[3,4]])
b = torch.tensor([[5,6],[7,8]])

In [None]:
# Concatenating along the last dimension joins the rows side by side:
# two (2, 2) tensors become one (2, 4) tensor.
torch.cat([a,b], dim=-1)

tensor([[1, 2, 5, 6],
        [3, 4, 7, 8]])

In [None]:
def feed_forward(dim_input=512, dim_feedforward=2048):
  """Build the position-wise feed-forward sub-network.

  Args:
    dim_input: feature size of the input and of the output.
    dim_feedforward: feature size of the hidden layer.

  Returns:
    An ``nn.Sequential`` of Linear -> ReLU -> Linear that maps the last
    dimension from ``dim_input`` to ``dim_feedforward`` and back.
  """
  # No debug printing here: this factory runs once per layer, so a print
  # would flood the notebook output every time a model is constructed.
  return nn.Sequential(
      nn.Linear(dim_input, dim_feedforward),
      nn.ReLU(),
      nn.Linear(dim_feedforward, dim_input)
  )

In [None]:
class Residual(nn.Module):
  """Wraps a sub-layer with dropout, a skip connection and layer norm.

  Args:
    sublayer: callable applied to all of the forward inputs.
    dimension: size of the last dimension, passed to ``nn.LayerNorm``.
    dropout: dropout probability applied to the sub-layer output.
  """

  def __init__(self, sublayer, dimension, dropout=0.1):
    super().__init__()
    self.sublayer = sublayer
    self.norm = nn.LayerNorm(dimension)
    self.dropout = nn.Dropout(dropout)

  def forward(self, *tensors):
    """Return ``norm(tensors[0] + dropout(sublayer(*tensors)))``.

    The first positional tensor is the one carried by the skip connection.
    """
    skip = tensors[0]
    transformed = self.sublayer(*tensors)
    regularised = self.dropout(transformed)
    return self.norm(skip + regularised)

In [None]:
class TransformerEncoderLayer(nn.Module):
  """One encoder layer: self-attention, then a feed-forward network.

  Both stages are wrapped in ``Residual`` (skip connection, dropout and
  layer norm).

  Args:
    dim_model: feature size of the layer's input and output.
    num_heads: number of attention heads.
    dim_feedforward: hidden size of the feed-forward network.
    dropout: dropout probability used inside both ``Residual`` wrappers.
  """

  def __init__(self, dim_model = 512, num_heads = 6, dim_feedforward = 2048, dropout = 0.1):
    super().__init__()
    # Per-head projection size; integer division, floored at 1.
    head_dim = max(dim_model // num_heads, 1)

    self_attention = MultiHeadAttention(num_heads, dim_model, head_dim, head_dim)
    self.attention = Residual(self_attention, dimension=dim_model, dropout=dropout)

    position_wise = feed_forward(dim_model, dim_feedforward)
    self.feed_forward = Residual(position_wise, dimension=dim_model, dropout=dropout)

  def forward(self, src):
    """Self-attend over ``src`` (used as query, key and value), then feed forward."""
    attended = self.attention(src, src, src)
    return self.feed_forward(attended)

In [None]:
class TransformerEncoder(nn.Module):
  """A stack of ``TransformerEncoderLayer`` modules with position encoding.

  Args:
    num_layers: number of encoder layers in the stack.
    dim_model: feature size of the inputs and outputs.
    num_heads: number of attention heads per layer.
    dim_feedforward: hidden size of each layer's feed-forward network.
    dropout: dropout probability used inside every layer.
  """

  def __init__(self, num_layers=6, dim_model=512, num_heads=8, dim_feedforward=2048, dropout=0.1):
    super().__init__()
    self.layers = nn.ModuleList(
        [
            TransformerEncoderLayer(dim_model, num_heads, dim_feedforward, dropout)
            for _ in range(num_layers)
        ]
    )

  def forward(self, src):
    """Add position encodings to ``src`` and pass it through every layer.

    Args:
      src: tensor indexed as (batch, seq_len, features).

    Returns:
      Tensor with the same shape as ``src``.
    """
    seq_len, dimension = src.size(1), src.size(2)
    # Out-of-place add: `src += ...` would silently modify the caller's
    # tensor and fails on leaf tensors that require grad. The encoding is
    # created on the input's device so non-CPU inputs work as well.
    src = src + position_encoding(seq_len, dimension, device=src.device)
    for layer in self.layers:
      src = layer(src)

    return src

In [None]:
class TransformerDecoderLayer(nn.Module):
  """One decoder layer: self-attention, encoder attention, feed-forward.

  Each of the three stages is wrapped in ``Residual`` (skip connection,
  dropout and layer norm). No attention mask is passed to either attention
  block.

  Args:
    dim_model: feature size of the layer's input and output.
    num_heads: number of attention heads in each attention block.
    dim_feedforward: hidden size of the feed-forward network.
    dropout: dropout probability used inside every ``Residual`` wrapper.
  """

  def __init__(self, dim_model=512, num_heads=6, dim_feedforward=2048,
               dropout=0.1):
    super().__init__()
    # Per-head projection size; integer division, floored at 1.
    dim_q = dim_k = max(dim_model // num_heads, 1)
    self.attention_1 = Residual(
        MultiHeadAttention(num_heads, dim_model, dim_q, dim_k),
        dimension=dim_model,
        dropout=dropout
    )
    self.attention_2 = Residual(
        MultiHeadAttention(num_heads, dim_model, dim_q, dim_k),
        dimension=dim_model,
        dropout=dropout
    )
    self.feed_forward = Residual(
        feed_forward(dim_model, dim_feedforward),
        dimension=dim_model,
        dropout=dropout
    )

  # `forward` must be a method of the class (not nested inside `__init__`),
  # otherwise nn.Module cannot dispatch calls to it.
  def forward(self, tgt, memory):
    """Run the three stages of the decoder layer.

    Args:
      tgt: decoder-side tensor; query, key and value of the first attention.
      memory: encoder output; key and value of the second attention.

    Returns:
      Tensor with the same shape as ``tgt``.
    """
    tgt = self.attention_1(tgt, tgt, tgt)
    tgt = self.attention_2(tgt, memory, memory)
    return self.feed_forward(tgt)

In [None]:
class TransformerDecoder(nn.Module):
  """A stack of ``TransformerDecoderLayer`` modules with an output head.

  Args:
    num_layers: number of decoder layers in the stack.
    dim_model: feature size of the inputs and outputs.
    num_heads: number of attention heads per layer.
    dim_feedforward: hidden size of each layer's feed-forward network.
    dropout: dropout probability used inside every layer.
  """

  def __init__(self, num_layers=6, dim_model=512, num_heads=8, dim_feedforward=2048, dropout=0.1):
    super().__init__()
    self.layers = nn.ModuleList(
        [
            TransformerDecoderLayer(dim_model, num_heads, dim_feedforward, dropout)
            for _ in range(num_layers)
        ]
    )
    self.linear = nn.Linear(dim_model, dim_model)

  def forward(self, tgt, memory):
    """Decode ``tgt`` against the encoder output ``memory``.

    Args:
      tgt: tensor indexed as (batch, seq_len, features).
      memory: encoder output handed to every decoder layer.

    Returns:
      Tensor shaped like ``tgt`` whose last dimension is softmax-normalised.
    """
    seq_len, dimension = tgt.size(1), tgt.size(2)
    # Out-of-place add: `tgt += ...` would silently modify the caller's
    # tensor and fails on leaf tensors that require grad. The encoding is
    # created on the input's device so non-CPU inputs work as well.
    tgt = tgt + position_encoding(seq_len, dimension, device=tgt.device)
    for layer in self.layers:
      tgt = layer(tgt, memory)

    # The keyword is `dim`; normalise over the feature axis.
    return torch.softmax(self.linear(tgt), dim=-1)

In [None]:
class Transformer(nn.Module):
  """Encoder-decoder model built from TransformerEncoder/TransformerDecoder.

  Args:
    num_encoder_layers: number of layers in the encoder stack.
    num_decoder_layers: number of layers in the decoder stack.
    dim_model: feature size of the source and target tensors.
    num_heads: number of attention heads in every layer.
    dim_feedforward: hidden size of every feed-forward network.
    dropout: dropout probability used throughout the model.
    activation: accepted but not used anywhere in this class.
  """
  # NOTE(review): the `dim_model` default is 6 while the encoder and decoder
  # classes default to 512 -- confirm which is intended. Callers with wider
  # inputs must pass `dim_model` explicitly.

  def __init__(self, num_encoder_layers=6, num_decoder_layers = 6, dim_model=6, num_heads=6, dim_feedforward=2048, dropout=0.1,
               activation = nn.ReLU()):
    super().__init__()
    self.encoder = TransformerEncoder(
        num_layers = num_encoder_layers,
        dim_model = dim_model,
        num_heads = num_heads,
        dim_feedforward = dim_feedforward,
        dropout = dropout
    )

    # The decoder depth comes from `num_decoder_layers`, not the encoder's.
    self.decoder = TransformerDecoder(
        num_layers = num_decoder_layers,
        dim_model = dim_model,
        num_heads = num_heads,
        dim_feedforward = dim_feedforward,
        dropout = dropout
    )

  def forward(self, src, tgt):
    """Encode ``src`` and decode ``tgt`` against the encoder output."""
    return self.decoder(tgt, self.encoder(src))

In [1]:
# Smoke test: push random source/target batches through the model.
src = torch.rand(64,32,512)
tgt = torch.rand(64,16,512)
# Transformer's default dim_model (6) does not match these 512-feature
# inputs, so the model width is taken from the input explicitly.
out = Transformer(dim_model=src.size(-1))(src, tgt)