In [1]:
import torch
import torch.nn as nn
import os
import math
import requests
import tiktoken
import pandas as pd
import matplotlib.pyplot as plt

In [48]:
# Hyperparameters
batch_size = 4
# context_length = 16
max_seq_len = 16
d_model = 64
n_heads = 4
n_layers = 4
learning_rate = 1e-3
dropout = 0.1
max_iters = 5000
eval_interval = 50
eval_iters = 20
device = 'cuda' if torch.cuda.is_available() else 'cpu'

TORCH_SEED = 3047
torch.manual_seed(TORCH_SEED)

<torch._C.Generator at 0x2e03d4c6950>

## Prepare the Dataset

In [11]:
if not os.path.exists('datasets/sales_textbook.txt'):
    url = 'https://huggingface.co/datasets/goendalf666/sales-textbook_for_convincing_and_selling/raw/main/sales_textbook.txt'
    with open('datasets/sales_textbook.txt', 'w') as f:
        f.write(requests.get(url).text)

with open('datasets/sales_textbook.txt', 'r', encoding='utf-8') as f:
    text = f.read()

## Step 1: Tokenization

In [33]:
# Using Tiktoken to tokenize the source text
encoding = tiktoken.get_encoding('cl100k_base')
tokenized_text = encoding.encode(text)
tokenized_text = torch.tensor(tokenized_text, dtype=torch.long)
vocab_size = len(set(tokenized_text))
max_token_value = max(tokenized_text)

print(f'Tokenized text size: {len(tokenized_text)}')
print(f'Vocabulary size: {vocab_size}')
print(f'The maximum value in the tokenized text is: {max_token_value}')

Tokenized text size: 77919
Vocabulary size: 77919
The maximum value in the tokenized text is: 100069


## Step 2: Word Embedding

In [24]:
# split train and validation
split_idx = int(len(tokenized_text) * 0.8)
train_data = tokenized_text[:split_idx]
val_data = tokenized_text[split_idx:]

# training_batch
idxs = torch.randint(low=0, high=len(train_data) - max_seq_len, size=(batch_size,))
x_batch = torch.stack([train_data[idx:idx + max_seq_len] for idx in idxs])
y_batch = torch.stack([train_data[idx + 1:idx + max_seq_len + 1] for idx in idxs])

x_batch.shape, y_batch.shape

(torch.Size([4, 16]), torch.Size([4, 16]))

In [30]:
pd.DataFrame(x_batch.numpy())

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
0,11,433,374,16996,311,8881,323,63179,279,6130,596,14847,13,1115,15105,539
1,1317,627,2520,54111,5552,311,5597,28846,11447,11,3085,19351,323,1862,311,21736
2,11411,311,9455,279,1888,6425,369,1124,627,644,17102,11,9204,22785,2802,304
3,13,12040,279,892,311,2610,1825,84175,4860,323,22815,9020,311,872,14847,13


In [32]:
encoding.decode(x_batch[0].numpy())

", it is crucial to reflect and summarize the customer's responses. This technique not"

In [34]:
# define input embedding table
input_embedding_lookup_table = nn.Embedding(max_token_value + 1, d_model)
input_embedding_lookup_table.weight.data

tensor([[-2.4084, -0.0766,  0.1495,  ...,  0.4334,  0.3934,  0.0186],
        [-0.3088,  1.7746, -1.9879,  ..., -0.2933, -0.9630, -0.4039],
        [-0.3746,  0.4040, -1.6022,  ...,  0.4541, -0.0680, -0.9949],
        ...,
        [-0.8532, -0.7116, -0.4817,  ...,  0.6795,  0.0603,  1.1497],
        [ 0.6459, -0.1134, -0.3980,  ...,  0.5834, -1.3536,  1.3549],
        [ 1.6187, -0.2309,  1.6308,  ...,  0.7032, -0.2993,  1.4346]])

In [None]:
x_batch_embedding = input_embedding_lookup_table(x_batch).to(device)
y_batch_embedding = input_embedding_lookup_table(y_batch).to(device)

x_batch_embedding, y_batch_embedding

## Step 3: Positional Encoding

In [52]:
# positional encoding
from transformer_from_scratch.embedding.positional_encoding import PositionalEncoding

positional_encoding = PositionalEncoding(d_model=d_model, max_seq_len=max_seq_len, device=device)
positional_encoding = positional_encoding.positional_encoding.unsqueeze(0).expand(batch_size, -1, -1)

# [batch_size, seq_len, d_model]
positional_encoding.shape

torch.Size([4, 16, 64])

In [53]:
x_batch_embedding = x_batch_embedding + positional_encoding
y_batch_embedding = x_batch_embedding + positional_encoding

In [55]:
x_example = x_batch_embedding[0].detach().cpu().numpy()
pd.DataFrame(x_example)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,54,55,56,57,58,59,60,61,62,63
0,-2.128163,1.098318,0.061589,0.659486,-0.486816,-1.684022,0.531371,-0.144989,1.098932,2.950274,...,0.586137,1.686508,-0.753595,1.446968,-1.097538,1.330085,-0.87864,1.390788,0.380007,2.132031
1,-0.224984,1.834972,0.39189,1.039342,1.002326,1.506255,0.933553,-0.084586,0.273862,1.94319,...,-1.214233,0.296266,0.479791,1.20636,0.100527,-0.301546,-0.596328,0.335199,-0.32895,1.047948
2,0.70946,2.405553,0.478345,-0.359229,0.712264,1.05307,2.090398,0.196823,1.711592,1.730036,...,-1.34138,-0.154286,0.437158,0.352002,-1.144311,1.890062,0.686984,2.137396,0.677471,0.322397
3,1.687803,-0.035164,0.453855,-0.539957,-0.5149,0.320815,0.11701,0.744642,1.506247,0.805907,...,1.003182,-1.810592,-0.917659,2.064042,-0.729338,0.875718,0.587507,2.125895,-1.236793,-1.225436
4,0.466571,-1.589764,-0.452444,-2.848962,-0.391449,-0.294052,0.314704,-1.106485,1.848796,0.64099,...,-0.633569,2.333714,-0.634152,1.278784,-0.91001,1.094407,-0.205325,1.954334,0.04486,-0.120603
5,-2.151568,0.799066,1.076717,-0.917921,1.902736,-1.722682,0.910654,0.126296,1.507143,-2.642344,...,1.432108,1.310794,2.867361,-0.445517,0.85283,0.803569,1.808615,0.509389,-0.282991,1.647054
6,-0.606167,-1.072449,-0.482319,-0.340754,-0.254739,-1.708647,1.312923,-0.526044,2.225625,0.941511,...,-0.430943,0.453285,-1.937428,1.839861,-0.682724,-0.290616,-0.686125,2.562037,-0.057366,0.903035
7,1.670955,0.732331,-1.220412,0.670999,1.564215,-0.218517,1.145501,0.373214,0.285209,-1.377281,...,0.883945,1.907699,-1.401604,0.891313,-0.396377,0.891786,-2.270334,1.776717,-0.353932,1.630784
8,1.333383,0.098351,-1.309119,2.373107,-1.221122,-0.64833,0.240119,-2.220782,1.847678,-1.169461,...,-0.227376,2.118015,-1.067782,1.980372,0.070198,0.73668,-0.656532,0.906004,-2.101119,1.25267
9,0.595091,-1.312883,0.687742,1.216406,-1.838146,-1.903039,-1.301278,-2.901025,0.851294,-0.959089,...,0.2373,1.297597,-0.856803,-0.740361,0.012208,2.505404,-1.421558,-0.04777,-1.499051,0.449086
