In [65]:
import torch
import torch.nn as nn

class Attention(nn.Module):
  """Multi-head dot-product attention.

  The last dimension of every input is split into `headsCount` heads of size
  `embeddingSize // headsCount`. The same three bias-free per-head linear maps
  are shared by all heads, and the concatenated heads go through a final
  linear layer.

  Args:
    embeddingSize: size of the last dimension of the inputs and the output.
    headsCount: number of heads; must divide `embeddingSize`.
  """

  def __init__(self,embeddingSize, headsCount):
    super(Attention, self).__init__()
    self.embeddingSize = embeddingSize
    self.headsCount = headsCount

    assert (embeddingSize%headsCount == 0), "Embedding Size need to be divisible by the count of heads"

    self.headDim = embeddingSize // headsCount
    # Per-head projections, applied along the trailing headDim axis.
    self.V = nn.Linear(self.headDim, self.headDim, bias = False)
    self.K = nn.Linear(self.headDim, self.headDim, bias = False)
    self.Q = nn.Linear(self.headDim, self.headDim, bias = False)
    # Output projection applied after the heads are concatenated again.
    self.linear = nn.Linear(embeddingSize, embeddingSize)

  def forward(self, V, K, Q, mask):
    """Attend from the queries `Q` over the keys `K` and values `V`.

    Args:
      V: values, reshaped here to (batch, keyLen, headsCount, headDim).
      K: keys, reshaped here to (batch, keyLen, headsCount, headDim).
      Q: queries, reshaped here to (batch, queryLen, headsCount, headDim).
      mask: None, or a tensor broadcastable to
        (batch, headsCount, queryLen, keyLen); positions where it equals 0
        are excluded from the softmax.

    Returns:
      Tensor of shape (batch, queryLen, embeddingSize).
    """
    # Split the embedding axis into heads, then project each head.
    V = self.V(V.reshape(V.shape[0], V.shape[1], self.headsCount, self.headDim))
    K = self.K(K.reshape(K.shape[0], K.shape[1], self.headsCount, self.headDim))
    Q = self.Q(Q.reshape(Q.shape[0], Q.shape[1], self.headsCount, self.headDim))

    # energy[n, h, q, k] = <Q[n, q, h, :], K[n, k, h, :]>
    energy = torch.einsum("nqhd,nkhd->nhqk", Q, K)

    if mask is not None:
      # -inf becomes a weight of exactly 0 after the softmax.
      energy = energy.masked_fill(mask == 0, float("-inf"))

    # Square root of the embedding size. The previous `** 1/2` parsed as
    # `(embeddingSize ** 1) / 2`, i.e. a division by half the embedding size.
    # NOTE(review): the scaling uses embeddingSize, not headDim, as written by
    # the original author -- confirm this is the intended temperature.
    scale = self.embeddingSize ** 0.5
    attention = torch.softmax(energy / scale, dim = 3)

    # Weighted sum of the values for every query position and head.
    valueWeights = torch.einsum("nhqk,nkhd->nqhd", attention, V)
    # Concatenate the heads back into a single embedding axis.
    valueWeights = valueWeights.reshape(valueWeights.shape[0], valueWeights.shape[1], self.embeddingSize)

    return self.linear(valueWeights)
    

In [66]:
class TransformerUnit(nn.Module):
  """Attention followed by a position-wise feed-forward network.

  Each of the two stages adds a residual connection, applies layer
  normalisation and then dropout.

  Args:
    embeddingSize: size of the last dimension of the inputs and the output.
    headsCount: number of attention heads.
    dropout: dropout probability used after both normalisations.
    factor: the hidden width of the feed-forward network is
      `factor*embeddingSize`.
  """

  def __init__(self, embeddingSize, headsCount, dropout, factor):
    super().__init__()
    hiddenSize = factor*embeddingSize

    self.attention = Attention(embeddingSize, headsCount)
    self.normalization1 = nn.LayerNorm(embeddingSize)
    self.normalization2 = nn.LayerNorm(embeddingSize)

    # Widen, apply the non-linearity, then project back to embeddingSize.
    widen = nn.Linear(embeddingSize, hiddenSize)
    narrow = nn.Linear(hiddenSize, embeddingSize)
    self.feedForward = nn.Sequential(widen, nn.ReLU(), narrow)

    self.dropout = nn.Dropout(dropout)

  def forward(self, V, K, Q, mask):
    """Run the unit; `Q` is also the residual input of the attention stage."""
    # Stage 1: attention, residual with the query, normalise, dropout.
    attended = self.attention(V, K, Q, mask)
    normedAttention = self.normalization1(attended + Q)
    feed = self.dropout(normedAttention)

    # Stage 2: feed-forward, residual with the stage-1 result, normalise, dropout.
    projected = self.feedForward(feed)
    normedOutput = self.normalization2(projected + feed)
    return self.dropout(normedOutput)


