In [7]:
from typing import Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.types import _device, _int, _size, _TensorOrTensors

from gpt import set_seed

# Seed once, before any cell below draws random numbers (e.g. torch.randint in get_batch).
# NOTE(review): set_seed comes from the project-local `gpt` module; presumably it seeds
# torch's global RNG (and possibly others) — confirm against its implementation.
set_seed(1999)

In [8]:
"""
Loading the sample text to train and validate the implemented gpt model.
"""
# Read the whole corpus into memory as a single string.
# The encoding is pinned so the result does not depend on the platform's
# locale default (which differs between Linux, macOS and Windows).
with open("assets/input.txt", "r", encoding="utf-8") as f:
    text = f.read()

# Preview the first 200 characters to confirm the file loaded as expected.
print(text[:200])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


In [13]:
# Vocabulary: every distinct character that occurs in the corpus, in sorted order.
# sorted() accepts any iterable, so the set can be passed to it directly.
chars = sorted(set(text))
print(repr("".join(chars)))

"\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"


In [10]:
"""
Create lookup tables to encode a given string (or list of characters) into integer ids and to decode ids back into characters.
"""

# Number to character lookup table: the id of a character is its position in `chars`.
ntoc = dict(enumerate(chars))
# Character to number lookup table: the exact inverse of `ntoc`.
cton = {ch: idx for idx, ch in ntoc.items()}

def encode(s: list[str] | str) -> list[int]:
    """Map each character of ``s`` to its integer id using the ``cton`` table.

    Args:
        s (list[str] | str): a string, or a list of single characters.

    Returns:
        list[int]: one id per character, in the same order as the input.

    Raises:
        KeyError: if a character is not a key of ``cton``.
    """
    return [cton[ch] for ch in s]


def decode(e: list[int] | torch.Tensor | int) -> list[str] | str:
    """Map integer ids back to characters using the ``ntoc`` table.

    Args:
        e (list[int] | torch.Tensor | int): a single id, a list of ids, or a
            tensor of ids. Tensors are converted with ``tolist()`` first, so a
            0-dim tensor is handled like a single id.

    Returns:
        list[str] | str: a single character when the input is one id,
        otherwise a list with one character per id.

    Raises:
        KeyError: if an id is not a key of ``ntoc``.
    """
    ids = e.tolist() if isinstance(e, torch.Tensor) else e
    if not isinstance(ids, int):
        return [ntoc[idx] for idx in ids]
    # A lone id (plain int, or a 0-dim tensor after tolist()) maps to one character.
    return ntoc[ids]

# Round-trip sanity check: encoding and then decoding must reproduce the input.
test_text = "This must be properly encoded and decoded!"
enc_test = encode(test_text)
print(enc_test)
dec_text = decode(enc_test)
print("".join(dec_text))
# Enforce the property instead of relying on someone eyeballing the printed output.
assert "".join(dec_text) == test_text, "encode/decode round trip changed the text"

[32, 46, 47, 57, 1, 51, 59, 57, 58, 1, 40, 43, 1, 54, 56, 53, 54, 43, 56, 50, 63, 1, 43, 52, 41, 53, 42, 43, 42, 1, 39, 52, 42, 1, 42, 43, 41, 53, 42, 43, 42, 2]
This must be properly encoded and decoded!


