In [1]:
import torch

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

  cpu = _conversion_method_template(device=torch.device("cpu"))


Using mps device


In [2]:
class DataCollator:
    """Character-level dataset: reads a text file, builds a vocabulary and samples batches.

    The whole file is read as UTF-8, every distinct character becomes one
    vocabulary entry (sorted), and the encoded text is split into a leading
    training part and a trailing evaluation part.

    Args:
        path_to_data: Path of the UTF-8 text file to read.
        ratio_train: Fraction of the encoded characters (from the start of the
            file) assigned to the training split; the rest is the eval split.
        device: Optional device that sampled batches are moved to. When left as
            ``None`` the module-level ``device`` variable is used instead, which
            is how the class behaved before this argument existed.
    """

    def __init__(self, path_to_data, ratio_train, device=None):
        with open(path_to_data, mode='r', encoding='utf-8') as f:
            content = f.read()

        # Raw text and the sorted list of distinct characters found in it.
        self.content = content
        self.vocab = sorted(set(content))

        dict_ctoi = {char: idx for idx, char in enumerate(self.vocab)}
        dict_itoc = dict(enumerate(self.vocab))
        # fn_encode: string -> list of token ids.
        self.fn_encode = lambda s: [dict_ctoi[c] for c in s]
        # fn_decode: iterable of token ids -> list of characters (NOT a joined string).
        self.fn_decode = lambda s: [dict_itoc[i] for i in s]

        data = torch.tensor(self.fn_encode(content), dtype=torch.long)
        n = int(len(data) * ratio_train)
        self.train_data = data[:n]
        self.eval_data = data[n:]
        self.device = device

    def collate_data(self, category, batch_size, context_size):
        """Sample a random batch of (input, target) windows from one split.

        Args:
            category: ``'train'`` selects the training split; any other value
                selects the evaluation split.
            batch_size: Number of windows to sample (with replacement).
            context_size: Length of every window.

        Returns:
            ``(x, y)``, two ``(batch_size, context_size)`` long tensors on the
            target device, where ``y`` is ``x`` shifted one position to the right.

        Raises:
            ValueError: If the selected split holds fewer than
                ``context_size + 1`` tokens, so no window plus target fits.
        """
        data = self.train_data if category == 'train' else self.eval_data
        # A window starting at idx needs data[idx : idx + context_size + 1], so the
        # valid starts are 0 .. len(data) - context_size - 1. torch.randint's upper
        # bound is exclusive, hence the bound below (no extra "- 1").
        n_windows = len(data) - context_size
        if n_windows <= 0:
            raise ValueError(
                f"split has {len(data)} tokens; need at least {context_size + 1} "
                f"for context_size={context_size}"
            )
        batch_start_idx = torch.randint(n_windows, (batch_size,))
        x = torch.stack([data[idx:idx+context_size] for idx in batch_start_idx])
        y = torch.stack([data[idx+1:idx+context_size+1] for idx in batch_start_idx])
        # Fall back to the module-level `device` only when none was configured.
        target_device = self.device if self.device is not None else device
        x, y = x.to(target_device), y.to(target_device)
        return x, y

In [3]:
# Read the corpus and keep the first 80% of its characters for training.
dc = DataCollator(path_to_data='./TinyS.txt', ratio_train=0.8)

# Report the corpus length in characters, then the vocabulary and its size.
print("Read in: ", len(dc.content))
print(dc.vocab, len(dc.vocab))

