In [1]:
from src.learner import Learner
from argparse import Namespace

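## Args

All data paths, hyperparameters and runtime options for training the nationality-conditioned surname generator.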

In [2]:
# Configuration for training the nationality-conditioned surname generator.
# Settings are grouped by purpose and unpacked, in this order, into a single
# Namespace; the resulting attributes are exactly the keywords listed below.
args = Namespace(
    # Data and path information. With expand_filepaths_to_save_dir=True the
    # vectorizer/model file names end up under save_dir (see the
    # "Expanded filepaths" output of the Learner cell).
    **dict(
        surname_csv="data/surnames/surnames_with_splits.csv",
        vectorizer_file="vectorizer.json",
        model_state_file="model.pth",
        save_dir="model_storage/ch7/model2_conditioned_surname_generation",
    ),
    # Model hyperparameters: character-embedding width and RNN hidden size.
    **dict(char_embedding_size=32, rnn_hidden_size=32),
    # Training hyperparameters.
    **dict(
        seed=1337,  # presumably seeds the RNGs inside learner_from_args — TODO confirm
        learning_rate=0.001,
        batch_size=128,
        num_epochs=10,
        early_stopping_criteria=5,  # presumably epochs without val improvement before stopping — confirm in Learner
    ),
    # Runtime options.
    **dict(
        catch_keyboard_interrupt=True,
        cuda=True,  # a request only: the Learner cell reports "Using CUDA: False" on this machine
        expand_filepaths_to_save_dir=True,
        reload_from_files=False,  # create a fresh vectorizer/model instead of loading saved files
        conditioned=True,  # presumably selects the nationality-conditioned model variant — confirm
    ),
)

## Learner

In [3]:
# Build the Learner from the training configuration. Its console output shows
# it expanding the file paths, choosing the device, and loading the dataset
# while creating a new vectorizer.
learner = Learner.learner_from_args(args)

Expanded filepaths: 
	model_storage/ch7/model2_conditioned_surname_generation\vectorizer.json
	model_storage/ch7/model2_conditioned_surname_generation\model.pth
Using CUDA: False
Loading dataset and creating vectorizer


## Train

In [4]:
# Run the training loop. The progress bars below track the overall routine
# (max=10, matching num_epochs) and the per-epoch train and val passes.
learner.train()

HBox(children=(IntProgress(value=0, description='training routine', max=10, style=ProgressStyle(description_wi…

HBox(children=(IntProgress(value=0, description='split=train', max=60, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='split=val', max=12, style=ProgressStyle(description_width='in…

In [5]:
# Evaluate the trained model and print its loss and accuracy.
# NOTE(review): the method is called validate() but prints "Test" metrics —
# confirm which data split it actually evaluates.
learner.validate()

Test loss: 2.776
Test Accuracy: 21.825


## Generate

In [6]:
# Sample surnames from the trained model; the output lists this many samples
# for each nationality.
samples_per_nationality = 3
learner.generate(samples_per_nationality)

Sampled for Arabic: 
-  Surkin
-  Anan
-  Hhasaon
Sampled for Chinese: 
-  ta
-  Aeh
-  Kiy
Sampled for Czech: 
-  Ha
-  Beh
-  Taatyela
Sampled for Dutch: 
-  áealeVivenri
-  Mrrtge
-  Daele
Sampled for English: 
-  SaTutot
-  Ltbinss
-  Socerh
Sampled for French: 
-  Dehe
-  Vada
-  Pmhtn
Sampled for German: 
-  Det
-  Cineo
-  Sellec
Sampled for Greek: 
-  Lltedn
-  Ha
-  Taha
Sampled for Irish: 
-  Amte
-  Osaeuo
-  Wenr
Sampled for Italian: 
-  JiHbar
-  Celte
-  iie
Sampled for Japanese: 
-  Hra
-  Iaetb
-  Karda
Sampled for Korean: 
-  Mel
-  Soirei
-  Oeo
Sampled for Polish: 
-  Mbetv
-  
-  Hlonrs
Sampled for Portuguese: 
-  Saeao
-  Msar
-  Saii
Sampled for Russian: 
-  Vaka
-  Sei
-  Eaoun
Sampled for Scottish: 
-  Silae
-  Var
-  Dacanor
Sampled for Spanish: 
-  Mzahra
-  BVaead
-  Si
Sampled for Vietnamese: 
-  kl
-  Medy
-  1aal


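## Load saved model

Rebuild a learner from the vectorizer and model weights saved during training (`reload_from_files=True`), then re-run validation. The saved artifacts should reproduce the loss and accuracy reported above.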

In [7]:
# Same settings as the training configuration, except reload_from_files=True,
# so the Learner loads the saved vectorizer and model instead of creating new ones.
# NOTE(review): deliberately not built from vars(args). The "Expanded filepaths"
# output suggests learner_from_args rewrites the file paths held in `args`, and
# copying those would prefix save_dir twice — confirm before deduplicating.
args_saved = Namespace(
    # Data and path information.
    **dict(
        surname_csv="data/surnames/surnames_with_splits.csv",
        vectorizer_file="vectorizer.json",
        model_state_file="model.pth",
        save_dir="model_storage/ch7/model2_conditioned_surname_generation",
    ),
    # Model hyperparameters (must match the saved model).
    **dict(char_embedding_size=32, rnn_hidden_size=32),
    # Training hyperparameters.
    **dict(
        seed=1337,
        learning_rate=0.001,
        batch_size=128,
        num_epochs=10,
        early_stopping_criteria=5,
    ),
    # Runtime options.
    **dict(
        catch_keyboard_interrupt=True,
        cuda=True,
        expand_filepaths_to_save_dir=True,
        reload_from_files=True,  # the only difference from the training configuration
        conditioned=True,
    ),
)

In [8]:
# Rebuild a Learner from the files saved during training and re-evaluate it.
# Metrics identical to the trained learner's indicate that the vectorizer and
# weights were saved and reloaded intact.
learner_loaded = Learner.learner_from_args(args_saved)
learner_loaded.validate()

Expanded filepaths: 
	model_storage/ch7/model2_conditioned_surname_generation\vectorizer.json
	model_storage/ch7/model2_conditioned_surname_generation\model.pth
Using CUDA: False
Loading dataset and loading vectorizer
Test loss: 2.776
Test Accuracy: 21.825
