In [1]:
from dataclasses import dataclass
import torch
import torch.nn as nn

In [2]:
# Report the name of the current CUDA device; nothing is printed on a
# machine where CUDA is unavailable.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(torch.cuda.current_device()))

Quadro P5000


In [62]:
@dataclass
class ModelArgs:
    """Hyperparameters read by the model-building classes in this notebook."""

    d_model: int = 512        # embedding width; also the LayerNorm / Linear feature size
    d_ff: int = 2048          # hidden width of the position-wise feed-forward block
    N: int = 6                # number of layers stacked in the encoder and in the decoder
    dropout: float = 0.1      # dropout probability passed to the feed-forward block
    num_heads: int = 8        # number of attention heads per multi-head attention module
    vocab_size: int = 50000   # rows of the embedding table and width of the output logits
    pad_idx: int = 0          # value the attention heads compare against to detect padding

class Embeddings(nn.Module):
    """Embedding lookup whose output is multiplied by ``d_model ** 0.5``."""

    def __init__(self, args:ModelArgs):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=args.vocab_size,
            embedding_dim=args.d_model,
        )
        self.d_model = args.d_model

    def forward(self, x):
        """Return the embedding rows selected by ``x``, scaled by sqrt(d_model)."""
        scale = self.d_model ** 0.5
        vectors = self.embedding(x)
        return vectors * scale

        

class MultiHeadAttention(nn.Module):
    """Run ``args.num_heads`` attention heads and concatenate their outputs.

    Args:
        args: model hyperparameters; ``num_heads`` sets how many heads are built
            and the whole object is forwarded to each head.
        attn_mask: forwarded to every head; enables the causal (look-ahead) mask.
        pad_mask: forwarded to every head; enables padding masking.
    """

    def __init__(self, args: ModelArgs, attn_mask=False, pad_mask=True):
        super().__init__()
        self.attn_mask = attn_mask
        self.pad_mask = pad_mask
        self.heads = nn.ModuleList(
            [SelfAttention(args, self.attn_mask, self.pad_mask) for _ in range(args.num_heads)]
        )

    def forward(self, x, encoder_output=None):
        """Concatenate every head's output along the last dimension.

        Args:
            x: input that each head turns into queries (and into keys/values
                when ``encoder_output`` is not given).
            encoder_output: optional second input for cross-attention; when
                given, each head is called as ``head(x, encoder_output)``.

        Returns:
            The per-head outputs joined with ``torch.cat(..., dim=-1)``.
        """
        # Compare against None explicitly: taking the truth value of a tensor
        # with more than one element raises a RuntimeError.
        if encoder_output is not None:
            head_outputs = [head(x, encoder_output) for head in self.heads]
        else:
            head_outputs = [head(x) for head in self.heads]
        return torch.cat(head_outputs, dim=-1)

