In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Global RNG seed so weight init, batch sampling and text generation are reproducible.
SEED = 1337
# Trailing ';' suppresses the uninformative `<torch._C.Generator ...>` repr in the cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x22b3fde8870>

In [35]:
# Read the whole corpus into memory as a single string.
with open('../data/input.txt', mode='r', encoding='utf-8') as f:
    shakespeare = f.read()

# Character-level vocabulary: every distinct character that occurs in the text,
# sorted so the character -> index assignment is deterministic across runs.
chars = sorted(set(shakespeare))
vocab_size = len(chars)

In [22]:
# mapping from characters to integers for encoding
# Lookup tables between characters and their integer ids (position in `chars`).
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Convert a string into a list of integer ids, one per character.

    Raises KeyError if `s` contains a character that is not in `chars`.
    """
    return [stoi[c] for c in s]


def decode(i):
    """Convert an iterable of integer ids back into a single string."""
    return ''.join(itos[token_id] for token_id in i)


test_string = "hii, there"
print(encode(test_string))
print(decode(encode(test_string)))
# Sanity check: decoding must exactly invert encoding.
assert decode(encode(test_string)) == test_string

[46, 47, 47, 6, 1, 58, 46, 43, 56, 43]
hii, there


In [36]:
# encode the entire text dataset and store it into a torch.Tensor
# The full corpus as a 1-D tensor of integer character ids.
data = torch.tensor(
    encode(shakespeare),
    dtype=torch.long,
)

In [24]:
# Let's now split up the data into train and validation sets
# Contiguous (unshuffled) split: the first 90% of tokens are for training,
# and the remaining tail is held out for validation.
n = int(len(data) * 0.9)
train_data, val_data = data[:n], data[n:]

In [37]:
torch.manual_seed(1337)
block_size = 8  # maximum context length (in tokens) used for prediction
batch_size = 4  # number of independent sequences processed in parallel

In [39]:
def get_batch(split):
    """Sample a random batch of (input, target) sequences from one data split.

    Args:
        split: 'train' to sample from `train_data`, 'val' to sample from `val_data`.

    Returns:
        (x, y): two tensors of shape (batch_size, block_size). y is x shifted
        one position to the right, so y[b, t] is the token that follows x[b, t]
        in the source text.

    Raises:
        ValueError: if `split` is neither 'train' nor 'val'.

    Note: `batch_size` and `block_size` are read from global scope at call
    time, so rebinding them in a later cell changes the batches returned.
    """
    if split == 'train':
        source = train_data
    elif split == 'val':
        source = val_data
    else:
        # Fail loudly instead of silently falling back to val_data, so a typo
        # such as 'training' cannot quietly sample from the wrong split.
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    # The upper bound leaves room for the target slice, which extends one token further.
    ix = torch.randint(len(source) - block_size, (batch_size,))
    x = torch.stack([source[i:i + block_size] for i in ix])
    y = torch.stack([source[i + 1:i + block_size + 1] for i in ix])
    return x, y

In [29]:
# A modified neural network for bigram model
class EnhancedBigramLanguageModel(nn.Module):
    """Bigram language model with a learned token embedding and a linear output head.

    Each token id is embedded into an `n_embd`-dimensional vector, which is then
    projected to `vocab_size` logits. Nothing mixes information across positions,
    so the prediction at each position depends only on the token at that position.

    Args:
        vocab_size: number of distinct token ids.
        n_embd: width of each token embedding.
    """

    def __init__(self, vocab_size, n_embd) -> None:
        super().__init__()
        # Embedding table of shape (vocab_size, n_embd): one learned vector per token id.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)

        # ------------------modified--------------------------------------
        # Linear head mapping each n_embd embedding to vocab_size logits.
        # NOTE: "ln_" does not denote LayerNorm here; the attribute name is kept
        # so existing state_dict keys still match.
        self.ln_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer token ids of shape (B, T).
            targets: optional integer ids of shape (B, T); the target id for
                each position of `idx`.

        Returns:
            (logits, loss). Without `targets`, logits have shape (B, T, vocab_size)
            and loss is None. With `targets`, logits are flattened to
            (B*T, vocab_size) and loss is F.cross_entropy's scalar result.
        """
        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embd)

        # ------------------modified--------------------------------------
        logits = self.ln_head(tok_emb)  # (B, T, vocab_size)

        if targets is None:  # e.g. during generation, when no targets are supplied
            loss = None
        else:
            # F.cross_entropy expects (N, C) logits and (N,) class indices, so
            # fold the batch and time dimensions together. reshape (unlike view)
            # also accepts non-contiguous inputs such as transposed targets.
            B, T, C = logits.shape
            logits = logits.reshape(B * T, C)
            targets = targets.reshape(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generation(self, idx, max_new_tokens):
        """Sample `max_new_tokens` tokens one at a time, appending each one to `idx`.

        Args:
            idx: (B, T) tensor of token ids used as the starting context.
            max_new_tokens: number of tokens to sample.

        Returns:
            A (B, T + max_new_tokens) tensor: the original context followed by
            the sampled tokens.

        Runs under torch.no_grad(): sampling never needs gradients, so no
        autograd graph is recorded for the repeated forward passes.
        """
        for _ in range(max_new_tokens):
            logits, _ = self(idx)

            # Only the prediction made at the final position is needed.
            logits = logits[:, -1, :]  # (B, C)

            probs = F.softmax(logits, dim=-1)  # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)

            # Append the sampled id to the running sequence.
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)

        return idx

