In [2]:
import torch

# Pick the compute backend in order of preference: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

Using cuda device


In [16]:
# Read the whole corpus into memory as a single string.
with open("input.txt", encoding="utf-8") as file:
    text = file.read()

# The vocabulary is the set of distinct characters in the corpus, kept in
# sorted order so that the character list is deterministic between runs.
chars = sorted(set(text))
vocab_size = len(chars)

print("Dataset vocab size:", vocab_size)
print("Dataset chars:", "".join(chars))

Dataset vocab size: 65
Dataset chars: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [32]:
# Lookup tables between characters and integer ids: a character's id is its
# position in `chars`.
str_to_int = dict(zip(chars, range(len(chars))))
int_to_str = dict(enumerate(chars))


def encode(string):
    """Translate a string into a list of integer ids, one per character.

    Each character is looked up in `str_to_int`; a character missing from
    that table raises KeyError.
    """
    token_ids = []
    for symbol in string:
        token_ids.append(str_to_int[symbol])
    return token_ids


def decode(int_list):
    """Translate an iterable of integer ids back into a string.

    Each id is looked up in `int_to_str`; an id missing from that table
    raises KeyError.
    """
    return "".join(int_to_str[token_id] for token_id in int_list)


# Round-trip a short sample through encode/decode to eyeball both directions.
string = "hii there"
encoded = encode(string)

print(f"'{string}' encoded is: {encoded}")
print(f"'{encoded}' decoded is: '{decode(encoded)}'")

'hii there' encoded is: [46, 47, 47, 1, 58, 46, 43, 56, 43]
'[46, 47, 47, 1, 58, 46, 43, 56, 43]' decoded is: 'hii there'


In [44]:
# convert text to pytorch tensor
import torch

# Encode the full corpus as one tensor of int64 ids.
data = torch.tensor(encode(text), dtype=torch.long)
print(data)

# Hold out the tail of the corpus for validation: the first 90% of the ids
# are used for training, the remaining ids for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

print(f"Data lenght: {len(data)}")
print(f"Train data lenght: {len(train_data)}")
print(f"Validation data lenght: {len(val_data)}")

tensor([18, 47, 56,  ..., 52, 45,  8])
Data lenght: 1115393
Train data lenght: 1003853
Validation data lenght: 111540


In [46]:
block_size = 8

# A chunk of block_size + 1 consecutive tokens yields block_size training
# examples: for every position t, the first t + 1 tokens of the chunk are the
# context and the token that follows them is the target.
x = train_data[:block_size]  # inputs
y = train_data[1 : block_size + 1]  # the same window shifted one token ahead
for t in range(block_size):
    context = x[: t + 1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


In [58]:
batch_size = 4  # number of sequences sampled per batch
sequence_length = 8  # number of tokens in each input/target sequence


def get_batch(split):
    """Sample a random mini-batch of input and target sequences.

    Args:
        split: "train" samples from `train_data`; any other value samples
            from `val_data`.

    Returns:
        A pair (input_sequences, target_sequences) of stacked tensors, one
        row per sampled start index. Each target row covers the same window
        of the dataset as its input row, shifted one position ahead.
    """
    if split == "train":
        dataset = train_data
    else:
        dataset = val_data

    # Draw one random start index per sequence, e.g. [100, 5, 45, 6]: the first
    # sequence of the batch starts at index 100, the second at index 5, and so
    # on. torch.randint's upper bound is exclusive, so even the largest start
    # index leaves room for the target window that ends one token later.
    start_indices = torch.randint(len(dataset) - sequence_length, (batch_size,))
    offsets = start_indices.tolist()

    input_sequences = torch.stack(
        [dataset[offset : offset + sequence_length] for offset in offsets]
    )
    target_sequences = torch.stack(
        [dataset[offset + 1 : offset + sequence_length + 1] for offset in offsets]
    )

    return input_sequences, target_sequences


# Draw one training batch and show the raw input/target tensors.
inputs_batch, targets_batch = get_batch("train")
print("Inputs:")
print(inputs_batch)
print("Targets:")
print(targets_batch)

print(" ---- ")

# Walk every row of the batch and, within it, every time step: the context is
# the row's prefix up to and including that step, the target is the aligned
# entry of the matching target row.
for input_row, target_row in zip(inputs_batch, targets_batch):
    for time_step, target_token in enumerate(target_row):
        current_context = input_row[: time_step + 1]
        print(f"When context is {current_context.tolist()}, the target token: {target_token}")

Inputs:
tensor([[ 1, 45, 56, 43, 39, 58,  1, 39],
        [56, 43, 56,  6,  1, 39, 40, 53],
        [ 1, 58, 46, 43,  1, 41, 56, 53],
        [52,  0, 32, 46, 39, 58,  1, 56]])
Targets:
tensor([[45, 56, 43, 39, 58,  1, 39, 50],
        [43, 56,  6,  1, 39, 40, 53, 59],
        [58, 46, 43,  1, 41, 56, 53, 61],
        [ 0, 32, 46, 39, 58,  1, 56, 53]])
 ---- 
When context is [1], the target token: 45
When context is [1, 45], the target token: 56
When context is [1, 45, 56], the target token: 43
When context is [1, 45, 56, 43], the target token: 39
When context is [1, 45, 56, 43, 39], the target token: 58
When context is [1, 45, 56, 43, 39, 58], the target token: 1
When context is [1, 45, 56, 43, 39, 58, 1], the target token: 39
When context is [1, 45, 56, 43, 39, 58, 1, 39], the target token: 50
When context is [56], the target token: 43
When context is [56, 43], the target token: 56
When context is [56, 43, 56], the target token: 6
When context is [56, 43, 56, 6], the target token: 1
