## Transformer Decoder

In [3]:
import torch
import math
import torch.nn as nn

In [4]:
# block / layer
class SimpleDecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer: causal multi-head self-attention + FFN.

    Each sub-layer is wrapped as ``LayerNorm(X + sublayer(X))``.
    Input and output tensors are ``(batch, seq, hidden_dim)``.

    Args:
        hidden_dim: model width; must be divisible by ``head_num``.
        head_num: number of attention heads.
        attention_dropout_rate: dropout applied to the attention probabilities.
        ffn_dropout_rate: dropout applied to the FFN output before the residual add.

    Raises:
        ValueError: if ``hidden_dim`` is not divisible by ``head_num``.
    """

    def __init__(self, hidden_dim, head_num, attention_dropout_rate=0.1, ffn_dropout_rate=0.1):
        super().__init__()
        if hidden_dim % head_num != 0:
            # Otherwise the per-head reshape in ``mha`` cannot cover all features.
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by head_num ({head_num})"
            )
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.head_dim = hidden_dim // head_num

        # Multi-head self-attention: Q/K/V projections, output projection, post-LN.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)
        self.drop_att = nn.Dropout(attention_dropout_rate)
        self.att_ln = nn.LayerNorm(hidden_dim, eps=1e-7)

        # Feed-forward network: expand 4x -> GELU -> project back -> dropout -> post-LN.
        self.up_proj = nn.Linear(hidden_dim, hidden_dim * 4)
        self.down_proj = nn.Linear(hidden_dim * 4, hidden_dim)
        self.act_fn = nn.GELU()
        self.drop_ffn = nn.Dropout(ffn_dropout_rate)
        self.ffn_ln = nn.LayerNorm(hidden_dim, eps=1e-7)

    def attention_layer(self, query, key, value, attention_mask=None):
        """Scaled dot-product attention with a causal (lower-triangular) mask.

        Args:
            query, key, value: ``(batch, head_num, seq, head_dim)``.
            attention_mask: optional tensor whose last two dims are ``(seq, seq)``
                and which broadcasts against ``(batch, head_num, seq, seq)``;
                zeros mark positions to hide. It is always intersected with the
                causal mask via ``tril()``.

        Returns:
            ``(batch, seq, hidden_dim)`` after the output projection.
        """
        # (b, h, s, d) @ (b, h, d, s) -> (b, h, s, s)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attention_mask is None:
            # A single (s, s) mask broadcasts over batch and heads.
            attention_mask = torch.ones(scores.shape[-2:], device=scores.device)
        # tril() keeps only key positions j <= query position i (no look-ahead).
        attention_mask = attention_mask.tril()
        # NOTE: a query row whose every key is masked becomes all -inf and softmax
        # yields NaN for it; supplied masks should keep at least one visible key.
        scores = scores.masked_fill(attention_mask == 0, float("-inf"))

        probs = torch.softmax(scores, dim=-1)
        probs = self.drop_att(probs)
        context = torch.matmul(probs, value)  # (b, h, s, d)

        # Merge heads: (b, h, s, d) -> (b, s, h, d) -> (b, s, h * d)
        batch, seq = context.size(0), context.size(2)
        context = context.transpose(1, 2).contiguous().view(batch, seq, -1)
        return self.o_proj(context)

    def mha(self, X, mask=None):
        """Causal multi-head self-attention sub-layer with residual + post-LayerNorm.

        Args:
            X: ``(batch, seq, hidden_dim)``.
            mask: optional attention mask forwarded to ``attention_layer``.
        """
        batch, seq, _ = X.size()

        def split_heads(t):
            # (b, s, hidden_dim) -> (b, head_num, s, head_dim)
            return t.view(batch, seq, self.head_num, self.head_dim).transpose(1, 2)

        query = split_heads(self.q_proj(X))
        key = split_heads(self.k_proj(X))
        value = split_heads(self.v_proj(X))

        output = self.attention_layer(query, key, value, mask)
        return self.att_ln(X + output)

    def ffn(self, X):
        """Position-wise feed-forward sub-layer with residual + post-LayerNorm."""
        hidden = self.act_fn(self.up_proj(X))
        down = self.drop_ffn(self.down_proj(hidden))
        return self.ffn_ln(X + down)

    def forward(self, X, attention_mask=None):
        """Apply attention then FFN; shape ``(batch, seq, hidden_dim)`` is preserved."""
        X = self.mha(X, attention_mask)
        return self.ffn(X)

class Decoder(nn.Module):
    """Token decoder: embedding -> stack of SimpleDecoderLayer -> vocabulary softmax.

    The defaults reproduce the original fixed configuration
    (vocab 12, width 64, 8 heads, 5 layers).

    Args:
        vocab_size: number of token ids accepted by the embedding and predicted.
        hidden_dim: model width passed to every layer.
        head_num: attention heads per layer.
        num_layers: number of stacked decoder layers.
    """

    def __init__(self, vocab_size=12, hidden_dim=64, head_num=8, num_layers=5):
        super().__init__()
        self.layer_list = nn.ModuleList(
            [SimpleDecoderLayer(hidden_dim, head_num) for _ in range(num_layers)]
        )
        self.emb = nn.Embedding(vocab_size, hidden_dim)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, X, mask=None):
        """Map token ids to per-position probability distributions over the vocabulary.

        Args:
            X: integer token ids, ``(batch, seq)``; values must be < ``vocab_size``.
            mask: optional attention mask forwarded unchanged to every layer.

        Returns:
            ``(batch, seq, vocab_size)`` probabilities (softmax over the last dim).
        """
        X = self.emb(X)  # (b, s) -> (b, s, hidden_dim)
        for layer in self.layer_list:
            X = layer(X, mask)
        logits = self.out(X)
        # NOTE(review): returns probabilities, not logits; nn.CrossEntropyLoss
        # expects raw logits — confirm the intended training usage.
        return torch.softmax(logits, dim=-1)
    


    
# Example input: 3 sequences of 4 random token ids in [0, 12).
x = torch.randint(low=0, high=12, size=(3, 4))

