In [1]:
from argparse import Namespace
from src.learner import Learner

## Args

In [2]:
# Experiment configuration consumed by Learner.learner_from_args, written as a
# key/value table so every tunable setting is visible in one place.
args = Namespace(**{
    # -- data and file locations ------------------------------------------
    "surname_csv": "data/surnames/surnames_with_splits.csv",
    # File names below are expanded under save_dir (see "Expanded filepaths"
    # in the next cell's output) because expand_filepaths_to_save_dir=True.
    "vectorizer_file": "vectorizer.json",
    "model_state_file": "model.pth",
    "save_dir": "model_storage/ch6/surname_classification",
    # -- model architecture -----------------------------------------------
    "char_embedding_size": 100,
    "rnn_hidden_size": 64,
    # -- optimisation -----------------------------------------------------
    # NOTE(review): if early_stopping_criteria is a patience in epochs, it can
    # never trigger while num_epochs (3) is smaller than it (5).
    "num_epochs": 3,
    "learning_rate": 1e-3,
    "batch_size": 64,
    "seed": 1337,
    "early_stopping_criteria": 5,
    # -- runtime ----------------------------------------------------------
    # CUDA is only requested; the output shows "Using CUDA: False" on a
    # machine without a GPU.
    "cuda": True,
    "catch_keyboard_interrupt": True,
    "reload_from_files": False,  # build a fresh vectorizer/model
    "expand_filepaths_to_save_dir": True,
})

## Build the learner and train

In [3]:
# Build the learner from the config above; it reports the expanded file paths,
# the device in use, and whether the vectorizer was created or loaded.
learner = Learner.learner_from_args(args)

Expanded filepaths: 
	model_storage/ch6/surname_classification\vectorizer.json
	model_storage/ch6/surname_classification\model.pth
Using CUDA: False
Loading dataset and creating vectorizer


In [4]:
# Train the model. The progress widgets in the output show an outer
# "training routine" bar (max=3, the same as num_epochs) and nested bars for
# split=train (max=120) and split=val (max=25).
# NOTE(review): the 'precision'/'recall' warnings below look like
# scikit-learn's UndefinedMetricWarning, emitted when some class has no
# predicted or no true samples (the report shows Spanish with support 0);
# confirm in src.learner.
learner.train()

