# Train BigramLanguageModel

    This notebook trains a neural network based on the lecture
    https://www.youtube.com/watch?v=kCc8FmEb1nY
    Let's build GPT: from scratch, in code, spelled out. by Andrej Karpathy.

    The objective is to assimilate the contents of the lecture and practice with pytorch, transformer architecture, and neural network training.

    This is approached by implementing a gpt-like model from scratch using the layers pytorch provides, instead of the custom ones from the lecture, to replicate the results of loss and text generation.

In [78]:
import torch

from bigram_language_model_karpathy import BigramLanguageModelKarpathy
from bigram_language_model_torch_layers import BigramLanguageModelTorchLayers

In [79]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [80]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2023-03-01 14:54:01--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.10’


2023-03-01 14:54:01 (30.2 MB/s) - ‘input.txt.10’ saved [1115394/1115394]



In [81]:
batch_size = 32
block_size = 128
max_iters = 1000
eval_interval = 500
learning_rate = 3e-4
dropout = 0.2
eval_iters = 200
number_layers = 2
number_heads = 4
number_embeddings = number_heads * 32  # 384 / 6 = 64 dimensional heads

torch.manual_seed(1337)

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open("input.txt", "r", encoding="utf-8") as f:
    corpus = f.read()

# The properties of the test
chars = sorted(list(set(corpus)))
vocab_size = len(chars)

# Encoding and decoding
string_to_int = {character: index for index, character in enumerate(chars)}
int_to_string = {index: character for index, character in enumerate(chars)}
encode = lambda string: [string_to_int[char] for char in string]
decode = lambda list_int: "".join([int_to_string[integer] for integer in list_int])

# Train and test splits
data = torch.tensor(encode(corpus), dtype=torch.long, device=device)
number_train = int(0.9 * len(data))
train_data = data[:number_train]
validation_data = data[:number_train]

In [82]:
is_model_karpathy = False

if is_model_karpathy:
    model = BigramLanguageModelKarpathy(
        vocab_size=vocab_size,
        number_embeddings=number_embeddings,
        block_size=block_size,
        number_heads=number_heads,
        number_layers=number_layers,
        dropout=dropout,
        device=device,
    ).to(device)
else:
    model = BigramLanguageModelTorchLayers(
        vocab_size=vocab_size,
        number_embeddings=number_embeddings,
        block_size=block_size,
        number_heads=number_heads,
        number_layers=number_layers,
        dropout=dropout,
        device=device,
    ).to(device)

In [83]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

model.train_model(
    max_iters,
    train_data,
    validation_data,
    optimizer,
    batch_size,
    eval_interval,
    eval_iters,
)

At step 0: train loss 4.2914, val loss 4.2902.
At step 500: train loss 2.4619, val loss 2.4626.
At step 999: train loss 2.3506, val loss 2.3495.


In [84]:
# Generate from the model
print("\nAn example of text generated from the model.")
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, 300)[0].tolist()))


An example of text generated from the model.

CENI:
Whofowe th sarod, theJor. mplay thereis re.
Y RUK:

DI pr Pumer d yovaroubomsed t,hy dll, adirea d the. Ris Ve?
SQ

ORESABOMALARUCERIH:
Whore Whyathe ' VETofre S:
Are chas ckdsor
Ye wndes't may lowe tod, t, ssswin opamspe
The llleaits pater los h o s langorabe hyot ng
An wice h and: habeld onk


In [86]:
print("\nNow we generate text forever.\n")
model.generate_forever(decode, 0.1)


Now we generate text forever.



Sititee rd th woutoucrs. PUThy she, thy whor maw---pe F icoureanclinlele thenp w
e,

Maver:
