In [1]:
import torch

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

  cpu = _conversion_method_template(device=torch.device("cpu"))


Using mps device


In [2]:
import tiktoken

class DataCollator:
    def __init__(self, path_to_data, ratio_train):
        with open(path_to_data, mode='r', encoding='utf-8', errors='replace') as f:
            content = f.read()

        self.content = content
        
        # (1) using Home-made solution
        # self.vocab = sorted(list(set(content)))
        # self.n_vocab = len(self.vocab)
        # dict_ctoi = { char:idx for idx, char in enumerate(self.vocab) }
        # dict_itoc = { idx:char for idx, char in enumerate(self.vocab) }
        # self.fn_encode = lambda s: [dict_ctoi[c] for c in s]
        # self.fn_decode = lambda s: ''.join([dict_itoc[i] for i in s])

        # (2) using tiktoken
        encoding = tiktoken.get_encoding("gpt2")
        self.n_vocab = encoding.n_vocab
        self.fn_encode = encoding.encode
        self.fn_decode = encoding.decode

        data = torch.tensor(self.fn_encode(content), dtype=torch.long)
        n = int(len(data) * ratio_train)
        self.train_data = data[:n]
        self.eval_data = data[n:]

    def collate_data(self, category, batch_size, context_size):
        data = self.train_data if category == 'train' else self.eval_data
        batch_start_idx = torch.randint(len(data) - context_size - 1, (batch_size,))
        x = torch.stack([data[idx:idx+context_size] for idx in batch_start_idx])
        y = torch.stack([data[idx+1:idx+context_size+1] for idx in batch_start_idx])
        x, y = x.to(device), y.to(device)
        return x, y

In [3]:
# dc = DataCollator('./TinyS.txt', 0.9)
dc = DataCollator('./Tolkien.txt', 0.9)

print("Read in: ", len(dc.content))
print(dc.n_vocab)

Read in:  3712783
50257


In [4]:
class MaskedSingleHeadAttention(torch.nn.Module):
    def __init__(self, head_size, context_size, n_embedding, dropout_p):
        super().__init__()
        self.query = torch.nn.Linear(n_embedding, head_size, bias=False)
        self.key = torch.nn.Linear(n_embedding, head_size, bias=False)
        self.value = torch.nn.Linear(n_embedding, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(context_size, context_size)))
        self.dropout = torch.nn.Dropout(dropout_p)

    def forward(self, x):
        # x: (b, c, f)
        batch, ctx, features = x.shape
        # q or k: (b, c, f) @ (f, h) = (b, c, h) where h(head_size) = f / n_head
        q = self.query(x)
        k = self.key(x)
        # calc attention score, w: (b, c, c)
        w = q @ k.transpose(-2, -1) * features**-0.5
        w = w.masked_fill(self.tril[:ctx, :ctx] == 0, float('-inf'))
        w = torch.nn.functional.softmax(w, dim=-1)
        w = self.dropout(w)
        # cal weighted value, v: (b, c, h)
        v = self.value(x)
        # (b, c, c) @ (b, c, h) = (b, c ,h)
        rslt = w @ v
        return rslt

# params: 4 * n_embedding ^ 2 (Q, K, V, projection)
class MaskedMultiHeadAttention(torch.nn.Module):
    def __init__(self, n_head, context_size, n_embedding, dropout_p):
        super().__init__()
        head_size = n_embedding // n_head
        self.heads = torch.nn.ModuleList([MaskedSingleHeadAttention(head_size, context_size, n_embedding, dropout_p) for _ in range(n_head)])
        self.projection = torch.nn.Linear(n_embedding, n_embedding)
        self.dropout = torch.nn.Dropout(dropout_p)

    def forward(self, x):
        # (b, c ,h) --cat--> (b, c, f)
        rslt = torch.cat([head(x) for head in self.heads], dim=-1)
        rslt = self.dropout(self.projection(rslt))
        return rslt

