In [2]:
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from datetime import datetime
from byte_tokenizer import ByteTokenizer
from class_gpt import GPT
from class_textdataset import TextDataset

In [3]:
today = datetime.today().strftime('%Y-%m-%d')

# parameters of the model
context_length = 16
model_dim = 12  # dimensionality for embedding and attention
num_blocks = 4  # number of repetitions of the transformer block
num_heads = 4  # number of self attention instances, each with size model_dim // num_heads

tokenizer = ByteTokenizer()

vocab_size = tokenizer.vocab_size
batch_size = 8
epochs = 10
lr=3e-4  # learning rate for the gradient descent method
checkpoints = sorted([f for f in os.listdir() if 'pth' in f])
checkpoint_file = torch.load(checkpoints[-1])

To train the model, it takes the txt file 'bon_jovi'

In [4]:
with open("bon_jovi.txt", 'r', encoding='utf-8') as file:
    test = file.read()
print(len(test))  # 39604

dataset = TextDataset(test, tokenizer, context_length)

loader = DataLoader(dataset, batch_size=8, shuffle=True)


39604


In [5]:
# Load the checkpoint and continue training

model = GPT(vocab_size, context_length, model_dim, num_blocks, num_heads)
optimizer = optim.AdamW(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

# Load checkpoint
checkpoint = checkpoint_file 
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1


new_lr = 0.001  # new learning rate

for param_group in optimizer.param_groups:
    param_group['lr'] = new_lr

model.train()  # make sure to set back to training mode

GPT(
  (token_embeddings): Embedding(258, 12)
  (pos_embeddings): Embedding(16, 12)
  (blocks): Sequential(
    (0): TransformerBlock(
      (mhsa): MultiHeadedSelfAttention(
        (heads): ModuleList(
          (0-3): 4 x SingleHeadAttention(
            (get_keys): Linear(in_features=12, out_features=3, bias=False)
            (get_queries): Linear(in_features=12, out_features=3, bias=False)
            (geet_values): Linear(in_features=12, out_features=3, bias=False)
          )
        )
      )
      (first_ln): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
      (second_ln): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
      (ff): VanillaNeuralNetwork(
        (fc1): Linear(in_features=12, out_features=12, bias=True)
        (fc2): Linear(in_features=12, out_features=12, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
      )
    )
    (1): TransformerBlock(
      (mhsa): MultiHeadedSelfAttention(
        (heads): ModuleList(
          (0-3): 4 x Sing

In [6]:

# training loop
for epoch in range(10):
    for batch_x, batch_y in loader:
        
        # forward
        logits = model(batch_x).squeeze()  # [batch, seq, vocab]
        loss = criterion(logits.view(-1, logits.size(-1)), batch_y.view(-1))

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch} | Loss {loss.item():.4f}")

Epoch 0 | Loss 5.0596
Epoch 1 | Loss 5.0596
Epoch 2 | Loss 5.0596
Epoch 3 | Loss 5.0596
Epoch 4 | Loss 5.0596
Epoch 5 | Loss 5.0596
Epoch 6 | Loss 5.0596
Epoch 7 | Loss 5.0596
Epoch 8 | Loss 5.0596
Epoch 9 | Loss 5.0596


In [7]:
# Saving the model’s weights only
gpt_decoder_weights_file = f'gpt_decoder_weights_{today}.pth'

torch.save(model.state_dict(), gpt_decoder_weights_file )

"""
This saves only the parameters (weights and biases).

Lightweight and flexible.

To load it, you need to recreate the model architecture first:
"""


'\nThis saves only the parameters (weights and biases).\n\nLightweight and flexible.\n\nTo load it, you need to recreate the model architecture first:\n'

In [8]:
# Saving the entire model (architecture + weights)
gpt_decoder_full_file = f'decoder_full_{today}.pth'

torch.save(model, gpt_decoder_full_file)


In [9]:
# Save optimizer state too. If you want to resume training:
checkpoint_file = f'checkpoint_{today}.pth'

torch.save({
    'epoch': epoch,                    # last completed epoch
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss                       # optional, for logging
}, checkpoint_file)


⚠️ Important Tips

Always call model.train() after loading if you want to train.

If you only want to evaluate, use model.eval().

Make sure the model architecture and optimizer are created exactly as before.

You don’t need pickle — torch.save / torch.load handles everything safely.

You can change the learning rate (or other hyperparameters) after loading a checkpoint. The key is that the optimizer state (like momentum in Adam/SGD) will still be there, but you can override the learning rate.
