In [1]:
import tiktoken

enc = tiktoken.get_encoding("gpt2")

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

In [3]:
batch_size = 64
block_size = 128
max_iters = 2400
eval_interval = 300
learning_rate = 5e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
with open('gone_with_the_wind.txt', 'r', encoding='utf-8') as f:
    text = f.read()
text = enc.encode(text)

In [5]:
data = torch.tensor(text, dtype=torch.long, device=device)
n = int(0.8 * len(data))

In [6]:
train_data = data[:n]
val_data = data[n:]

def get_batch(split):
    data = train_data if split == 'train' else val_data
    
    ix = torch.randint(0, len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    
    return x, y

In [7]:
from dataclasses import dataclass
@dataclass
class Config:
    n_vocab: int
    d_model: int
    n_block: int
    n_head: int
    n_layer: int
    d_inner: int
    dropout: float
    emb_dropout: float
    bias: bool = False
    
    def __post_init__(self):
        self.d_k = self.d_v = self.d_model // self.n_head
        self.n_embd = self.d_model

In [8]:
class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.qkv = nn.Linear(config.d_model, 3*config.d_model, bias=config.bias)
        self.proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
        self.proj_drop = nn.Dropout(config.dropout)
        
        self.dropout_p = config.dropout
        
        self.n_head = config.n_head
        self.d_model = config.d_model
        self.d_k = config.d_k
        
    def forward(self, x):
        B, T, C = x.shape
        
        q,k,v = self.qkv(x).split(self.d_model, dim=2)
        q = q.view(B, T, self.n_head, self.d_k).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.d_k).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.d_k).transpose(1, 2)
        
        y = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout_p if self.train else 0, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        
        y = self.proj_drop(self.proj(y))
        
        return y
    
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.fc1 = nn.Linear(config.d_model,  config.d_inner, bias=config.bias)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(config.d_inner, config.d_model, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn = Attention(config)
        self.ln1 = nn.LayerNorm(config.d_model)
        self.mlp = MLP(config)
        self.ln2 = nn.LayerNorm(config.d_model)
    
    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
        
    
class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # self.embedding = nn.Embedding(vocab_size, dim)
        self.word_embeddings = nn.Embedding(config.n_vocab, config.d_model)
        self.position_embeddings = nn.Embedding(config.n_block, config.d_model)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)])
        self.layernorm = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.n_vocab, bias=False)
        
        self.lm_head.weight = self.word_embeddings.weight
        
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        
    def forward(self, idx, targets=None):
        tok_emb = self.word_embeddings(idx)
        pos_emb = self.position_embeddings(torch.arange(idx.shape[1], device=idx.device))
        x = tok_emb + pos_emb
        
        for block in self.blocks:
            x = block(x)

        x = self.layernorm(x)
        
        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            logits = self.lm_head(x[:,[-1],:])
            loss = None
            
        return logits, loss
        
    def generate(self, idx, max_new_tokens, temperature=1.0):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=-1)
            
        return idx

In [9]:
config = Config(
    n_vocab=enc.n_vocab,
    d_model=128,
    n_block=block_size,
    n_head=4,
    n_layer=6,
    d_inner=512,
    dropout=0.2,
    emb_dropout=0.0,
    bias=True
)
model = Transformer(config).to(device)

In [10]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [11]:
for iter in tqdm.tqdm(range(max_iters)):
    x, y = get_batch('train')
    logits, loss = model(x, y)
    
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    if iter % eval_interval == 0:
        times = 10
        total_loss = 0
        for _ in range(times):
            x, y = get_batch('val')
            logits, loss = model(x, y)
            total_loss += loss.item()
        print(f'Validation loss: {total_loss / times:.4f}')


  0%|          | 1/2400 [00:05<3:30:53,  5.27s/it]

Validation loss: 10.7486


 13%|█▎        | 301/2400 [06:01<1:30:04,  2.57s/it]

Validation loss: 5.8670


 25%|██▌       | 601/2400 [11:59<1:16:46,  2.56s/it]

Validation loss: 5.1771


 38%|███▊      | 901/2400 [17:56<1:04:14,  2.57s/it]

Validation loss: 4.8755


 50%|█████     | 1201/2400 [23:55<51:53,  2.60s/it] 

Validation loss: 4.6956


 63%|██████▎   | 1501/2400 [29:56<39:00,  2.60s/it]

Validation loss: 4.6197


 75%|███████▌  | 1801/2400 [35:54<25:39,  2.57s/it]

Validation loss: 4.5403


 88%|████████▊ | 2101/2400 [41:53<12:52,  2.58s/it]

Validation loss: 4.4818


100%|██████████| 2400/2400 [47:46<00:00,  1.19s/it]


In [24]:
x, y = get_batch('val')
logits, loss = model(x, y)
print(f'Validation loss: {loss.item()}')

ctx = 'She smiled, and the world was hers.'
ctx = enc.encode(ctx)
ctx = torch.tensor(ctx, dtype=torch.long, device=device).unsqueeze(0)
ctx = model.generate(ctx, 60, 0.4)
ctx = ctx.squeeze(0).tolist()
ctx = enc.decode(ctx)
print(ctx)

Validation loss: 4.48477029800415
She smiled, and the world was hers. She was not to have a child and she had not know.
　“I’m afraid of the world,” he said. “I’ll think I’ve got a silly. I’m not sorry to have a man.”


In [21]:
torch.save(model.state_dict(), 'model.pth')