# params: 2 * 4 * n_embedding ^ 2
class FeedFoward(torch.nn.Module):
    def __init__(self, n_embedding, dropout_p):
        super().__init__()
        self.seq = torch.nn.Sequential(
            torch.nn.Linear(n_embedding, n_embedding * 4),
            torch.nn.ReLU(),
            torch.nn.Linear(n_embedding * 4, n_embedding),
            torch.nn.Dropout(dropout_p),
        )

    def forward(self, x):
        return self.seq(x)

class TransformerUnit(torch.nn.Module):
    def __init__(self, n_head, context_size, n_embedding, dropout_p):
        super().__init__()
        self.mha = MaskedMultiHeadAttention(n_head, context_size, n_embedding, dropout_p)
        self.ff = FeedFoward(n_embedding, dropout_p)
        self.mha_ln = torch.nn.LayerNorm(n_embedding)
        self.ff_ln = torch.nn.LayerNorm(n_embedding)

    def forward(self, x):
        x = x + self.mha(self.mha_ln(x))
        x = x + self.ff(self.ff_ln(x))
        return x

# params: vocab_size * n_embedding * 2 + context_size * n_embedding
class NaiveLangModel(torch.nn.Module):
    def __init__(self, vocab_size, n_layer, n_head, context_size, n_embedding, dropout_p):
        super().__init__()
        # params: vocab_size * n_embedding
        self.token_embed = torch.nn.Embedding(vocab_size, n_embedding)
        # params: context_size * n_embedding
        self.position_embed = torch.nn.Embedding(context_size, n_embedding)
        self.units = torch.nn.Sequential(*[TransformerUnit(n_head, context_size, n_embedding, dropout_p) for _ in range(n_layer)])
        self.ln = torch.nn.LayerNorm(n_embedding)
        # params: vocab_size * n_embedding
        self.pred_head = torch.nn.Linear(n_embedding, vocab_size)
        self.context_size = context_size

    def forward(self, inputs, labels=None):
        batch, ctx = inputs.shape
        # t_embed: (b, c, f); p_embed: (c,f)
        t_embed = self.token_embed(inputs)
        p_embed = self.position_embed(torch.arange(ctx, device=device))
        # x: (b, c, f)
        x = t_embed + p_embed
        x = self.units(x)
        x = self.ln(x)
        # logits: (b, c, v) 
        logits = self.pred_head(x)

        if labels is None:
            return logits, None

        batch, ctx, features = logits.shape
        predicts = logits.view(batch*ctx, features)
        targets = labels.view(batch*ctx)
        return logits, torch.nn.functional.cross_entropy(predicts, targets)

    def generate(self, inputs, max_gen):
        for _ in range(max_gen):
            inputs_last_window = inputs[:, -self.context_size:]
            logits, loss = self(inputs_last_window)
            logits = logits[:, -1, :]
            pred_next = torch.multinomial(torch.nn.functional.softmax(logits, dim=1), num_samples=1)
            inputs = torch.cat((inputs, pred_next), dim=1)
        return inputs

In [5]:
n_layer = 4
#n_head = 4
n_head = 6
# n_embedding = 256
n_embedding = 192
dropout_p = 0.2
context_size=128 # context length for prediction

# params: vocab_size * n_embedding * 2 + context_size * n_embedding + 12 * n_embedding ^ 2
model = NaiveLangModel(vocab_size=dc.n_vocab, n_layer=n_layer, n_head=n_head, context_size=context_size, n_embedding=n_embedding, dropout_p=dropout_p)
model = model.to(device)
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

21.151057 M parameters


In [None]:
# model.load_state_dict(torch.load("model.pth", weights_only=True))

