In [1]:
import torch
block_size = 8
batch_size = 4

In [2]:
with open('data/wizard_of_oz.txt', 'r', encoding = 'utf-8') as f:
    text = f.read()
chars = sorted(set(text))
print(chars)
print(len(chars))

['\n', ' ', '!', '"', '&', "'", '(', ')', '*', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
80


In [3]:
string_to_int = {ch:i for i, ch in enumerate(chars)}
int_to_string = {i:ch for i, ch in enumerate(chars)}
encode = lambda s: [string_to_int[c] for c in s]
decode = lambda l: ''.join(int_to_string[i] for i in l)

In [4]:
encoded_paul = encode('Paul')
decoded_paul = decode(encoded_paul)
print(encoded_paul, decoded_paul)

[40, 54, 74, 65] Paul


In [5]:
data = torch.tensor(encode(text), dtype = torch.long)
print(data[:100])

tensor([ 1,  1, 28, 39, 42, 39, 44, 32, 49,  1, 25, 38, 28,  1, 44, 32, 29,  1,
        47, 33, 50, 25, 42, 28,  1, 33, 38,  1, 39, 50,  0,  0,  1,  1, 26, 49,
         0,  0,  1,  1, 36, 11,  1, 30, 42, 25, 38, 35,  1, 26, 25, 45, 37,  0,
         0,  1,  1, 25, 45, 44, 32, 39, 42,  1, 39, 30,  1, 44, 32, 29,  1, 47,
        33, 50, 25, 42, 28,  1, 39, 30,  1, 39, 50,  9,  1, 44, 32, 29,  1, 36,
        25, 38, 28,  1, 39, 30,  1, 39, 50,  9])


In [6]:
data.shape

torch.Size([232309])

In [7]:
len(text)

232309

In [8]:
n = int(0.8 * len(data))
train_data = data[:n]
val_data = data[n:]

In [9]:
train_data.shape, val_data.shape

(torch.Size([185847]), torch.Size([46462]))

In [12]:
def get_batch(split):
  data = train_data if split == 'train' else val_data
  ix = torch.randint(len(data) - block_size, (batch_size,))
  print(ix)
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1: i+block_size+1] for i in ix])
  return x, y

x, y = get_batch('train')
print(f'Inputs \n: {x}')
print(f'Outputs \n: {y}')


tensor([ 41110, 138355, 144507, 143536])
Inputs 
: tensor([[ 1, 73, 68,  1, 73, 61, 58,  0],
        [60, 68, 78, 65, 58, 72,  9,  1],
        [61, 62, 72,  1, 76, 71, 62, 67],
        [58, 54, 73, 65, 78,  1, 58, 77]])
Outputs 
: tensor([[73, 68,  1, 73, 61, 58,  0, 47],
        [68, 78, 65, 58, 72,  9,  1, 76],
        [62, 72,  1, 76, 71, 62, 67, 64],
        [54, 73, 65, 78,  1, 58, 77, 56]])


In [11]:
x = train_data[:block_size]
y = train_data[1: block_size+1]

for t in range(block_size):
    content = x[:t+1]
    target = y[t]
    print(f'When input is {content} target is {target}')

When input is tensor([1]) target is 1
When input is tensor([1, 1]) target is 28
When input is tensor([ 1,  1, 28]) target is 39
When input is tensor([ 1,  1, 28, 39]) target is 42
When input is tensor([ 1,  1, 28, 39, 42]) target is 39
When input is tensor([ 1,  1, 28, 39, 42, 39]) target is 44
When input is tensor([ 1,  1, 28, 39, 42, 39, 44]) target is 32
When input is tensor([ 1,  1, 28, 39, 42, 39, 44, 32]) target is 49
