In [23]:
import torch
from torch import nn, optim
from torch.utils.data import random_split, DataLoader
from torchinfo import summary
from torchvision import datasets, transforms, models

# Pick the accelerator when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
print(DEVICE)

cuda


## Implementation

In [24]:
class SelfAttention(nn.Module):
    """A single head of scaled dot-product attention.

    Queries, keys and values are each projected from ``input_dim`` to ``embedding_dim``
    features, and the head returns ``softmax(q @ k^T / sqrt(embedding_dim)) @ v``.

    Args:
        input_dim: feature size of the incoming ``x_q`` / ``x_k`` / ``x_v`` tensors.
        embedding_dim: size of the per-head projection (d_k); also the output feature size.
    """

    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim

        self.Q = nn.Linear(input_dim, embedding_dim)
        self.K = nn.Linear(input_dim, embedding_dim)
        self.V = nn.Linear(input_dim, embedding_dim)

    def forward(self, x_q, x_k, x_v, mask=None):
        """Attend from ``x_q`` over ``x_k`` / ``x_v``.

        Args:
            x_q: queries, ``(..., len_q, input_dim)``.
            x_k: keys, ``(..., len_k, input_dim)``.
            x_v: values, ``(..., len_k, input_dim)``.
            mask: optional boolean tensor broadcastable to ``(..., len_q, len_k)``;
                ``True`` marks positions that must NOT be attended to (the same
                convention as a boolean ``attn_mask`` of ``nn.MultiheadAttention``).

        Returns:
            Tensor of shape ``(..., len_q, embedding_dim)``.
        """
        q = self.Q(x_q)
        k = self.K(x_k)
        v = self.V(x_v)

        # Scale by sqrt(d_k) of the *projected* keys, not of the input features.
        scores = (q @ k.transpose(-2, -1)) / (self.embedding_dim ** 0.5)
        if mask is not None:
            # Out-of-place fill (no in-place write into an autograd tensor), broadcasts
            # 2-D and batched masks alike, and finfo.min stays finite in fp16/bf16
            # where a literal like -1e10 would overflow.
            blocked = mask.to(device=scores.device, dtype=torch.bool)
            scores = scores.masked_fill(blocked, torch.finfo(scores.dtype).min)
        return torch.softmax(scores, dim=-1) @ v


class MultiheadSelfAttention(nn.Module):
    """Multi-head attention assembled from ``n_heads`` independent ``SelfAttention`` heads.

    Each head projects to ``input_dim // n_heads`` features. The head outputs are
    concatenated back to ``input_dim`` features and mixed by a final Linear layer.

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide ``input_dim``.
            Otherwise the concatenated heads would not match the input size of
            ``self.linear``, and the model would only fail later, at forward time.
    """

    def __init__(self, input_dim, n_heads):
        super().__init__()
        if n_heads <= 0 or input_dim % n_heads != 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        head_dim = input_dim // n_heads
        self.attention_heads = nn.ModuleList([SelfAttention(input_dim, head_dim) for _ in range(n_heads)])
        self.linear = nn.Linear(input_dim, input_dim)

    def forward(self, x_q, x_k, x_v, mask=None):
        """Run every head on the same ``(x_q, x_k, x_v, mask)`` and merge the results.

        Returns a tensor whose last dimension is ``input_dim`` (``n_heads * head_dim``).
        """
        head_outputs = [head(x_q, x_k, x_v, mask) for head in self.attention_heads]
        return self.linear(torch.cat(head_outputs, dim=-1))


# Sanity check: the hand-written attention and PyTorch's reference implementation
# should produce outputs of the same shape under the same causal mask.
x_batch = torch.randn(4, 100, 512)
# True strictly above the diagonal -> each position may only attend to itself and the past.
mask = torch.ones(100, 100).triu(diagonal=1).bool()

custom_out = MultiheadSelfAttention(512, 16)(x_batch, x_batch, x_batch, mask=mask)
torch_out, _ = nn.MultiheadAttention(512, 16, batch_first=True)(
    x_batch, x_batch, x_batch, need_weights=False, attn_mask=mask
)
print(custom_out.shape)
print(torch_out.shape)

torch.Size([4, 100, 512])
torch.Size([4, 100, 512])