In [6]:
def train_model(learning_rate, batch_size, steps, eval_interval, n_eval):
    @torch.no_grad()
    def calc_loss(n_eval, batch_size):
        rslt = {}
        model.eval()
        for c in ['train', 'eval']:
            losses = torch.zeros(n_eval)
            for i in range(n_eval):
                x, y = dc.collate_data(c, batch_size, model.context_size)
                _, loss = model(x, y)
                losses[i] = loss.item()
            rslt[c] = losses.mean()
        model.train()
        return rslt

    def prepare_optimizer(learning_rate, weight_decay):
        decay = set()
        no_decay = set()
        
        for name, module in model.named_modules():
            for pname, param in module.named_parameters(recurse=False):
                full_name = f"{name}.{pname}" if name else pname
                if pname.endswith("bias") or "embed" in name or "pred_head" in name:
                    no_decay.add(full_name)
                else:
                    decay.add(full_name)
        
        param_dict = {pn: p for pn, p in model.named_parameters()}
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
    
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate)
        return optimizer
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    # optimizer = prepare_optimizer(learning_rate=learning_rate, weight_decay=0.01)

    for step in range(steps):
        if step % eval_interval == 0 or step == steps - 1:
            losses = calc_loss(n_eval, batch_size)
            print(f"[step {step}] train loss {losses['train']:.4f}, eval loss {losses['eval']:.4f}")
    
        x, y = dc.collate_data('train', batch_size, model.context_size)
        _, loss = model(x, y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

In [7]:
steps = 5000
eval_interval = 100 # evaluate every N steps
# batch_size = 128
batch_size = 64
n_eval = 100       # evaluate n_eval times then calculate the mean
lr = 3e-4

train_model(learning_rate=lr, batch_size=batch_size, steps=steps, eval_interval=eval_interval, n_eval=n_eval)

[step 0] train loss 11.0381, eval loss 11.0385
[step 100] train loss 5.3590, eval loss 5.4330
[step 200] train loss 4.7298, eval loss 4.9026
[step 300] train loss 4.4295, eval loss 4.6635
[step 400] train loss 4.2294, eval loss 4.5488
[step 500] train loss 4.1018, eval loss 4.4709
[step 600] train loss 3.9775, eval loss 4.4301
[step 700] train loss 3.8573, eval loss 4.3623
[step 800] train loss 3.7787, eval loss 4.3467
[step 900] train loss 3.6940, eval loss 4.3356
[step 1000] train loss 3.6344, eval loss 4.2964
[step 1100] train loss 3.5427, eval loss 4.3204
[step 1200] train loss 3.4754, eval loss 4.2962
[step 1300] train loss 3.4174, eval loss 4.2885
[step 1400] train loss 3.3711, eval loss 4.3041
[step 1500] train loss 3.3121, eval loss 4.2956
[step 1600] train loss 3.2549, eval loss 4.2969
[step 1700] train loss 3.2134, eval loss 4.3234
[step 1800] train loss 3.1553, eval loss 4.3328
[step 1900] train loss 3.1056, eval loss 4.3219
[step 2000] train loss 3.0641, eval loss 4.3338
[s

In [8]:
prompt = torch.zeros((1, 1), dtype=torch.long, device=device)
print(dc.fn_decode(model.generate(prompt, max_gen=2000)[0].tolist()))

! I can change your mind away.' At the Lady's shoulder had opened his hand and opened his eyes and vanished. The gloom opened his eyes; slungive at this colour else closed their eyes; then they came to his side:
     `Smagol!' cried Shagrat. `Froh, nice water!' A quick end as Gollum had stabbed him in. `Don't wash the rest, go! I want '
     `He's no only yet,' he said Sam. `Be up, soft his neck wasn't got more than even an elft as well.'
     `Ach! ' His enemy was tied, and as they'll know what they mean. No fear of them'll lose!'
     Sam stood out that he fernhelm!' said Sam unled. He was bearing the bundle that he held up and ran, both sterner bolted in all view. 'Well, Lugbrz, eh? But the trick of the magcordor we found here, things needed speed towards Mordor. What about, and why, I had thought to doings made; but at least he made Frodo's getting on the Ring, trying to escape by some good advice, was full. He didn't think that he'd have much _at which he's doing stuck out the tru

In [9]:
torch.save(model.state_dict(), "model.pth")
print("Model saved.")

Model saved.
