In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import config, generalizable
# Re-read the project configuration; the cell output reports which ENV_NAME was loaded.
config.reload_config()

Loading config for ENV_NAME=dev
Reloaded config
Loading config for ENV_NAME=dev
Reloaded config


In [3]:
import torch
# Seed PyTorch's global RNG so torch-driven randomness (parameter initialisation,
# torch.multinomial sampling during generation) is repeatable between runs.
torch.manual_seed(1337)
# Device selection is delegated to the project helper; the data-loading cell's
# output shows which device was picked on this machine.
device = generalizable.get_best_torch_device()

### Load data

In [4]:
from custom_tokenizer import FrequencyGreedyTokenizer

# Corpus: the text files found under the books directory (see the cell output).
# Single-file alternative:
# raw_text = generalizable.load_text_file('~/datasets/complete_shakespeare.txt')
corpus_dir = '~/datasets/books'
raw_text = generalizable.load_text_directory(corpus_dir)

# Tokenize with the previously saved vocabulary.
vocab_path = "vocab.json"
tokenizer = FrequencyGreedyTokenizer()
tokenizer.load(vocab_path)
tds = generalizable.encode_text(raw_text, tokenizer=tokenizer)
# The raw string is not needed once it has been encoded; drop it to free memory.
del raw_text

# Keep the whole token stream as a single int64 tensor on the training device.
full_data = torch.tensor(tds.all_tokens, dtype=torch.long, device=device)

# Sanity check: decode the first ten tokens, then show shape / dtype / device.
preview_ids = full_data[:10].tolist()
print(tds.decode_token_list_to_string(preview_ids))
print(full_data.shape, full_data.dtype, full_data.device)

Loading *.txt from: /Users/chris/datasets/books
Loaded 63 files, total length: 48810735
Text length: 48810735, total tokens: 13165927, vocab size: 9922
 WUTHERING HEI
torch.Size([13165927]) torch.int64 mps:0


In [5]:
# Contiguous split of the token stream: the leading TRAIN_FRACTION of tokens is
# used for training, the remaining tail is held out for validation.
TRAIN_FRACTION = 0.9

n = int(len(full_data) * TRAIN_FRACTION)
train_data, val_data = full_data[:n], full_data[n:]

### Define the model

In [6]:
import torch
import torch.nn as nn
from torch.nn import functional as F

class BigramLanguageModel(nn.Module):
    """Bigram language model: next-token logits depend only on the current token.

    The model is a single (vocab_size x vocab_size) embedding table; row `i`
    holds the logits over the next token given that the current token is `i`.
    """

    def __init__(self, vocab_size):
        """Create the lookup table for a vocabulary of `vocab_size` tokens."""
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when `targets` is given, the loss.

        Args:
            idx: (B, T) integer tensor of token ids.
            targets: optional (B, T) integer tensor of next-token ids.

        Returns:
            (logits, loss). Without targets, logits has shape (B, T, C) and
            loss is None. With targets, logits is flattened to (B*T, C) and
            loss is the mean cross-entropy over all B*T positions.
        """
        # Logits are the raw (non-normalized) predictions that a classification model generates,
        # which are then passed to a normalization function, like softmax.
        # The cross-entropy loss expects raw logits, not the output of a softmax,
        # so that it can apply its own softmax.
        logits = self.token_embedding_table(idx)  # (B,T,C)
        # logits is B*T*C, because for each batch, we make `Timestep` predictions,
        # and for each prediction we have C classes:
        # one token prediction for every timestep across the batch sequences.

        if targets is None:
            return logits, None

        # Batch, Timestep, Channels
        B, T, C = logits.shape

        # cross_entropy expects channels to be the second dimension, so the first
        # two dimensions are squashed into one: (B, T, C) -> (B*T, C).
        # reshape (rather than view) is used so non-contiguous inputs also work.
        logits = logits.reshape(B * T, C)
        targets = targets.reshape(B * T)

        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling is never back-propagated through; skip graph construction
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after the context `idx`.

        Args:
            idx: (B, T) integer tensor of token ids forming the current context.
            max_new_tokens: number of tokens to append to each sequence.

        Returns:
            (B, T + max_new_tokens) integer tensor: the context followed by the samples.
        """
        for _ in range(max_new_tokens):
            # get the predictions
            logits, _loss = self(idx)
            # extract only the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)

        return idx

### Create a WandB run

In [7]:
import wandb
import humanize

# Hyperparameters for this run.
LR = 1e-3
BLOCK_SIZE = 8
BATCH_SIZE = 128
# Number of batches averaged per loss estimate (passed to generalizable.estimate_loss).
NUM_LOSS_ESTIMATE_BATCHES = 400

# Hyperparameters and run metadata recorded with the WandB run.
run_config = {
    "learning_rate": LR,
    "batch_size": BATCH_SIZE,
    "block_size": BLOCK_SIZE,
    "model": "bigram",
    "tokenization": "greedy_frequency",
}
# `project` is the WandB project this run is logged under.
wandb.init(project="language-models", config=run_config)

m = BigramLanguageModel(tds.vocab_size).to(device)
# Learning rate should usually be 1e-4, but for small networks like this, 1e-3 works
optimizer = torch.optim.AdamW(m.parameters(), lr=LR)
# Counter advanced by the training loop (once per optimizer step).
epochs_trained = 0