In [25]:
class MSABlock(nn.Module):
    """Multi-head self-attention followed by dropout.

    Args:
        input_dim: model width.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention output.
        torch_msa: if True, use PyTorch's ``nn.MultiheadAttention``. If False (the
            default), use the hand-written ``MultiheadSelfAttention``.
    """

    def __init__(self, input_dim, n_heads, dropout, torch_msa=False):
        super().__init__()
        self.torch_msa = torch_msa

        # The flag now selects the implementation its name refers to.
        if torch_msa:
            self.msa = nn.MultiheadAttention(input_dim, n_heads, batch_first=True)
        else:
            self.msa = MultiheadSelfAttention(input_dim, n_heads)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_q, x_k, x_v, mask=None):
        """Attend and apply dropout.

        ``mask`` is passed as ``attn_mask`` to ``nn.MultiheadAttention``, or as
        ``mask`` to the hand-written module. For a boolean mask, both treat ``True``
        as "blocked".
        """
        if self.torch_msa:
            # nn.MultiheadAttention returns (output, weights); the weights are skipped.
            x = self.msa(x_q, x_k, x_v, need_weights=False, attn_mask=mask)[0]
        else:
            x = self.msa(x_q, x_k, x_v, mask=mask)
        return self.dropout(x)


class MLPBlock(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear -> Dropout.

    Expands features from ``input_dim`` to ``mlp_dim`` and projects them back to
    ``input_dim``. Dropout with probability ``dropout`` is applied to the result.
    """

    def __init__(self, input_dim, mlp_dim, dropout):
        super().__init__()
        # Submodule names and creation order are kept so state_dict keys and
        # parameter initialisation order stay unchanged.
        self.linear1 = nn.Linear(input_dim, mlp_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(mlp_dim, input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the two-layer MLP to the last dimension of ``x``."""
        hidden = self.relu(self.linear1(x))
        return self.dropout(self.linear2(hidden))

In [26]:
class DecoderBlock(nn.Module):
    """Pre-LayerNorm transformer decoder layer.

    Computes ``x = x + MSA(LN1(x))`` and then ``x = x + MLP(LN2(x))``. The same
    normalised tensor is used as query, key and value.
    """

    def __init__(self, input_dim, mlp_dim, n_heads, dropout):
        super().__init__()
        self.msa_block = MSABlock(input_dim, n_heads, dropout)
        self.layer_norm1 = nn.LayerNorm(input_dim)
        self.mlp_block = MLPBlock(input_dim, mlp_dim, dropout)
        self.layer_norm2 = nn.LayerNorm(input_dim)

    def forward(self, x, mask=None):
        """Apply the attention and MLP sub-layers, each wrapped in a residual connection."""
        normed = self.layer_norm1(x)
        x = x + self.msa_block(normed, normed, normed, mask)
        normed = self.layer_norm2(x)
        return x + self.mlp_block(normed)

In [27]:
class GPT2(nn.Module):
    """Decoder-only transformer with learned positional embeddings and a linear output head.

    Args:
        seq_len: maximum sequence length (size of the learned positional table).
        word_dim: size of the last input dimension. Tokens arrive as ``word_dim``-wide
            vectors (presumably one-hot rows), mapped to ``embedding_dim`` by a
            bias-free Linear layer.
        embedding_dim: model width.
        mlp_dim: hidden size of each DecoderBlock MLP.
        n_heads: attention heads per DecoderBlock.
        n_layers: number of stacked DecoderBlocks.
        n_classes: output size of the classification head.
        dropout: dropout probability inside each DecoderBlock.
    """

    def __init__(self, seq_len, word_dim, embedding_dim, mlp_dim, n_heads, n_layers, n_classes, dropout):
        super().__init__()
        self.seq_len = seq_len

        self.input_embedding = nn.Linear(word_dim, embedding_dim, bias=False)
        self.positional_encoding = nn.Parameter(torch.randn(1, seq_len, embedding_dim))

        self.decoder_layers = nn.ModuleList([
            DecoderBlock(embedding_dim, mlp_dim, n_heads, dropout) for _ in range(n_layers)
        ])

        self.classification_head = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            nn.Linear(embedding_dim, n_classes)
        )

    def create_attention_mask(self, seq_len=None, device=None):
        """Return a ``(seq_len, seq_len)`` boolean causal mask.

        ``True`` marks blocked (future) positions. ``seq_len`` defaults to
        ``self.seq_len``. ``device`` lets the mask live next to the activations.
        """
        if seq_len is None:
            seq_len = self.seq_len
        return torch.ones(seq_len, seq_len, device=device).tril() == 0

    def forward(self, input_seq):
        """Compute logits for every position of ``input_seq``.

        Args:
            input_seq: ``(batch, length, word_dim)`` with ``length <= self.seq_len``.
                Longer inputs fail to broadcast against the positional table.

        Returns:
            Logits of shape ``(batch, length, n_classes)``.
        """
        length = input_seq.shape[-2]
        # Only the first `length` positional embeddings are used, so shorter inputs work.
        hidden = self.input_embedding(input_seq) + self.positional_encoding[:, :length]

        mask = self.create_attention_mask(length, device=input_seq.device)
        for decoder_layer in self.decoder_layers:
            # Chain the layers: each block consumes the previous block's output.
            hidden = decoder_layer(hidden, mask)

        return self.classification_head(hidden)