Read in:  1115394
['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 65


In [4]:
class MaskedSingleHeadAttention(torch.nn.Module):
    """One head of causal (lower-triangular masked) scaled dot-product self-attention.

    Args:
        head_size: Output width of the query/key/value projections.
        context_size: Maximum sequence length; sets the size of the causal mask.
        n_feature: Width of the incoming embeddings.
        dropout_p: Dropout probability applied to the attention weights.
    """

    def __init__(self, head_size, context_size, n_feature, dropout_p):
        super().__init__()
        self.query = torch.nn.Linear(n_feature, head_size, bias=False)
        self.key = torch.nn.Linear(n_feature, head_size, bias=False)
        self.value = torch.nn.Linear(n_feature, head_size, bias=False)
        # Lower-triangular ones matrix; a buffer so it follows the module across
        # devices without being a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(context_size, context_size)))
        self.dropout = torch.nn.Dropout(dropout_p)

    def forward(self, x):
        """Attend over ``x`` so each position only sees itself and earlier positions.

        Args:
            x: Tensor of shape (batch, ctx, n_feature) with ctx <= context_size.

        Returns:
            Tensor of shape (batch, ctx, head_size).
        """
        ctx = x.shape[1]
        # q, k: (b, c, f) @ (f, h) = (b, c, h)
        q = self.query(x)
        k = self.key(x)
        # Attention scores, w: (b, c, c). Scale by 1/sqrt(d_k), where d_k is the
        # key/query width (head_size), NOT the embedding width: the dot product
        # sums over head_size terms, so that is what controls its variance.
        head_size = k.shape[-1]
        w = q @ k.transpose(-2, -1) * head_size**-0.5
        # Hide future positions: -inf scores become zero weight after softmax.
        w = w.masked_fill(self.tril[:ctx, :ctx] == 0, float('-inf'))
        w = torch.nn.functional.softmax(w, dim=-1)
        w = self.dropout(w)
        # Weighted sum of values: (b, c, c) @ (b, c, h) = (b, c, h)
        v = self.value(x)
        rslt = w @ v
        return rslt

class MaskedMultiHeadAttention(torch.nn.Module):
    """Several causal attention heads run side by side, concatenated and projected.

    Args:
        n_head: Number of attention heads; must be positive and divide ``n_feature``.
        context_size: Maximum sequence length, forwarded to every head.
        n_feature: Width of the incoming (and outgoing) embeddings.
        dropout_p: Dropout probability, used inside each head and after the
            output projection.

    Raises:
        ValueError: If ``n_head`` is not positive or does not divide ``n_feature``.
    """

    def __init__(self, n_head, context_size, n_feature, dropout_p):
        super().__init__()
        # Each head emits n_feature // n_head columns; unless the division is exact
        # the concatenation is narrower than n_feature and the projection below
        # would only fail later, inside forward(), with an opaque shape error.
        if n_head <= 0 or n_feature % n_head != 0:
            raise ValueError(
                f"n_feature ({n_feature}) must be divisible by a positive n_head ({n_head})"
            )
        head_size = n_feature // n_head
        self.heads = torch.nn.ModuleList([MaskedSingleHeadAttention(head_size, context_size, n_feature, dropout_p) for _ in range(n_head)])
        self.projection = torch.nn.Linear(n_feature, n_feature)
        self.dropout = torch.nn.Dropout(dropout_p)

    def forward(self, x):
        """Apply every head to ``x``, concatenate on the last axis, project, drop out.

        Args:
            x: Tensor of shape (batch, ctx, n_feature).

        Returns:
            Tensor of shape (batch, ctx, n_feature).
        """
        # n_head tensors of (b, c, h) --cat--> (b, c, f)
        rslt = torch.cat([head(x) for head in self.heads], dim=-1)
        rslt = self.dropout(self.projection(rslt))
        return rslt

class FeedFoward(torch.nn.Module):
    """Position-wise MLP: widen to 4x the feature size, ReLU, narrow back, dropout.

    Args:
        n_feature: Width of the input and output features.
        dropout_p: Dropout probability applied to the MLP output.
    """

    def __init__(self, n_feature, dropout_p):
        super().__init__()
        hidden_width = 4 * n_feature
        # Layer order (and therefore parameter names seq.0 / seq.2) is fixed.
        layers = [
            torch.nn.Linear(n_feature, hidden_width),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_width, n_feature),
            torch.nn.Dropout(dropout_p),
        ]
        self.seq = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension must equal ``n_feature``."""
        out = self.seq(x)
        return out

class TransformerUnit(torch.nn.Module):
    """One transformer block: attention and feed-forward sub-layers with residuals.

    Layer normalisation is applied to the input of each sub-layer (before it),
    and each sub-layer's output is added back onto its un-normalised input.

    Args:
        n_head: Number of attention heads.
        context_size: Maximum sequence length, forwarded to the attention layer.
        n_feature: Embedding width shared by every sub-layer.
        dropout_p: Dropout probability forwarded to both sub-layers.
    """

    def __init__(self, n_head, context_size, n_feature, dropout_p):
        super().__init__()
        self.mha = MaskedMultiHeadAttention(
            n_head, context_size, n_feature, dropout_p
        )
        self.ff = FeedFoward(n_feature, dropout_p)
        self.mha_ln = torch.nn.LayerNorm(n_feature)
        self.ff_ln = torch.nn.LayerNorm(n_feature)

    def forward(self, x):
        """Map a (batch, ctx, n_feature) tensor to a tensor of the same shape."""
        # Attention sub-layer with residual connection.
        attended = self.mha(self.mha_ln(x))
        x = x + attended
        # Feed-forward sub-layer with residual connection.
        transformed = self.ff(self.ff_ln(x))
        return x + transformed

class NaiveLangModel(torch.nn.Module):
    """Small decoder-only transformer language model over integer token ids.

    Args:
        vocab_size: Number of distinct tokens.
        n_layer: Number of stacked ``TransformerUnit`` blocks.
        n_head: Attention heads per block.
        context_size: Maximum number of tokens the model attends over.
        n_feature: Embedding width.
        dropout_p: Dropout probability forwarded to every block.
    """

    def __init__(self, vocab_size, n_layer, n_head, context_size, n_feature, dropout_p):
        super().__init__()
        self.token_embed = torch.nn.Embedding(vocab_size, n_feature)
        self.position_embed = torch.nn.Embedding(context_size, n_feature)
        self.units = torch.nn.Sequential(*[TransformerUnit(n_head, context_size, n_feature, dropout_p) for _ in range(n_layer)])
        self.ln = torch.nn.LayerNorm(n_feature)
        self.pred_head = torch.nn.Linear(n_feature, vocab_size)
        self.context_size = context_size

    def forward(self, inputs, labels=None):
        """Compute next-token logits and, when labels are given, the mean loss.

        Args:
            inputs: Long tensor of token ids, shape (batch, ctx), ctx <= context_size.
            labels: Optional long tensor of target ids with the same shape.

        Returns:
            ``(logits, loss)``: logits has shape (batch, ctx, vocab_size); loss is
            the scalar cross-entropy over all positions, or ``None`` when
            ``labels`` is ``None``.
        """
        ctx = inputs.shape[1]
        # t_embed: (b, c, f); p_embed: (c, f), broadcast over the batch.
        t_embed = self.token_embed(inputs)
        # Build the position ids on the inputs' own device so the model does not
        # depend on a notebook-level `device` variable matching its placement.
        positions = torch.arange(ctx, device=inputs.device)
        p_embed = self.position_embed(positions)
        # x: (b, c, f)
        x = t_embed + p_embed
        x = self.units(x)
        x = self.ln(x)
        # logits: (b, c, v)
        logits = self.pred_head(x)

        if labels is None:
            return logits, None

        # cross_entropy expects (N, v) predictions against (N,) targets.
        vocab_size = logits.shape[-1]
        predicts = logits.reshape(-1, vocab_size)
        targets = labels.reshape(-1)
        return logits, torch.nn.functional.cross_entropy(predicts, targets)

    @torch.no_grad()
    def generate(self, inputs, max_gen):
        """Autoregressively sample ``max_gen`` tokens after ``inputs``.

        Runs without gradient tracking and with the model in eval mode (so dropout
        is inactive while sampling); the previous train/eval mode is restored
        before returning.

        Args:
            inputs: Long tensor of prompt token ids, shape (batch, prompt_len).
            max_gen: Number of tokens to append.

        Returns:
            Long tensor of shape (batch, prompt_len + max_gen).
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_gen):
                # Only the most recent context_size tokens fit the position table.
                inputs_last_window = inputs[:, -self.context_size:]
                logits, _ = self(inputs_last_window)
                # Distribution over the token following the last position.
                probs = torch.nn.functional.softmax(logits[:, -1, :], dim=-1)
                pred_next = torch.multinomial(probs, num_samples=1)
                inputs = torch.cat((inputs, pred_next), dim=1)
        finally:
            self.train(was_training)
        return inputs