[34m[1mwandb[0m: Currently logged in as: [33mchrisc[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
def get_batch(dataset):
    """Draw one batch from `dataset` using the notebook-level BATCH_SIZE, BLOCK_SIZE and device.

    Thin wrapper around `generalizable.get_batch`; the training loop unpacks the
    result as an `(xb, yb)` pair of model inputs and targets.
    """
    return generalizable.get_batch(dataset, BATCH_SIZE, BLOCK_SIZE, device)

def log_wandb_stats(batches_trained):
    """Estimate train and validation loss for model `m` and log both to WandB.

    Args:
        batches_trained: value used as the WandB `step` for this log entry.
            (The training loop passes `epochs_trained * BATCH_SIZE`.)

    Returns:
        (train_loss, val_loss) as returned by `generalizable.estimate_loss`.
    """
    def _estimate(split):
        # Loss estimate over NUM_LOSS_ESTIMATE_BATCHES batches drawn from `split`.
        return generalizable.estimate_loss(NUM_LOSS_ESTIMATE_BATCHES, lambda: get_batch(split), m)

    # Train first, then validation.
    train_loss = _estimate(train_data)
    val_loss = _estimate(val_data)

    metrics = {"train_loss": train_loss, "val_loss": val_loss}
    wandb.log(metrics, step=batches_trained)
    return train_loss, val_loss

def print_text_sample(num_tokens=30):
    """Sample `num_tokens` new tokens from model `m` and print the decoded text.

    Generation starts from a single sequence containing only token id 0.
    """
    start_context = torch.zeros((1, 1), dtype=torch.long, device=device)
    generated = m.generate(idx=start_context, max_new_tokens=num_tokens)
    # Batch size is 1, so take the only sequence and convert it to plain ints.
    token_ids = generated[0].tolist()
    print(tds.decode_token_list_to_string(token_ids))

# Losses are estimated and logged every `wandb_interval` optimizer steps; a text
# sample is printed every `sample_print_multiplier`-th time that happens.
wandb_interval = 400
sample_print_multiplier = 6
num_train_steps = 100000

# NOTE: despite its name, `epochs_trained` counts optimizer steps (one batch each),
# not passes over the dataset. The WandB step is `epochs_trained * BATCH_SIZE`.
log_wandb_stats(epochs_trained * BATCH_SIZE)

for _ in range(num_train_steps):
    # sample a batch of data
    xb, yb = get_batch(train_data)

    # evaluate the loss and take one optimizer step
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    epochs_trained += 1

    if epochs_trained % wandb_interval == 0:
        train_loss, val_loss = log_wandb_stats(epochs_trained * BATCH_SIZE)

        if epochs_trained % (wandb_interval * sample_print_multiplier) == 0:
            print(f"\n\nStep {humanize.intcomma(epochs_trained)}, train_loss: {train_loss}, val_loss: {val_loss}")
            print("------")
            print_text_sample()

# "\n\n" here: the previous "\n\F" was an invalid escape that printed a literal backslash.
print("\n\nFinal text sample:\n==============\n\n")
print_text_sample(100)

step 2,400, train_loss: 0.045582786202430725, val_loss: 0.045298460870981216
  He wes, t He  went in ch e hererinoug staon’ go t hishowevceivewhen own te ads n, aexprieldinid
w enoforts de r
aow thmone
step 4,800, train_loss: 0.04175248742103577, val_loss: 0.04218919202685356
 [ul d’our Sance. but the bo thesan s had
th hKύbertto s bettf
e onodech
crowisYespirs oned, a ben sudMi
step 7,200, train_loss: 0.03910098597407341, val_loss: 0.039747752249240875
 er o eimber rifEnd cm wall affe. Tbles! ifttansupproddaysaffςs no ύe
cat’sl yoks ouressioiness busme o
step 9,600, train_loss: 0.036468103528022766, val_loss: 0.037566471844911575
  walway,  hop

—ut i rect, thing.d. Thpond mi readt by eignt inrrieittle

—sonal tatferenld,e an. T, soldih, an to pbly 
step 12,000, train_loss: 0.034075602889060974, val_loss: 0.03594163432717323
  thre olmonth anck passn mes
thmar
Drely“Itith ad evhad tng at
Tperhss.necede,ar, dera h, and I shformount re ofs. A
step 14,400, train_loss: 0.03266495838761329

In [11]:
# Print a longer (100-token) sample from the current model.
# print_text_sample is defined once, in the training cell; it is not re-defined here.
print_text_sample(100)

 ze:

— Yo, se, we don’t me
andp tog blastrighrough its flacy, broad, and close and SENO! Siby, Mine! ‘ORuth of the truth.

Julies,
Tashoreover, she might behavi each symen, who did she put my idea—“No, bendingw height face, id, I am now living betweeelicious course you eat many such tried to twentosura combing the ECzard.

Pip


In [9]:
# Finish the WandB run; the cell output shows the final upload and the
# train_loss / val_loss history and summary values for the run.
wandb.finish()


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train_loss,█▇▆▅▄▄▃▃▂▂▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▇▆▅▄▄▃▃▂▂▂▂▂▁▁▂▁▁▂▁▁▂▁▂▁▁▁▁▁▁▁▁▂▁▁▁▁▁▂▂

0,1
train_loss,0.02567
val_loss,0.0297


In [None]:
# Reference text sample kept as a plain string; no code visible in this notebook reads it.
# NOTE(review): the name suggests it was generated by a character-level version of
# the model, for comparison with the tokenized samples above — confirm.
with_chars = """
Gofuidrme :
Toun hu: l:
I w herseso thell de de; omole l to.
Aneverkined leoumy e thuftxy, wlveres juront lure omy hesph yo harer core' veneayoro ne trsw'll isurendime ilerve fin?

BRI meso ot ce pr w o can w'losto h prin hif mem! Soums masan IZBELARomit Mu out ikentharn hen rvaiwit, curencer y ve il
UKI I to welilearas be mpure thiks! hatte'd aized'?

tar m anju tathe afedind inghevere Yowhare fofot foling a ge wr'farmfien ar aw bellathy feldin bawd henoslendothecendeang--

d! ond buthet, miz! 
"""