In [11]:
def get_batch(
    train: torch.Tensor, validation: torch.Tensor, batch_size: int, context_length: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample a random mini-batch of aligned windows from two equal-length tensors.

    The same random start offsets are used to slice both tensors, so ``x[b]``
    and ``y[b]`` always cover the same positions along dim 0.

    Args:
        train (torch.Tensor): input sequence, indexed along dim 0.
        validation (torch.Tensor): sequence aligned with ``train`` (same size
            along dim 0). Despite the name it is simply the second tensor that
            is windowed; it can hold e.g. the inputs shifted by one position
            (``torch.roll(train, -1)``) to serve as next-token targets.
        batch_size (int): number of windows to sample.
        context_length (int): length of each window along dim 0. Must be
            smaller than ``train.size(0)``.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(x, y)`` where ``x`` has shape
        ``(batch_size, context_length, *train.shape[1:])`` and ``y`` has shape
        ``(batch_size, context_length, *validation.shape[1:])``.

    Raises:
        AssertionError: if the two tensors differ in size along dim 0.
        ValueError: if ``context_length`` is not smaller than the sequence length.
    """

    n_positions = train.size(0)
    assert (
        n_positions == validation.size(0)
    ), "train and validation must have the same length along dim 0"

    n_offsets = n_positions - context_length
    if n_offsets <= 0:
        # Without this check torch.randint fails with an unrelated-looking message.
        raise ValueError(
            f"context_length ({context_length}) must be smaller than the "
            f"sequence length ({n_positions})"
        )

    # The upper bound of randint is exclusive, so the very last possible window
    # is never drawn. That keeps the final position out of every window, which
    # matters when `validation` is a rolled copy whose last element wrapped around.
    # Converting to Python ints once avoids slicing with 0-dim tensors in the loops.
    offsets = torch.randint(n_offsets, (batch_size,)).tolist()
    x = torch.stack([train[i : i + context_length] for i in offsets])
    y = torch.stack([validation[i : i + context_length] for i in offsets])

    return x, y

In [16]:
"""
Example of drawing a mini-batch from the sample text. Run this cell multiple times to see different examples.
"""

# Encode the whole corpus into a 1-D tensor of character ids.
inp_tensor = torch.tensor(encode(text))

batch_size = 2
context_length = 12
# Targets are the inputs shifted left by one, so position i of the target holds
# the character that follows position i of the input. torch.roll wraps around,
# so the last target element is the first character of the corpus.
inp_mini_batch, target_mini_batch = get_batch(
    inp_tensor, torch.roll(inp_tensor, -1), batch_size, context_length
)
print(inp_mini_batch, target_mini_batch)

# For every sample, show each growing prefix of the input window together with
# the character at the matching position of the target window.
for b in range(batch_size):
    for i in range(context_length):
        print(
            f"given {repr(''.join(decode(inp_mini_batch[b][:i+1])))}, -> {repr(decode(target_mini_batch[b][i]))}"
        )

tensor([[53, 56, 51, 39, 50,  1, 61, 53, 51, 43, 52,  1],
        [53, 58,  1, 54, 43, 56, 57, 53, 52, 39, 50,  6]]) tensor([[56, 51, 39, 50,  1, 61, 53, 51, 43, 52,  1, 39],
        [58,  1, 54, 43, 56, 57, 53, 52, 39, 50,  6,  1]])
given 'o', -> 'r'
given 'or', -> 'm'
given 'orm', -> 'a'
given 'orma', -> 'l'
given 'ormal', -> ' '
given 'ormal ', -> 'w'
given 'ormal w', -> 'o'
given 'ormal wo', -> 'm'
given 'ormal wom', -> 'e'
given 'ormal wome', -> 'n'
given 'ormal women', -> ' '
given 'ormal women ', -> 'a'
given 'o', -> 't'
given 'ot', -> ' '
given 'ot ', -> 'p'
given 'ot p', -> 'e'
given 'ot pe', -> 'r'
given 'ot per', -> 's'
given 'ot pers', -> 'o'
given 'ot perso', -> 'n'
given 'ot person', -> 'a'
given 'ot persona', -> 'l'
given 'ot personal', -> ','
given 'ot personal,', -> ' '


['/mnt/e/internship/transformers/_test', '/home/alireza/.pyenv/versions/3.10.9/lib/python310.zip', '/home/alireza/.pyenv/versions/3.10.9/lib/python3.10', '/home/alireza/.pyenv/versions/3.10.9/lib/python3.10/lib-dynload', '', '/mnt/e/internship/transformers/.venv/lib/python3.10/site-packages']