class SelfAttention(nn.Module):
    """One scaled dot-product attention head, usable for self- or cross-attention.

    Args:
        args: model hyperparameters; reads ``d_model``, ``num_heads`` and ``pad_idx``.
        attn_mask: when true, each query position may only attend to key
            positions at or before its own index (causal mask).
        pad_mask: when true, key positions detected as padding are masked out.
    """

    def __init__(self, args: ModelArgs, attn_mask=False, pad_mask=True):
        super().__init__()
        d_in = args.d_model
        # Each head projects to d_model // num_heads features so that the
        # concatenation of all heads has (up to rounding) d_model features again.
        self.d_out_kq = d_in // args.num_heads
        d_out_v = self.d_out_kq
        # NOTE(review): torch.rand yields only non-negative weights in [0, 1);
        # kept as in the original, but a scaled/centred init may train better.
        self.W_q = nn.Parameter(torch.rand(d_in, self.d_out_kq))
        self.W_k = nn.Parameter(torch.rand(d_in, self.d_out_kq))
        self.W_v = nn.Parameter(torch.rand(d_in, d_out_v))
        self.attn_mask = attn_mask
        self.pad_mask = pad_mask
        self.pad_idx = args.pad_idx

    def forward(self, x, encoder_output=None):
        """Compute the attention context vectors for ``x``.

        Args:
            x: tensor whose last dimension is ``d_model``; the source of the
                queries. Works for ``(seq, d_model)`` and for inputs with
                leading batch dimensions.
            encoder_output: optional tensor (last dimension ``d_model``) that
                supplies keys and values for cross-attention. When ``None``,
                keys and values come from ``x``.

        Returns:
            Tensor shaped like ``x`` except that the last dimension is the
            per-head width ``d_model // num_heads``.
        """
        # Keys/values come from the encoder output for cross-attention and from
        # x itself otherwise. Tensors must be tested with `is None`, not truthiness.
        source = x if encoder_output is None else encoder_output
        queries = x @ self.W_q
        keys = source @ self.W_k
        values = source @ self.W_v

        # Swap only the last two dimensions so batched inputs keep their batch axes.
        attn_scores = queries @ keys.transpose(-2, -1)

        # Causal mask: hide every key position that lies after the query position.
        if self.attn_mask:
            q_len, k_len = attn_scores.shape[-2:]
            future = torch.ones(
                q_len, k_len, dtype=torch.bool, device=attn_scores.device
            ).triu(diagonal=1)
            attn_scores = attn_scores.masked_fill(future, float("-inf"))

        # Padding mask: a key position counts as padding when every feature of
        # its input vector equals pad_idx.
        # NOTE(review): this inspects feature vectors, not token ids, so it only
        # fires if padded positions really carry constant pad_idx vectors --
        # confirm, or pass an explicit token-level mask instead.
        if self.pad_mask:
            pad_mask = (source == self.pad_idx).all(dim=-1)
            # (..., k_len) -> (..., 1, k_len): broadcast over the query axis only.
            attn_scores = attn_scores.masked_fill(pad_mask.unsqueeze(-2), float("-inf"))

        attn_weights = torch.softmax(attn_scores / self.d_out_kq ** 0.5, dim=-1)
        return attn_weights @ values

class PositionWiseFeedForward(nn.Module):
    """Two-layer feed-forward block: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_in_out: feature size of both the input and the output.
        d_hidden: feature size of the hidden layer.
        dropout: dropout probability applied to the hidden activations.
    """

    def __init__(self, d_in_out, d_hidden, dropout):
        super().__init__()
        self.w1 = nn.Linear(d_in_out, d_hidden)
        self.w2 = nn.Linear(d_hidden, d_in_out)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block to the last dimension of ``x`` and return the result."""
        # torch.relu needs no extra import; the previous `F.ReLU` referenced an
        # unimported alias and a name torch.nn.functional does not provide.
        hidden = torch.relu(self.w1(x))
        return self.w2(self.dropout(hidden))

class EncoderLayer(nn.Module):
    """Attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer's output is added to its input and the sum is passed
    through a LayerNorm.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # Submodules are registered in the same order as before (ff, mha, norms).
        self.ff = PositionWiseFeedForward(args.d_model, args.d_ff, args.dropout)
        self.mha = MultiHeadAttention(args)
        self.norm1 = nn.LayerNorm(args.d_model)
        self.norm2 = nn.LayerNorm(args.d_model)

    def forward(self, x):
        """Return ``norm2(h + ff(h))`` where ``h = norm1(x + mha(x))``."""
        attended = self.mha(x)
        hidden = self.norm1(x + attended)
        transformed = self.ff(hidden)
        return self.norm2(hidden + transformed)
    
class Encoder(nn.Module):
    """A stack of ``args.N`` encoder layers applied one after another."""

    def __init__(self, args:ModelArgs):
        super().__init__()
        stack = [EncoderLayer(args) for _ in range(args.N)]
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the final output."""
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden)
        return hidden
    

class DecoderLayer(nn.Module):
    """Masked attention, attention over the encoder output, then feed-forward.

    Each of the three sub-layers is wrapped as ``LayerNorm(input + sublayer)``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # Submodules are registered in the same order as before.
        self.mmha = MultiHeadAttention(args, attn_mask=True)
        self.mhca = MultiHeadAttention(args)
        self.ff = PositionWiseFeedForward(args.d_model, args.d_ff, args.dropout)
        self.norm1 = nn.LayerNorm(args.d_model)
        self.norm2 = nn.LayerNorm(args.d_model)
        self.norm3 = nn.LayerNorm(args.d_model)

    def forward(self, x, encoder_output):
        """Run the three sub-layers in sequence.

        Args:
            x: decoder-side input.
            encoder_output: passed as the second argument of the ``mhca`` module.
        """
        self_attended = self.mmha(x)
        hidden = self.norm1(x + self_attended)

        cross_attended = self.mhca(hidden, encoder_output)
        hidden = self.norm2(hidden + cross_attended)

        transformed = self.ff(hidden)
        return self.norm3(hidden + transformed)

