In [36]:
from functools import partial
import torch
from word_window_model import train, custom_collate_fn

## Generate Dataset

In [37]:
# Toy corpus: five short sentences, each mentioning at least one place name.
training_data = [
    "We always come to Paris",
    "The professor is from Australia",
    "I live in Stanford",
    "He comes from Taiwan",
    "The capital of Turkey is Ankara",
]

# Lower-cased place names; a token gets label 1 exactly when it is in this set.
locations = {"paris", "australia", "stanford", "taiwan", "turkey", "ankara"}

# Tokenise every sentence: lower-case first, then split on whitespace.
training_sentences = list(map(str.split, map(str.lower, training_data)))

# One 0/1 label per token, aligned position-by-position with training_sentences.
training_labels = [
    [int(token in locations) for token in tokens]
    for tokens in training_sentences
]

# Uncomment to inspect the tokenised corpus and its labels:
# print(f"training_sentences:\n{training_sentences}")
# print(f"training_labels:\n{training_labels}")

## Model Hyperparameters
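- `window_size`: size of the context window taken around each word
- `hidden_dim`: size of the model's hidden layer
- `embedded_dim`: dimensionality of the word embeddings
- `batch_size`: number of sentences per training batch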

In [38]:
# Settings bundle handed to train(); 'window_size' is read again later when
# the collate function for the test loader is built.
model_hyperparameters = dict(
    window_size=2,
    hidden_dim=25,
    embedded_dim=25,
    batch_size=4,
)

### Training our model

In [39]:
# Train on the toy corpus. train() comes from word_window_model and returns two
# values, unpacked here as the model and the word2idx mapping; both are used
# again by the test cell (word2idx is passed to custom_collate_fn).
# NOTE(review): the raw strings in training_data are passed here, not the
# tokenised training_sentences that training_labels was derived from --
# presumably train() tokenises internally; confirm against word_window_model.
# NOTE(review): no random seed is set anywhere in this notebook, so the logged
# losses may differ between runs unless train() seeds internally -- TODO confirm.
model, word2idx = train(training_data, training_labels, model_hyperparameters)

  4%|▍         | 45/1000 [00:00<00:02, 413.53it/s]

0.20507566258311272


 15%|█▌        | 150/1000 [00:00<00:01, 466.20it/s]

0.14433403685688972


 27%|██▋       | 266/1000 [00:00<00:01, 500.67it/s]

0.1214066743850708
0.10080908611416817


 47%|████▋     | 467/1000 [00:00<00:01, 482.54it/s]

0.0808707382529974


 57%|█████▋    | 569/1000 [00:01<00:00, 482.62it/s]

0.08577993791550398


 66%|██████▋   | 665/1000 [00:01<00:00, 449.28it/s]

0.05147194303572178


 76%|███████▌  | 756/1000 [00:01<00:00, 436.98it/s]

0.02967604622244835


 86%|████████▋ | 865/1000 [00:01<00:00, 465.57it/s]

0.02411596290767193


 96%|█████████▌| 961/1000 [00:02<00:00, 448.08it/s]

0.03127680439502001


100%|██████████| 1000/1000 [00:02<00:00, 463.67it/s]


### Test

In [40]:
# Single evaluation sentence; it does not appear verbatim in training_data.
test_sentences = ['She comes from Paris']

# Tokenise the same way the dataset cell does: lower-case, split on whitespace.
# Kept under its own name so test_data only ever holds one kind of object.
test_tokens = [sentence.lower().split() for sentence in test_sentences]
test_labels = [[0, 0, 0, 1]]  # only the last token ("paris") is a location

# Guard against a label list drifting out of sync with its sentence.
assert len(test_tokens) == len(test_labels)
assert all(len(tokens) == len(tags) for tokens, tags in zip(test_tokens, test_labels)), \
    "each test sentence needs exactly one label per token"

# (tokens, labels) pairs -- the item format handed to the DataLoader below.
test_data = list(zip(test_tokens, test_labels))

# Bind the training-time window size and vocabulary so the DataLoader can call
# the collate function with just the batch.
collate_fn = partial(
    custom_collate_fn,
    window_size=model_hyperparameters['window_size'],
    word2idx=word2idx,
)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=1, shuffle=False, collate_fn=collate_fn
)

# Inference only: disable autograd so no graph is recorded for the forward
# pass (the printed tensor then carries no grad_fn).
with torch.no_grad():
    # The third item yielded by the collate function is not needed here.
    for test_instance, labels, _ in test_loader:
        outputs = model(test_instance)
        print(f"True labels: {labels}")
        print(f"Estimated Probabilities: {outputs}")

True labels: tensor([[0, 0, 0, 1]])
Estimated Probabilities: tensor([[0.0569, 0.0497, 0.0229, 0.9464]], grad_fn=<ViewBackward0>)


Based on the probabilities above, the model assigns a very high probability (about 0.95) to the fourth word of the test sentence ("paris") being a location word, while each of the other three words scores below 0.06.