In [2]:
import torch
from torch import nn
import math

class multi_head_attn(nn.Module):
    """Multi-head scaled dot-product attention with an optional causal mask.

    Args:
        head: number of attention heads.
        d_model: model (embedding) dimension; must be divisible by ``head``.
        causal: if True (default), query position i may only attend to key
            positions j <= i.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``head``.
    """

    def __init__(self, head=8, d_model=512, causal=True):
        super().__init__()
        if d_model % head != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by head ({head})")
        self.q_w = nn.Linear(d_model, d_model)
        self.k_w = nn.Linear(d_model, d_model)
        self.v_w = nn.Linear(d_model, d_model)
        self.o_w = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)
        self.head = head
        self.new_d = d_model // head  # per-head dimension
        self.d_model = d_model
        self.causal = causal

    def _split_heads(self, x):
        """Reshape (batch, length, d_model) -> (batch, head, length, new_d)."""
        batch, length, _ = x.shape
        return x.view(batch, length, self.head, self.new_d).transpose(1, 2)

    def forward(self, q, k, v):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: query tensor of shape (batch, T_q, d_model).
            k: key tensor of shape (batch, T_k, d_model).
            v: value tensor of shape (batch, T_k, d_model).

        Returns:
            Tensor of shape (batch, T_q, d_model).
        """
        # Shapes come from the inputs, not from notebook globals.
        batch, t_q, _ = q.shape
        t_k = k.shape[1]

        q = self._split_heads(self.q_w(q))
        k = self._split_heads(self.k_w(k))
        v = self._split_heads(self.v_w(v))

        score = q @ k.transpose(-2, -1) / math.sqrt(self.new_d)
        if self.causal:
            # Build the mask on the scores' device to avoid CPU/GPU mismatch.
            # Every row of tril keeps at least column 0, so -inf never yields
            # an all-masked row (which would softmax to NaN).
            mask = torch.tril(
                torch.ones(t_q, t_k, dtype=torch.bool, device=score.device)
            )
            # masked_fill is out-of-place: the result must be assigned.
            score = score.masked_fill(~mask, float("-inf"))

        weights = self.softmax(score)  # normalize scores into attention weights
        out = weights @ v
        out = out.transpose(1, 2).contiguous().view(batch, t_q, self.d_model)

        return self.o_w(out)
    
# Smoke test: causal self-attention over a random batch.
torch.manual_seed(0)  # reproducible inputs and weight init
B = 16        # batch size
T = 64        # sequence length
D_MODEL = 512
x = torch.randn(B, T, D_MODEL)
attn = multi_head_attn(d_model=D_MODEL)
y = attn(x, x, x)
assert y.shape == (B, T, D_MODEL)
y.shape

torch.Size([16, 64, 512])

In [6]:
class layer_norm(nn.Module):
    """Layer normalization over the last dimension with learnable scale/shift.

    Computes ``gamma * (x - mean) / sqrt(var + eps) + beta``, where ``mean`` and
    the biased (population) variance ``var`` are taken over the last axis,
    matching ``torch.nn.LayerNorm``.

    Args:
        eps: small constant added to the variance for numerical stability.
        d_model: size of the last (normalized) dimension.
    """

    def __init__(self, eps=1e-12, d_model=512):
        super().__init__()
        self.eps = eps
        self.beta = nn.Parameter(torch.zeros(d_model))   # shift, initialized to 0
        self.gamma = nn.Parameter(torch.ones(d_model))   # scale, initialized to 1

    def forward(self, x):
        """Normalize ``x`` over its last axis (size ``d_model``); shape is preserved."""
        mean = x.mean(-1, keepdim=True)
        # LayerNorm uses the biased variance (divide by N, not N - 1).
        var = x.var(-1, unbiased=False, keepdim=True)
        # Divide by the standard deviation, not the variance.
        out = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * out + self.beta
    
# Use a distinct instance name: rebinding `layer_norm` to an instance shadowed the
# class, so re-running this cell (or constructing another norm later) would call
# the instance's forward() with no input and fail.
torch.manual_seed(0)  # reproducible random input
norm = layer_norm()
print(norm.gamma[:10])
print(norm.beta[:10])
x = torch.randn(16, 64, 512)
y = norm(x)
assert y.shape == x.shape
# With the initial gamma=1 / beta=0, each position is centred over the last axis.
assert torch.allclose(y.mean(-1), torch.zeros_like(y[..., 0]), atol=1e-4)
y.shape

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], grad_fn=<SliceBackward0>)
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<SliceBackward0>)


torch.Size([16, 64, 512])