In [1]:
import warnings
# Silence warnings before the remaining imports run.
warnings.filterwarnings('ignore')

import pickle

import pandas as pd
import torch

from product_classification.data_processing.text_processing import FitCategoricalData, InitFields, CreateDatasets, BuildTextVocabulary, BuildIterators
from product_classification.learner.train import CreateLearner, TrainHighLevels, FinetuneAll, EvaluateClassifier
from product_classification.models import CnnHyperParameters, Epochs, LearningRates

In [2]:
# Load the three pickled artefacts from the data directory.
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only point `path` at files that come from a trusted source.
path = "../data"

with open(f"{path}/datasets.pkl", "rb") as datasets_file:
    simple_datasets = pickle.load(datasets_file)

with open(f"{path}/multilabel_binarizer.pkl", "rb") as binarizer_file:
    multilabel_binarizer = pickle.load(binarizer_file)

with open(f"{path}/pos_weight.pkl", "rb") as pos_weight_file:
    pos_weight = pickle.load(pos_weight_file)

#### Define CNN Hyperparameters

In [3]:
# Hyperparameters handed to dataset creation, learner creation and both
# training steps. Learning rates and epoch counts are given separately for
# the initial phase and the fine-tuning phase.
cnn_hyperparameters = CnnHyperParameters(
    nb_filters=100,
    kernels=[3, 4, 5],
    dropout=0.5,
    lrates=LearningRates(init_phase=5e-4, finetuning_phase=1e-4),
    epochs=Epochs(init_phase=2, finetuning_phase=3),
)

#### Init torch components for text processing

In [4]:
# Run the categorical-data fitting step on the loaded datasets; its result is
# passed to later steps as `one_hot_encoder`.
fit_categorical_data = FitCategoricalData()
one_hot_encoder = fit_categorical_data.execute(
    processed_data=simple_datasets,
)

In [5]:
# Create the torchtext fields; the log output reports a categorical, a text
# and a label field being created.
init_fields = InitFields()
torch_fields = init_fields.execute(
    one_hot_encoder=one_hot_encoder,
)

2022-09-12 07:10:15,812 :: text_processing/text_processing.py/execute :: INFO :: Categorical field created
2022-09-12 07:10:17,163 :: text_processing/text_processing.py/execute :: INFO :: Text field created
2022-09-12 07:10:17,165 :: text_processing/text_processing.py/execute :: INFO :: Label field created


#### Create datasets mapping features and labels

In [6]:
# Every column from position 5 onwards is treated as a label column.
# NOTE(review): assumes the first five columns of the training frame are
# features/metadata and everything after them is a label — confirm against
# the data-preparation step.
labels = list(simple_datasets.training.columns[5:])
cat_cols = ["brand_name", "merchant_name"]

create_datasets = CreateDatasets()
torch_dataset = create_datasets.execute(
    processed_data=simple_datasets,
    torch_fields=torch_fields,
    cnn_hparams=cnn_hyperparameters,
    txt_col="product_name",
    cat_cols=cat_cols,
    lbl_cols=labels,
)

2022-09-12 07:10:17,196 :: __init__/__init__.py/__init__ :: INFO :: fields: [('text', <torchtext.data.field.Field object at 0x7f309a944c10>), ('category', <torchtext.data.field.Field object at 0x7f309a946670>), ('label', <torchtext.data.field.LabelField object at 0x7f312fbdfa00>)]
2022-09-12 07:11:00,794 :: __init__/__init__.py/__init__ :: INFO :: fields: [('text', <torchtext.data.field.Field object at 0x7f309a944c10>), ('category', <torchtext.data.field.Field object at 0x7f309a946670>), ('label', <torchtext.data.field.LabelField object at 0x7f312fbdfa00>)]
2022-09-12 07:11:04,378 :: __init__/__init__.py/__init__ :: INFO :: fields: [('text', <torchtext.data.field.Field object at 0x7f309a944c10>), ('category', <torchtext.data.field.Field object at 0x7f309a946670>), ('label', <torchtext.data.field.LabelField object at 0x7f312fbdfa00>)]
2022-09-12 07:11:08,149 :: text_processing/text_processing.py/execute :: INFO :: Torchtext datasets created from dataframes


#### Build text vocabulary with fasttext embeddings

In [7]:
# Settings for the vocabulary build: where embedding vectors are cached, the
# vocabulary size passed to the builder, and the name of the embedding to use.
vector_cache = ".vector_cache"
vocab_size = 100_000
embedding_name = "fasttext.simple.300d"

# The returned fields replace `torch_fields`, which later cells read.
build_text_vocabulary = BuildTextVocabulary()
torch_fields = build_text_vocabulary.execute(
    torch_datasets=torch_dataset,
    torch_fields=torch_fields,
    vocab_size=vocab_size,
    embedding_name=embedding_name,
    vectors_cache=vector_cache,
)

2022-09-12 07:11:10,112 :: text_processing/text_processing.py/execute :: INFO :: Text vocabulary created. Corpus_dim: 62838


#### Create pytorch iterators

In [8]:
# Batch size and device are reused by the learner and training cells.
batch_size = 128
device = "cpu"

# Wrap the torchtext datasets in iterators.
build_iterators = BuildIterators()
torch_iterators = build_iterators.execute(
    torch_datasets=torch_dataset,
    batch_size=batch_size,
    device=device,
)

2022-09-12 07:11:10,122 :: text_processing/text_processing.py/execute :: INFO :: Torchtext iterators created from datasets


#### Create Learner to train CNN

