In [7]:
import torch
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print (x)
else:
    print ("MPS device not found.")

tensor([1.], device='mps:0')


In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math, copy

In [6]:
class MultiHeadAttention(nn.Module):
  def __init__(self, dataModel, numHeads):
    super(MultiHeadAttention, self).__init__()
    assert dataModel % numHeads == 0, "dataModel must be divisible by numHeads"

    self.dataModel = dataModel
    self.numHeads = numHeads
    self.dimK = dataModel // numHeads # int division
    # Weights for Q, K, V
    self.weightQ = nn.Linear(dataModel, dataModel)
    self.weightK = nn.Linear(dataModel, dataModel)
    self.weightV = nn.Linear(dataModel, dataModel)
    self.weightO = nn.Linear(dataModel, dataModel)

  def scaledDotProductAttention(self, Q, K, V, mask=None):
    attentionScores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.dimK)
    if mask is not None:
      attentionScores = attentionScores.masked_fill(mask, -1e9)
    attentionProbabilities = nn.functional.softmax(attentionScores, dim=-1)
    output = torch.matmul(attentionProbabilities, V)
    return output
  
  def splitHeads(self, x, batchSize):
    batchSize, seqLen, _ = x.size()
    return x.view(batchSize, seqLen, self.numHeads, self.dimK).transpose(1, 2)
  
  def combineHeads(self, x):
    batchSize, _, seqLen, _ = x.size()
    return x.transpose(1, 2).contiguous().view(batchSize, seqLen, self.dataModel)
  
  def forward(self, Q, K, V, mask=None):
    Q = self.splitHeads(self.weightQ(Q))
    K = self.splitHeads(self.weightK(K))
    V = self.splitHeads(self.weightV(V))
    attentionOutput = self.scaledDotProductAttention(Q, K, V, mask)
    output = self.weightO(self.combineHeads(attentionOutput))
    return output




In [7]:
class PositionWiseFeedForwardNN(nn.Module):
  def __init__(self, dataModel, dFF):
    super(PositionWiseFeedForwardNN, self).__init__()
    self.forwardConnected1 = nn.Linear(dataModel, dFF)
    self.forwardConnected2 = nn.Linear(dFF, dataModel)
    self.relu = nn.ReLU()

  def forward(self, x):
    return self.forwardConnected2(self.relu(self.forwardConnected1(x)))

In [8]:
class PositionalEncoding(nn.Module):
  def __init__(self, dataModel, maxSeqLen):
    super(PositionalEncoding, self).__init__()
    pe = torch.zeros(maxSeqLen, dataModel)
    position = torch.arange(start=0, end=maxSeqLen, step=1, dtype=torch.float32).unsqueeze(1) # Unsqueeze makes it into a column matrix
    embeddingIndex = torch.arange(start=0, end=dataModel, step=2, dtype=torch.float32)
    divTerm = 1 / torch.tensor(1e4)**(embeddingIndex/dataModel)

    pe[:, 0::2] = torch.sin(position * divTerm)
    pe[:, 1::2] = torch.cos(position * divTerm)
    self.register_buffer('pe', pe)
  
  def forward(self, x):
    return x + self.pe[:, :x.size(1)]

In [13]:
x = torch.tensor(1e4)
x.dtype

torch.float32

In [12]:
class EncoderLayer(nn.Module):
  def __init__(self, dataModel, numHeads, dFF, dropout):
    super(EncoderLayer, self).__init__()
    self.selfAttention = MultiHeadAttention(dataModel=dataModel, numHeads=numHeads)
    self.feedForward = PositionWiseFeedForwardNN(dataModel=dataModel, dFF=dFF)
    self.layerNorm1 = nn.LayerNorm(dataModel)
    self.layerNorm2 = nn.LayerNorm(dataModel)
    self.Dropout = nn.Dropout(dropout)

  def forward(self, x, mask):
    attentionOut = self.selfAttention(x, x, x, mask)
    x = self.layerNorm1(x + self.Dropout(attentionOut))
    feedForwardOut = self.feedForward(x)
    x = self.layerNorm2(x + self.Dropout(feedForwardOut))
    return x


In [13]:
class DecoderLayer(nn.Module):
  def __init__(self, dataModel, numHeads, dFF, dropout):
    super(DecoderLayer, self).__init__()
    self.selfAttention = MultiHeadAttention(dataModel=dataModel, numHeads=numHeads)
    self.crossAttention = MultiHeadAttention(dataModel=dataModel, numHeads=numHeads)
    self.feedForward = PositionWiseFeedForwardNN(dataModel=dataModel, dFF=dFF)
    self.layerNorm1 = nn.LayerNorm(dataModel)
    self.layerNorm2 = nn.LayerNorm(dataModel)
    self.layerNorm3 = nn.LayerNorm(dataModel)
    self.dropout = nn.Dropout(dropout)
  
  def forward(self, x, encoderOutput, srcMask, targetMask):
    attentionOut = self.selfAttention(x, x, x, targetMask)
    x = self.layerNorm1(x + self.dropout(attentionOut))
    crossAttentionOut = self.crossAttention(x, encoderOutput, encoderOutput, srcMask)
    x = self.layerNorm2(x + self.dropout(crossAttentionOut))
    feedForwardOut = self.feedForward(x)
    x = self.layerNorm3(x + self.dropout(feedForwardOut))
    return x