In [28]:
# GPT-2 "small" sizes: same vocabulary, context length and width as the defaults of
# transformers' GPT2Config used in the next cell.
seq_len = 1024
word_dim = 50257
embedding_dim = 768
mlp_dim = 4 * embedding_dim
n_heads = 12
n_layers = 12
dropout = 0.1

# The head predicts one of `word_dim` tokens, so n_classes is the vocabulary size.
gpt2_model = GPT2(
    seq_len=seq_len,
    word_dim=word_dim,
    embedding_dim=embedding_dim,
    mlp_dim=mlp_dim,
    n_heads=n_heads,
    n_layers=n_layers,
    n_classes=word_dim,
    dropout=dropout,
)
summary(
    gpt2_model,
    input_size=(1, seq_len, word_dim),
    device='cpu',
    col_names=['output_size', 'num_params', 'mult_adds'],
    depth=2,
)

Layer (type:depth-idx)                                  Output Shape              Param #                   Mult-Adds
GPT2                                                    [1, 1024, 50257]          786,432                   --
├─Linear: 1-1                                           [1, 1024, 768]            38,597,376                38,597,376
├─ModuleList: 1-2                                       --                        --                        --
│    └─DecoderBlock: 2-1                                [1, 1024, 768]            7,087,872                 7,087,872
│    └─DecoderBlock: 2-2                                [1, 1024, 768]            7,087,872                 7,087,872
│    └─DecoderBlock: 2-3                                [1, 1024, 768]            7,087,872                 7,087,872
│    └─DecoderBlock: 2-4                                [1, 1024, 768]            7,087,872                 7,087,872
│    └─DecoderBlock: 2-5                                [1, 1024, 768

In [29]:
from transformers import GPT2Config, GPT2Model

# Reference point: Hugging Face's GPT-2 built from the default GPT2Config, summarised
# with the same columns as the hand-written model above.
config = GPT2Config()
gpt1_torch_model = GPT2Model(config)  # NOTE: the name says "gpt1", but this is a GPT-2 model

# One sequence of `seq_len` integer token ids drawn from [0, word_dim).
input_tensor = torch.randint(0, word_dim, (1, seq_len), dtype=torch.long)
summary(
    gpt1_torch_model,
    input_data=input_tensor,
    device='cpu',
    col_names=['output_size', 'num_params', 'mult_adds'],
    depth=2,
)

Layer (type:depth-idx)                        Output Shape              Param #                   Mult-Adds
GPT2Model                                     [1, 12, 1024, 64]         --                        --
├─Embedding: 1-1                              [1, 1024, 768]            38,597,376                38,597,376
├─Embedding: 1-2                              [1, 1024, 768]            786,432                   786,432
├─Dropout: 1-3                                [1, 1024, 768]            --                        --
├─ModuleList: 1-4                             --                        --                        --
│    └─GPT2Block: 2-1                         [1, 1024, 768]            7,087,872                 13,605,473,280
│    └─GPT2Block: 2-2                         [1, 1024, 768]            7,087,872                 13,605,473,280
│    └─GPT2Block: 2-3                         [1, 1024, 768]            7,087,872                 13,605,473,280
│    └─GPT2Block: 2-4              