In [1]:
import warnings

# Install a catch-all "ignore" filter *before* the remaining imports so that
# warnings raised while importing the project modules are suppressed as well.
# NOTE(review): this hides every warning for the whole session; consider
# narrowing it to specific categories/modules.
warnings.simplefilter("ignore")

import pickle

from product_classification.data_processing.text_processing import (
    BuildIterators,
    BuildTextVocabulary,
    CreateDatasets,
    FitCategoricalData,
    InitFields,
)
from product_classification.learner.train import (
    CreateLearner,
    EvaluateClassifier,
    FinetuneAll,
    TrainHighLevels,
)
from product_classification.models import CnnHyperParameters, Epochs, LearningRates

In [2]:
path = "../data"


def _load_pickle(filename):
    """Unpickle and return the object stored in ``{path}/{filename}``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only load these
    artifacts from a trusted source.
    """
    with open(f"{path}/{filename}", "rb") as handle:
        return pickle.load(handle)


# Preprocessed train/validation/test data, the label binarizer used to decode
# predictions, and the loss' positive-class weights.
simple_datasets = _load_pickle("datasets.pkl")
multilabel_binarizer = _load_pickle("multilabel_binarizer.pkl")
pos_weight = _load_pickle("pos_weight.pkl")

#### Define CNN Hyperparameters

In [3]:
# CNN architecture plus a two-phase training schedule. The `init_phase`
# settings presumably apply to TrainHighLevels and `finetuning_phase` to
# FinetuneAll, which is consistent with the logged 4 and 5 epochs respectively.
cnn_hyperparameters = CnnHyperParameters(
    nb_filters=100,      # presumably filters per kernel size — TODO confirm
    kernels=[3, 4, 5],   # convolution kernel widths
    dropout=0.5,
    lrates=LearningRates(init_phase=5e-3, finetuning_phase=1e-3),
    epochs=Epochs(init_phase=4, finetuning_phase=5),
)

In [4]:
# Optional quick-iteration mode. Set to an int (e.g. 1000) to train on a
# subsample: training keeps the first N_DEBUG_ROWS rows, and test/validation
# keep N_DEBUG_ROWS // 5 rows each. Leave as None for the full run.
# NOTE: this mutates `simple_datasets` in place. Re-running with the same value
# is harmless, but a larger value cannot restore dropped rows; reload the data
# cell (or restart the kernel) after changing it.
N_DEBUG_ROWS = None
if N_DEBUG_ROWS is not None:
    simple_datasets.training = simple_datasets.training.head(N_DEBUG_ROWS)
    simple_datasets.test = simple_datasets.test.head(N_DEBUG_ROWS // 5)
    simple_datasets.validation = simple_datasets.validation.head(N_DEBUG_ROWS // 5)

#### Initialize the torchtext fields used for text processing

In [5]:
# Run `FitCategoricalData` on the loaded datasets to obtain `one_hot_encoder`,
# which is passed to `InitFields` and `CreateLearner` below.
fit_categorical_data = FitCategoricalData()
one_hot_encoder = fit_categorical_data.execute(
    processed_data=simple_datasets,
)

In [6]:
# Create the torchtext fields from the fitted encoder. The logged output shows
# that a categorical, a text and a label field are built.
init_fields = InitFields()
torch_fields = init_fields.execute(
    one_hot_encoder=one_hot_encoder,
)

2022-09-12 16:07:32,334 :: text_processing/text_processing.py/execute :: INFO :: Categorical field created
2022-09-12 16:07:33,677 :: text_processing/text_processing.py/execute :: INFO :: Text field created
2022-09-12 16:07:33,678 :: text_processing/text_processing.py/execute :: INFO :: Label field created


#### Create datasets mapping features and labels

In [7]:
# Every column from position FIRST_LABEL_COL onward is treated as a label
# column. The columns before it presumably hold the inputs (text/categorical
# features and ids) — TODO confirm against the preprocessing step.
FIRST_LABEL_COL = 5
labels = simple_datasets.training.columns[FIRST_LABEL_COL:].tolist()
txt_col = "product_name"
cat_cols = ["brand_name", "merchant_name"]

# The positional slice is fragile: fail loudly if it yields no labels or if it
# swallowed one of the model's input columns.
assert labels, f"no label columns found at or after position {FIRST_LABEL_COL}"
overlap = set(labels) & {txt_col, *cat_cols}
assert not overlap, f"input columns found among label columns: {sorted(overlap)}"

create_datasets = CreateDatasets()
torch_dataset = create_datasets.execute(processed_data=simple_datasets,
                                        torch_fields=torch_fields,
                                        cnn_hparams=cnn_hyperparameters,
                                        txt_col=txt_col,
                                        cat_cols=cat_cols,
                                        lbl_cols=labels)

2022-09-12 16:07:34,030 :: __init__/__init__.py/__init__ :: INFO :: fields: [('text', <torchtext.data.field.Field object at 0x7f2182851a00>), ('category', <torchtext.data.field.Field object at 0x7f2218aed4c0>), ('label', <torchtext.data.field.LabelField object at 0x7f221936d0d0>)]
2022-09-12 16:08:44,954 :: __init__/__init__.py/__init__ :: INFO :: fields: [('text', <torchtext.data.field.Field object at 0x7f2182851a00>), ('category', <torchtext.data.field.Field object at 0x7f2218aed4c0>), ('label', <torchtext.data.field.LabelField object at 0x7f221936d0d0>)]
2022-09-12 16:08:50,971 :: __init__/__init__.py/__init__ :: INFO :: fields: [('text', <torchtext.data.field.Field object at 0x7f2182851a00>), ('category', <torchtext.data.field.Field object at 0x7f2218aed4c0>), ('label', <torchtext.data.field.LabelField object at 0x7f221936d0d0>)]
2022-09-12 16:08:57,419 :: text_processing/text_processing.py/execute :: INFO :: Torchtext datasets created from dataframes


#### Build text vocabulary with fasttext embeddings

In [8]:
# Build the text vocabulary with the named pretrained embedding, caching the
# downloaded vectors under `vector_cache`.
# NOTE(review): `vocab_size` presumably caps the vocabulary size; the log
# reports a corpus dimension of 65871, which is below this cap.
vector_cache = ".vector_cache"
vocab_size = 100_000
embedding_name = "fasttext.simple.300d"
build_text_vocabulary = BuildTextVocabulary()
torch_fields = build_text_vocabulary.execute(
    torch_datasets=torch_dataset,
    torch_fields=torch_fields,
    vocab_size=vocab_size,
    embedding_name=embedding_name,
    vectors_cache=vector_cache,
)

2022-09-12 16:08:59,636 :: text_processing/text_processing.py/execute :: INFO :: Text vocabulary created. Corpus_dim: 65871


#### Create pytorch iterators

In [9]:
# Wrap the torchtext datasets in batched iterators on the chosen device.
# `batch_size` and `device` are reused by the learner cells below.
batch_size = 128
device = "cpu"
build_iterators = BuildIterators()
torch_iterators = build_iterators.execute(
    torch_datasets=torch_dataset,
    batch_size=batch_size,
    device=device,
)

2022-09-12 16:08:59,645 :: text_processing/text_processing.py/execute :: INFO :: Torchtext iterators created from datasets


#### Create Learner to train CNN

In [10]:
import torch

# Focal-loss focusing parameter. The "Re Calibrate focal loss gamma parameter"
# section repeats this cell with gamma = 1.0.
gamma = 2.0

# Convert the loaded weights (presumably one positive-class weight per label)
# to a float32 tensor on `device`. `torch.as_tensor` returns the input
# unchanged when it already is such a tensor, so re-running this cell is safe.
# The legacy `torch.Tensor(...)` constructor is avoided because it treats a
# bare int as a *size* and returns uninitialized memory.
pos_weight = torch.as_tensor(pos_weight, dtype=torch.float32, device=device)

create_learner = CreateLearner()
learner = create_learner.execute(cnn_hparams=cnn_hyperparameters,
                                    embedding_name=embedding_name,
                                    torch_fields=torch_fields,
                                    processed_data=simple_datasets,
                                    batch_size=batch_size,
                                    label_number=len(labels),
                                    one_hot_encoder=one_hot_encoder,
                                    pos_weight=pos_weight,
                                    gamma=gamma)

2022-09-12 11:52:24,802 :: train/train.py/execute :: INFO :: CNN learner compiled for the text classification task


#### Training

Train the high-level layers (embedding layer frozen)

In [11]:
# Phase 1: train the layers above the embedding while the embedding stays
# frozen (per the section heading), presumably with the `init_phase` settings.
train_high_levels = TrainHighLevels()
learner = train_high_levels.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    cnn_hparams=cnn_hyperparameters,
)

2022-09-12 11:52:24,811 :: train/train.py/execute :: INFO :: Starting training of CNN and classification layers


Embedding(65871, 300)


Training Epoch [1/4]:  20%|██████████████████                                                                         | 202/1018 [02:01<08:29,  1.60it/s, loss=0.627]2022-09-12 11:54:35,457 :: learner/learner.py/evaluate :: INFO :: Validation:
2022-09-12 11:54:35,458 :: learner/learner.py/evaluate :: INFO :: 	hamming_loss = 0.205
2022-09-12 11:54:35,459 :: learner/learner.py/evaluate :: INFO :: 	f1_score = 0.264
2022-09-12 11:54:35,462 :: learner/learner.py/evaluate :: INFO :: 	loss = 61.686
2022-09-12 11:54:35,463 :: learner/learner.py/evaluate :: INFO :: ****************************************
Training Epoch [1/4]:  40%|████████████████████████████████████▌                                                       | 405/1018 [03:30<03:49,  2.67it/s, loss=1.01]2022-09-12 11:56:04,635 :: learner/learner.py/evaluate :: INFO :: Validation:
2022-09-12 11:56:04,636 :: learner/learner.py/evaluate :: INFO :: 	hamming_loss = 0.156
2022-09-12 11:56:04,637 :: learner/learner.py/evaluate :: INFO :: 

Train all model layers

In [12]:
# Phase 2: unfreeze and finetune the whole network (logged as "Unfreezing the
# whole CNN network"). The loss settings (pos_weight, gamma) are passed again.
finetune_all = FinetuneAll()
learner = finetune_all.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    cnn_hparams=cnn_hyperparameters,
    processed_data=simple_datasets,
    batch_size=batch_size,
    pos_weight=pos_weight,
    gamma=gamma,
)

2022-09-12 12:25:52,281 :: train/train.py/execute :: INFO :: Unfreezing the whole CNN network
2022-09-12 12:25:52,283 :: train/train.py/execute :: INFO :: Starting finetuning of the whole CNN network
Training Epoch [1/5]:  20%|██████████████████                                                                         | 202/1018 [02:58<12:00,  1.13it/s, loss=0.453]2022-09-12 12:28:59,634 :: learner/learner.py/evaluate :: INFO :: Validation:
2022-09-12 12:28:59,635 :: learner/learner.py/evaluate :: INFO :: 	hamming_loss = 0.098
2022-09-12 12:28:59,636 :: learner/learner.py/evaluate :: INFO :: 	f1_score = 0.504
2022-09-12 12:28:59,636 :: learner/learner.py/evaluate :: INFO :: 	loss = 0.335
2022-09-12 12:28:59,638 :: learner/learner.py/evaluate :: INFO :: ****************************************
Training Epoch [1/5]:  40%|████████████████████████████████████▌                                                       | 405/1018 [05:21<06:26,  1.59it/s, loss=0.32]2022-09-12 12:31:23,224 :: learne

#### Evaluation

In [13]:
# Score the trained learner. The per-label test report is both logged and
# returned as `perfs`.
evaluate_classifier = EvaluateClassifier()
perfs = evaluate_classifier.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    multilabel_binarizer=multilabel_binarizer,
)

2022-09-12 13:28:35,722 :: train/train.py/execute :: INFO :: Test results
2022-09-12 13:28:35,723 :: train/train.py/execute :: INFO :: {'animalerie': {'precision': 0.7296340023612751, 'recall': 0.939209726443769, 'f1-score': 0.8212624584717607, 'support': 658}, 'auto et moto': {'precision': 0.4765751211631664, 'recall': 0.8753709198813057, 'f1-score': 0.6171548117154811, 'support': 337}, 'bagages et sacs': {'precision': 0.538860103626943, 'recall': 0.832, 'f1-score': 0.6540880503144654, 'support': 250}, 'beaute et parfum': {'precision': 0.8170391061452514, 'recall': 0.9420289855072463, 'f1-score': 0.8750934928945401, 'support': 621}, 'bebe et puericulture': {'precision': 0.4917043740573152, 'recall': 0.8810810810810811, 'f1-score': 0.6311713455953534, 'support': 370}, 'bijoux': {'precision': 0.729381443298969, 'recall': 0.9370860927152318, 'f1-score': 0.8202898550724637, 'support': 302}, 'bricolage': {'precision': 0.5633802816901409, 'recall': 0.8559201141226819, 'f1-score': 0.67950169

In [14]:
import pandas as pd

# Rows are the keys of `perfs`; columns are the metrics
# (precision, recall, f1-score, support).
perfs_by_label = pd.DataFrame(perfs).transpose()
perfs_by_label

Unnamed: 0,precision,recall,f1-score,support
animalerie,0.729634,0.93921,0.821262,658.0
auto et moto,0.476575,0.875371,0.617155,337.0
bagages et sacs,0.53886,0.832,0.654088,250.0
beaute et parfum,0.817039,0.942029,0.875093,621.0
bebe et puericulture,0.491704,0.881081,0.631171,370.0
bijoux,0.729381,0.937086,0.82029,302.0
bricolage,0.56338,0.85592,0.679502,701.0
cd et vinyles,0.138462,0.857143,0.238411,21.0
chaussures et accessoires,0.748044,0.929961,0.829141,514.0
"commerce, industrie et science",0.141026,0.392857,0.207547,28.0


In [15]:
# Persist the gamma = 2.0 learner using its own `save` method.
model_path = f"{path}/cnn_model.pt"
learner.save(model_path)

In [16]:
# Also pickle the whole learner object.
# NOTE(review): loading this file later presumably requires the
# `product_classification` package to be importable; only unpickle it from a
# trusted location.
with open(f"{path}/cnn_learner.pkl", mode="wb") as handle:
    pickle.dump(learner, handle, pickle.HIGHEST_PROTOCOL)

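### Recalibrate the focal-loss gamma parameter

Retrain the same pipeline with `gamma = 1.0` (instead of `2.0`) and compare the test metrics. Note that only the `gamma = 2.0` learner was saved above; the learner trained below is not written to disk.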

#### Create Learner to train CNN

In [10]:
import torch

# Focal-loss focusing parameter, lowered from the 2.0 used in the first run.
gamma = 1.0

# `pos_weight` is already a tensor if the first learner cell ran. Converting
# with `torch.as_tensor` is then a no-op, whereas re-wrapping with the legacy
# `torch.Tensor(...)` constructor is not idempotent and treats a bare int as a
# *size*.
pos_weight = torch.as_tensor(pos_weight, dtype=torch.float32, device=device)

create_learner = CreateLearner()
learner = create_learner.execute(cnn_hparams=cnn_hyperparameters,
                                    embedding_name=embedding_name,
                                    torch_fields=torch_fields,
                                    processed_data=simple_datasets,
                                    batch_size=batch_size,
                                    label_number=len(labels),
                                    one_hot_encoder=one_hot_encoder,
                                    pos_weight=pos_weight,
                                    gamma=gamma)

2022-09-12 16:08:59,969 :: train/train.py/execute :: INFO :: CNN learner compiled for the text classification task


#### Training

Train the high-level layers (embedding layer frozen)

In [11]:
# Phase 1 (gamma = 1.0 run): train the layers above the frozen embedding.
train_high_levels = TrainHighLevels()
learner = train_high_levels.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    cnn_hparams=cnn_hyperparameters,
)

2022-09-12 16:08:59,978 :: train/train.py/execute :: INFO :: Starting training of CNN and classification layers


Embedding(65871, 300)


Training Epoch [1/4]:  20%|██████████████████▍                                                                          | 202/1018 [02:00<08:04,  1.69it/s, loss=0.8]2022-09-12 16:11:10,300 :: learner/learner.py/evaluate :: INFO :: Validation:
2022-09-12 16:11:10,301 :: learner/learner.py/evaluate :: INFO :: 	hamming_loss = 0.164
2022-09-12 16:11:10,302 :: learner/learner.py/evaluate :: INFO :: 	f1_score = 0.270
2022-09-12 16:11:10,304 :: learner/learner.py/evaluate :: INFO :: 	loss = 72.228
2022-09-12 16:11:10,305 :: learner/learner.py/evaluate :: INFO :: ****************************************
Training Epoch [1/4]:  40%|████████████████████████████████████▌                                                       | 405/1018 [03:30<03:53,  2.63it/s, loss=1.32]2022-09-12 16:12:39,833 :: learner/learner.py/evaluate :: INFO :: Validation:
2022-09-12 16:12:39,839 :: learner/learner.py/evaluate :: INFO :: 	hamming_loss = 0.131
2022-09-12 16:12:39,840 :: learner/learner.py/evaluate :: INFO :: 

Train all model layers

In [12]:
# Phase 2 (gamma = 1.0 run): unfreeze and finetune the whole network.
finetune_all = FinetuneAll()
learner = finetune_all.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    cnn_hparams=cnn_hyperparameters,
    processed_data=simple_datasets,
    batch_size=batch_size,
    pos_weight=pos_weight,
    gamma=gamma,
)

2022-09-12 16:42:27,382 :: train/train.py/execute :: INFO :: Unfreezing the whole CNN network
2022-09-12 16:42:27,386 :: train/train.py/execute :: INFO :: Starting finetuning of the whole CNN network
Training Epoch [1/5]:  20%|██████████████████                                                                         | 202/1018 [02:58<11:25,  1.19it/s, loss=0.536]2022-09-12 16:45:35,522 :: learner/learner.py/evaluate :: INFO :: Validation:
2022-09-12 16:45:35,522 :: learner/learner.py/evaluate :: INFO :: 	hamming_loss = 0.085
2022-09-12 16:45:35,523 :: learner/learner.py/evaluate :: INFO :: 	f1_score = 0.544
2022-09-12 16:45:35,525 :: learner/learner.py/evaluate :: INFO :: 	loss = 1.113
2022-09-12 16:45:35,526 :: learner/learner.py/evaluate :: INFO :: ****************************************
Training Epoch [1/5]:  40%|████████████████████████████████████▏                                                      | 405/1018 [05:24<07:31,  1.36it/s, loss=0.431]2022-09-12 16:48:01,187 :: learne

Epoch    25: reducing learning rate of group 0 to 5.0000e-04.


Training Epoch [5/5]:  40%|████████████████████████████████████▏                                                      | 405/1018 [05:08<06:34,  1.55it/s, loss=0.206]2022-09-12 17:38:07,685 :: learner/learner.py/evaluate :: INFO :: Validation:
2022-09-12 17:38:07,686 :: learner/learner.py/evaluate :: INFO :: 	hamming_loss = 0.026
2022-09-12 17:38:07,688 :: learner/learner.py/evaluate :: INFO :: 	f1_score = 0.754
2022-09-12 17:38:07,689 :: learner/learner.py/evaluate :: INFO :: 	loss = 0.350
2022-09-12 17:38:07,690 :: learner/learner.py/evaluate :: INFO :: ****************************************
Training Epoch [5/5]:  60%|██████████████████████████████████████████████████████▎                                    | 608/1018 [07:26<04:15,  1.61it/s, loss=0.162]2022-09-12 17:40:25,840 :: learner/learner.py/evaluate :: INFO :: Validation:
2022-09-12 17:40:25,841 :: learner/learner.py/evaluate :: INFO :: 	hamming_loss = 0.027
2022-09-12 17:40:25,842 :: learner/learner.py/evaluate :: INFO :: 	

#### Evaluation

In [13]:
# Score the gamma = 1.0 learner. The per-label test report is logged and
# returned as `perfs`.
evaluate_classifier = EvaluateClassifier()
perfs = evaluate_classifier.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    multilabel_binarizer=multilabel_binarizer,
)

2022-09-12 17:45:26,302 :: train/train.py/execute :: INFO :: Test results
2022-09-12 17:45:26,303 :: train/train.py/execute :: INFO :: {'animalerie': {'precision': 0.8443526170798898, 'recall': 0.9316109422492401, 'f1-score': 0.8858381502890174, 'support': 658}, 'auto et moto': {'precision': 0.5376146788990825, 'recall': 0.8694362017804155, 'f1-score': 0.6643990929705215, 'support': 337}, 'bagages et sacs': {'precision': 0.5205811138014528, 'recall': 0.86, 'f1-score': 0.6485671191553545, 'support': 250}, 'beaute et parfum': {'precision': 0.8200836820083682, 'recall': 0.9468599033816425, 'f1-score': 0.8789237668161436, 'support': 621}, 'bebe et puericulture': {'precision': 0.5988483685220729, 'recall': 0.8432432432432433, 'f1-score': 0.7003367003367004, 'support': 370}, 'bijoux': {'precision': 0.7634408602150538, 'recall': 0.9403973509933775, 'f1-score': 0.8427299703264095, 'support': 302}, 'bricolage': {'precision': 0.5719844357976653, 'recall': 0.8388017118402282, 'f1-score': 0.680161

In [14]:
import pandas as pd

# Rows are the keys of `perfs`; columns are the metrics
# (precision, recall, f1-score, support).
perfs_by_label = pd.DataFrame(perfs).transpose()
perfs_by_label

Unnamed: 0,precision,recall,f1-score,support
animalerie,0.844353,0.931611,0.885838,658.0
auto et moto,0.537615,0.869436,0.664399,337.0
bagages et sacs,0.520581,0.86,0.648567,250.0
beaute et parfum,0.820084,0.94686,0.878924,621.0
bebe et puericulture,0.598848,0.843243,0.700337,370.0
bijoux,0.763441,0.940397,0.84273,302.0
bricolage,0.571984,0.838802,0.680162,701.0
cd et vinyles,0.147826,0.809524,0.25,21.0
chaussures et accessoires,0.770096,0.931907,0.84331,514.0
"commerce, industrie et science",0.160494,0.464286,0.238532,28.0


### Test-set inference time

In [27]:
import timeit

# Wall-clock time for `n_runs` full passes of `learner.predict` over the test
# iterator. It uses the notebook's `device` rather than a hard-coded string, and
# a callable instead of a statement string evaluated against all of globals().
# The displayed value is the TOTAL over all runs; divide by `n_runs` for the
# per-pass time.
n_runs = 10
timeit.timeit(lambda: learner.predict(torch_iterators.test, device), number=n_runs)

115.75525344512425