In [1]:
from nltk.tokenize import word_tokenize
import torch

In [2]:
# nltk.download() # if not install the nltk library then uncomment this line

In [3]:
text = 'I love this flavor! It\'s by far the best choice and my go-to whenever I go to the grocery store. I wish they would restock it more often though.'

In [4]:
word_tokens = word_tokenize(text)
print(word_tokens)

['I', 'love', 'this', 'flavor', '!', 'It', "'s", 'by', 'far', 'the', 'best', 'choice', 'and', 'my', 'go-to', 'whenever', 'I', 'go', 'to', 'the', 'grocery', 'store', '.', 'I', 'wish', 'they', 'would', 'restock', 'it', 'more', 'often', 'though', '.']


In [5]:
with open('../data/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [6]:
words = sorted(list(set(word_tokenize(text))))
vocab_size = len(words)
vocab_size

14310

In [7]:
# create a mapping from characters to integers
stoi = { w:i for i,w in enumerate(words) }
itos = { i:w for i,w in enumerate(words) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ' '.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

In [8]:
test_string = 'You are all resolved rather to die than to famish?'
print(encode(word_tokenize(test_string)))
print(decode(encode(word_tokenize(test_string))))

[3053, 3512, 3324, 11053, 10791, 13010, 5723, 12819, 13010, 6533, 225]
You are all resolved rather to die than to famish ?


In [9]:
# encode the entire text dataset and store it into a torch.Tensor
data = torch.tensor(encode(word_tokenize(text)), dtype=torch.long)
print(data.shape, data.type)
print(data[:1000])

torch.Size([254509]) <built-in method type of Tensor object at 0x000001B072FA8C20>
tensor([ 1152,   709,   223,   482, 13877, 10480,  3440,  7080,   219,  7604,
         8993, 12087,   221,   323,   223,  2520,   219, 12087,   221,  1152,
          709,   223,  3053,  3512,  3324, 11053, 10791, 13010,  5723, 12819,
        13010,  6533,   225,   323,   223,  2256,   221, 11053,   221,  1152,
          709,   223,  1152,   219, 14291,  8402,   640,  1769,  8232,  4679,
         6251, 13010, 12831, 10036,   221,   323,   223,  2919,  8404,   219,
        13877,  8404,   221,  1152,   709,   223,  1679, 13581,  8340,  7738,
          219,  3412, 13877,   162,  7567,  5147,  3596,  9761,  9833, 10439,
          221,  1547,  3061, 13661,   225,   323,   223,  1924,  9271, 12703,
         9583,  9377,   224,  8580,  8243,  3786,  5952,   223,  3659,   219,
         3659,     0,  2370,   709,   223,  1972, 14183,   219,  7277,  4738,
          221,  1152,   709,   223,  2919,  3512,  3145, 10

In [10]:
# Let's now split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]
print(f'The size of train data is : {len(train_data)}')
print(f'The size of val data is : {len(val_data)}')

The size of train data is : 229058
The size of val data is : 25451


In [11]:
block_size = 8
train_data[:block_size+1]

tensor([ 1152,   709,   223,   482, 13877, 10480,  3440,  7080,   219])

In [12]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f'input: {context}, target: {target}')

input: tensor([1152]), target: 709
input: tensor([1152,  709]), target: 223
input: tensor([1152,  709,  223]), target: 482
input: tensor([1152,  709,  223,  482]), target: 13877
input: tensor([ 1152,   709,   223,   482, 13877]), target: 10480
input: tensor([ 1152,   709,   223,   482, 13877, 10480]), target: 3440
input: tensor([ 1152,   709,   223,   482, 13877, 10480,  3440]), target: 7080
input: tensor([ 1152,   709,   223,   482, 13877, 10480,  3440,  7080]), target: 219


In [13]:
torch.manual_seed(1337)
batch_size = 4 # How many independent sequences will be process in parallel?
block_size = 8 # What is the maximum context length for predictions?

In [14]:
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [15]:
batch_size = 4
block_size = 8

sindex = 0
eindex = batch_size * block_size
batch = data[sindex: eindex]
print("Unseprated batch:")
print(batch, "\n")
batch = batch.reshape((batch_size, block_size))
print(f"Sperated batch:")
print(batch)

Unseprated batch:
tensor([ 1152,   709,   223,   482, 13877, 10480,  3440,  7080,   219,  7604,
         8993, 12087,   221,   323,   223,  2520,   219, 12087,   221,  1152,
          709,   223,  3053,  3512,  3324, 11053, 10791, 13010,  5723, 12819,
        13010,  6533]) 

Sperated batch:
tensor([[ 1152,   709,   223,   482, 13877, 10480,  3440,  7080],
        [  219,  7604,  8993, 12087,   221,   323,   223,  2520],
        [  219, 12087,   221,  1152,   709,   223,  3053,  3512],
        [ 3324, 11053, 10791, 13010,  5723, 12819, 13010,  6533]])


In [16]:
def get_batch(batch_size, split):
    data = train_data if split == 'train' else val_data
    
    sindex = 0
    eindex = batch_size * block_size
    while eindex < (len(data) - block_size):
        # Extract batch_size * block_size tokens in the data
        xb = data[sindex: eindex]
        yb = data[sindex+1: eindex+1]
        # Reshape the batch by the shape (batch_size, block_size)
        xb = xb.reshape(batch_size, block_size)
        yb = yb.reshape(batch_size, block_size)
        # Update the indexes for extracting tokens sections
        temp = eindex
        eindex = eindex + block_size * batch_size
        sindex = temp
        yield xb, yb
    
    # For the last batch that cannot be batched as in the size (batch_size, block_size)
    # Maintain the 'block_size' dimension
    if eindex >= (len(data) - block_size):
        # The number of useable completed sample including 'block_size' elements
        num_sample = (len(data) - sindex) // block_size 
        if num_sample >= 1:
            xb = data[sindex: sindex + (num_sample * block_size)]
            yb = data[sindex+1: sindex + (num_sample * block_size)+1]
            xb = xb.reshape(num_sample, block_size)
            yb = yb.reshape(num_sample, block_size)
            yield xb, yb 

In [19]:
batch_size = 32
block_size = 32

num_batch = 0
for i, batch in enumerate(get_batch(batch_size, 'train')):
    xb, yb = batch
    if i < 1:
        print('inputs:')
        print(xb.shape)
        print(xb)
        print('targets:')
        print(yb.shape)
        print(yb)
    num_batch += 1

print(f'TOTAL number of batches: {num_batch}')

inputs:
torch.Size([32, 32])
tensor([[ 1152,   709,   223,  ..., 12819, 13010,  6533],
        [  225,   323,   223,  ...,  8404,   221,  1152],
        [  709,   223,  1679,  ...,  8580,  8243,  3786],
        ...,
        [ 8243,  5721, 10938,  ..., 13984, 12831,  9756],
        [ 8161,   919, 11488,  ..., 14028,  4132,   221],
        [ 2661,  3916,  3430,  ...,  8348,  9659, 11942]])
targets:
torch.Size([32, 32])
tensor([[  709,   223,   482,  ..., 13010,  6533,   225],
        [  323,   223,  2256,  ...,   221,  1152,   709],
        [  223,  1679, 13581,  ...,  8243,  3786,  5952],
        ...,
        [ 5721, 10938,  1478,  ..., 12831,  9756,  8161],
        [  919, 11488,  3412,  ...,  4132,   221,  2661],
        [ 3916,  3430,   130,  ...,  9659, 11942,   219]])
TOTAL number of batches: 224


In [18]:
for b in range(batch_size): # batch dimension
    print(f'batch {b+1}/{batch_size}')
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")
    print()

batch 1/4
when input is [11835] the target: 219
when input is [11835, 219] the target: 9085
when input is [11835, 219, 9085] the target: 14291
when input is [11835, 219, 9085, 14291] the target: 13793
when input is [11835, 219, 9085, 14291, 13793] the target: 8636
when input is [11835, 219, 9085, 14291, 13793, 8636] the target: 3061
when input is [11835, 219, 9085, 14291, 13793, 8636, 3061] the target: 12377
when input is [11835, 219, 9085, 14291, 13793, 8636, 3061, 12377] the target: 223

batch 2/4
when input is [223] the target: 8986
when input is [223, 8986] the target: 1478
when input is [223, 8986, 1478] the target: 3786
when input is [223, 8986, 1478, 3786] the target: 11970
when input is [223, 8986, 1478, 3786, 11970] the target: 4137
when input is [223, 8986, 1478, 3786, 11970, 4137] the target: 13010
when input is [223, 8986, 1478, 3786, 11970, 4137, 13010] the target: 8402
when input is [223, 8986, 1478, 3786, 11970, 4137, 13010, 8402] the target: 12831

batch 3/4
when input 