In [None]:
class DecoderLayer(nn.Module):
    """Pre-norm transformer decoder layer.

    Three sub-layers, each applied as ``x + dropout(sublayer(norm(x)))``:
    masked self-attention over the decoder sequence, cross-attention from the
    decoder to the encoder output, and a position-wise feed-forward network.

    Args:
        dropout: dropout probability used on attention weights, on every
            residual branch and inside the feed-forward network.
        embedding: model width (d_model).
        nhead: number of attention heads; ``embedding`` must be divisible by it
            (``MultiHeadAttention`` raises ``ValueError`` otherwise).
    """

    def __init__(self, dropout, embedding, nhead):
        super().__init__()
        # NOTE: submodules are created in the original order so seeded weight
        # initialisation and state_dict keys are unchanged.

        # masked self-attention over the decoder tokens
        self.attention_self = MultiHeadAttention(d_model=embedding, nhead=nhead, mask=None, dropout=dropout)

        # cross-attention projections: queries from the decoder, keys/values from the encoder
        self.cross_q_proj = nn.Linear(embedding, embedding)
        self.cross_k_proj = nn.Linear(embedding, embedding)
        self.cross_v_proj = nn.Linear(embedding, embedding)
        self.cross_out_proj = nn.Linear(embedding, embedding)

        self.nhead = nhead
        self.head_dim = embedding // nhead
        self.scale = self.head_dim ** -0.5  # 1 / sqrt(d_k)

        # one pre-norm LayerNorm per sub-layer
        self.norm1 = nn.LayerNorm(embedding)
        self.norm2 = nn.LayerNorm(embedding)
        self.norm3 = nn.LayerNorm(embedding)

        # shared dropout for residual branches and cross-attention weights
        self.dropout = nn.Dropout(dropout)

        # position-wise feed-forward network (2x expansion)
        self.fcnn = nn.Sequential(
            nn.Linear(embedding, embedding * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embedding * 2, embedding),
        )

    def forward(self, decoder_input, encoded_context, casual_mask, memory_mask=None):
        """Run the decoder layer.

        Args:
            decoder_input: decoder hidden states, (B, L_q, embedding).
            encoded_context: encoder output, (B, L_kv, embedding).
            casual_mask: (sic, "causal") self-attention mask handed to
                ``MultiHeadAttention``; positions equal to 0 are blocked.
            memory_mask: optional cross-attention mask; 0/False entries block
                encoder positions (see ``encoder_cross_attention``).

        Returns:
            Tensor of shape (B, L_q, embedding).
        """
        # 1) masked self-attention among decoder positions
        residual = decoder_input
        norm_x = self.norm1(decoder_input)
        self_attn = self.attention_self(norm_x, casual_mask)
        decoder_input = residual + self.dropout(self_attn)

        # 2) cross-attention: decoder positions query the encoder output
        norm_x = self.norm2(decoder_input)
        cross_attn = self.encoder_cross_attention(norm_x, encoded_context, memory_mask)
        # keep this add out-of-place: decoder_input was just fed to norm2 and
        # may be saved for backward, so mutating it would break autograd
        decoder_input = decoder_input + self.dropout(cross_attn)

        # 3) position-wise feed-forward network
        norm_x = self.norm3(decoder_input)
        ff_out = self.fcnn(norm_x)
        return decoder_input + self.dropout(ff_out)

    def encoder_cross_attention(self, query, key_value, memory_mask=None):
        """Multi-head attention from decoder queries to encoder keys/values.

        Args:
            query: decoder states, (B, L_q, embedding).
            key_value: encoder output, (B, L_kv, embedding).
            memory_mask: optional mask where 0/False marks encoder positions to
                ignore. A 2-D mask is a key-padding mask (B, L_kv); a 3-D mask
                is (B, L_q, L_kv); a 4-D mask must broadcast to
                (B, nhead, L_q, L_kv).

        Returns:
            Tensor of shape (B, L_q, embedding) in ``query``'s original dtype.
        """
        out_dtype = query.dtype
        B, L_q, _ = query.shape  # decoder side
        _, L_kv, _ = key_value.shape  # encoder side

        # attention math runs in full precision
        query = query.float()
        key_value = key_value.float()

        q = self.cross_q_proj(query)
        k = self.cross_k_proj(key_value)
        v = self.cross_v_proj(key_value)

        # (B, L, embedding) -> (B, nhead, L, head_dim)
        q = q.view(B, L_q, self.nhead, self.head_dim).transpose(1, 2)
        k = k.view(B, L_kv, self.nhead, self.head_dim).transpose(1, 2)
        v = v.view(B, L_kv, self.nhead, self.head_dim).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # (B, nhead, L_q, L_kv)

        if memory_mask is not None:
            if memory_mask.dim() == 2:
                # key padding (B, L_kv) -> (B, 1, 1, L_kv)
                memory_mask = memory_mask[:, None, None, :]
            elif memory_mask.dim() == 3:
                # (B, L_q, L_kv) -> (B, 1, L_q, L_kv)
                memory_mask = memory_mask.unsqueeze(1)
            scores = scores.masked_fill(memory_mask == 0, torch.finfo(scores.dtype).min)

        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, v)
        context = context.transpose(1, 2).contiguous().view(B, L_q, -1)

        output = self.cross_out_proj(context)
        # restore the caller's dtype, matching MultiHeadAttention's behaviour
        return output.to(out_dtype)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    One fused linear layer yields queries, keys and values. Attention is
    computed per head in float32 and the projected result is cast back to the
    dtype of the input.

    Args:
        d_model: model width; must be divisible by ``nhead``.
        nhead: number of attention heads.
        mask: accepted for call-site compatibility but not stored or used;
            supply masks to ``forward`` instead.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``nhead``.
    """

    def __init__(self, d_model, nhead, mask, dropout=0.15):
        super().__init__()

        if d_model % nhead:
            raise ValueError(f"d_model ({d_model}) must be divisible by nhead ({nhead})")

        self.d_model = d_model
        self.nhead = nhead
        self.head_dim = d_model // nhead

        # creation order (qkv_proj, out_proj, dropout) fixes init RNG use and state_dict keys
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)  # fused Q/K/V projection
        self.out_proj = nn.Linear(d_model, d_model)  # mixes the concatenated heads
        self.dropout = nn.Dropout(dropout)  # applied to attention weights
        self.scale = self.head_dim ** -0.5  # 1 / sqrt(head_dim)

    def forward(self, x, mask=None):
        """Attend over ``x``.

        Args:
            x: input of shape (batch, seq_len, d_model).
            mask: optional mask; entries equal to 0 are blocked. A 3-D mask
                gains a head axis; 2-D, 3-D and 4-D masks are expanded to
                (batch, nhead, seq_len, seq_len).

        Returns:
            Tensor of shape (batch, seq_len, d_model) with ``x``'s dtype.
        """
        batch, seq_len, _ = x.shape
        heads, width = self.nhead, self.head_dim

        # projections and scores are computed in float32 regardless of input dtype
        fused = self.qkv_proj(x.to(torch.float32))

        # (batch, seq_len, d_model) -> (batch, heads, seq_len, head_dim), once per Q/K/V
        queries, keys, values = [
            part.view(batch, seq_len, heads, width).transpose(1, 2)
            for part in fused.chunk(3, dim=-1)
        ]

        logits = (queries @ keys.transpose(-2, -1)) * self.scale

        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # (batch, L, L) -> (batch, 1, L, L)
            if mask.dim() in (2, 4):
                mask = mask.expand(batch, heads, seq_len, seq_len)
            # most-negative finite value, so blocked keys get ~zero softmax weight
            logits = logits.masked_fill(mask == 0, torch.finfo(logits.dtype).min)

        weights = self.dropout(torch.softmax(logits, dim=-1))

        # merge heads back into (batch, seq_len, d_model)
        merged = (weights @ values).transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)

        return self.out_proj(merged).to(x.dtype)

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention followed by a feed-forward
    network, each wrapped as ``x + dropout(sublayer(norm(x)))``.

    Args:
        d_model: model width; must be divisible by ``nhead``.
        nhead: number of attention heads. Required; the ``None`` default is
            kept only so the signature stays backward compatible.
        mask: forwarded to ``MultiHeadAttention``'s constructor, which does not
            use it; pass masks to ``forward`` instead.
        dropout: dropout probability for attention weights, residual branches
            and the feed-forward network.

    Raises:
        TypeError: if ``nhead`` is not given.
    """

    def __init__(self, d_model, nhead=None, mask=None, dropout=0.15):
        super().__init__()
        # fail early with a clear message instead of a bare `int % None` error
        if nhead is None:
            raise TypeError("TransformerBlock requires nhead (number of attention heads)")

        # NOTE: submodule creation order is unchanged so seeded initialisation
        # and state_dict keys stay the same.
        self.self_attn = MultiHeadAttention(d_model=d_model, nhead=nhead, dropout=dropout, mask=mask)

        # position-wise feed-forward network (2x expansion)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),  # regularisation against overfitting
            nn.Linear(d_model * 2, d_model),
        )

        # LayerNorm standardises each token's features (zero mean, unit
        # variance) and then applies a learned gain (gamma) and bias (beta);
        # it does not squash values into 0-1.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # dropout on the residual branches
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply the block.

        Args:
            x: input of shape (batch, seq_len, d_model).
            mask: optional attention mask passed to the self-attention layer;
                entries equal to 0 are blocked. Defaults to no masking.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # pre-norm: normalise only the sub-layer input, so the residual path
        # carries x through unchanged and gradients flow straight back
        residual = x
        attn_output = self.self_attn(self.norm1(x), mask)
        x = residual + self.dropout(attn_output)

        # normalising again keeps activations on a stable scale as layers stack
        residual = x
        ff_output = self.feed_forward(self.norm2(x))
        # the identity residual adds a direct gradient path, easing vanishing gradients
        return residual + self.dropout(ff_output)