In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.nn import functional as F

In [2]:
block_size = 8
batch_size = 4

In [3]:
from datasets import load_dataset
ds = load_dataset("Trelis/tiny-shakespeare")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
ds_ = ds['train']['Text']

In [5]:
chars = sorted(set(''.join(ds_))) # vocabulory
len(chars)

65

In [None]:
encoder_dict = {k:v for k,v in zip(chars, range(len(chars)))}
decoder_dict = {v:k for k,v in zip(chars, range(len(chars)))}

# Encoder, Decoder
encode = lambda x: [encoder_dict[letter] for letter in x]
decode = lambda x: ''.join([decoder_dict[letter] for letter in x])
encode('hello'), decode([46, 43, 50, 50, 53])

([46, 43, 50, 50, 53], 'hello')

In [7]:
ds_all = '\n'.join(ds_)
ds_encoded = encode(ds_all)

In [8]:
n = int(.9*len(ds_encoded))
train_data = ds_encoded[:n]
val_data = ds_encoded[n:]

In [39]:
def get_batch(data, block_size=block_size, batch_size=batch_size):
    ix = torch.randint(len(data)-block_size, (batch_size,)) # these numbers are the start of each batch
    xy = torch.tensor([[data[i] for i in range(ix[i],ix[i]+block_size+1)] for i in range(batch_size)])
    xb = xy[:, :block_size]
    yb = xy[:,1:]
    # print(xb.shape, yb.shape) # (B x T)
    return xb, yb

get_batch(train_data)

(tensor([[59, 51, 54, 46, 39, 52, 58,  1],
         [57, 58, 39, 63,  1, 58, 47, 50],
         [58, 56, 47, 41, 47, 39, 52, 57],
         [57, 43,  0, 21, 52,  1, 58, 46]]),
 tensor([[51, 54, 46, 39, 52, 58,  1, 44],
         [58, 39, 63,  1, 58, 47, 50, 50],
         [56, 47, 41, 47, 39, 52, 57,  2],
         [43,  0, 21, 52,  1, 58, 46, 43]]))

In [95]:
class LLM(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)
        if not targets:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
            # print(logits.shape, loss.shape)
        return logits, loss
    
    def generate(self, idx, max_tokens=100):
        for _ in range(max_tokens):
            logits, loss = self(idx)
            # logits is (B x T x C)
            logits = logits[:, -1, :] # take only the last(latest) one in T component
            probs = F.softmax(logits, dim=-1) # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

llm = LLM(len(chars))
# llm.forward(get_batch(train_data)[0])
# decode(llm.generate(get_batch(train_data)[0]))
print(decode(llm.generate(idx = torch.zeros((1, 1), dtype=torch.long))[0].tolist()))



gz$ULKqMkcHesTNDCfhYzsl$l3VkZcw?gNWcKRn$!MENSCIsMisfHDrd,amO!Ioj3
EXEFjZwFNYFUkGXl!!o$o$lbjN,Nse'soY


In [76]:
decode(llm.generate(get_batch(train_data)[0]).tolist())

TypeError: unhashable type: 'list'