In [5]:
%run model.ipynb

train has 301,966 tokens
val has 36,059 tokens


In [6]:
class Trainer:

    def __init__(self, config, model, train_dataset):
        self.config = config
        self.model = model
        self.optimizer = None
        self.train_dataset = train_dataset
        self.callbacks = defaultdict(list)
        self.device = config.device
        self.model = self.model.to(self.device)

        # variables that will be assigned to trainer class later for logging and etc
        self.iter_num = 0
        self.iter_time = 0.0
        self.iter_dt = 0.0

    def add_callback(self, onevent: str, callback):
        self.callbacks[onevent].append(callback)

    def set_callback(self, onevent: str, callback):
        self.callbacks[onevent] = [callback]

    def trigger_callbacks(self, onevent: str):
        for callback in self.callbacks.get(onevent, []):
            callback(self)

    def run(self):
        model, config = self.model, self.config

        # setup the optimizer
        self.optimizer = model.configure_optimizers(config)

        # setup the dataloader
        train_loader = DataLoader(
            self.train_dataset,
            sampler=torch.utils.data.RandomSampler(self.train_dataset, replacement=True, num_samples=int(1e10)),
            shuffle=False,
            # pin_memory=True,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
        )

        model.train()
        self.iter_num = 0
        self.iter_time = time.time()
        data_iter = iter(train_loader)
        while True:

            # fetch the next batch (x, y) and re-init iterator if needed
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                batch = next(data_iter)
            batch = [t.to(self.device) for t in batch]
            x, y = batch

            # forward the model
            logits, self.loss = model(x, y)

            # backprop and update the parameters
            model.zero_grad(set_to_none=True)
            self.loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
            self.optimizer.step()

            self.trigger_callbacks('on_batch_end')
            self.iter_num += 1
            tnow = time.time()
            self.iter_dt = tnow - self.iter_time
            self.iter_time = tnow

            # termination conditions
            if config.max_iters is not None and self.iter_num >= config.max_iters:
                break

In [38]:
class GPTConfig:
    def __init__(self, vocab_size, **kwargs):
        self.vocab_size = vocab_size
        for key, value in kwargs.items():
            setattr(self, key, value)

class CustomConfig(GPTConfig):
    n_layer = 8
    n_head = 8
    n_embd = 256
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1
    dropout = 0.1
    compile = True
    device = 'cuda'
    num_workers = 0
    max_iters = 2e4
    batch_size = 4
    block_size = 64
    learning_rate = 6e-4
    betas = (0.9, 0.95)
    weight_decay = 1e-1
    grad_norm_clip = 1.0

vocab_size = len(train_ids)
config = CustomConfig(vocab_size=vocab_size)

In [39]:
model = GPT(config).to(config.device)
model.load_state_dict(torch.load('model.pth'))
trainer = Trainer(config, model, train_dataset)

def batch_end_callback(trainer):
    if trainer.iter_num % 500 == 0:
        print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")
trainer.set_callback('on_batch_end', batch_end_callback)
trainer.run()

number of parameters: 83.64M
iter_dt 0.00ms; iter 0: train loss 5.14950
iter_dt 34.36ms; iter 500: train loss 4.63397
iter_dt 33.56ms; iter 1000: train loss 4.19386
iter_dt 34.42ms; iter 1500: train loss 5.01648
iter_dt 34.47ms; iter 2000: train loss 4.15115
iter_dt 34.44ms; iter 2500: train loss 4.45165
iter_dt 34.45ms; iter 3000: train loss 4.13863
iter_dt 34.43ms; iter 3500: train loss 4.23959
iter_dt 34.51ms; iter 4000: train loss 4.50570
iter_dt 34.47ms; iter 4500: train loss 3.18282
iter_dt 34.48ms; iter 5000: train loss 4.10450
iter_dt 34.45ms; iter 5500: train loss 3.56379
iter_dt 34.43ms; iter 6000: train loss 4.21151
iter_dt 34.48ms; iter 6500: train loss 4.54891
iter_dt 34.39ms; iter 7000: train loss 4.05246
iter_dt 34.43ms; iter 7500: train loss 4.15091
iter_dt 34.39ms; iter 8000: train loss 3.32997
iter_dt 34.42ms; iter 8500: train loss 4.45516
iter_dt 34.46ms; iter 9000: train loss 3.54904
iter_dt 34.45ms; iter 9500: train loss 3.37663
iter_dt 34.45ms; iter 10000: train l