In [30]:
# Instantiate the model: 32-dimensional token embeddings over the character vocabulary.
mn = EnhancedBigramLanguageModel(
    vocab_size=vocab_size,
    n_embd=32,
)

In [31]:
# AdamW over all model parameters with a fixed learning rate.
optimizer = torch.optim.AdamW(params=mn.parameters(), lr=1e-3)

In [32]:
from tqdm import tqdm

# Each iteration draws one fresh random batch, so `epochs` counts optimisation
# steps rather than full passes over the training set.
# NOTE: this rebinds the global `batch_size` (previously 4). get_batch reads
# the global at call time, so every later batch has 32 sequences.
epochs = 10000
batch_size = 32
losses = []  # per-step training loss, for later inspection or plotting
for e in tqdm(range(epochs)):
    # Clear gradients left over from the previous step before computing new ones.
    optimizer.zero_grad(set_to_none=True)
    xb, yb = get_batch('train')
    logits, loss = mn(xb, yb)
    loss.backward()
    losses.append(loss.item())
    optimizer.step()

100%|██████████| 10000/10000 [00:13<00:00, 721.15it/s]


In [33]:
# Loss of the final training mini-batch only: a noisy value, not a validation loss.
print(f"{loss.item()}")

2.3846118450164795


In [34]:
# Sample 500 characters starting from a single token with id 0
# (presumably '\n', the first character in the sorted vocabulary; TODO confirm),
# then decode the first (and only) sequence in the batch.
print(
    decode(
        mn.generation(
            idx=torch.zeros((1, 1), dtype=torch.long),
            max_new_tokens=500,
        )[0].tolist()
    )
)


Grk b n mertee sonidin y y arermet hn y, denjod be w illd CHALe mer thoun s's:
Thicuntilalllepane sthalldy hangilyoteng h hasbe pan hatrance
RDe hicomyonthar's
PES:
AKEd ith henoungincenonthiousir thondy, y heltieiengerofo'dsssit ey
KINld pe wither vouplloutherccnoha t,
K:

My hind tt hinig t ouchos tes; st younind wotte grotonear 'so it t jod weancothanan hay. t--s n prids, r loncave w holldse se O:
HIs; ht anje caike ineent.

Lavinde.
athave l.
KEONH:
ARThanco be y,-

NEEYowedy scace, aridesar
