In [2]:
import torch

### Adding Context

Previously, with bigrams, we modelled the probability of a character given only the previous character (hence the "bi" in bigram). Now we'd like to add more context: let's consider the probability of the next character given the three previous characters.

In [6]:
# One name per line. The `with` block closes the file handle as soon as the
# contents have been read, instead of leaving it open in the kernel.
with open("names.txt", "r", encoding="utf-8") as names_file:
    names = names_file.read().splitlines()
names[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [8]:
# total number of names loaded from names.txt
len(names)

32033

In [16]:
### Build the vocab ###

# Sorted list of every distinct character in the dataset, with "." placed
# first so that it gets index 0.
vocab = ["."] + sorted(set("".join(names)))

# Index -> character, then invert it for character -> index.
idx_to_char = dict(enumerate(vocab))
char_to_idx = {char: idx for idx, char in idx_to_char.items()}

In [49]:
context_len = 3


def build_dataset(words, char_to_idx, context_len=3, verbose=False):
    """Turn words into (context, next-character) training pairs.

    Every word is terminated with "." and scanned with a sliding window of the
    previous `context_len` character indices. The window starts filled with
    index 0, which is the "." token in this notebook's vocab.

    Args:
        words: iterable of strings. Every character, and ".", must be a key
            of `char_to_idx`.
        char_to_idx: mapping from character to integer index.
        context_len: number of previous characters used as context.
        verbose: if True, print each word and every
            `context ---> target` pair it produces.

    Returns:
        (X, Y), where X is an int64 tensor of shape (num_examples, context_len)
        and Y is an int64 tensor of shape (num_examples,).
    """
    idx_to_char = {idx: char for char, idx in char_to_idx.items()}
    contexts, targets = [], []

    for word in words:
        # initialise the context with "." characters (index 0)
        context = [0] * context_len
        if verbose:
            print(word)

        for char in word + ".":
            target = char_to_idx[char]
            contexts.append(context)
            targets.append(target)
            if verbose:
                context_str = "".join(idx_to_char[idx] for idx in context)
                print(f"{context_str} ---> {idx_to_char[target]} | {context} ---> {target}")

            # Shift the window (drop the oldest index, append the new one).
            # Rebinding rather than mutating keeps the lists already stored
            # in `contexts` unchanged.
            context = context[1:] + [target]

        if verbose:
            print()

    # torch.tensor([]) is 1-D, so reshape explicitly to keep X 2-D even
    # when `words` is empty.
    X = torch.tensor(contexts, dtype=torch.long).reshape(len(contexts), context_len)
    Y = torch.tensor(targets, dtype=torch.long)
    return X, Y


# Small, printed demo on the first five names.
X, Y = build_dataset(names[:5], char_to_idx, context_len, verbose=True)

emma
... ---> e | [0, 0, 0] ---> 5
..e ---> m | [0, 0, 5] ---> 13
.em ---> m | [0, 5, 13] ---> 13
emm ---> a | [5, 13, 13] ---> 1
mma ---> . | [13, 13, 1] ---> 0

olivia
... ---> o | [0, 0, 0] ---> 15
..o ---> l | [0, 0, 15] ---> 12
.ol ---> i | [0, 15, 12] ---> 9
oli ---> v | [15, 12, 9] ---> 22
liv ---> i | [12, 9, 22] ---> 9
ivi ---> a | [9, 22, 9] ---> 1
via ---> . | [22, 9, 1] ---> 0

ava
... ---> a | [0, 0, 0] ---> 1
..a ---> v | [0, 0, 1] ---> 22
.av ---> a | [0, 1, 22] ---> 1
ava ---> . | [1, 22, 1] ---> 0

isabella
... ---> i | [0, 0, 0] ---> 9
..i ---> s | [0, 0, 9] ---> 19
.is ---> a | [0, 9, 19] ---> 1
isa ---> b | [9, 19, 1] ---> 2
sab ---> e | [19, 1, 2] ---> 5
abe ---> l | [1, 2, 5] ---> 12
bel ---> l | [2, 5, 12] ---> 12
ell ---> a | [5, 12, 12] ---> 1
lla ---> . | [12, 12, 1] ---> 0

sophia
... ---> s | [0, 0, 0] ---> 19
..s ---> o | [0, 0, 19] ---> 15
.so ---> p | [0, 19, 15] ---> 16
sop ---> h | [19, 15, 16] ---> 8
oph ---> i | [15, 16, 8] ---> 9
phi ---> a | [16, 8,

In [46]:
# Sanity-check the dataset before it is used. There is one context row per
# target, each row is `context_len` indices wide, and the indices are int64
# (the integer type used for index lookups).
assert X.shape[0] == Y.shape[0], "X and Y must have the same number of examples"
assert X.ndim == 2 and X.shape[1] == context_len, "each row of X must hold context_len indices"
assert X.dtype == torch.int64 and Y.dtype == torch.int64, "indices must be int64"

print(f"{X.shape} with dtype: {X.dtype}")
print(f"{Y.shape} with dtype: {Y.dtype}")

torch.Size([32, 3]) with dtype: torch.int64
torch.Size([32]) with dtype: torch.int64