In [67]:
class Encoder(nn.Module):
  """Stack of `TransformerUnit` layers over token + position embeddings.

  Args:
    vocabLength: number of rows in the token embedding table.
    embeddingSize: embedding width used throughout the stack.
    layerCount: number of `TransformerUnit` layers.
    headsCount: number of attention heads per layer.
    deviceType: stored on the instance for backward compatibility; tensors
      created in `forward` follow the device of the input instead.
    factor: feed-forward expansion factor passed to every layer.
    dropout: dropout probability.
    max: number of rows in the position embedding table, i.e. the longest
      sequence that can be encoded.
  """

  def __init__(self, vocabLength, embeddingSize, layerCount, headsCount, deviceType, factor, dropout, max):
    super(Encoder, self).__init__()
    self.embeddingSize = embeddingSize
    self.deviceType = deviceType
    self.wordEmbedding = nn.Embedding(vocabLength, embeddingSize)
    self.position = nn.Embedding(max, embeddingSize)
    self.layers = nn.ModuleList(
        [
         TransformerUnit(embeddingSize, headsCount, dropout, factor) for _ in range(layerCount)
        ]
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask):
    """Encode a batch of token ids.

    Args:
      x: 2-D integer tensor of token ids, (batch, seqLen).
      mask: attention mask forwarded unchanged to every layer.

    Returns:
      Tensor of shape (batch, seqLen, embeddingSize).

    Raises:
      IndexError: if seqLen exceeds the size of the position table.
    """
    n, seqLen = x.shape
    if seqLen > self.position.num_embeddings:
      # Fail early with a readable message instead of an opaque lookup error.
      raise IndexError(
          f"sequence length {seqLen} exceeds the maximum of {self.position.num_embeddings} positions"
      )

    # Create the position ids directly on the input's device so the lookup
    # works wherever the batch lives, regardless of the stored deviceType.
    pos = torch.arange(seqLen, device=x.device).expand(n, -1)
    output = self.dropout(self.wordEmbedding(x) + self.position(pos))

    # Self-attention: values, keys and queries are all the running output.
    for layer in self.layers:
      output = layer(output, output, output, mask)

    return output


In [68]:
class DecoderUnit(nn.Module):
  """One decoder layer: masked self-attention feeding a `TransformerUnit`.

  Args:
    embeddingSize: embedding width of the inputs and the output.
    headsCount: number of attention heads.
    deviceType: accepted for signature compatibility; not used by this class.
    factor: feed-forward expansion factor of the inner `TransformerUnit`.
    dropout: dropout probability.
  """

  def __init__(self, embeddingSize, headsCount, deviceType, factor, dropout):
    super().__init__()
    self.attention = Attention(embeddingSize, headsCount)
    self.norm = nn.LayerNorm(embeddingSize)
    self.transformer = TransformerUnit(embeddingSize, headsCount, dropout, factor)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, V, K, sourceMask, targetMask):
    """Self-attend over `x`, then use the result as the query over `V` and `K`.

    `targetMask` is applied to the self-attention; `sourceMask` is passed to
    the inner `TransformerUnit` together with `V` and `K`.
    """
    # Self-attention on x with a residual connection, normalisation and dropout.
    selfAttended = self.attention(x, x, x, targetMask)
    residual = x + selfAttended
    query = self.dropout(self.norm(residual))

    # The refined x becomes the query of the inner unit.
    return self.transformer(V, K, query, sourceMask)

In [69]:
class Decoder(nn.Module):
  """Stack of `DecoderUnit` layers followed by a projection to the vocabulary.

  Args:
    vocabLength: number of rows in the token embedding table and size of the
      last dimension of the output.
    embeddingSize: embedding width used throughout the stack.
    layerCount: number of `DecoderUnit` layers.
    headsCount: number of attention heads per layer.
    deviceType: stored on the instance and passed to each layer for backward
      compatibility; tensors created in `forward` follow the device of the
      input instead.
    factor: feed-forward expansion factor passed to every layer.
    dropout: dropout probability.
    max: number of rows in the position embedding table, i.e. the longest
      sequence that can be decoded.
  """

  def __init__(self, vocabLength, embeddingSize, layerCount, headsCount, deviceType, factor, dropout, max):
    super(Decoder, self).__init__()
    self.deviceType = deviceType
    self.wordEmbedding = nn.Embedding(vocabLength, embeddingSize)
    self.position = nn.Embedding(max, embeddingSize)
    self.layers = nn.ModuleList([
                                 DecoderUnit(embeddingSize, headsCount, deviceType, factor, dropout) for _ in range(layerCount)
    ])

    self.linear = nn.Linear(embeddingSize, vocabLength)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, encoderOutput, sourceMask, targetMask):
    """Decode a batch of target token ids against the encoder output.

    Args:
      x: 2-D integer tensor of token ids, (batch, seqLen).
      encoderOutput: passed to every layer as both its values and its keys.
      sourceMask: mask forwarded to every layer for the attention over
        `encoderOutput`.
      targetMask: mask forwarded to every layer for the self-attention.

    Returns:
      Tensor of shape (batch, seqLen, vocabLength).

    Raises:
      IndexError: if seqLen exceeds the size of the position table.
    """
    n, seqLen = x.shape
    if seqLen > self.position.num_embeddings:
      # Fail early with a readable message instead of an opaque lookup error.
      raise IndexError(
          f"sequence length {seqLen} exceeds the maximum of {self.position.num_embeddings} positions"
      )

    # Create the position ids directly on the input's device so the lookup
    # works wherever the batch lives, regardless of the stored deviceType.
    pos = torch.arange(seqLen, device=x.device).expand(n, -1)
    output = self.dropout(self.wordEmbedding(x) + self.position(pos))

    for layer in self.layers:
      output = layer(output, encoderOutput, encoderOutput, sourceMask, targetMask)

    # Project every position onto the vocabulary.
    return self.linear(output)

