E06: meta-exercise! Think of a fun/interesting exercise and complete it.

My idea is to use the model on a different data set, and see how it performs. I will use German street names, from the [OpenAdresses dataset of Germany](https://www.kaggle.com/datasets/openaddresses/openaddresses-europe?resource=download&select=germany.csv)

In [327]:
import torch
import torch.nn.functional as F
from math import floor

In [328]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [329]:
device

device(type='cuda')

In [330]:
# Read in the dataset
streetnames = open('./sample_data/street-names.txt', 'r', encoding='utf-8').read().splitlines()

In [331]:
streetnames[:10]

['am hohen rand',
 'lehrer-geßner-straße',
 'wittekindallee',
 'am darloh',
 'taxusweg',
 'hinterm friedhof',
 'kleine wende',
 'an der liff',
 'posthorn',
 'landsberger allee']

In [332]:
# Create training, dev, and test sets
train_index = floor(len(streetnames) * 0.90)
dev_index = floor(len(streetnames) * 0.95)

train = streetnames[:train_index]
dev = streetnames[train_index:dev_index]
test = streetnames[dev_index:]

In [333]:
special_token = '.'

In [334]:
chars = sorted(list(set(''.join([special_token] + streetnames))))

# Create look up tables for the alphabet
  # stoi = string to index
  # itos = index to string
stoi = {s:i for i, s in enumerate(chars)}
itos = {i:s for s, i in stoi.items()}

In [335]:
itos.values()

dict_values([' ', '&', "'", '-', '.', ':', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'ß', 'ä', 'ö', 'ü'])

In [336]:
alphabet_size = len(chars)
alphabet_size

36

In [337]:
xs_train,  ys_train = [], []

# Create the training data
# input xs: (ch1, ch2) 
# prediction ys: ch3
for streetname in train:

  # prepend two special characters and append one special characters to each streetname
  chs = [special_token] * 2 + list(streetname) + [special_token]
  for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    ix3 = stoi[ch3]
    xs_train.append((ix1, ix2))
    ys_train.append(ix3)

num = len(xs_train)
print('number of training examples: ', num)
xs_train = torch.tensor(xs_train, device=device)
ys_train = torch.tensor(ys_train, device=device)

number of training examples:  2753260


In [338]:
# Create the dev data
xs_dev,  ys_dev = [], []
for word in dev:
  chs = [special_token] * 2 + list(word) + [special_token]
  for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    ix3 = stoi[ch3]
    xs_dev.append((ix1, ix2))
    ys_dev.append(ix3)

xs_dev = torch.tensor(xs_dev, device=device)
ys_dev = torch.tensor(ys_dev, device=device)
print('number of development examples: ', xs_dev.nelement())

number of development examples:  306006


In [339]:
# Create the test data
xs_test,  ys_test = [], []
for word in test:
  chs = [special_token] * 2 + list(word) + [special_token]
  for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    ix3 = stoi[ch3]
    xs_test.append((ix1, ix2))
    ys_test.append(ix3)

xs_test = torch.tensor(xs_test, device=device)
ys_test = torch.tensor(ys_test, device=device)
print('number of testing examples: ', xs_test.nelement())

number of testing examples:  304704


In [340]:
g = torch.Generator(device=device).manual_seed(2147483647)
W = torch.randn((alphabet_size*2, alphabet_size), generator=g, device=device, requires_grad=True)

In [341]:
# gradient descent
iterations = 500
learning_rate = 200
smoothing_strenth = 0.01

for k in range(iterations):

  # forward pass
  xenc= F.one_hot(xs_train, num_classes=alphabet_size).float()
  xenc_flat = xenc.flatten(1) # flatten the one-hot encoded input vector
  logits = xenc_flat @ W # predict log-counts
  loss = F.cross_entropy(logits, ys_train) + smoothing_strenth * (W**2).mean() # compute loss
  # print loss every 10% of iterations
  if k % floor(iterations/10) == 0:
    print(f'loss at step {k}: {loss.item():.3f}')

  # backward pass
  W.grad = None # flush the gradients
  loss.backward()

  # update step
  cool_down = 1.0 / (1 + 0.01 * k)
  W.data += -learning_rate * cool_down * W.grad
print(f'final training loss: {loss.item():.3f}, with smoothing strength {smoothing_strenth}')

loss at step 0: 4.512
loss at step 50: 2.205
loss at step 100: 2.111
loss at step 150: 2.080
loss at step 200: 2.061
loss at step 250: 2.059
loss at step 300: 2.059
loss at step 350: 2.058
loss at step 400: 2.058
loss at step 450: 2.058
final training loss: 2.057, with smoothing strength 0.01


In [342]:
# Evaluate the model on the dev set
xenc= F.one_hot(xs_dev, num_classes=alphabet_size).float()
xenc_flat = xenc.flatten(1) # flatten the one-hot encoded input vector
logits = xenc_flat @ W # predict log-counts
# softmax
counts = logits.exp() # counts
probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
# loss function (cross-entropy) + regularization (L2)
loss = F.cross_entropy(logits, ys_dev)
print(f'loss on dev set: {loss.item():.3f}, with smoothing strength {smoothing_strenth}')

loss on dev set: 2.040, with smoothing strength 0.01


In [343]:
# Evaluate the model on the test set
xenc= F.one_hot(xs_test, num_classes=alphabet_size).float()
xenc_flat = xenc.flatten(1) # flatten the one-hot encoded input vector
logits = xenc_flat @ W # predict log-counts
# softmax
counts = logits.exp() # counts
probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
# loss function (cross-entropy) + regularization (L2)
loss = F.cross_entropy(logits, ys_test) + smoothing_strenth * (W**2).mean() # compute loss
print(f'loss on test set: {loss.item():.3f}')

loss on test set: 2.061


In [344]:
g = torch.Generator(device=device).manual_seed(2147483647)
# sample street names
name_count = 100
sampled_street_names = []
for i in range(name_count):

  out = []
  ix1 = stoi[special_token]
  ix2 = stoi[special_token]

  while True:
    xenc = F.one_hot(torch.tensor([ix1, ix2], device=device), num_classes=alphabet_size).float()
    xenc_flat = xenc.flatten()
    logits = xenc_flat @ W # predict log-counts
    # softmax
    counts = logits.exp() # counts, equivalent to N
    p = counts / counts.sum(0, keepdims=True) # probabilities for next character
    
    # move index to next character
    ix1 = ix2
    ix2 = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
    out.append(itos[ix2])
    if ix2 == stoi[special_token]:
      # stop if we reach the end of the word
      # 0 is the index of the special character '.'
      break

  sampled_street_names.append(''.join(out))
for streetname in sampled_street_names:
  print(streetname)

ga-strkoreindjjake.
hanwol-hen de.
arastraße.
aner straße.
bag-z.
vilbaufsuren.
imobner wies-old.
ser.
benborde.
sendenätsper weg.
rindet.
fendiuxweg.
jecken dwes ald-rtgerwereg.
weg.
obeildts-steimbre-ge.
aurweg.
mmieg.
arnachaul-weg.
of.
haustraße.
haundeit.
llter dühl-gauf.
borpistr der wegern-weg.
andstraßerweiläusstod.
chkim gerbstenm weg.
batzen.
heraße.
mofkertweg.
böstraße.
winsthng.
rnhupwegereh.
taastraße.
haubler stichtraße.
jastraße.
am chhraße.
jalbreomafl-elee denhen.
kaystatraße.
tben sjog.
rochaulestraße.
kuberger straße.
düstrastr atz-plaraße.
amer örmest.
votbachuf.
rum kttandericker bahofse.
fer wilsten.
psteelen.
lereße.
pastraße.
bucherstraße.
ornsilberg.
kot.
waber.
ülstraße.
traße.
mühkale.
kingalerg.
alttam .
karie-dotziegersweg.
wzwel.
raße.
dor'inb-jokarp.
zulgen dem ges ronder dtben dphieg.
jofferiestrathraße.
areße.
zuach.
peten weg.
schef derwegrausg.
intstrep.
zem raße.
ackserweg straße.
raße.
ben dhallttwar pllicken derstraße.
um gaberberheveng.
spökelder

In [345]:
out = ['a', 'm']
ix1 = stoi[out[0]]
ix2 = stoi[out[1]]

while True:
  xenc = F.one_hot(torch.tensor([ix1, ix2], device=device), num_classes=alphabet_size).float()
  xenc_flat = xenc.flatten()
  logits = xenc_flat @ W # predict log-counts
  # softmax
  counts = logits.exp() # counts, equivalent to N
  p = counts / counts.sum(0, keepdims=True) # probabilities for next character
  
  # move index to next character
  ix1 = ix2
  ix2 = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
  out.append(itos[ix2])
  if ix2 == stoi[special_token]:
    # stop if we reach the end of the word
    # 0 is the index of the special character '.'
    break
print(''.join(out))

am raße.