In [40]:
torch.save(model.state_dict(), 'model.pth')

In [33]:
model.load_state_dict(torch.load('model.pth'))

<All keys matched successfully>

In [27]:

state_dict = torch.load('model.pth')
print("Saved state_dict keys:", state_dict.keys())

Saved state_dict keys: odict_keys(['_orig_mod.transformer.wte.weight', '_orig_mod.transformer.wpe.weight', '_orig_mod.transformer.h.0.ln_1.weight', '_orig_mod.transformer.h.0.ln_1.bias', '_orig_mod.transformer.h.0.attn.c_attn.weight', '_orig_mod.transformer.h.0.attn.c_attn.bias', '_orig_mod.transformer.h.0.attn.c_proj.weight', '_orig_mod.transformer.h.0.attn.c_proj.bias', '_orig_mod.transformer.h.0.ln_2.weight', '_orig_mod.transformer.h.0.ln_2.bias', '_orig_mod.transformer.h.0.mlp.c_fc.weight', '_orig_mod.transformer.h.0.mlp.c_fc.bias', '_orig_mod.transformer.h.0.mlp.c_proj.weight', '_orig_mod.transformer.h.0.mlp.c_proj.bias', '_orig_mod.transformer.h.1.ln_1.weight', '_orig_mod.transformer.h.1.ln_1.bias', '_orig_mod.transformer.h.1.attn.c_attn.weight', '_orig_mod.transformer.h.1.attn.c_attn.bias', '_orig_mod.transformer.h.1.attn.c_proj.weight', '_orig_mod.transformer.h.1.attn.c_proj.bias', '_orig_mod.transformer.h.1.ln_2.weight', '_orig_mod.transformer.h.1.ln_2.bias', '_orig_mod.transf

In [42]:
text = 'Lord:\nRise! My people, conquer the north!'
sample_ids = torch.Tensor(enc.encode_ordinary(text)).long()
sample_ids = torch.unsqueeze(sample_ids, 0).to(config.device)
model.eval()
result = model.generate(sample_ids, max_new_tokens=500, temperature=1, do_sample=False, top_k=None)
print(enc.decode(result.detach().cpu().tolist()[0]))

Lord:
Rise! My people, conquer the north!

DUCHESS OF YORK:
I pray thee, let me hear; I will not hear some of him.

DUCHESS OF YORK:
I am a man, sir, that was a man.

DUKE OF YORK:
I was, my lord, I know not what I was,
That I may be, I fear, that have done.

DUCHESS OF YORK:
I pray thee, let me hear; but, gentle Clarence, let me speak.

DUCHESS OF YORK:
I am a king, and I am not king.

YORK:
I am not, my lord.

DUCHESS OF YORK:
I am not yet, I am not yet.

YORK:
I am not, my lord.

YORK:
I am not yet, I am not yet.

YORK:
I am not, my lord.

KING RICHARD III:
I am not king; but, as I hear,
I am the king, the king, the king, and I,
The king, the king, the king, the king,
The king, the king, the king, and all,
The king, the king, and all his son,
With all his son, his son, his son, his son:
And then, I hope, he is at hand.

CLARENCE:
I am too late, my lord.

GLOUCESTER:

KING EDWARD IV:
I am too late, my lord.

YORK:
I am not yet, my lord.

GLOUCESTER:

YORK:
I am not yet.

YORK:
I am n