In [70]:
class Transformer(nn.Module):
  """Encoder-decoder model built from `Encoder` and `Decoder`.

  Args:
    sourceVocab: vocabulary size passed to the encoder.
    targetVocab: vocabulary size passed to the decoder.
    sourcePaddingIdx: token id treated as padding when building the source mask.
    targetPaddingIdx: stored on the instance; not used by the code in this class.
    embeddingSize: embedding width shared by encoder and decoder.
    layerCount: number of layers in each stack.
    factor: feed-forward expansion factor.
    headsCount: number of attention heads.
    dropout: dropout probability.
    deviceType: forwarded to the encoder and decoder and stored on the
      instance; the masks built here follow the device of their input tensor.
    max: maximum sequence length (size of the position tables).
  """

  def __init__(
      self,
      sourceVocab,
      targetVocab,
      sourcePaddingIdx,
      targetPaddingIdx,
      embeddingSize = 256,
      layerCount = 4,
      factor = 4,
      headsCount = 4,
      dropout = 0,
      deviceType = "cuda",
      max = 100,
  ):
    super(Transformer, self).__init__()
    self.encoder = Encoder(sourceVocab, embeddingSize, layerCount, headsCount, deviceType, factor, dropout, max)
    self.decoder = Decoder(targetVocab, embeddingSize, layerCount, headsCount, deviceType, factor, dropout, max)
    self.sourcePaddingIdx = sourcePaddingIdx
    self.targetPaddingIdx = targetPaddingIdx
    self.deviceType = deviceType

  def getSourceMask(self, source):
    """Return a boolean (batch, 1, 1, sourceLen) mask, False at padding tokens."""
    # The comparison result already lives on the same device as `source`.
    return (source != self.sourcePaddingIdx).unsqueeze(1).unsqueeze(2)

  def getTargetMask(self, target):
    """Return a lower-triangular (batch, 1, targetLen, targetLen) mask of ones.

    Row i has ones in columns 0..i, so each position is limited to itself
    and earlier positions.
    """
    n, length = target.shape
    # Build the mask directly on the input's device instead of moving it later.
    causal = torch.tril(torch.ones(length, length, device=target.device))
    return causal.expand(n, 1, length, length)

  def forward(self, source, target):
    """Encode `source`, then decode `target` against the encoder output.

    Args:
      source: 2-D integer tensor of source token ids, (batch, sourceLen).
      target: 2-D integer tensor of target token ids, (batch, targetLen).

    Returns:
      The decoder output for `target`.
    """
    sourceMask = self.getSourceMask(source)
    targetMask = self.getTargetMask(target)
    encOut = self.encoder(source, sourceMask)
    decOut = self.decoder(target, encOut, sourceMask, targetMask)
    return decOut

In [72]:
# Smoke test: run one forward pass on a tiny batch.
# Use the GPU when one is available and fall back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Source batch; the trailing 0s in the second row equal the padding index below.
x = torch.tensor([[1,2,3,4,5,6,7], [1,2,3,4,5, 0, 0]], device=device)
# Target batch; the last token is dropped when it is fed to the model.
y = torch.tensor([[1,2,3,4,5], [5,4,3,2,1]], device=device)
# Pass the device explicitly so the model does not rely on its "cuda" default.
model = Transformer(8, 8, 0, 0, deviceType=device).to(device)
out = model(x, y[:,:-1])
print(out)

tensor([[[ 0.4069, -0.1897,  0.0228, -1.2993, -0.4289,  0.3691,  0.0747,
          -0.6761],
         [ 0.3462,  0.2732, -0.8237,  0.6872,  0.2129, -0.5731,  0.5943,
          -0.2611],
         [ 0.2636, -0.6270, -0.4473, -0.7676, -0.4255,  0.4864,  0.2619,
           0.0401],
         [-0.5796, -0.6556,  0.3651, -0.7396,  0.0257, -0.9096, -0.3184,
          -0.5247]],

        [[ 0.0892,  0.0401,  0.2288, -1.1982,  0.5981, -0.9803, -0.6447,
          -0.3452],
         [-0.5032,  0.1254,  0.0109, -0.5198,  1.1096, -1.3209, -0.1106,
          -0.6891],
         [-0.0353, -0.2241, -0.2844, -1.1636, -0.0219,  0.3274, -0.0452,
          -0.2564],
         [-0.4210,  0.2478,  0.1842, -0.2557, -0.2561, -0.5669, -0.0486,
          -0.5985]]], device='cuda:0', grad_fn=<AddBackward0>)