class Decoder(nn.Module):
    """A stack of ``args.N`` decoder layers that all receive the same encoder output."""

    def __init__(self, args:ModelArgs):
        super().__init__()
        stack = [DecoderLayer(args) for _ in range(args.N)]
        self.layers = nn.ModuleList(stack)

    def forward(self, x, encoder_output):
        """Feed ``x`` through every layer in order, passing ``encoder_output`` to each."""
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, encoder_output)
        return hidden
    
class Transformer(nn.Module):
    """Encoder-decoder model with one weight matrix shared by three modules.

    The input embedding, the output embedding and the final linear projection
    all use the same ``(vocab_size, d_model)`` parameter.
    """

    def __init__(self, args:ModelArgs):
        super().__init__()
        self.pad_idx = args.pad_idx
        self.input_embedding = Embeddings(args)
        self.output_embedding = Embeddings(args)
        self.encoder = Encoder(args)
        self.decoder = Decoder(args)
        self.linear = nn.Linear(args.d_model, args.vocab_size)
        shared_weight = self.input_embedding.embedding.weight
        self.linear.weight = shared_weight
        # Tie the inner nn.Embedding's weight. Assigning to
        # `self.output_embedding.weight` would only add an unused attribute to
        # the wrapper and leave the decoder-side embedding untied.
        self.output_embedding.embedding.weight = shared_weight

    def shift_right(self, tgt):
        """Return ``tgt`` shifted one position to the right along the last axis.

        The first position is filled with ``pad_idx`` and the last target token
        is dropped, so position ``t`` of the result holds token ``t - 1``.

        NOTE(review): using ``pad_idx`` as the start symbol is an assumption;
        switch to a dedicated start-of-sequence id if the vocabulary has one.
        """
        shifted = torch.full_like(tgt, self.pad_idx)
        shifted[..., 1:] = tgt[..., :-1]
        return shifted

    def forward(self, src, tgt):
        """Return log-probabilities over the vocabulary for each target position.

        Args:
            src: source token ids, embedded and passed to the encoder.
            tgt: target token ids; shifted right before being embedded and
                passed to the decoder.

        Returns:
            ``log_softmax`` (over the last dimension) of the linear projection
            of the decoder output.
        """
        memory = self.encoder(self.input_embedding(src))
        decoder_input = self.shift_right(tgt)
        decoded = self.decoder(self.output_embedding(decoder_input), memory)
        return torch.log_softmax(self.linear(decoded), dim=-1)

    



In [63]:
# Build a Transformer from the default hyperparameters and print its module tree.
args = ModelArgs()
print(Transformer(args))

Transformer(
  (input_embedding): Embeddings(
    (embedding): Embedding(50000, 512)
  )
  (output_embedding): Embeddings(
    (embedding): Embedding(50000, 512)
  )
  (encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderLayer(
        (ff): PositionWiseFeedForward(
          (w1): Linear(in_features=512, out_features=2048, bias=True)
          (w2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (mha): MultiHeadAttention(
          (heads): ModuleList(
            (0): SelfAttention()
            (1): SelfAttention()
            (2): SelfAttention()
            (3): SelfAttention()
            (4): SelfAttention()
            (5): SelfAttention()
            (6): SelfAttention()
            (7): SelfAttention()
          )
        )
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
      (1): EncoderL

In [29]:
# Print the name and shape of every parameter registered on `sa`.
# NOTE(review): `sa` is not defined in any cell visible here, and this cell's
# execution count is lower than the class-definition cell's -- confirm where
# `sa` comes from so the notebook survives Restart & Run All.
for name, param in sa.named_parameters():
    print(name, param.shape)

W_q torch.Size([512, 512])
W_k torch.Size([512, 512])
W_v torch.Size([512, 512])
