In [1]:
import torch

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

Using cpu device


In [2]:
def prepare_vocab(path_to_data):
    """Read a UTF-8 text file and return its text plus its sorted character set.

    Args:
        path_to_data: Path of the text file to read.

    Returns:
        A ``(content, vocab)`` tuple: ``content`` is the full file text and
        ``vocab`` is a sorted list of the distinct characters occurring in it.
    """
    with open(path_to_data, encoding='utf-8') as handle:
        text = handle.read()
    unique_chars = sorted(set(text))
    return text, unique_chars

In [3]:
# Load the raw corpus and derive its character-level vocabulary.
content, vocab = prepare_vocab('./TinyS.txt')
print("Read in: ", len(content))
vocab_size = len(vocab)
print(vocab, vocab_size)

Read in:  1115394
['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 65


In [4]:
def prepare_data(content, vocab, ratio_train):
    """Encode ``content`` as integer token ids and split it into train/eval parts.

    Args:
        content: Text to encode; every character must appear in ``vocab``.
        vocab: Sequence of characters; a character's index is its token id.
        ratio_train: Fraction of the encoded sequence (from the front) used for
            training; the remainder becomes the eval split.

    Returns:
        ``(train_data, eval_data, fn_decode)``: two 1-D ``torch.long`` tensors and
        a function mapping an iterable of ids to a list of characters.

    Raises:
        KeyError: if ``content`` contains a character not in ``vocab``.
    """
    char_to_id = {}
    id_to_char = {}
    for position, symbol in enumerate(vocab):
        char_to_id[symbol] = position
        id_to_char[position] = symbol

    def fn_decode(ids):
        return [id_to_char[i] for i in ids]

    encoded = torch.tensor([char_to_id[ch] for ch in content], dtype=torch.long)
    split_at = int(len(encoded) * ratio_train)
    train_part, eval_part = encoded[:split_at], encoded[split_at:]
    return train_part, eval_part, fn_decode

In [5]:
# First 80% of the encoded characters train the model; the remaining 20% are held out for evaluation.
train_data, eval_data, decode = prepare_data(content, vocab, ratio_train=0.8)

In [6]:
# Requires the notebook globals train_data, eval_data, context_size, batch_size and device.
def collate_data(category):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        category: ``'train'`` selects ``train_data``; any other value selects
            ``eval_data``.

    Returns:
        ``(x, y)``: tensors of shape ``(batch_size, context_size)`` moved to
        ``device``; ``y`` is ``x`` shifted one position to the right, so
        ``y[:, t]`` is the next-token target for ``x[:, :t + 1]``.

    Raises:
        ValueError: if the selected split has no more than ``context_size``
            tokens, so not even one full (input, target) window exists.
    """
    data = train_data if category == 'train' else eval_data
    # A start index i needs data[i : i + context_size + 1] to exist, so the
    # largest valid start is len(data) - context_size - 1. randint's upper
    # bound is exclusive, hence len(data) - context_size (the previous extra
    # "- 1" silently dropped the last window and its final target token).
    n_starts = len(data) - context_size
    if n_starts <= 0:
        raise ValueError(
            f"split {category!r} has {len(data)} tokens; need more than "
            f"context_size={context_size} to form a window"
        )
    batch_start_idx = torch.randint(n_starts, (batch_size,))
    x = torch.stack([data[idx:idx + context_size] for idx in batch_start_idx])
    y = torch.stack([data[idx + 1:idx + context_size + 1] for idx in batch_start_idx])
    return x.to(device), y.to(device)

In [7]:
@torch.no_grad()
def calc_loss():
    """Estimate the mean loss on the train and eval splits without tracking gradients.

    Reads the notebook globals ``model``, ``n_eval`` and ``collate_data``. For
    each split it averages ``n_eval`` batch losses. ``model`` is switched to
    eval mode for the measurement and then restored to whatever training mode
    it was in before, even if a batch raises.

    Returns:
        dict mapping ``'train'`` and ``'eval'`` to a 0-d tensor holding the mean loss.
    """
    rslt = {}
    # NOTE(review): assumes model is a torch.nn.Module (it exposes .training/.train/.eval).
    was_training = model.training
    model.eval()
    try:
        for c in ['train', 'eval']:
            losses = torch.zeros(n_eval)
            for i in range(n_eval):
                x, y = collate_data(c)
                _, loss = model(x, y)
                losses[i] = loss.item()
            rslt[c] = losses.mean()
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        model.train(was_training)
    return rslt

In [8]:
class MaskedSingleHeadAttention(torch.nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    Args:
        head_size: Width h of this head's query/key/value projections (and of its output).
        context_size: Maximum sequence length supported; sizes the causal mask.
        n_feature: Width f of the input features.
        dropout_p: Dropout probability applied to the attention weights.
    """

    def __init__(self, head_size, context_size, n_feature, dropout_p):
        super().__init__()
        self.query = torch.nn.Linear(n_feature, head_size, bias=False)
        self.key = torch.nn.Linear(n_feature, head_size, bias=False)
        self.value = torch.nn.Linear(n_feature, head_size, bias=False)
        # Lower-triangular ones: entry (i, j) is 1 iff j <= i, i.e. position i may
        # attend to position j. torch.tril needs a tensor, not the two sizes.
        self.register_buffer('tril', torch.tril(torch.ones(context_size, context_size)))
        self.dropout = torch.nn.Dropout(dropout_p)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape (b, c, f) with c <= context_size.

        Returns:
            Tensor of shape (b, c, h).
        """
        _, ctx, _ = x.shape
        q = self.query(x)  # (b, c, h)
        k = self.key(x)    # (b, c, h)
        # Attention scores (b, c, c), scaled by 1/sqrt(d_k) where d_k is the head
        # size (the width of q and k), not the input feature width f.
        w = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        # Block attention to future positions.
        w = w.masked_fill(self.tril[:ctx, :ctx] == 0, float('-inf'))
        w = torch.nn.functional.softmax(w, dim=-1)
        w = self.dropout(w)
        v = self.value(x)  # (b, c, h)
        # (b, c, c) @ (b, c, h) = (b, c, h)
        return w @ v

In [9]:
class MaskedMultiHeadAttention(torch.nn.Module):
    """Several causal attention heads run in parallel, concatenated, then projected.

    Args:
        n_head: Number of heads; must be positive and divide ``n_feature``.
        context_size: Maximum sequence length supported by each head's mask.
        n_feature: Input/output feature width f; each head has width f // n_head.
        dropout_p: Dropout probability used inside each head and after the projection.

    Raises:
        ValueError: if ``n_head`` is not positive or does not divide ``n_feature``.
    """

    def __init__(self, n_head, context_size, n_feature, dropout_p):
        super().__init__()
        if n_head <= 0 or n_feature % n_head != 0:
            raise ValueError(
                f"n_feature={n_feature} must be divisible by a positive n_head={n_head}"
            )
        head_size = n_feature // n_head
        # One independent head per n_head (previously only a single head was built,
        # so the concatenated width did not match the projection's input).
        self.heads = torch.nn.ModuleList(
            [MaskedSingleHeadAttention(head_size, context_size, n_feature, dropout_p)
             for _ in range(n_head)]
        )
        self.projection = torch.nn.Linear(n_feature, n_feature)
        self.dropout = torch.nn.Dropout(dropout_p)

    def forward(self, x):
        """Map x of shape (b, c, f) to (b, c, f)."""
        # n_head outputs of (b, c, h) --cat--> (b, c, n_head * h) = (b, c, f)
        rslt = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.dropout(self.projection(rslt))

In [10]:
class FeedFoward(torch.nn.Module):
    """Two-layer MLP applied along the last axis: f -> 4f -> ReLU -> f -> dropout.

    Args:
        n_feature: Input and output width f.
        dropout_p: Dropout probability applied to the output.
    """

    def __init__(self, n_feature, dropout_p):
        super().__init__()
        hidden = 4 * n_feature
        layers = [
            torch.nn.Linear(n_feature, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, n_feature),
            torch.nn.Dropout(dropout_p),
        ]
        # Kept as one Sequential under `seq` so parameter names stay seq.0.* / seq.2.*.
        self.seq = torch.nn.Sequential(*layers)

    def forward(self, x):
        out = self.seq(x)
        return out

In [11]:
class TransformerUnit(torch.nn.Module):
    """Transformer block: LayerNorm -> masked multi-head attention, then
    LayerNorm -> feed-forward, each added back to its input as a residual.

    Args:
        n_head: Number of attention heads.
        context_size: Maximum sequence length supported by the attention mask.
        n_feature: Feature width f of the block's input and output.
        dropout_p: Dropout probability passed to the sub-layers.
    """

    def __init__(self, n_head, context_size, n_feature, dropout_p):
        super().__init__()
        self.mha = MaskedMultiHeadAttention(n_head, context_size, n_feature, dropout_p)
        self.ff = FeedFoward(n_feature, dropout_p)
        self.mha_ln = torch.nn.LayerNorm(n_feature)
        self.ff_ln = torch.nn.LayerNorm(n_feature)

    def forward(self, x):
        """Map x of shape (b, c, f) to a new tensor of the same shape."""
        # Out-of-place residual adds: `x += ...` would overwrite the caller's tensor
        # and the input LayerNorm saved for backward, breaking autograd.
        x = x + self.mha(self.mha_ln(x))
        x = x + self.ff(self.ff_ln(x))
        return x

In [12]:
batch_size = 16    # number of windows sampled per batch by collate_data
context_size = 64  # context length for prediction: tokens per input window (also sizes the causal mask)
n_eval = 100       # evaluate n_eval times then calculate the mean (batches per split in calc_loss)
n_feature = 64     # feature width f flowing through attention / feed-forward layers
n_head = 4         # attention heads per block; per-head width is n_feature // n_head = 16
dropout_p = 0.0    # dropout probability for all Dropout layers; 0.0 disables dropout