In [5]:
# Model hyperparameters: 4 transformer layers with 4 attention heads each,
# 64-wide embeddings, a 32-token context window for prediction, dropout disabled.
model = NaiveLangModel(
    vocab_size=len(dc.vocab),
    n_layer=4,
    n_head=4,
    context_size=32,
    n_feature=64,
    dropout_p=0.0,
).to(device)
# Total parameter count, reported in millions.
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')

0.209729 M parameters


In [6]:
def train_model(learning_rate, batch_size, steps, eval_interval, n_eval):
    """Train the notebook-level ``model`` on batches drawn from ``dc`` with AdamW.

    Args:
        learning_rate: Learning rate handed to the AdamW optimiser.
        batch_size: Number of sequences per training and evaluation batch.
        steps: Number of optimisation steps to run.
        eval_interval: Losses are estimated and printed every this many steps
            (and once more on the final step).
        n_eval: Number of random batches averaged per loss estimate.
    """
    @torch.no_grad()
    def calc_loss(n_eval, batch_size):
        """Return the mean loss over ``n_eval`` random batches for each split."""
        rslt = {}
        # Eval mode while measuring; training mode is switched back on afterwards.
        model.eval()
        for c in ['train', 'eval']:
            losses = torch.zeros(n_eval)
            for i in range(n_eval):
                x, y = dc.collate_data(c, batch_size, model.context_size)
                _, loss = model(x, y)
                losses[i] = loss.item()
            rslt[c] = losses.mean()
        model.train()
        return rslt

    # Use the caller's learning rate (it was previously ignored in favour of a
    # hard-coded 1e-3).
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    for step in range(steps):
        if step % eval_interval == 0 or step == steps - 1:
            losses = calc_loss(n_eval, batch_size)
            print(f"[step {step}] train loss {losses['train']:.4f}, eval loss {losses['eval']:.4f}")

        x, y = dc.collate_data('train', batch_size, model.context_size)
        _, loss = model(x, y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

In [None]:
# Training run: 5000 steps on batches of 16 sequences with a requested learning
# rate of 1e-3. Every 100 steps the train and eval losses are each estimated as
# the mean over 100 randomly sampled batches.
train_model(
    steps=5000,
    batch_size=16,
    learning_rate=1e-3,
    eval_interval=100,
    n_eval=100,
)

In [None]:
# Start from a single sequence holding one token with id 0 (the first entry of
# the sorted vocabulary) and sample 1000 further tokens.
prompt = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_ids = model.generate(prompt, max_gen=1000)[0].tolist()
# fn_decode returns a list of single characters; join them so the sample is
# printed as readable text instead of a Python list repr.
print(''.join(dc.fn_decode(generated_ids)))