In [1]:
import torch
from tqdm import tqdm
from pathlib import Path
from model import MiniGPT
from config import MiniGPTConfig

In [2]:
dataset_filepath = Path("./data.txt")
trained_model_filepaht = Path("./mini_gpt.pt")
batch_size = 32
block_size = 8

# Create Vocab

In [3]:
with open(dataset_filepath,"r") as file:
    text = file.read()

In [4]:
# Create vocab for tokenizer
chars = sorted(list(set(text)))
vocab_size = len(chars)

In [5]:
# Create lookup table
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

In [6]:
# Create simple encoder and decoder in global namespace
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: "".join([itos[i] for i in l])

# Create Simple Dataset

In [7]:
# Encode dataset and store it as a torch long tensor.
data = torch.tensor(encode(text), dtype=torch.long)

In [8]:
def get_batch(split):
    """
    generate a small batch of data of inputs x and targets y.
    this helps you to get random chunk in the entire dataset
    """
    # Create four random numbers.
    ix = torch.randint(len(data) - block_size, (batch_size,)) # exp. tensor([ 76049, 234249, 934904, 560986])
    # Create x and y
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

Check if encoder/decoder work properly.

In [9]:
print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [10]:
model_config = MiniGPTConfig()
model = MiniGPT(model_config)

In [11]:
device = torch.device("cuda")
model.to(device)

MiniGPT(
  (token_embedder): Embedding(65, 768)
  (positional_embedder): Embedding(1024, 768)
  (dropout): Dropout(p=0.0, inplace=False)
  (decoder_layers): ModuleList(
    (0-11): 12 x Decoder(
      (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (masked_multihead_attention): CausalSelfAttention(
        (attention_projection): Linear(in_features=768, out_features=2304, bias=True)
        (output_projection): Linear(in_features=768, out_features=768, bias=True)
        (attention_dropout): Dropout(p=0.0, inplace=False)
        (residual_dropout): Dropout(p=0.0, inplace=False)
      )
      (feed_forward): FeedForward(
        (linear1): Linear(in_features=768, out_features=3072, bias=True)
        (gelu): GELU(approximate='none')
        (linear2): Linear(in_features=3072, out_features=768, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (layer_normali

Just some random text to generate to see it's total non-sense.

In [12]:
# Text generation is going to start with new line character
idx = torch.zeros((1,1), dtype= torch.long, device=device)
# Use index 0 to unplug a sample
print(decode(model.generate(idx, max_new_tokens=100)[0].tolist()))


:$gq3fUeoCiuooMtT
kyWhLEEocl
'fybK
QcEo.yOWR:'ynWCahe$T:yc?gwOK
o-PHQunRu.Ux3.jdWRIlYAcEQ.l;j,
?RhIv


Since this is just a simple demo we won't use tensorboard and only print out models loss before and after the training.

In [13]:
xb, yb = get_batch("train")
xb = xb.to(device)
yb = yb.to(device)
logits, loss = model(xb, yb)
print(f"Random sample model loss before training: {loss.item()}")

Random sample model loss before training: 4.333423137664795


# Train Model

In [14]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model.train()
for steps in tqdm(range(50000)):
    xb, yb = get_batch("train")
    xb = xb.to(device)
    yb = yb.to(device)
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(f"Random sample model loss after training: {loss.item()}")

100%|██████████| 50000/50000 [19:04<00:00, 43.69it/s]

Random sample model loss after training: 1.572378158569336





In [17]:
# Save trained model
model.state_dict()
torch.save(model.state_dict(), trained_model_filepaht)

Generate some text to see if model has learned a little or not (Don't expect anything crazy this is just to see if training is working or not.)

In [31]:
model.eval()
with torch.inference_mode():
    print(decode(model.generate(idx, max_new_tokens=50)[0].tolist())) # 0 is to unplug a sample


son, your thor twalllld whounourne ougeld y thinou
