In [1]:
import torch
from torch import nn

In [2]:
# Model hyperparameters, read as module-level globals by the classes in this file.
vocab_size = 16000  # rows of the token embedding table and width of the output logits
seq_len = 128  # length of the positional table and side of the causal mask
d_model = 128  # embedding / hidden width used throughout the model
n_layer = 4  # number of Block modules stacked in GPTModel
n_head = 4  # attention heads in MultiAttention (d_model is split across them)

In [12]:
import math
from torchinfo import summary
from torch.nn import functional as F

class SinusoidPE(nn.Module):
    """Fixed sine/cosine positional encoding.

    A table of shape ``(1, seq_len, d_model)`` is computed once at
    construction and stored as the non-trainable buffer ``sinusoid_pe``.
    Even feature columns hold ``sin(pos * freq_i)`` and odd columns hold
    ``cos(pos * freq_i)``, with ``freq_i = 10000 ** (-2i / d_model)``.
    """

    def __init__(self):
        super().__init__()

        pos = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        emb = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = pos * emb  # (seq_len, ceil(d_model / 2))

        pe = torch.zeros(seq_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even columns, so trim
        # the angle table; for an even d_model this slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # token embedding: B * C * E
        # pos embedding:   1 * C * E  (leading axis of size 1 broadcasts over the batch)
        self.register_buffer('sinusoid_pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return the encodings for the first ``x.shape[1]`` positions.

        Only the size of ``x`` along dim 1 is read. The result has shape
        ``(1, C, E)`` where ``C`` is that size capped at the table length.
        """
        return self.sinusoid_pe[:, :x.shape[1], :]

class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension of its input.

    The layer stack is Linear -> ReLU -> Linear -> Dropout, with the hidden
    layer four times as wide as the ``n_embd`` input/output width.
    """

    def __init__(self, n_embd, dropout=0.0):
        super().__init__()
        hidden = 4 * n_embd
        expand = nn.Linear(n_embd, hidden)
        contract = nn.Linear(hidden, n_embd)
        # Kept as a Sequential under the name ``net`` so parameter names
        # (net.0.*, net.2.*) stay the same.
        self.net = nn.Sequential(expand, nn.ReLU(), contract, nn.Dropout(dropout))

    def forward(self, x):
        """Run ``x`` through the MLP; the output has the same shape as ``x``."""
        y = self.net(x)
        return y
    
    
class Head(nn.Module):
    """One head of causal (masked) self-attention.

    The ``d_model``-wide input is projected to ``head_size``-wide key, query
    and value tensors. Each position attends only to itself and to earlier
    positions, enforced by the lower-triangular ``mask`` buffer.

    Args:
        head_size: width of the key/query/value projections.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, head_size, dropout=0.0):
        super().__init__()
        self.key = nn.Linear(d_model, head_size, bias=False)
        self.query = nn.Linear(d_model, head_size, bias=False)
        self.value = nn.Linear(d_model, head_size, bias=False)
        # Lower-triangular ones: position i may look at positions j <= i.
        self.register_buffer('mask', torch.tril(torch.ones(seq_len, seq_len)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``x`` of shape (B, C, E) to attended values of shape (B, C, hs)."""
        _, C, _ = x.shape
        k = self.key(x)    # (B, C, hs)
        q = self.query(x)  # (B, C, hs)
        # compute attention scores ("affinities")
        # Scale by the key/query width (head_size), not by the input width E:
        # the dot products are sums over hs terms, so 1/sqrt(hs) is the
        # factor that keeps their variance independent of the head size.
        wei = q @ k.transpose(-2, -1) * k.size(-1) ** -0.5  # (B, C, hs) @ (B, hs, C) -> (B, C, C)
        # Future positions get -inf so softmax assigns them zero weight.
        wei = wei.masked_fill(self.mask[:C, :C] == 0, float('-inf'))  # (B, C, C)
        wei = F.softmax(wei, dim=-1)  # (B, C, C)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x)  # (B, C, hs)
        out = wei @ v  # (B, C, C) @ (B, C, hs) -> (B, C, hs)
        return out
    
class SelfAttention(nn.Module):
    """Multi-head self-attention built from independent ``Head`` modules.

    Every head processes the same input; their outputs are concatenated
    along the feature axis and mixed by a final linear projection.

    Args:
        num_heads: number of ``Head`` modules to run side by side.
        head_size: output width of each head.
        dropout: dropout probability, used both for the attention weights
            inside each head and for the projected output.
    """

    def __init__(self, num_heads, head_size, dropout=0.0):
        super().__init__()
        # Pass the dropout rate on to the heads; otherwise they silently
        # fall back to their own default of 0.0.
        self.heads = nn.ModuleList(
            [Head(head_size, dropout=dropout) for _ in range(num_heads)]
        )
        # The concatenated head outputs are num_heads * head_size wide
        # (equal to d_model in the usual configuration).
        self.proj = nn.Linear(num_heads * head_size, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``x`` of shape (B, C, E) to an output of shape (B, C, d_model)."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, C, num_heads * head_size)
        out = self.dropout(self.proj(out))
        return out
    
class MultiAttention(nn.Module):
    """Causal multi-head self-attention with a fused q/k/v projection.

    All heads are computed in one batched matmul instead of looping over
    separate per-head modules. The head count is the module-level ``n_head``.

    Args:
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, dropout=0.0):
        super().__init__()
        # A single linear layer emits query, key and value side by side
        # (3 * d_model features); forward() slices them apart again.
        self.attn = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        lower_tri = torch.tril(torch.ones(seq_len, seq_len))
        # Two leading singleton axes let the mask broadcast over batch and heads.
        self.register_buffer('mask', lower_tri.view(1, 1, seq_len, seq_len))

    def forward(self, x):
        """Map ``x`` of shape (B, C, E) to an output of the same shape."""
        B, C, E = x.shape
        head_dim = E // n_head

        def to_heads(t):
            # (B, C, E) -> (B, C, nh, hs) -> (B, nh, C, hs)
            return t.view(B, C, n_head, head_dim).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.attn(x).split(d_model, dim=2))

        # (B, nh, C, hs) @ (B, nh, hs, C) -> (B, nh, C, C), scaled by 1/sqrt(hs)
        scores = q @ k.transpose(-2, -1) * head_dim ** -0.5
        hidden = self.mask[:, :, :C, :C] == 0
        scores = scores.masked_fill(hidden, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))

        # (B, nh, C, C) @ (B, nh, C, hs) -> (B, nh, C, hs)
        mixed = weights @ v
        # (B, nh, C, hs) -> (B, C, nh, hs) -> (B, C, E)
        merged = mixed.transpose(1, 2).contiguous().view(B, C, E)
        return self.proj(merged)
            
class Block(nn.Module):
    """Pre-norm transformer decoder block.

    Computes ``x + attn(ln1(x))`` followed by ``x + ffn(ln2(x))``: layer
    normalisation is applied before each sub-layer, and each sub-layer's
    output is added back to its input (residual connection).
    """

    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiAttention()
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model)

    def forward(self, x):
        """Return the block output; it has the same shape as ``x``."""
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

class GPTModel(nn.Module):
    """Decoder-only transformer language model.

    Pipeline: token embedding + sinusoidal position encoding -> ``n_layer``
    ``Block`` modules -> final LayerNorm -> linear projection to
    ``vocab_size`` logits.
    """

    def __init__(self):
        super().__init__()
        self.tok_embed_table = nn.Embedding(vocab_size, d_model)
        self.pos_embed_table = SinusoidPE()
        self.decoder_blocks = nn.Sequential(*[Block() for _ in range(n_layer)])
        self.ln = nn.LayerNorm(d_model)
        self.final_linear = nn.Linear(d_model, vocab_size)

    def forward(self, features, targets=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            features: integer token ids of shape (B, C).
            targets: optional token ids with B * C elements.

        Returns:
            ``(logits, loss)``. Without targets, logits have shape (B, C, V)
            and loss is ``None``. With targets, logits are returned flattened
            to (B * C, V) together with the scalar loss.
        """
        tok_emb = self.tok_embed_table(features)   # (B, C, E)
        pos_emb = self.pos_embed_table(tok_emb)    # (1, C, E), broadcast over the batch
        x = tok_emb + pos_emb
        x = self.decoder_blocks(x)
        out = self.final_linear(self.ln(x))        # (B, C, V)

        if targets is not None:
            # F.cross_entropy wants (N, V) logits and (N,) class indices.
            B, C, V = out.shape
            out = out.view(B * C, V)
            targets = targets.view(B * C)
            loss = F.cross_entropy(out, targets)
            return out, loss
        else:
            return out, None

    @torch.no_grad()
    def generate(self, seq, max_new_tokens):
        """Sample ``max_new_tokens`` tokens and append them to ``seq``.

        Args:
            seq: integer token ids of shape (B, L).
            max_new_tokens: number of tokens to sample per row.

        Returns:
            Token ids of shape (B, L + max_new_tokens): the complete input
            followed by the sampled continuation.
        """
        for _ in range(max_new_tokens):
            # Only the model input is cropped to the last seq_len tokens (the
            # size of the positional table and causal mask). ``seq`` itself
            # keeps growing so the returned sequence never loses its prefix.
            window = seq[:, -seq_len:]  # (B, <= seq_len)
            pred, _ = self(window)      # (B, C, V)
            pred = pred[:, -1, :]       # logits at the last position: (B, V)
            probs = F.softmax(pred, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
            seq = torch.cat((seq, next_token), dim=1)
        return seq
        
# Build the model. No checkpoint is loaded in the visible code, so the
# parameters are whatever the module constructors initialise them to.
model = GPTModel()
# summary(model)

import sentencepiece as spm
import sys

# Load the SentencePiece tokenizer model from disk; print a message and exit
# with status 1 if `load` returns a falsy value.
model_file = "bird_shooter.model"
sp = spm.SentencePieceProcessor()
if not sp.load(model_file=model_file):
    print("load tokenizer model failed")
    sys.exit(1)

# Encode the prompt as a one-row batch of token ids, sample 20 more tokens,
# then decode the first (only) row back to text.
user_input = "郭靖一掌挥出"   
context = torch.tensor([sp.encode(user_input)], dtype=torch.int32)
gpt_output = model.generate(context, max_new_tokens=20)[0].tolist()
print(f"gpt => {sp.decode(gpt_output)}")

gpt => 郭靖一掌挥出恶毒果是一技谄的话不加理会众姬犒原本赏去甚是畅快油腻一灯大师的两个嗤的一声点点头延倚奈
