# Datasets


In [4]:
import torch
import torch.nn as nn

from bow_text_classifier.data import create_training_datasets

# Build the default train/test datasets plus the two lookup tables
# (word_to_index and tag_to_index are later used to size the model).
(
    train_data,
    test_data,
    word_to_index,
    tag_to_index,
) = create_training_datasets()

# Training


In [5]:
from bow_text_classifier.nn import _BoW, train_bow

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tensor constructor passed to train_bow, chosen to match the device the model
# is placed on (presumably used there to build index tensors -- confirm in
# bow_text_classifier.nn). Deliberately NOT named `type`: that would shadow the
# Python builtin for every later cell in this kernel.
tensor_type = torch.cuda.LongTensor if device == "cuda" else torch.LongTensor

# Build the BoW model sized by the two lookup tables and move it to `device`
# once; `.to(device)` already covers the CUDA case, so no second move is needed.
model = _BoW(len(word_to_index), len(tag_to_index)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# Train the model; train_bow prints per-iteration loss and train/test accuracy.
train_bow(model, optimizer, criterion, train_data, test_data, tensor_type)

ITER: 1 | train loss/sent: 1.4751 | train accuracy: 0.3626 | test accuracy: 0.4005
ITER: 2 | train loss/sent: 1.1220 | train accuracy: 0.6074 | test accuracy: 0.4100
ITER: 3 | train loss/sent: 0.9124 | train accuracy: 0.7117 | test accuracy: 0.4140
ITER: 4 | train loss/sent: 0.7687 | train accuracy: 0.7670 | test accuracy: 0.4113
ITER: 5 | train loss/sent: 0.6628 | train accuracy: 0.8078 | test accuracy: 0.4158
ITER: 6 | train loss/sent: 0.5821 | train accuracy: 0.8325 | test accuracy: 0.4068
ITER: 7 | train loss/sent: 0.5164 | train accuracy: 0.8545 | test accuracy: 0.4032
ITER: 8 | train loss/sent: 0.4637 | train accuracy: 0.8714 | test accuracy: 0.4036
ITER: 9 | train loss/sent: 0.4188 | train accuracy: 0.8814 | test accuracy: 0.4018
ITER: 10 | train loss/sent: 0.3823 | train accuracy: 0.8915 | test accuracy: 0.3941


# Saving the model


In [6]:
from bow_text_classifier.nn import save_model, model_dir

# Save the trained model together with both lookup tables under model_dir.
save_model(
    model,
    word_to_index,
    tag_to_index,
    model_dir / "bow_model",
)

In [None]:
# Loading the model and running inference

In [7]:
# model_dir is imported here as well so this cell can be run on its own
# (e.g. on a fresh kernel with an already-saved model) without first
# executing the "Saving the model" cell.
from bow_text_classifier.nn import BoW_Classifier, model_dir

# Create a classifier and load the model stored at model_dir / "bow_model".
classifier = BoW_Classifier()
classifier.load_model(model_dir / "bow_model")

# Classify a single example sentence.
sample_sentence = "I love programming"
predicted_tag = classifier.predict(sample_sentence)

print(f"Predicted Tag: {predicted_tag}")

Predicted Tag: 4