In [9]:
# Convert the value loaded from pos_weight.pkl into a float32 tensor on the
# target device. torch.as_tensor is used instead of the legacy torch.Tensor
# constructor: it takes dtype and device explicitly and hands back an existing
# tensor unchanged when it already matches, so re-running this cell is safe.
# The name `pos_weight` is kept because the fine-tuning cell reads it.
# NOTE(review): presumably these are per-label positive-class weights for the
# loss — confirm against CreateLearner.
pos_weight = torch.as_tensor(pos_weight, dtype=torch.float32, device=device)

# Build the CNN learner from the hyperparameters, fields and encoder prepared
# in the earlier cells; one output per label column.
create_learner = CreateLearner()
learner = create_learner.execute(
    cnn_hparams=cnn_hyperparameters,
    embedding_name=embedding_name,
    torch_fields=torch_fields,
    processed_data=simple_datasets,
    batch_size=batch_size,
    label_number=len(labels),
    one_hot_encoder=one_hot_encoder,
    pos_weight=pos_weight,
)

2022-09-12 07:11:10,446 :: train/train.py/execute :: INFO :: CNN learner compiled for the text classification task


#### Training

Train the high-level layers (embedding layer frozen)

In [10]:
# First training phase; the log describes it as training the CNN and
# classification layers. The trained learner replaces `learner`.
train_high_levels = TrainHighLevels()
learner = train_high_levels.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    cnn_hparams=cnn_hyperparameters,
)

2022-09-12 07:11:10,458 :: train/train.py/execute :: INFO :: Starting training of CNN and classification layers


Embedding(62838, 300)


Training Epoch [1/2]:  20%|██████████████████▎                                                                          | 132/669 [01:02<03:47,  2.36it/s, loss=1.16]
  0%|                                                                                                                                         | 0/60 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                             | 0/60 [00:00<?, ?it/s][A
Validation:   3%|███▉                                                                                                                 | 2/60 [00:00<00:03, 18.01it/s][A
Validation:   8%|█████████▊                                                                                                           | 5/60 [00:00<00:02, 20.26it/s][A
Validation:   8%|████████▊                                                                                                | 5/60 [00:00<00:02, 20.26it/s, loss

Train all model layers

In [11]:
# Second training phase; the log reports the whole network being unfrozen
# before fine-tuning starts. The fine-tuned learner replaces `learner`.
# NOTE(review): the recorded progress bar reads "Epoch [1/2]" although
# finetuning_phase is set to 3 epochs — check whether the saved output is
# stale or FinetuneAll reads the init-phase epoch count.
finetune_all = FinetuneAll()
learner = finetune_all.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    cnn_hparams=cnn_hyperparameters,
    processed_data=simple_datasets,
    batch_size=batch_size,
    pos_weight=pos_weight,
)

2022-09-12 07:20:24,558 :: train/train.py/execute :: INFO :: Unfreezing the whole CNN network
2022-09-12 07:20:24,561 :: train/train.py/execute :: INFO :: Starting finetuning of the whole CNN network
Training Epoch [1/2]:  20%|██████████████████▏                                                                         | 132/669 [01:39<06:53,  1.30it/s, loss=0.781]
  0%|                                                                                                                                         | 0/60 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                             | 0/60 [00:00<?, ?it/s][A
Validation:   3%|███▉                                                                                                                 | 2/60 [00:00<00:02, 19.49it/s][A
Validation:   8%|█████████▊                                                                                                    

#### Evaluation

In [12]:
# Evaluate the trained learner. The logged test results are a dict keyed by
# class name, each holding precision, recall, f1-score and support.
evaluate_classifier = EvaluateClassifier()
perfs = evaluate_classifier.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    multilabel_binarizer=multilabel_binarizer,
)

2022-09-12 07:34:57,711 :: train/train.py/execute :: INFO :: Test results
2022-09-12 07:34:57,712 :: train/train.py/execute :: INFO :: {'animalerie': {'precision': 0.7431818181818182, 'recall': 0.9108635097493036, 'f1-score': 0.818523153942428, 'support': 359}, 'auto et moto': {'precision': 0.556390977443609, 'recall': 0.8680351906158358, 'f1-score': 0.6781214203894617, 'support': 341}, 'bagages et sacs': {'precision': 0.45010615711252655, 'recall': 0.8760330578512396, 'f1-score': 0.5946704067321178, 'support': 242}, 'beaute et parfum': {'precision': 0.8623853211009175, 'recall': 0.9591836734693877, 'f1-score': 0.9082125603864735, 'support': 294}, 'bebe et puericulture': {'precision': 0.5568862275449101, 'recall': 0.8254437869822485, 'f1-score': 0.6650774731823599, 'support': 338}, 'bijoux': {'precision': 0.8057142857142857, 'recall': 0.9276315789473685, 'f1-score': 0.8623853211009175, 'support': 304}, 'bricolage': {'precision': 0.6003898635477583, 'recall': 0.9221556886227545, 'f1-sco

In [16]:
# Show the evaluation metrics as a table: one row per class, one column per
# metric (precision, recall, f1-score, support). pandas is imported in the
# notebook's top import cell rather than here.
pd.DataFrame(perfs).T

Unnamed: 0,precision,recall,f1-score,support
animalerie,0.743182,0.910864,0.818523,359.0
auto et moto,0.556391,0.868035,0.678121,341.0
bagages et sacs,0.450106,0.876033,0.59467,242.0
beaute et parfum,0.862385,0.959184,0.908213,294.0
bebe et puericulture,0.556886,0.825444,0.665077,338.0
bijoux,0.805714,0.927632,0.862385,304.0
bricolage,0.60039,0.922156,0.727273,334.0
cd et vinyles,0.132075,0.875,0.229508,16.0
chaussures et accessoires,0.629474,0.943218,0.755051,317.0
"commerce, industrie et science",0.202703,0.6,0.30303,25.0
