In [1]:
import warnings
# Installs a blanket "ignore" filter before the project imports below, so every
# warning category (including deprecation warnings) is silenced for the session.
# NOTE(review): consider narrowing this to the specific category/module that is noisy.
warnings.filterwarnings('ignore')

import pickle
from product_classification.data_processing.text_processing import FitCategoricalData, InitFields, CreateDatasets, BuildTextVocabulary, BuildIterators
from product_classification.learner.train import CreateLearner, TrainHighLevels, FinetuneAll, EvaluateClassifier
from product_classification.models import CnnHyperParameters, Epochs, LearningRates

In [2]:
# Folder the two pickled inputs are read from (resolved against the current
# working directory, since it is a relative path).
path = "../data"

# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only point `path` at files from a trusted source.
with open(f"{path}/datasets.pkl", mode="rb") as handle:
    simple_datasets = pickle.load(handle)

with open(f"{path}/multilabel_binarizer.pkl", mode="rb") as handle:
    multilabel_binarizer = pickle.load(handle)

#### Define CNN Hyperparameters

In [4]:
# Hyperparameter bundle passed as `cnn_hparams` to the dataset, learner and
# training steps further down. Each phase (init / finetuning) gets its own
# learning rate and is configured here for a single epoch.
cnn_hyperparameters = CnnHyperParameters(
    nb_filters=100,
    kernels=[3, 4, 5],
    dropout=0.5,
    lrates=LearningRates(init_phase=5e-4, finetuning_phase=1e-4),
    epochs=Epochs(init_phase=1, finetuning_phase=1),
)

#### Init torch components for text processing

In [5]:
# Run the FitCategoricalData step on the loaded datasets. Its result is kept as
# `one_hot_encoder` and handed to InitFields and CreateLearner later on.
fit_categorical_data = FitCategoricalData()
one_hot_encoder = fit_categorical_data.execute(
    processed_data=simple_datasets,
)

In [6]:
# Create the fields object from the fitted encoder; the step logs the creation
# of a categorical, a text and a label field.
init_fields = InitFields()
torch_fields = init_fields.execute(
    one_hot_encoder=one_hot_encoder,
)

2022-09-11 16:11:23,654 :: text_processing/text_processing.py/execute :: INFO :: Categorical field created
2022-09-11 16:11:24,943 :: text_processing/text_processing.py/execute :: INFO :: Text field created
2022-09-11 16:11:24,944 :: text_processing/text_processing.py/execute :: INFO :: Label field created


#### Create datasets mapping features and labels

In [7]:
# Every column of the training frame from position 5 onwards is treated as a label.
# NOTE(review): assumes the first five columns are non-label columns — confirm
# against the layout of the pickled datasets.
labels = list(simple_datasets.training.columns[5:])
cat_cols = ["brand_name", "merchant_name"]

# Map the text column, the categorical columns and the label columns onto the
# fields created earlier.
create_datasets = CreateDatasets()
torch_dataset = create_datasets.execute(
    processed_data=simple_datasets,
    torch_fields=torch_fields,
    cnn_hparams=cnn_hyperparameters,
    txt_col="product_name",
    cat_cols=cat_cols,
    lbl_cols=labels,
)

2022-09-11 16:11:24,961 :: __init__/__init__.py/__init__ :: INFO :: fields: [('text', <torchtext.data.field.Field object at 0x7f1752a79e80>), ('category', <torchtext.data.field.Field object at 0x7f1752a79d00>), ('label', <torchtext.data.field.LabelField object at 0x7f17e449e550>)]
2022-09-11 16:17:29,720 :: __init__/__init__.py/__init__ :: INFO :: fields: [('text', <torchtext.data.field.Field object at 0x7f1752a79e80>), ('category', <torchtext.data.field.Field object at 0x7f1752a79d00>), ('label', <torchtext.data.field.LabelField object at 0x7f17e449e550>)]
2022-09-11 16:17:36,751 :: __init__/__init__.py/__init__ :: INFO :: fields: [('text', <torchtext.data.field.Field object at 0x7f1752a79e80>), ('category', <torchtext.data.field.Field object at 0x7f1752a79d00>), ('label', <torchtext.data.field.LabelField object at 0x7f17e449e550>)]
2022-09-11 16:19:04,666 :: text_processing/text_processing.py/execute :: INFO :: Torchtext datasets created from dataframes


