In [1]:
from argparse import Namespace
from src.learner import Learner

## Args

In [2]:
# Experiment configuration for a fresh training run, bundled in a Namespace
# so it can be handed to Learner.learner_from_args in one piece.
args = Namespace(**{
    # --- data and paths ---
    "news_csv": "data/ag_news/news_with_splits.csv",
    "vectorizer_file": "vectorizer.json",
    "model_state_file": "model.pth",
    "save_dir": "model_storage/ch5/document_classification",
    # --- model ---
    "glove_filepath": "data/glove/glove.6B.100d.txt",
    "use_glove": True,
    "embedding_size": 100,
    "hidden_dim": 100,
    "num_channels": 100,
    # --- training ---
    "seed": 1337,
    "learning_rate": 0.001,
    "dropout_p": 0.1,
    "batch_size": 128,
    # Quick-run value; the original cell carried a commented-out 100 here.
    "num_epochs": 2,
    "early_stopping_criteria": 5,
    # --- runtime ---
    "cuda": True,
    "catch_keyboard_interrupt": True,
    # False: do not reload previously saved artefacts for this run.
    "reload_from_files": False,
    "expand_filepaths_to_save_dir": True,
})

## Learner

In [3]:
# Build the Learner from the configuration namespace defined above.
# Per the recorded output, this step reports the expanded file paths and the
# CUDA status, loads the dataset and creates a new vectorizer.
learner = Learner.learner_from_args(args)

Expanded filepaths: 
	model_storage/ch5/document_classification\vectorizer.json
	model_storage/ch5/document_classification\model.pth
Using CUDA: False
Loading dataset and creating vectorizer
Using pre-trained embeddings


## Train

In [4]:
# Run the training routine for one epoch (the recorded progress bar shows
# 'training routine' with max=1, followed by train and val splits).
# NOTE(review): args.num_epochs is 2; presumably this argument takes
# precedence over it -- confirm in Learner.train.
learner.train(num_epochs=1)

HBox(children=(IntProgress(value=0, description='training routine', max=1, style=ProgressStyle(description_wid…

HBox(children=(IntProgress(value=0, description='split=train', max=656, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='split=val', max=140, style=ProgressStyle(description_width='i…

## Validate

In [5]:
# Evaluate the trained learner; the call prints a loss and an accuracy.
# NOTE(review): the recorded output is labelled "Test ..." although the method
# is named validate -- confirm which data split is actually evaluated.
learner.validate()

Test loss: 0.669
Test Accuracy: 75.1


## Predict

In [6]:
# Classify a short free-text snippet. The recorded result is a dict with the
# predicted 'category' and its 'probability'.
learner.predict_category('new crisis is coming')

{'category': 'Sports', 'probability': 0.5312585830688477}

In [7]:
# A sports-flavoured example; the recorded prediction is 'Sports' with a
# probability of about 0.92.
learner.predict_category('they won stanley cup')

{'category': 'Sports', 'probability': 0.9244305491447449}

In [8]:
# A business-flavoured example; the recorded prediction is 'Business' with a
# probability of about 0.57.
learner.predict_category('malpractice insurers face a tough market')

{'category': 'Business', 'probability': 0.5694186687469482}

## Load model

In [9]:
# Configuration for restoring the saved learner. Every value matches the
# training configuration (`args`) except reload_from_files, which is True here.
# The values are spelled out in full rather than copied from `args`, so this
# cell does not depend on whatever state `args` is in after earlier cells ran.
args_saved = Namespace(**{
    # --- data and paths ---
    "news_csv": "data/ag_news/news_with_splits.csv",
    "vectorizer_file": "vectorizer.json",
    "model_state_file": "model.pth",
    "save_dir": "model_storage/ch5/document_classification",
    # --- model ---
    "glove_filepath": "data/glove/glove.6B.100d.txt",
    "use_glove": True,
    "embedding_size": 100,
    "hidden_dim": 100,
    "num_channels": 100,
    # --- training ---
    "seed": 1337,
    "learning_rate": 0.001,
    "dropout_p": 0.1,
    "batch_size": 128,
    # Quick-run value; the original cell carried a commented-out 100 here.
    "num_epochs": 2,
    "early_stopping_criteria": 5,
    # --- runtime ---
    "cuda": True,
    "catch_keyboard_interrupt": True,
    # The one setting that differs from the training configuration.
    "reload_from_files": True,
    "expand_filepaths_to_save_dir": True,
})

In [10]:
# Build a second Learner from the reload configuration. Per the recorded
# output, this run loads an existing vectorizer instead of creating a new one.
learner_loaded = Learner.learner_from_args(args_saved)

Expanded filepaths: 
	model_storage/ch5/document_classification\vectorizer.json
	model_storage/ch5/document_classification\model.pth
Using CUDA: False
Loading dataset and loading vectorizer
Using pre-trained embeddings


In [11]:
# Sanity check on the reloaded learner: the recorded metrics (loss 0.669,
# accuracy 75.1) match those printed for the freshly trained learner above.
learner_loaded.validate()

Test loss: 0.669
Test Accuracy: 75.1
