In [19]:
from dataclasses import dataclass
import math

import torch
from torch import nn
import torch.nn.functional as F

# ---- run configuration -------------------------------------------------------
# Pick the compute device once, up front: CUDA when torch can see a GPU, else CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Model hyperparameters (not read in this cell; presumably consumed by model code in later cells).
n_embd: int = 128
n_layer: int = 4
n_head: int = 4
multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
norm_eps: float = 1e-5

# Batching: get_batch draws `batch_size` random windows of `block_size` tokens each.
batch_size: int = 16
block_size: int = 128


# Load the raw training corpus as one string. The context manager closes the file
# handle as soon as it has been read; a bare open(...).read() leaves the handle open
# until the garbage collector happens to reclaim it.
with open('../data/shakespeare.txt', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

# Character-level vocabulary: every distinct character in the corpus becomes one token,
# with ids assigned in sorted-character order.
chars = sorted(set(text))  # sorted() already returns a list; wrapping in list() is redundant
vocab_size = len(chars)
# Lookup tables in both directions between characters and their integer ids.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Encode a string as a list of integer token ids.

    Raises KeyError for any character that does not occur in the corpus vocabulary.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decode a list of integer token ids back into a string (inverse of ``encode``)."""
    return ''.join(itos[i] for i in l)

# Train/validation split: encode the whole corpus into one long tensor, then keep the
# first 90% for training and hold out the final 10% for validation (no shuffling).
data = torch.tensor(encode(text), dtype=torch.long)
n = int(len(data) * 0.9)  # index at which the validation split begins
train_data, val_data = data[:n], data[n:]

# data loading
def get_batch(split):
    """Sample a random batch of (input, target) token windows from one data split.

    Args:
        split: ``'train'`` samples from ``train_data``; ``'val'`` samples from
            ``val_data``. Any other value raises rather than silently falling back
            to the validation split, so a typo such as ``'trian'`` cannot go unnoticed.

    Returns:
        A tuple ``(x, y)`` of tensors with the split's dtype, each shaped
        ``(batch_size, block_size)`` and moved to ``device``. ``y`` is ``x`` shifted
        by one position: ``y[b, t]`` is the token that follows ``x[b, t]``.

    Raises:
        ValueError: if ``split`` is not ``'train'`` or ``'val'``, or if the chosen
            split holds ``block_size`` tokens or fewer (no full window plus target fits).
    """
    splits = {'train': train_data, 'val': val_data}
    if split not in splits:
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    # Local name chosen so the module-level `data` tensor is not shadowed.
    source = splits[split]

    # Each window needs block_size inputs plus one extra token for the shifted target,
    # so valid start offsets are 0 .. len(source) - block_size - 1.
    n_starts = len(source) - block_size
    if n_starts <= 0:
        raise ValueError(
            f"split {split!r} has {len(source)} tokens; need more than block_size={block_size}"
        )
    ix = torch.randint(n_starts, (batch_size,))

    # Gather every window in one indexing op: row b holds indices ix[b] .. ix[b] + block_size - 1.
    offsets = ix[:, None] + torch.arange(block_size)
    x = source[offsets]
    y = source[offsets + 1]
    return x.to(device), y.to(device)



In [20]:
# Peek at one training batch: (inputs, targets), where targets are the inputs shifted one token ahead.
get_batch(split='train')

(tensor([[40, 39, 45,  ..., 46, 43,  1],
         [ 0, 32, 46,  ...,  0,  0, 18],
         [50, 43, 58,  ..., 53, 52, 11],
         ...,
         [58,  1, 40,  ...,  6,  0, 35],
         [ 1, 50, 53,  ..., 53, 59, 56],
         [43, 40, 39,  ...,  1, 40, 39]], device='cuda:0'),
 tensor([[39, 45, 45,  ..., 43,  1, 61],
         [32, 46, 39,  ...,  0, 18, 30],
         [43, 58,  1,  ..., 52, 11,  1],
         ...,
         [ 1, 40, 43,  ...,  0, 35, 47],
         [50, 53, 53,  ..., 59, 56,  5],
         [40, 39, 58,  ..., 40, 39, 41]], device='cuda:0'))