In [52]:
import os, re, torch
import torch.nn as nn
from torch.nn import functional as F
device = 'cpu'

In [58]:

data_str = open(os.path.join('data_processing','media','tinyshakespeare.txt'), 'r').read() # should be simple plain text file
character_all = set(data_str)
SIZE_VOCAB  = len(character_all)
# data = re.sub(r'[^a-zA-Z\s]', '', data) # remove non-alphabet
# seq_length = 8
# words_all = data.split()
# words = list(set([w for w in words_all if len(w) == seq_length]))
# chars_all = [c for w in words for c in w] #[*data]
# chars = list(set(chars_all))  #list(set(data))
# data_size, vocab_size = len(words), len(chars)
# print('data has %d words of size %d, and %d unique characters.' % (len(words), seq_length, vocab_size))
# 'Chars: ' + ' '.join(chars)

In [59]:
d_char2ID = {c:i for i,c in enumerate(character_all)}
d_ID2char = {i:c for i,c in enumerate(character_all)}
encode = lambda string: [d_char2ID[s] for s in string]
decode = lambda code: ''.join([d_ID2char[i] for i in code])
test_str = 'Hello world!'
_encoded = encode(test_str)

print(f'{_encoded = }\n{test_str = }, {decode(encode(test_str)) == test_str = }')

_encoded = [32, 39, 12, 12, 40, 38, 33, 40, 21, 12, 51, 42]
test_str = 'Hello world!', decode(encode(test_str)) == test_str = True


In [60]:
data = torch.tensor(encode(data_str), dtype=torch.long)
train_part = int(0.9* len(data))
data_train = data[:train_part]
data_valid = data[train_part:]

In [65]:
torch.manual_seed(69)
SIZE_CONTEXT = 8
SIZE_BATCH   = 3

def get_batch(which):
    
    data_which = data_train if which == 'train' else data_valid
    data_big = torch.zeros(size=(SIZE_BATCH, SIZE_CONTEXT + 1), dtype = torch.long, device=device)
    # if len(data_which) = size_context + 1 -> index_start = randint(0,1) -> only 0
    index_start = torch.randint(0, len(data_which) - SIZE_CONTEXT, size = (SIZE_BATCH,))  
    for batch, i in enumerate(index_start):
        data_big[batch] = data_which[i:i+SIZE_CONTEXT+1]
    x, y = data_big[:,:-1], data_big[:,1:]
    return x,y

xb, yb = get_batch('train')
xb, yb

(tensor([[40, 14,  7, 53, 24,  8, 38,  3],
         [40, 21, 51, 46,  7, 53,  9, 38],
         [12,  2, 24, 22, 16, 23, 52, 14]]),
 tensor([[14,  7, 53, 24,  8, 38,  3, 62],
         [21, 51, 46,  7, 53,  9, 38,  3],
         [ 2, 24, 22, 16, 23, 52, 14, 39]]))

In [83]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size) -> None:
        super().__init__()
        # table retrieves vocab encodings. SEQ-> [SEQ_LEN, VOCAB]. and its batched.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self,  IDs, targets = None):
        logits = self.token_embedding_table(IDs) # [BATCH, SEQ_LEN, VOCAB]
        if targets is None:
            loss = None
        else:
            BATCH, SEQ_LEN, VOCAB = logits.size()
            logits = logits.view(BATCH*SEQ_LEN, VOCAB)  
            targets = targets.reshape(BATCH*SEQ_LEN)
            loss = F.cross_entropy(logits, targets)
            
        return logits, loss
    
    def generate(self, IDs, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(IDs)
            # take last in SEQ
            logits = logits[:,-1,:]             
            # prob-normalize
            probs = F.softmax(logits, dim = -1) 
            # get ID from prob
            sample_probs = torch.multinomial(probs, num_samples=1)  
            # add to SEQ dimension
            IDs = torch.cat((IDs, sample_probs), dim = 1)
        return IDs



mod = BigramLanguageModel(SIZE_VOCAB)
decode(mod.generate(IDs = torch.zeros((1,1), dtype=torch.long), max_new_tokens=100)[0].tolist())

"W'? Sqjv3&QRXzTeb$N'suq;QcYiCea3U:SUqqiJVWdt TdvBMtPH$QY A;, ZSXq\nz;\nHtP ctMt h&s;l.FvuAwI\nzGIbak!B'f"

In [85]:
optimizer = torch.optim.Adam(mod.parameters(), lr=1e-3)

In [93]:
SIZE_BATCH = 32

for step in range(1):
    xb, yb = get_batch('train')
    logits, loss = mod(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(loss.item())
print(decode(mod.generate(IDs = torch.zeros((1,1), dtype=torch.long), max_new_tokens=100)[0].tolist()))

2.479257345199585
Wind mathe m w cht d's meswit t bu I de,
Tovil;
D whiseng.
Anove,
K:
Th.
Theg ath w, mee s whinuth he
