In [1]:
import sys
from pathlib import Path

def find_repo_root(start: Path, marker: str = "src") -> Path:
    """Return the closest directory at or above ``start`` that contains ``marker``.

    Walking upward makes the import setup independent of how deeply the
    notebook is nested inside the repository. If no directory on the way up
    contains ``marker``, fall back to the notebook's original assumption that
    the repo root is two levels above ``start`` (or ``start`` itself when it
    has no grandparent, instead of raising IndexError).
    """
    for candidate in (start, *start.parents):
        if (candidate / marker).exists():
            return candidate
    return start.parents[1] if len(start.parents) > 1 else start


here = Path.cwd().resolve()
repo_root = find_repo_root(here)

# Put the repo root first on sys.path so the later `from src....` import resolves.
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

In [2]:
import torch 

In [3]:
# Seed torch's global RNG so the randomly initialised embedding weights (and the
# other randomly initialised layers created after this cell) are reproducible
# on Restart & Run All.
SEED = 123
torch.manual_seed(SEED)

vocab_size = 50257  # presumably the GPT-2 BPE vocabulary used by the data loader — confirm
output_dim = 256    # embedding dimension of each token vector
token_embedding_layer = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=output_dim)

In [4]:
from src.gpt_blocks.data_loader import create_dataloader_v1

In [5]:
# Load the raw text corpus; the relative path is resolved against the kernel's cwd.
raw_text = Path("the-verdict.txt").read_text(encoding="utf-8")

In [6]:
# Cut the text into fixed-length token windows and take the first batch.
max_length = 4
dataloader = create_dataloader_v1(
    raw_text,
    batch_size=8,
    max_length=max_length,
    stride=max_length,  # stride equals the window length — presumably non-overlapping windows
    shuffle=False,      # presumably keeps windows in text order — confirm in create_dataloader_v1
)
data_iter = iter(dataloader)
inputs, targets = next(data_iter)

print("Token IDs:\n", inputs)
print("\nInputs shape:\n", inputs.shape)

Token IDs:
 tensor([[   40,   367,  2885,  1464],
        [ 1807,  3619,   402,   271],
        [10899,  2138,   257,  7026],
        [15632,   438,  2016,   257],
        [  922,  5891,  1576,   438],
        [  568,   340,   373,   645],
        [ 1049,  5975,   284,   502],
        [  284,  3285,   326,    11]])

Inputs shape:
 torch.Size([8, 4])


In [7]:
# Map every token ID in the batch to its output_dim-sized embedding vector.
token_embeddings = token_embedding_layer(inputs)
print(token_embeddings.size())

torch.Size([8, 4, 256])


In [12]:
# Token embeddings for the whole batch: torch.Size([8, 4, 256]) as printed by the previous cell.
token_embeddings

tensor([[[-1.2134,  1.2211,  0.3301,  ...,  0.8280, -0.3973,  0.7939],
         [ 0.9694,  1.2444, -1.3146,  ..., -2.0704,  0.0074,  0.0206],
         [ 0.0369,  0.2536, -0.5361,  ...,  0.3765, -0.4631, -2.0810],
         [-0.9944,  1.6484,  0.6639,  ..., -0.0684,  0.3267, -0.5061]],

        [[-0.0610,  0.2250,  0.6774,  ..., -1.1411,  0.0616,  0.3454],
         [-0.2939,  1.1642, -2.4171,  ..., -0.4432,  2.1483,  0.5958],
         [-0.9705,  1.5022, -1.1859,  ..., -0.5569,  1.7828, -0.3870],
         [-0.7826,  0.0840,  1.5790,  ...,  0.0882, -0.0665,  0.1866]],

        [[ 0.1592,  1.7994, -0.2480,  ...,  0.2744,  0.9752, -0.0873],
         [-1.4104, -1.8607,  1.5669,  ..., -0.4709,  0.4952,  1.5287],
         [ 0.3872, -1.2531, -0.1333,  ...,  0.3078, -0.8653, -0.0966],
         [ 0.1558,  0.4309, -1.3325,  ...,  2.2045, -0.0404,  0.0356]],

        ...,

        [[ 0.9477, -0.6950,  0.6029,  ...,  0.4806, -1.1319,  1.9854],
         [-0.6073,  0.8065, -0.0648,  ..., -2.0335,  1.65

In [None]:
# Embeddings of the first sequence in the batch (index 0):
# max_length tokens x output_dim embedding dimensions.
# Note: this is `token_embeddings[0]`, not `token_embeddings[:, 0, :]`
# (which would be the first token of every sequence).
first_sequence_embeddings = token_embeddings[0]
assert first_sequence_embeddings.shape == (max_length, output_dim)
first_sequence_embeddings

In [8]:
# One trainable vector per position in the input window (positions 0..context_length-1).
context_length = max_length
pos_embedding_layer = torch.nn.Embedding(num_embeddings=context_length, embedding_dim=output_dim)
pos_embeddings = pos_embedding_layer(torch.arange(0, context_length))
print(pos_embeddings.size())

torch.Size([4, 256])


In [13]:
# Positional embeddings: torch.Size([4, 256]) (context_length x output_dim) as printed by the previous cell.
pos_embeddings

tensor([[ 0.7668,  0.2183,  0.1065,  ..., -0.0249, -1.7862,  0.3125],
        [ 0.9664, -2.4627,  0.5354,  ..., -0.5696,  0.3536,  0.7189],
        [ 0.5135,  1.1237, -0.3547,  ...,  0.7116, -0.7971, -0.3304],
        [ 0.6905,  0.8488, -1.7337,  ..., -1.3906,  1.6632, -0.9377]],
       grad_fn=<EmbeddingBackward0>)

In [15]:
# Add position information to each token vector. pos_embeddings (4, 256)
# broadcasts over the batch dimension of token_embeddings (8, 4, 256).
input_embeddings = torch.add(token_embeddings, pos_embeddings)
print(input_embeddings.size())

torch.Size([8, 4, 256])


In [16]:
# Final input embeddings (token + position): torch.Size([8, 4, 256]) as printed by the previous cell.
input_embeddings

tensor([[[-0.4466,  1.4393,  0.4365,  ...,  0.8031, -2.1835,  1.1064],
         [ 1.9358, -1.2183, -0.7791,  ..., -2.6400,  0.3610,  0.7395],
         [ 0.5503,  1.3774, -0.8908,  ...,  1.0881, -1.2603, -2.4115],
         [-0.3039,  2.4973, -1.0698,  ..., -1.4591,  1.9900, -1.4438]],

        [[ 0.7058,  0.4433,  0.7838,  ..., -1.1660, -1.7247,  0.6579],
         [ 0.6725, -1.2985, -1.8817,  ..., -1.0128,  2.5019,  1.3148],
         [-0.4571,  2.6260, -1.5406,  ...,  0.1547,  0.9856, -0.7174],
         [-0.0921,  0.9328, -0.1547,  ..., -1.3024,  1.5967, -0.7511]],

        [[ 0.9260,  2.0177, -0.1416,  ...,  0.2495, -0.8111,  0.2252],
         [-0.4441, -4.3234,  2.1024,  ..., -1.0405,  0.8488,  2.2477],
         [ 0.9007, -0.1294, -0.4880,  ...,  1.0194, -1.6624, -0.4270],
         [ 0.8463,  1.2797, -3.0662,  ...,  0.8139,  1.6229, -0.9021]],

        ...,

        [[ 1.7145, -0.4768,  0.7093,  ...,  0.4557, -2.9182,  2.2979],
         [ 0.3590, -1.6562,  0.4706,  ..., -2.6031,  2.00