In [12]:
import torch

# Select the compute device: the current accelerator when one is available,
# otherwise the CPU. `torch.accelerator` was added in PyTorch 2.6; on older
# versions fall back to a CUDA check instead of failing with AttributeError.
if hasattr(torch, "accelerator"):
    device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cpu device


In [25]:
def prepare_vocab(path_to_data):
    """Read a UTF-8 text file and derive its character vocabulary.

    Args:
        path_to_data: path of the text file to read.

    Returns:
        A ``(content, vocab)`` tuple: ``content`` is the whole file as a
        string and ``vocab`` is the sorted list of distinct characters in it.
    """
    with open(path_to_data, encoding='utf-8') as fh:
        text = fh.read()
    distinct_chars = sorted(set(text))
    return text, distinct_chars

In [26]:
# Load the corpus and build its character-level vocabulary.
content, vocab = prepare_vocab(path_to_data='./TinyS.txt')
vocab_size = len(vocab)  # number of distinct characters in the corpus

print("Read in: ", len(content))
print(vocab, vocab_size)

Read in:  1115394
['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 65


In [29]:
def prepare_data(content, vocab, ratio_train):
    """Encode text as character ids and split it into train / eval tensors.

    Args:
        content: the text to encode; every character must appear in ``vocab``.
        vocab: sequence of distinct characters; a character's position in it
            is its integer id.
        ratio_train: fraction in [0, 1] of the encoded sequence assigned to the
            training split (the leading part); the remainder is the eval split.

    Returns:
        ``(train_data, eval_data, decode)``: two 1-D ``torch.long`` tensors and
        a function mapping a sequence of ids back to a *list* of characters
        (use ``''.join(decode(ids))`` to obtain a string).

    Raises:
        ValueError: if ``ratio_train`` is outside [0, 1] (including NaN).
        KeyError: if ``content`` contains a character missing from ``vocab``.
    """
    # Out-of-range ratios would otherwise silently yield an empty eval split
    # (ratio > 1) or a negative slice index (ratio < 0).
    if not 0.0 <= ratio_train <= 1.0:
        raise ValueError(f"ratio_train must be in [0, 1], got {ratio_train!r}")

    char_to_id = {char: idx for idx, char in enumerate(vocab)}
    id_to_char = dict(enumerate(vocab))

    def fn_decode(ids):
        """Map a sequence of ids back to the list of their characters."""
        return [id_to_char[i] for i in ids]

    data = torch.tensor([char_to_id[c] for c in content], dtype=torch.long)
    # Contiguous prefix/suffix split (no shuffling): eval text is the tail.
    n_train = int(len(data) * ratio_train)
    return data[:n_train], data[n_train:], fn_decode

In [30]:
# Encode the corpus; the first 80% of tokens form the training split and the
# remaining 20% the eval split. `decode` maps ids back to characters.
train_data, eval_data, decode = prepare_data(content, vocab, ratio_train=0.8)

In [53]:
# Seed torch's global RNG so random draws made after this cell (e.g. the
# torch.randint batch sampling in collate_data) are reproducible on re-run.
seed = 1337
torch.manual_seed(seed)

batch_size = 16
context_size = 64  # context length for prediction
n_eval = 100       # evaluate n_eval times then calculate the mean
n_feature = 64     # input feature width of each attention head's projections
dropout_p = 0.0    # dropout probability used by the attention heads

In [50]:
def collate_data(category):
    """Sample a random batch of (input, next-token target) windows.

    Reads the notebook globals ``train_data`` / ``eval_data``, ``batch_size``,
    ``context_size`` and ``device``.

    Args:
        category: ``'train'`` to sample from ``train_data``, ``'eval'`` to
            sample from ``eval_data``.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, context_size)`` and moved to
        ``device``; ``y[:, t]`` is the token that follows ``x[:, t]`` in the
        source data.

    Raises:
        ValueError: for an unknown ``category``, or when the selected split
            holds fewer than ``context_size + 1`` tokens.
    """
    # Reject typos explicitly instead of silently sampling from eval_data.
    if category not in ('train', 'eval'):
        raise ValueError(f"category must be 'train' or 'eval', got {category!r}")
    data = train_data if category == 'train' else eval_data

    # A window starting at `start` uses tokens start .. start + context_size
    # (inclusive, for the shifted target), so valid starts are
    # 0 .. len(data) - context_size - 1: that is `len(data) - context_size`
    # choices, and randint's upper bound is exclusive.
    n_starts = len(data) - context_size
    if n_starts < 1:
        raise ValueError(
            f"{category} split has {len(data)} tokens; need at least {context_size + 1}"
        )
    starts = torch.randint(n_starts, (batch_size,))
    x = torch.stack([data[s:s + context_size] for s in starts])
    y = torch.stack([data[s + 1:s + context_size + 1] for s in starts])
    return x.to(device), y.to(device)

In [55]:
# Sanity check: one training batch should be (batch_size, context_size).
x, y = collate_data('train')
x.shape

torch.Size([16, 64])

In [52]:
@torch.no_grad()
def calc_loss():
    """Estimate the mean loss of the global ``model`` on both data splits.

    For each of ``'train'`` and ``'eval'``, draws ``n_eval`` batches via
    ``collate_data`` and averages the loss returned as the second element of
    ``model(x, y)``. Runs without gradients and with the model in eval mode.
    The model's previous train/eval mode is restored afterwards, even if an
    exception is raised.

    Returns:
        dict mapping ``'train'`` and ``'eval'`` to a 0-d tensor holding the
        mean loss over ``n_eval`` batches.
    """
    was_training = model.training
    model.eval()
    try:
        rslt = {}
        for split in ('train', 'eval'):
            losses = torch.zeros(n_eval)
            for i in range(n_eval):
                x, y = collate_data(split)
                _, loss = model(x, y)
                losses[i] = loss.item()
            rslt[split] = losses.mean()
        return rslt
    finally:
        # Restore the caller's mode rather than unconditionally switching to train.
        model.train(was_training)

In [None]:
class SingleHead(torch.nn.Module):
    def __init__(self, size):
        super.__init__()
        self.query = torch.nn.Linear(n_feature, size, bias=False)
        self.key = torch.nn.Linear(n_feature, size, bias=False)
        self.value = torch.nn.Linear(n_feature, size, bias=False)
        self.register_buffer('tril', torch.tril(context_size, context_size))
        self.dropout = torch.nn.Dropout(dropout_p)

    def forward(self, x):
        batch, ctx, features = x.shape
        