In [1]:
import warnings
# Ignore every warning from here on. This call sits ahead of the remaining
# imports -- presumably so that warnings raised while they load are hidden
# too (TODO confirm that this is the intent).
warnings.filterwarnings('ignore')

import pickle
from product_classification.data_processing.text_processing import FitCategoricalData, InitFields, CreateDatasets, BuildTextVocabulary, BuildIterators
from product_classification.learner.train import CreateLearner, TrainHighLevels, FinetuneAll, EvaluateClassifier
from product_classification.models import CnnHyperParameters, Epochs, LearningRates

In [2]:
# Directory holding the pickled inputs read by this cell.
path = "../data"


def _unpickle(file_path):
    """Open ``file_path`` in binary mode and return the object pickled in it."""
    # NOTE(review): unpickling can execute arbitrary code, so these files
    # must come from a trusted source.
    with open(file_path, "rb") as handle:
        return pickle.load(handle)


simple_datasets = _unpickle(f"{path}/datasets.pkl")
multilabel_binarizer = _unpickle(f"{path}/multilabel_binarizer.pkl")
pos_weight = _unpickle(f"{path}/pos_weight.pkl")

#### Define CNN Hyperparameters

In [3]:
# One learning rate per training phase (initial phase, then fine-tuning).
learning_rates = LearningRates(init_phase=5e-4, finetuning_phase=1e-4)

# One epoch count per training phase (initial phase, then fine-tuning).
epochs_per_phase = Epochs(init_phase=2, finetuning_phase=3)

# Single hyperparameter bundle handed to the dataset, learner and training cells.
cnn_hyperparameters = CnnHyperParameters(
    nb_filters=100,
    kernels=[3, 4, 5],
    dropout=0.5,
    lrates=learning_rates,
    epochs=epochs_per_phase,
)

#### Init torch components for text processing

In [4]:
# Run the categorical-fitting step on the loaded datasets; its result is the
# encoder object reused by the field-initialisation and learner cells.
fit_categorical_data = FitCategoricalData()
one_hot_encoder = fit_categorical_data.execute(
    processed_data=simple_datasets,
)

In [5]:
# Create the torch fields from the encoder. The log recorded for this cell
# reports a categorical, a text and a label field being created.
init_fields = InitFields()
torch_fields = init_fields.execute(
    one_hot_encoder=one_hot_encoder,
)

2022-09-11 20:02:21,483 :: text_processing/text_processing.py/execute :: INFO :: Categorical field created
2022-09-11 20:02:22,822 :: text_processing/text_processing.py/execute :: INFO :: Text field created
2022-09-11 20:02:22,824 :: text_processing/text_processing.py/execute :: INFO :: Label field created


#### Create datasets mapping features and labels

In [None]:
# Position of the first label column in the training frame.
# NOTE(review): assumes every column from this position onward is a label
# column -- confirm against the step that produced datasets.pkl.
FIRST_LABEL_COLUMN = 5

# Read the label names straight from the column index; this yields the same
# list as slicing the frame with .iloc first, without building a sliced frame.
labels = list(simple_datasets.training.columns[FIRST_LABEL_COLUMN:])

# Columns passed to the dataset builder as categorical features.
cat_cols = ["brand_name", "merchant_name"]

create_datasets = CreateDatasets()
torch_dataset = create_datasets.execute(
    processed_data=simple_datasets,
    torch_fields=torch_fields,
    cnn_hparams=cnn_hyperparameters,
    txt_col="product_name",
    cat_cols=cat_cols,
    lbl_cols=labels,
)

2022-09-11 20:02:25,021 :: __init__/__init__.py/__init__ :: INFO :: fields: [('text', <torchtext.data.field.Field object at 0x7f1404094640>), ('category', <torchtext.data.field.Field object at 0x7f1404094ac0>), ('label', <torchtext.data.field.LabelField object at 0x7f1404094700>)]


#### Build the text vocabulary with fastText embeddings

In [None]:
# Settings for the vocabulary build; `embedding_name` is reused by the
# learner-creation cell.
vector_cache = ".vector_cache"
vocab_size = 100_000
embedding_name = "fasttext.simple.300d"

build_text_vocabulary = BuildTextVocabulary()
# `torch_fields` is rebound to whatever the build returns, so re-running this
# cell feeds the previous result back in as input.
torch_fields = build_text_vocabulary.execute(
    torch_datasets=torch_dataset,
    torch_fields=torch_fields,
    vocab_size=vocab_size,
    embedding_name=embedding_name,
    vectors_cache=vector_cache,
)

#### Create pytorch iterators

In [None]:
# `batch_size` is reused when the learner is created; `device` is reused by
# the training and evaluation cells.
batch_size = 128
device = "cpu"

build_iterators = BuildIterators()
torch_iterators = build_iterators.execute(
    torch_datasets=torch_dataset,
    batch_size=batch_size,
    device=device,
)

#### Create Learner to train CNN

In [None]:
# Assemble the learner from the hyperparameters, fields, data and loss weights
# prepared in the earlier cells; the label count comes from `labels`.
create_learner = CreateLearner()
learner = create_learner.execute(
    cnn_hparams=cnn_hyperparameters,
    embedding_name=embedding_name,
    torch_fields=torch_fields,
    processed_data=simple_datasets,
    batch_size=batch_size,
    label_number=len(labels),
    one_hot_encoder=one_hot_encoder,
    pos_weight=pos_weight,
)

#### Training

Train the high-level layers (embedding layer frozen)

In [None]:
# First training phase. `learner` is rebound to the value returned by the
# step, so re-running this cell continues from the already-trained learner.
train_high_levels = TrainHighLevels()
learner = train_high_levels.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    cnn_hparams=cnn_hyperparameters,
)

Train all model layers

In [None]:
# Second training phase. `learner` is rebound again to the value returned by
# the fine-tuning step.
finetune_all = FinetuneAll()
learner = finetune_all.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    cnn_hparams=cnn_hyperparameters,
)

#### Evaluation

In [None]:
# Evaluate the trained learner; the binarizer loaded at the top of the
# notebook is passed along with the iterators. The result is kept in `perfs`.
evaluate_classifier = EvaluateClassifier()
perfs = evaluate_classifier.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    multilabel_binarizer=multilabel_binarizer,
)