#### Build text vocabulary with fasttext embeddings

In [8]:
# Settings for the vocabulary step; `embedding_name` is reused when the learner
# is created.
vector_cache = ".vector_cache"
vocab_size = 100_000
embedding_name = "fasttext.simple.300d"

# NOTE: `torch_fields` is rebound to the step's return value, so re-running this
# cell feeds the already-processed fields back in.
build_text_vocabulary = BuildTextVocabulary()
torch_fields = build_text_vocabulary.execute(
    torch_datasets=torch_dataset,
    torch_fields=torch_fields,
    vocab_size=vocab_size,
    embedding_name=embedding_name,
    vectors_cache=vector_cache,
)

2022-09-11 16:19:05,672 :: text_processing/text_processing.py/execute :: INFO :: Text vocabulary created. Corpus_dim: 30851


#### Create pytorch iterators

In [9]:
# Batch size and device are shared with the learner / training cells below.
batch_size = 128
device = "cpu"

build_iterators = BuildIterators()
torch_iterators = build_iterators.execute(
    torch_datasets=torch_dataset,
    batch_size=batch_size,
    device=device,
)

2022-09-11 16:19:05,682 :: text_processing/text_processing.py/execute :: INFO :: Torchtext iterators created from datasets


#### Create Learner to train CNN

In [10]:
# Assemble the learner from the hyperparameters, the fields and the encoder.
# The number of outputs is taken from the label column list built earlier.
create_learner = CreateLearner()
learner = create_learner.execute(
    cnn_hparams=cnn_hyperparameters,
    embedding_name=embedding_name,
    torch_fields=torch_fields,
    processed_data=simple_datasets,
    batch_size=batch_size,
    label_number=len(labels),
    one_hot_encoder=one_hot_encoder,
)

2022-09-11 16:19:05,878 :: train/train.py/execute :: INFO :: CNN learner compiled for the text classification task


#### Training

Train the high-level layers (embedding layer frozen)

In [11]:
# First training phase. `learner` is rebound to the object returned by the step,
# so the finetuning cell continues from this result.
train_high_levels = TrainHighLevels()
learner = train_high_levels.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    cnn_hparams=cnn_hyperparameters,
)

2022-09-11 16:19:05,887 :: train/train.py/execute :: INFO :: Starting training of CNN and classification layers


Embedding(30851, 300)