In [43]:
class Transformer(nn.Module):
  def __init__(self, sourceVocabSize, targetVocabSize, dataModel, numHeads, numLayers, dFF, maxSeqLen, dropout):
    super(Transformer, self).__init__()
    self.encoderEmbedding = nn.Embedding(sourceVocabSize, dataModel)
    self.decoderEmbedding = nn.Embedding(targetVocabSize, dataModel)
    self.positionalEncoding = PositionalEncoding(dataModel=dataModel, maxSeqLen=maxSeqLen)
    self.encoderLayers = nn.ModuleList([EncoderLayer(dataModel=dataModel, numHeads=numHeads, dFF=dFF, dropout=dropout) for _ in range(numLayers)])
    self.decoderLayers = nn.ModuleList([DecoderLayer(dataModel=dataModel, numHeads=numHeads, dFF=dFF, dropout=dropout) for _ in range(numLayers)])
    self.fc = nn.Linear(dataModel, targetVocabSize)
    self.dropout = nn.Dropout(dropout)

  def generateMask(self, src, target):
    srcMask = (src != 0).unsqueeze(1).unsqueeze(2)
    targetMask = (target != 0).unsqueeze(1).unsqueeze(2)
    sequenceLength = target.size(1)
    nopeakMask = (1 - torch.triu(torch.ones(sequenceLength, sequenceLength), diagonal=1)).bool()
    targetMask = targetMask & nopeakMask
    return srcMask, targetMask

  def forward(self, src, target):
    srcMask, targetMask = self.generateMask(src, target)
    srcEmbedded = self.dropout(self.positionalEncoding(self.encoderEmbedding(src)))
    targetEmbedded = self.dropout(self.positionalEncoding(self.decoderEmbedding(target)))

    encoderOut = srcEmbedded
    for encodedLayer in self.encoderLayers:
      encoderOut = encodedLayer(encoderOut, srcMask)
    decoderOut = targetEmbedded
    for decodedLayer in self.decoderLayers:
      decoderOut = decodedLayer(decoderOut, encoderOut, srcMask, targetMask)
    output = self.fc(decoderOut)
    return output

In [44]:
sourceVocabSize = 5000
targetVocabSize = 5000
dataModel = 512
numHeads = 8
numLayers = 6
dFF = 2**11
maxSequenceLength = 100
dropout = 0.1

In [45]:
transformer = Transformer(
  sourceVocabSize=sourceVocabSize,
  targetVocabSize=targetVocabSize,
  dataModel=dataModel,
  numHeads=numHeads,
  numLayers=numLayers,
  dFF=dFF,
  maxSeqLen=maxSequenceLength,
  dropout=dropout
  )

In [46]:
sourceData = torch.randint(1, sourceVocabSize, (64, maxSequenceLength))
targetData = torch.randint(1, targetVocabSize, (64, maxSequenceLength))

In [47]:
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=1e-4, betas=(0.9,0.98), eps=1e-9)

transformer.train()



Transformer(
  (encoderEmbedding): Embedding(5000, 512)
  (decoderEmbedding): Embedding(5000, 512)
  (positionalEncoding): PositionalEncoding()
  (encoderLayers): ModuleList(
    (0-5): 6 x EncoderLayer(
      (selfAttention): MultiHeadAttention(
        (weightQ): Linear(in_features=512, out_features=512, bias=True)
        (weightK): Linear(in_features=512, out_features=512, bias=True)
        (weightV): Linear(in_features=512, out_features=512, bias=True)
        (weightO): Linear(in_features=512, out_features=512, bias=True)
      )
      (feedForward): PositionWiseFeedForwardNN(
        (forwardConnected1): Linear(in_features=512, out_features=2048, bias=True)
        (forwardConnected2): Linear(in_features=2048, out_features=512, bias=True)
        (relu): ReLU()
      )
      (layerNorm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (layerNorm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (Dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (deco

In [48]:
for epoch in range(100):
  optimizer.zero_grad()
  output = transformer(sourceData, targetData[:, :-1])
  loss = criterion(output.contiguous().view(-1, targetVocabSize), targetData[:, 1:].contiguous().view(-1))
  loss.backward()
  optimizer.step()
  print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

RuntimeError: The size of tensor a (512) must match the size of tensor b (100) at non-singleton dimension 2