HBox(children=(IntProgress(value=0, description='training routine', max=3, style=ProgressStyle(description_wid…

HBox(children=(IntProgress(value=0, description='split=train', max=120, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='split=val', max=25, style=ProgressStyle(description_width='in…

  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


In [5]:
# Evaluate the trained model. Despite the method name, the printed metrics are
# labelled "Test" (loss and accuracy). They are followed by a per-class
# precision/recall/F1 report.
# NOTE(review): ~38% accuracy after num_epochs=3 suggests the model is
# undertrained; consider more epochs before drawing conclusions.
learner.validate()

Test loss: 2.68
Test Accuracy: 37.833


  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


More detailed report: 
               precision    recall  f1-score   support

      Arabic       0.50      0.04      0.08       241
     Chinese       0.11      0.15      0.13        33
       Czech       0.18      0.24      0.21        63
       Dutch       0.03      0.11      0.04        36
     English       0.60      0.10      0.17       447
      French       0.04      0.03      0.03        35
      German       0.24      0.16      0.19        87
       Greek       0.09      0.92      0.16        24
       Irish       0.06      0.39      0.10        28
     Italian       0.26      0.81      0.40        90
    Japanese       0.42      0.42      0.42       117
      Korean       0.00      0.00      0.00        13
      Polish       0.10      0.06      0.07        18
  Portuguese       0.00      0.00      0.00         9
     Russian       0.75      0.59      0.66       357
    Scottish       0.00      0.00      0.00         2
     Spanish       0.00      0.00      0.00         0
  V

In [6]:
# Top-1 prediction for one surname: the returned dict holds 'nationality',
# its 'probability' and the input 'surname'.
# NOTE(review): probabilities around 0.1 are only a little above uniform over
# the ~18 classes in the report, another sign the 3-epoch model is weak.
learner.predict_category("McMahan")

{'nationality': 'Irish',
 'probability': 0.10519526898860931,
 'surname': 'McMahan'}

In [7]:
# Second example surname, shown with the same dict layout.
learner.predict_category("Nakamoto")

{'nationality': 'Italian',
 'probability': 0.08115687966346741,
 'surname': 'Nakamoto'}

In [8]:
# A very short surname gives the model only a few characters to work with.
learner.predict_category("Wan")

{'nationality': 'Irish', 'probability': 0.09982389211654663, 'surname': 'Wan'}

In [9]:
# Another three-letter surname for comparison with "Wan" and "Che".
learner.predict_category("Cho")

{'nationality': 'Italian',
 'probability': 0.07724784314632416,
 'surname': 'Cho'}

In [10]:
# Differs from "Cho" by a single character.
learner.predict_category("Che")

{'nationality': 'Irish', 'probability': 0.0656074807047844, 'surname': 'Che'}

## Load saved model

In [11]:
# Configuration for rebuilding a learner from the files saved by the training
# run above. It is bound to `reload_args` instead of overwriting `args`, so
# re-running the training cells afterwards still builds a fresh model rather
# than silently switching to reload mode (reload_from_files=True).
#
# The architecture hyperparameters (char_embedding_size, rnn_hidden_size) and
# the file locations are repeated from the training config and must stay in
# sync with it; otherwise the saved weights presumably cannot be loaded into
# the rebuilt model.
# NOTE(review): the values are not copied from `vars(args)` because
# Learner.learner_from_args appears to rewrite the path fields of the
# namespace it receives (see the "Expanded filepaths" output). Confirm before
# deriving one config from the other.
reload_args = Namespace(
    # Data and path information
    surname_csv="data/surnames/surnames_with_splits.csv",
    vectorizer_file="vectorizer.json",
    model_state_file="model.pth",
    save_dir="model_storage/ch6/surname_classification",
    # Model hyper parameter (must match the trained model)
    char_embedding_size=100,
    rnn_hidden_size=64,
    # Training hyper parameter
    # NOTE(review): num_epochs differs from the training run (3). It presumably
    # only matters if this reloaded learner is trained further.
    num_epochs=20,
    learning_rate=1e-3,
    batch_size=64,
    seed=1337,
    early_stopping_criteria=5,
    # Runtime hyper parameter
    cuda=True,
    catch_keyboard_interrupt=True,
    reload_from_files=True,  # reuse saved artifacts (output: "loading vectorizer")
    expand_filepaths_to_save_dir=True,
)

learner_loaded = Learner.learner_from_args(reload_args)

Expanded filepaths: 
	model_storage/ch6/surname_classification\vectorizer.json
	model_storage/ch6/surname_classification\model.pth
Using CUDA: False
Loading dataset and loading vectorizer


In [12]:
# Sanity check: the reloaded learner should report the same test loss,
# accuracy and per-class report as `learner.validate()` on the freshly trained
# learner.
learner_loaded.validate()

Test loss: 2.68
Test Accuracy: 37.833
More detailed report: 
               precision    recall  f1-score   support

      Arabic       0.50      0.04      0.08       241
     Chinese       0.11      0.15      0.13        33
       Czech       0.18      0.24      0.21        63
       Dutch       0.03      0.11      0.04        36
     English       0.60      0.10      0.17       447
      French       0.04      0.03      0.03        35
      German       0.24      0.16      0.19        87
       Greek       0.09      0.92      0.16        24
       Irish       0.06      0.39      0.10        28
     Italian       0.26      0.81      0.40        90
    Japanese       0.42      0.42      0.42       117
      Korean       0.00      0.00      0.00        13
      Polish       0.10      0.06      0.07        18
  Portuguese       0.00      0.00      0.00         9
     Russian       0.75      0.59      0.66       357
    Scottish       0.00      0.00      0.00         2
     Spanish       

In [13]:
# Should reproduce the trained learner's prediction for the same surname.
learner_loaded.predict_category("McMahan")

{'nationality': 'Irish',
 'probability': 0.10519526898860931,
 'surname': 'McMahan'}

In [14]:
# Same surname as in the training section; the result should be identical.
learner_loaded.predict_category("Nakamoto")

{'nationality': 'Italian',
 'probability': 0.08115687966346741,
 'surname': 'Nakamoto'}