In [11]:
import torch
import numpy as np
import time
import torch.nn as nn
import os
import math
import torch.nn.functional as F 

In [3]:
class SciFiConfig:
    vocab_size: int = 100277  # cl100k-base
    n_embd: int = 768  # GPT-2
    
class MLP(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.config = config
    
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU()
        self.proj = nn.Linear(4 * config.n_embd, config.vocab_size)
        
        self.apply(self._init_weights)
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.uniform_(module.weight, -1.0, 1.0)
    
    def forward(self, idx, targets=None):
        #B, T = idx.shape
        tok_emb = self.wte(idx) # (B, T, n_embd)
        logits = self.proj(self.gelu(self.fc(tok_emb)))
        
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1), ignore_index=-1)
        return logits, loss
    
    def optimizer(self, learning_rate):
        optimizer = torch.optim.AdamW(self.parameters(), lr=learning_rate, betas=(0.9, 0.95), eps=1e-8)
        return optimizer

# ---------------------------------------------------------------------------------
import tiktoken
import numpy as np

def load_tokens(filename):
    tokens = np.loadtxt(filename, dtype=np.int32)
    tokens = torch.tensor(tokens, dtype=torch.long)
    return tokens

class DataLoaderSciFi:
    def __init__(self, B, T, split=None):
        self.B = B
        self.T = T
        
        # get filename of dataset
        self.data_dir = 'data'
        self.tokens_filename = 'tokens.txt'
        self.tokens_path = os.path.join(self.data_dir, self.tokens_filename)
        
        self.reset()
    
    def reset(self):
        self.tokens = load_tokens(self.tokens_path)
        self.cur_pos = 0
    
    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.cur_pos : self.cur_pos + B * T + 1]
        x = buf[: B * T].view(B, T)
        y = buf[1: ].view(B, T)
        self.cur_pos = self.cur_pos + B * T 
        return x, y    

In [4]:
# import tiktoken

# enc = tiktoken.get_encoding('cl100k_base')
# text = "Hello scientific fiction!"
# tokens = torch.tensor(enc.encode(text))
# targets = torch.cat((tokens[1:], torch.tensor([-1])), dim=-1)

# print(f"Encoded text: {tokens}")
# print(f"Targets: {targets}")
# print(f"The vocab_size of cl100k_base is {enc.n_vocab}.")

In [12]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"using device: {device}")

B = 64
T = 64
max_steps = 438230 # 28046749 tokens / (B * T)
mini_steps = 1000
lr = 3e-3

loader = DataLoaderSciFi(B=B, T=T)
model = MLP(SciFiConfig)
model.to(device)
optimizer = model.optimizer(lr)

torch.manual_seed(2024)

for step in range(mini_steps):
    t0 = time.time()
    x, y = loader.next_batch()
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    logits, loss = model(x, y)
    loss.backward()
    
    optimizer.step()
    t1 = time.time()
    dt = t1 - t0

    if step % 100 == 0:
        print(f'step {step} duration: {dt*1000:.2f}ms loss={loss:.4f}')

using device: cuda
step 0 duration: 14.29ms loss=11.5189
step 100 duration: 193.06ms loss=6.5283
step 200 duration: 197.85ms loss=5.3500
step 300 duration: 195.67ms loss=4.8795
step 400 duration: 197.31ms loss=8.5506
step 500 duration: 196.71ms loss=6.2702
step 600 duration: 196.45ms loss=5.4567
step 700 duration: 196.80ms loss=5.9329
step 800 duration: 197.70ms loss=5.6451
step 900 duration: 194.07ms loss=5.8578


In [35]:
# generate a sample after 1000 iterations
import tiktoken

max_length = 128
num_samples = 4
enc = tiktoken.get_encoding('cl100k_base')

tokens = enc.encode("In 2024, a boy was sitting in front of a park.")
tokens = torch.tensor(tokens, dtype=torch.long)
tokens = tokens.unsqueeze(0).repeat(num_samples, 1) # (4, n_tokens)
gen_idx = tokens.to(device)

sample_rng = torch.Generator(device=device)
sample_rng.manual_seed(2028)

while gen_idx.size(1) < max_length:
     logits, _ = model(gen_idx) # logits (num_samples, vocab_size)
     logits = logits[:, -1, :]
     probs = F.softmax(logits, dim=-1) 
     top_probs, top_indices = torch.topk(probs, 50, dim=-1)
     ix = torch.multinomial(top_probs, 1, generator=sample_rng)
     out_ix = torch.gather(top_indices, -1, ix)
     gen_idx = torch.cat((gen_idx, out_ix), dim=1)

for i in range(num_samples):
     tokens = gen_idx[i, :].tolist()
     decoded = enc.decode(tokens)
     print(f"sample {i}: {decoded}")

sample 0: In 2024, a boy was sitting in front of a park. I would
cunning at what might you are you will be
that he
me_ be
as he could not
for each of this time the
and there is."
"You'd better yet in hand. "I don't think in upon
had failed."
"Sure. It seemed to be in in this very
her, after all over. I've had been
their minds from the
n't been the top of the
succeeded in an intelligent faces."
"Where do, the shipwonderful
not to make you,
sample 1: In 2024, a boy was sitting in front of a park. He moved in so the
to die, the girl with the first to call
be to their position. "Well the
his
you didn't hurt.
"But why it, for
for you, even for it
all to the rest of course.
"Exactly." She can't see. He
was, for I'd never have
the worst.
"If you and more
of men. It's a week."
"I'm not know how
paled, not of their hands.
I see your
you're going to a
the first
sample 2: In 2024, a boy was sitting in front of a park. We will get back to see you
their winger of
her, but that,
it's heart of his fe