# In [2]:
import torch
import torch.nn as nn

# In [6]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The last dimension of each input (``embed_size``) is split into ``heads``
    chunks of size ``head_dim = embed_size // heads``. Each chunk is passed
    through value/key/query ``nn.Linear(head_dim, head_dim)`` projections
    (the same weights are shared by every head). Attention is computed per
    head, and the heads are concatenated and mixed by ``fc_out``.

    Args:
        embed_size: Model dimension; must be divisible by ``heads``.
        heads: Number of attention heads.

    Raises:
        ValueError: If ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        if self.head_dim * heads != embed_size:
            raise ValueError(
                f"embed_size ({embed_size}) must be divisible by heads ({heads})"
            )
        # Creation order (v, k, q, fc_out) is kept so parameter init consumes
        # the RNG in the same order as before.
        self.v = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.k = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.q = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, v, k, q, mask):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            v, k, q: Tensors reshaped below to (N, len, heads, head_dim), so
                they must have ``embed_size`` features in the last dim and
                share batch size N. ``v`` and ``k`` must have the same length.
            mask: ``None`` or a tensor broadcastable to (N, heads, q_len, k_len);
                positions where ``mask == 0`` receive no attention weight.

        Returns:
            Tensor of shape (N, q_len, embed_size).
        """
        N = q.shape[0]
        v_len, k_len, q_len = v.shape[1], k.shape[1], q.shape[1]

        # Split the feature dimension into heads: (N, len, heads, head_dim).
        v = v.reshape(N, v_len, self.heads, self.head_dim)
        k = k.reshape(N, k_len, self.heads, self.head_dim)
        q = q.reshape(N, q_len, self.heads, self.head_dim)

        # Apply the learned per-head projections (previously defined but unused).
        v = self.v(v)
        k = self.k(k)
        q = self.q(q)

        # energy[n, h, i, j] = q_i . k_j for head h  -> (N, heads, q_len, k_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [q, k])
        if mask is not None:
            # Use the most negative finite value of the dtype so masked logits
            # become exactly 0 after softmax; finite (not -inf) keeps fully
            # masked rows NaN-free, and finfo keeps it valid in fp16/bf16.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        # softmax(Q K^T / sqrt(d_k)) with d_k = per-head key dimension.
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # Weighted sum of values, then concatenate heads: (N, q_len, embed_size).
        out = torch.einsum("nhql,nlhd->nqhd", [attention, v]).reshape(
            N, q_len, self.heads * self.head_dim
        )
        return self.fc_out(out)

class TransformerBlock(nn.Module):
    """Post-norm transformer layer made of two sublayers.

    Each sublayer output is combined as ``dropout(LayerNorm(sublayer + residual))``:
    first multi-head attention (residual = the query ``q``), then a two-layer
    ReLU feed-forward network of width ``forward_expansion * embed_size``.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        ff_width = forward_expansion * embed_size
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, ff_width),
            nn.ReLU(),
            nn.Linear(ff_width, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, v, k, q, mask):
        """Attend from ``q`` to ``k``/``v``, then apply the feed-forward sublayer."""
        attended = self.attention(v, k, q, mask)
        # Residual around attention uses the query stream.
        hidden = self.dropout(self.norm1(q + attended))
        ff_result = self.feed_forward(hidden)
        return self.dropout(self.norm2(hidden + ff_result))


class Encoder(nn.Module):
    """Transformer encoder stack.

    Sums token embeddings with learned position embeddings, applies dropout,
    then runs ``num_layers`` TransformerBlock self-attention layers.

    Args:
        src_vocab_size: Size of the source vocabulary.
        embed_size: Model dimension.
        num_layers: Number of stacked TransformerBlock layers.
        heads: Attention heads per layer.
        device: Stored as ``self.device``; position indices are created on
            the input tensor's device instead, so the module works wherever
            it and its inputs actually live.
        forward_expansion: Feed-forward width multiplier.
        dropout: Dropout probability.
        max_length: Number of learned positions (max sequence length).
    """

    def __init__(
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length,
    ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)
        # One TransformerBlock per layer (previously only a single block was
        # built regardless of num_layers).
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode token ids ``x`` of shape (N, seq_length) into (N, seq_length, embed_size).

        ``mask`` is passed unchanged to every layer's attention.
        """
        N, seq_length = x.shape
        # (N, seq_length) position indices on the same device as the tokens.
        positions = torch.arange(0, seq_length, device=x.device).expand(N, seq_length)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            # Self-attention: values, keys and queries all come from `out`.
            out = layer(out, out, out, mask)

        return out

class DecoderBlock(nn.Module):
    """Decoder layer: masked self-attention over ``x``, then a TransformerBlock
    that attends from that result to ``v``/``k``.

    The self-attention output is combined as ``dropout(LayerNorm(attn + x))``
    and used as the query of the inner TransformerBlock. ``device`` is accepted
    for signature compatibility but not used here.
    """

    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, v, k, src_mask, trg_mask):
        """Run masked self-attention on ``x`` (``trg_mask``), then attention to ``v``/``k`` (``src_mask``)."""
        self_attended = self.attention(x, x, x, trg_mask)
        query = self.dropout(self.norm(x + self_attended))
        return self.transformer_block(v, k, query, src_mask)

class Decoder(nn.Module):
    """Transformer decoder stack with an output projection to vocabulary logits.

    Args:
        trg_vocab_size: Size of the target vocabulary (output logits width).
        embed_size: Model dimension.
        num_layers: Number of DecoderBlock layers.
        heads: Attention heads per layer.
        forward_expansion: Feed-forward width multiplier.
        dropout: Dropout probability.
        device: Stored as ``self.device`` and forwarded to DecoderBlock;
            position indices are created on the input tensor's device.
        max_length: Number of learned positions (max sequence length).
    """

    def __init__(self, trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)
        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
                for _ in range(num_layers)
            ]
        )
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        """Decode target ids ``x`` (N, seq_length) against ``enc_out``.

        Returns logits of shape (N, seq_length, trg_vocab_size).
        """
        N, seq_length = x.shape
        # (N, seq_length) position indices, built on the tokens' device so the
        # module does not depend on the device passed at construction time.
        positions = torch.arange(0, seq_length, device=x.device).expand(N, seq_length)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            # Encoder output serves as both values and keys for cross-attention.
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)
        return out

import torch
import torch.nn as nn
import torch.optim as optim

# Define the complete Transformer model
class Transformer(nn.Module):
    """Encoder-decoder Transformer.

    ``forward(src, trg)`` takes token-id tensors of shape (N, src_len) and
    (N, trg_len) and returns decoder logits of shape (N, trg_len, trg_vocab_size).

    Args:
        src_vocab_size, trg_vocab_size: Vocabulary sizes.
        src_pad_idx: Source padding id; padded source positions are masked
            out of attention.
        trg_pad_idx: Stored on the model; not used by the masks built here.
        embed_size, num_layers, forward_expansion, heads, dropout, max_length:
            Forwarded to Encoder and Decoder.
        device: Forwarded to Encoder/Decoder and stored as ``self.device``.
            Masks are created on the inputs' own device.
    """

    def __init__(self, src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, embed_size=512, num_layers=6,
                 forward_expansion=4, heads=8, dropout=0.1, max_length=100, device="cuda"):
        super(Transformer, self).__init__()

        self.encoder = Encoder(
            src_vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length
        )

        self.decoder = Decoder(
            trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        """Return a bool mask of shape (N, 1, 1, src_len), True at non-pad tokens.

        The singleton dims broadcast over heads and query positions.
        """
        # The comparison result already lives on src's device.
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        """Return a causal bool mask of shape (N, 1, trg_len, trg_len).

        Entry [i, j] is True when j <= i, so position i only attends to itself
        and earlier positions.
        """
        N, trg_len = trg.shape
        causal = torch.tril(torch.ones((trg_len, trg_len), dtype=torch.bool, device=trg.device))
        return causal.expand(N, 1, trg_len, trg_len)

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out

# Example usage
if __name__ == "__main__":
    # One optimisation step of a toy seq2seq batch, printing logits shape and loss.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Toy hyperparameters
    src_vocab_size = 10000
    trg_vocab_size = 10000
    src_pad_idx = 0
    trg_pad_idx = 0
    embed_size = 512
    num_layers = 6
    heads = 8
    forward_expansion = 4
    dropout = 0.1
    max_length = 100

    # Model, optimiser and a loss that skips padding targets
    model = Transformer(
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=embed_size,
        num_layers=num_layers,
        forward_expansion=forward_expansion,
        heads=heads,
        dropout=dropout,
        max_length=max_length,
        device=device,
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    criterion = nn.CrossEntropyLoss(ignore_index=trg_pad_idx)

    # Two hand-written sequences; 0 is the padding id
    src = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 0], [1, 8, 7, 3, 4, 5, 6, 2]]).to(device)
    trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 3]]).to(device)

    # Teacher forcing: the decoder sees trg[:, :-1] and predicts trg[:, 1:]
    logits = model(src, trg[:, :-1])
    print(logits.shape)  # (N, trg_len - 1, trg_vocab_size)

    flat_logits = logits.reshape(-1, logits.shape[2])
    flat_targets = trg[:, 1:].reshape(-1)
    loss = criterion(flat_logits, flat_targets)

    loss.backward()
    optimizer.step()

    print(f"Loss: {loss.item()}")






















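# Sample output from a notebook run of the example above
# (the loss value depends on random initialisation):
# torch.Size([2, 7, 10000])
# Loss: 9.382999420166016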