Training Epoch [1/1]:  10%|██████████▍                                                                                                | 38/391 [01:25<13:21,  2.27s/it, loss=0.707]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:01<00:13,  1.94s/it][A
Validation:  12%|███████████████                                                                                                         | 1/8 [00:03<00:13,  1.94s/it, loss=0.334][A
Validation:  25%|██████████████████████████████                                         

Validation:
	hamming_loss = 0.031
	f1_score = 0.283
	loss = 0.668
****************************************


Training Epoch [1/1]:  20%|█████████████████████▎                                                                                      | 77/391 [03:04<11:10,  2.14s/it, loss=0.47]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:01<00:13,  1.91s/it][A
Validation:  12%|██████████████▉                                                                                                        | 1/8 [00:03<00:13,  1.91s/it, loss=0.0544][A
Validation:  25%|█████████████████████████████▊                                         

Validation:
	hamming_loss = 0.033
	f1_score = 0.052
	loss = 0.101
****************************************


Training Epoch [1/1]:  30%|███████████████████████████████▍                                                                          | 116/391 [04:42<09:35,  2.09s/it, loss=0.343]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:01<00:13,  1.96s/it][A
Validation:  12%|██████████████▉                                                                                                        | 1/8 [00:03<00:13,  1.96s/it, loss=0.0448][A
Validation:  25%|█████████████████████████████▊                                         

Validation:
	hamming_loss = 0.026
	f1_score = 0.258
	loss = 0.078
****************************************


Training Epoch [1/1]:  40%|██████████████████████████████████████████                                                                | 155/391 [06:21<08:21,  2.12s/it, loss=0.275]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:01<00:13,  1.88s/it][A
Validation:  12%|██████████████▉                                                                                                        | 1/8 [00:03<00:13,  1.88s/it, loss=0.0381][A
Validation:  25%|█████████████████████████████▊                                         

Validation:
	hamming_loss = 0.021
	f1_score = 0.410
	loss = 0.064
****************************************


Training Epoch [1/1]:  50%|████████████████████████████████████████████████████▌                                                     | 194/391 [07:59<06:53,  2.10s/it, loss=0.233]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:02<00:14,  2.05s/it][A
Validation:  12%|██████████████▉                                                                                                        | 1/8 [00:03<00:14,  2.05s/it, loss=0.0342][A
Validation:  25%|█████████████████████████████▊                                         

Validation:
	hamming_loss = 0.018
	f1_score = 0.527
	loss = 0.057
****************************************


Training Epoch [1/1]:  60%|███████████████████████████████████████████████████████████████▏                                          | 233/391 [09:37<05:44,  2.18s/it, loss=0.204]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:01<00:12,  1.82s/it][A
Validation:  12%|██████████████▉                                                                                                        | 1/8 [00:03<00:12,  1.82s/it, loss=0.0303][A
Validation:  25%|█████████████████████████████▊                                         

Validation:
	hamming_loss = 0.015
	f1_score = 0.612
	loss = 0.052
****************************************


Training Epoch [1/1]:  70%|█████████████████████████████████████████████████████████████████████████▋                                | 272/391 [11:15<04:10,  2.11s/it, loss=0.182]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:01<00:13,  1.95s/it][A
Validation:  12%|██████████████▉                                                                                                        | 1/8 [00:03<00:13,  1.95s/it, loss=0.0278][A
Validation:  25%|█████████████████████████████▊                                         

Validation:
	hamming_loss = 0.014
	f1_score = 0.648
	loss = 0.048
****************************************


Training Epoch [1/1]:  80%|████████████████████████████████████████████████████████████████████████████████████▎                     | 311/391 [12:53<02:53,  2.17s/it, loss=0.165]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:02<00:14,  2.06s/it][A
Validation:  12%|███████████████                                                                                                         | 1/8 [00:04<00:14,  2.06s/it, loss=0.027][A
Validation:  25%|██████████████████████████████                                         

Validation:
	hamming_loss = 0.014
	f1_score = 0.661
	loss = 0.046
****************************************


Training Epoch [1/1]:  90%|██████████████████████████████████████████████████████████████████████████████████████████████▉           | 350/391 [14:32<01:25,  2.09s/it, loss=0.151]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:01<00:12,  1.84s/it][A
Validation:  12%|██████████████▉                                                                                                        | 1/8 [00:03<00:12,  1.84s/it, loss=0.0247][A
Validation:  25%|█████████████████████████████▊                                         

Validation:
	hamming_loss = 0.013
	f1_score = 0.683
	loss = 0.043
****************************************


Training Epoch [1/1]:  99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▍| 389/391 [16:09<00:04,  2.10s/it, loss=0.14]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:02<00:14,  2.04s/it][A
Validation:  12%|██████████████▉                                                                                                        | 1/8 [00:04<00:14,  2.04s/it, loss=0.0259][A
Validation:  25%|█████████████████████████████▊                                         

Validation:
	hamming_loss = 0.013
	f1_score = 0.657
	loss = 0.042
****************************************


Training Epoch [1/1]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [16:26<00:00,  2.52s/it, loss=0.14]


Training at epoch 1:
	hamming_loss = 0.013
	f1_score = 0.656
	loss = 0.140


Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:15<00:00,  1.88s/it, loss=0.0366]

Validation:
	hamming_loss = 0.013
	f1_score = 0.658
	loss = 0.042
****************************************





Train all model layers

In [12]:
# Second training phase: the step logs that it unfreezes the whole network
# before finetuning. `learner` is rebound to the returned object.
finetune_all = FinetuneAll()
learner = finetune_all.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    cnn_hparams=cnn_hyperparameters,
)

2022-09-11 16:35:47,607 :: train/train.py/execute :: INFO :: Unfreezing the whole CNN network
2022-09-11 16:35:47,611 :: train/train.py/execute :: INFO :: Starting finetuning of the whole CNN network
Training Epoch [1/1]:  10%|██████████▍                                                                                                | 38/391 [01:31<13:29,  2.29s/it, loss=0.612]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:01<00:13,  1.93s/it][A
Validation:  12%|███████████████                                       

Validation:
	hamming_loss = 0.083
	f1_score = 0.643
	loss = 0.622
****************************************


Training Epoch [1/1]:  20%|█████████████████████                                                                                      | 77/391 [03:17<11:55,  2.28s/it, loss=0.373]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:01<00:13,  1.95s/it][A
Validation:  12%|██████████████▉                                                                                                        | 1/8 [00:03<00:13,  1.95s/it, loss=0.0342][A
Validation:  25%|█████████████████████████████▊                                         

Validation:
	hamming_loss = 0.016
	f1_score = 0.558
	loss = 0.059
****************************************


Training Epoch [1/1]:  30%|███████████████████████████████▍                                                                          | 116/391 [05:01<10:42,  2.34s/it, loss=0.266]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:01<00:13,  1.86s/it][A
Validation:  12%|██████████████▉                                                                                                        | 1/8 [00:03<00:13,  1.86s/it, loss=0.0287][A
Validation:  25%|█████████████████████████████▊                                         

Validation:
	hamming_loss = 0.015
	f1_score = 0.605
	loss = 0.050
****************************************


Training Epoch [1/1]:  40%|██████████████████████████████████████████                                                                | 155/391 [06:45<08:42,  2.21s/it, loss=0.211]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:01<00:12,  1.84s/it][A
Validation:  12%|███████████████                                                                                                         | 1/8 [00:03<00:12,  1.84s/it, loss=0.027][A
Validation:  25%|██████████████████████████████                                         

Validation:
	hamming_loss = 0.014
	f1_score = 0.635
	loss = 0.046
****************************************


Training Epoch [1/1]:  50%|████████████████████████████████████████████████████▌                                                     | 194/391 [08:30<07:19,  2.23s/it, loss=0.177]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:01<00:12,  1.85s/it][A
Validation:  12%|██████████████▉                                                                                                        | 1/8 [00:03<00:12,  1.85s/it, loss=0.0252][A
Validation:  25%|█████████████████████████████▊                                         

Validation:
	hamming_loss = 0.013
	f1_score = 0.669
	loss = 0.043
****************************************


Training Epoch [1/1]:  60%|███████████████████████████████████████████████████████████████▏                                          | 233/391 [10:12<05:48,  2.20s/it, loss=0.154]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:01<00:13,  1.92s/it][A
Validation:  12%|██████████████▉                                                                                                        | 1/8 [00:03<00:13,  1.92s/it, loss=0.0234][A
Validation:  25%|█████████████████████████████▊                                         

Validation:
	hamming_loss = 0.012
	f1_score = 0.693
	loss = 0.040
****************************************


Training Epoch [1/1]:  70%|█████████████████████████████████████████████████████████████████████████▋                                | 272/391 [11:57<04:34,  2.30s/it, loss=0.138]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:02<00:16,  2.34s/it][A
Validation:  12%|██████████████▉                                                                                                        | 1/8 [00:04<00:16,  2.34s/it, loss=0.0229][A
Validation:  25%|█████████████████████████████▊                                         

Validation:
	hamming_loss = 0.012
	f1_score = 0.699
	loss = 0.039
****************************************


Training Epoch [1/1]:  80%|████████████████████████████████████████████████████████████████████████████████████▎                     | 311/391 [13:41<02:59,  2.25s/it, loss=0.125]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:01<00:13,  1.87s/it][A
Validation:  12%|███████████████                                                                                                         | 1/8 [00:03<00:13,  1.87s/it, loss=0.022][A
Validation:  25%|██████████████████████████████                                         

Validation:
	hamming_loss = 0.011
	f1_score = 0.725
	loss = 0.037
****************************************


Training Epoch [1/1]:  90%|██████████████████████████████████████████████████████████████████████████████████████████████▉           | 350/391 [15:25<01:33,  2.28s/it, loss=0.115]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:01<00:13,  1.88s/it][A
Validation:  12%|██████████████▉                                                                                                        | 1/8 [00:03<00:13,  1.88s/it, loss=0.0222][A
Validation:  25%|█████████████████████████████▊                                         

Validation:
	hamming_loss = 0.011
	f1_score = 0.726
	loss = 0.037
****************************************


Training Epoch [1/1]:  99%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▍| 389/391 [17:09<00:04,  2.23s/it, loss=0.107]
  0%|                                                                                                                                                        | 0/8 [00:00<?, ?it/s][A
Validation:   0%|                                                                                                                                            | 0/8 [00:00<?, ?it/s][A
Validation:  12%|████████████████▌                                                                                                                   | 1/8 [00:01<00:13,  1.98s/it][A
Validation:  12%|██████████████▉                                                                                                        | 1/8 [00:03<00:13,  1.98s/it, loss=0.0208][A
Validation:  25%|█████████████████████████████▊                                         

Validation:
	hamming_loss = 0.011
	f1_score = 0.741
	loss = 0.035
****************************************


Training Epoch [1/1]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [17:26<00:00,  2.68s/it, loss=0.107]


Training at epoch 1:
	hamming_loss = 0.011
	f1_score = 0.736
	loss = 0.107


Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:14<00:00,  1.83s/it, loss=0.0307]
2022-09-11 16:53:28,816 :: train/train.py/execute :: INFO :: Training done


Validation:
	hamming_loss = 0.011
	f1_score = 0.741
	loss = 0.035
****************************************


#### Evaluation

In [13]:
# Evaluate the trained learner; the step logs per-label precision / recall /
# f1-score / support as "Test results" and its return value is kept in `perfs`.
evaluate_classifier = EvaluateClassifier()
perfs = evaluate_classifier.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    multilabel_binarizer=multilabel_binarizer,
)

2022-09-11 16:56:06,153 :: train/train.py/execute :: INFO :: Test results
2022-09-11 16:56:06,156 :: train/train.py/execute :: INFO :: {
    "animalerie": {
        "precision": 0.9841772151898734,
        "recall": 0.4712121212121212,
        "f1-score": 0.6372950819672132,
        "support": 660
    },
    "auto et moto": {
        "precision": 0.0,
        "recall": 0.0,
        "f1-score": 0.0,
        "support": 342
    },
    "bagages et sacs": {
        "precision": 0.6666666666666666,
        "recall": 0.008333333333333333,
        "f1-score": 0.01646090534979424,
        "support": 240
    },
    "beaute et parfum": {
        "precision": 0.9097888675623801,
        "recall": 0.7834710743801653,
        "f1-score": 0.8419182948490231,
        "support": 605
    },
    "bebe et puericulture": {
        "precision": 1.0,
        "recall": 0.15447154471544716,
        "f1-score": 0.2676056338028169,
        "support": 369
    },
    "bijoux": {
        "precision": 0.974358974358