In [1]:
import warnings
warnings.filterwarnings('ignore')

import pickle
from product_classification.data_processing.text_processing import FitCategoricalData, InitFields, CreateDatasets, BuildTextVocabulary, BuildIterators
from product_classification.learner.train import CreateLearner, TrainHighLevels, FinetuneAll, EvaluateClassifier
from product_classification.models import CnnHyperParameters, Epochs, LearningRates

In [2]:
path = "../data"


def _load_pickle(file_name):
    """Return the object unpickled from ``file_name`` inside the ``path`` directory."""
    # NOTE: pickle.load can execute arbitrary code; only load files you trust.
    with open(f"{path}/{file_name}", "rb") as source:
        return pickle.load(source)


# Artifacts consumed by the cells below: the dataset splits, the label
# binarizer handed to the evaluation step, and the positive-class weights
# handed to the learner.
simple_datasets = _load_pickle("datasets.pkl")
multilabel_binarizer = _load_pickle("multilabel_binarizer.pkl")
pos_weight = _load_pickle("pos_weight.pkl")

#### Define CNN Hyperparameters

In [3]:
# Per-phase settings: "init_phase" and "finetuning_phase" are the two phase
# names exposed by LearningRates / Epochs.
phase_lrates = LearningRates(init_phase=5e-3, finetuning_phase=1e-3)
phase_epochs = Epochs(init_phase=4, finetuning_phase=5)

# Single hyperparameter bundle passed as `cnn_hparams` to the dataset,
# learner and training steps further down.
cnn_hyperparameters = CnnHyperParameters(
    nb_filters=100,
    kernels=[3, 4, 5],
    dropout=0.5,
    lrates=phase_lrates,
    epochs=phase_epochs,
)

In [4]:
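# Optional: uncomment the lines below to keep only the first rows of each split
# (e.g. for a quick smoke run of the whole notebook).
# simple_datasets.training = simple_datasets.training.head(1000)
# simple_datasets.test = simple_datasets.test.head(200)
# simple_datasets.validation = simple_datasets.validation.head(200)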

#### Init torch components for text processing

In [5]:
# Produce the `one_hot_encoder` from the loaded datasets; it is reused by
# InitFields and CreateLearner below.
fit_categorical_data = FitCategoricalData()
one_hot_encoder = fit_categorical_data.execute(
    processed_data=simple_datasets,
)

In [6]:
# Build the torchtext fields (the log reports categorical, text and label
# fields) from the fitted encoder.
init_fields = InitFields()
torch_fields = init_fields.execute(
    one_hot_encoder=one_hot_encoder,
)

2022-09-13 08:46:17,138 :: text_processing/text_processing.py/execute :: INFO :: Categorical field created
2022-09-13 08:46:18,492 :: text_processing/text_processing.py/execute :: INFO :: Text field created
2022-09-13 08:46:18,493 :: text_processing/text_processing.py/execute :: INFO :: Label field created


#### Create datasets mapping features and labels

In [7]:
# Label columns are every column of the training frame from position 5 onwards.
# NOTE(review): the offset 5 depends on the column layout of datasets.pkl — confirm
# before changing the preprocessing.
labels = list(simple_datasets.training.columns[5:])
cat_cols = ["brand_name", "merchant_name"]

# Map text, categorical and label columns onto the torchtext fields.
create_datasets = CreateDatasets()
torch_dataset = create_datasets.execute(
    processed_data=simple_datasets,
    torch_fields=torch_fields,
    cnn_hparams=cnn_hyperparameters,
    txt_col="product_name",
    cat_cols=cat_cols,
    lbl_cols=labels,
)

2022-09-13 08:46:19,974 :: __init__/__init__.py/__init__ :: INFO :: fields: [('text', <torchtext.data.field.Field object at 0x7fb05c076eb0>), ('category', <torchtext.data.field.Field object at 0x7fafc9e54e50>), ('label', <torchtext.data.field.LabelField object at 0x7fb05c077730>)]
2022-09-13 08:47:32,005 :: __init__/__init__.py/__init__ :: INFO :: fields: [('text', <torchtext.data.field.Field object at 0x7fb05c076eb0>), ('category', <torchtext.data.field.Field object at 0x7fafc9e54e50>), ('label', <torchtext.data.field.LabelField object at 0x7fb05c077730>)]
2022-09-13 08:47:38,221 :: __init__/__init__.py/__init__ :: INFO :: fields: [('text', <torchtext.data.field.Field object at 0x7fb05c076eb0>), ('category', <torchtext.data.field.Field object at 0x7fafc9e54e50>), ('label', <torchtext.data.field.LabelField object at 0x7fb05c077730>)]
2022-09-13 08:47:44,754 :: text_processing/text_processing.py/execute :: INFO :: Torchtext datasets created from dataframes


#### Build text vocabulary with fasttext embeddings

In [8]:
# Vocabulary settings: cache directory for the pretrained vectors, vocabulary
# size passed to the builder, and the name of the pretrained embedding.
vector_cache = ".vector_cache"
vocab_size = 100000
embedding_name = "fasttext.simple.300d"

# Returns the fields with the text vocabulary built; `torch_fields` is rebound
# to that result.
build_text_vocabulary = BuildTextVocabulary()
torch_fields = build_text_vocabulary.execute(
    torch_datasets=torch_dataset,
    torch_fields=torch_fields,
    vocab_size=vocab_size,
    embedding_name=embedding_name,
    vectors_cache=vector_cache,
)

2022-09-13 08:47:47,010 :: text_processing/text_processing.py/execute :: INFO :: Text vocabulary created. Corpus_dim: 65871


#### Create pytorch iterators

In [9]:
# `batch_size` and `device` are reused by the learner, training and
# evaluation cells below.
batch_size = 128
device = "cpu"

# Wrap the torchtext datasets into batched iterators.
build_iterators = BuildIterators()
torch_iterators = build_iterators.execute(
    torch_datasets=torch_dataset,
    batch_size=batch_size,
    device=device,
)

2022-09-13 08:47:47,019 :: text_processing/text_processing.py/execute :: INFO :: Torchtext iterators created from datasets


#### Create Learner to train CNN

In [10]:
import torch

# Loss parameter forwarded to CreateLearner and FinetuneAll.
# NOTE(review): presumably the focal-loss focusing parameter — confirm in
# product_classification.learner.
gamma = 2.0

# Convert the positive-class weights to a float32 tensor on `device`.
# torch.as_tensor accepts a list, an ndarray or an existing tensor, so this
# cell can be re-run (when `pos_weight` is already a tensor) without going
# through the legacy torch.Tensor(...) constructor again.
pos_weight = torch.as_tensor(pos_weight, dtype=torch.float32, device=device)

# Build the CNN learner from the hyperparameters, fields and encoder.
create_learner = CreateLearner()
learner = create_learner.execute(
    cnn_hparams=cnn_hyperparameters,
    embedding_name=embedding_name,
    torch_fields=torch_fields,
    processed_data=simple_datasets,
    batch_size=batch_size,
    label_number=len(labels),
    one_hot_encoder=one_hot_encoder,
    pos_weight=pos_weight,
    gamma=gamma,
)

2022-09-12 11:52:24,802 :: train/train.py/execute :: INFO :: CNN learner compiled for the text classification task


#### Training

Train the high-level layers (embedding layer frozen)

In [11]:
# First training phase; the log describes it as training the CNN and
# classification layers. The trained learner replaces `learner`.
train_high_levels = TrainHighLevels()
learner = train_high_levels.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    cnn_hparams=cnn_hyperparameters,
)

2022-09-12 11:52:24,811 :: train/train.py/execute :: INFO :: Starting training of CNN and classification layers


Embedding(65871, 300)


Training Epoch [1/4]:  20%|██████████████████                                                                         | 202/1018 [02:01<08:29,  1.60it/s, loss=0.627]2022-09-12 11:54:35,457 :: learner/learner.py/evaluate :: INFO :: Validation:
2022-09-12 11:54:35,458 :: learner/learner.py/evaluate :: INFO :: 	hamming_loss = 0.205
2022-09-12 11:54:35,459 :: learner/learner.py/evaluate :: INFO :: 	f1_score = 0.264
2022-09-12 11:54:35,462 :: learner/learner.py/evaluate :: INFO :: 	loss = 61.686
2022-09-12 11:54:35,463 :: learner/learner.py/evaluate :: INFO :: ****************************************
Training Epoch [1/4]:  40%|████████████████████████████████████▌                                                       | 405/1018 [03:30<03:49,  2.67it/s, loss=1.01]2022-09-12 11:56:04,635 :: learner/learner.py/evaluate :: INFO :: Validation:
2022-09-12 11:56:04,636 :: learner/learner.py/evaluate :: INFO :: 	hamming_loss = 0.156
2022-09-12 11:56:04,637 :: learner/learner.py/evaluate :: INFO :: 

Train all model layers

In [12]:
# Second training phase; the log reports that the whole CNN network is
# unfrozen and then finetuned. The result replaces `learner`.
finetune_all = FinetuneAll()
learner = finetune_all.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    cnn_hparams=cnn_hyperparameters,
    processed_data=simple_datasets,
    batch_size=batch_size,
    pos_weight=pos_weight,
    gamma=gamma,
)

2022-09-12 12:25:52,281 :: train/train.py/execute :: INFO :: Unfreezing the whole CNN network
2022-09-12 12:25:52,283 :: train/train.py/execute :: INFO :: Starting finetuning of the whole CNN network
Training Epoch [1/5]:  20%|██████████████████                                                                         | 202/1018 [02:58<12:00,  1.13it/s, loss=0.453]2022-09-12 12:28:59,634 :: learner/learner.py/evaluate :: INFO :: Validation:
2022-09-12 12:28:59,635 :: learner/learner.py/evaluate :: INFO :: 	hamming_loss = 0.098
2022-09-12 12:28:59,636 :: learner/learner.py/evaluate :: INFO :: 	f1_score = 0.504
2022-09-12 12:28:59,636 :: learner/learner.py/evaluate :: INFO :: 	loss = 0.335
2022-09-12 12:28:59,638 :: learner/learner.py/evaluate :: INFO :: ****************************************
Training Epoch [1/5]:  40%|████████████████████████████████████▌                                                       | 405/1018 [05:21<06:26,  1.59it/s, loss=0.32]2022-09-12 12:31:23,224 :: learne

#### Evaluation

In [13]:
# Evaluate the trained learner; the logged "Test results" are a per-label
# dict of precision / recall / f1-score / support, kept in `perfs`.
evaluate_classifier = EvaluateClassifier()
perfs = evaluate_classifier.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    multilabel_binarizer=multilabel_binarizer,
)

2022-09-12 13:28:35,722 :: train/train.py/execute :: INFO :: Test results
2022-09-12 13:28:35,723 :: train/train.py/execute :: INFO :: {'animalerie': {'precision': 0.7296340023612751, 'recall': 0.939209726443769, 'f1-score': 0.8212624584717607, 'support': 658}, 'auto et moto': {'precision': 0.4765751211631664, 'recall': 0.8753709198813057, 'f1-score': 0.6171548117154811, 'support': 337}, 'bagages et sacs': {'precision': 0.538860103626943, 'recall': 0.832, 'f1-score': 0.6540880503144654, 'support': 250}, 'beaute et parfum': {'precision': 0.8170391061452514, 'recall': 0.9420289855072463, 'f1-score': 0.8750934928945401, 'support': 621}, 'bebe et puericulture': {'precision': 0.4917043740573152, 'recall': 0.8810810810810811, 'f1-score': 0.6311713455953534, 'support': 370}, 'bijoux': {'precision': 0.729381443298969, 'recall': 0.9370860927152318, 'f1-score': 0.8202898550724637, 'support': 302}, 'bricolage': {'precision': 0.5633802816901409, 'recall': 0.8559201141226819, 'f1-score': 0.67950169

In [14]:
import pandas as pd

# One row per label, one column per metric (transpose of the dict-of-dicts frame).
perfs_table = pd.DataFrame(perfs).transpose()
perfs_table

Unnamed: 0,precision,recall,f1-score,support
animalerie,0.729634,0.93921,0.821262,658.0
auto et moto,0.476575,0.875371,0.617155,337.0
bagages et sacs,0.53886,0.832,0.654088,250.0
beaute et parfum,0.817039,0.942029,0.875093,621.0
bebe et puericulture,0.491704,0.881081,0.631171,370.0
bijoux,0.729381,0.937086,0.82029,302.0
bricolage,0.56338,0.85592,0.679502,701.0
cd et vinyles,0.138462,0.857143,0.238411,21.0
chaussures et accessoires,0.748044,0.929961,0.829141,514.0
"commerce, industrie et science",0.141026,0.392857,0.207547,28.0


In [15]:
# Persist the trained model next to the input data.
cnn_model_path = f"{path}/cnn_model.pt"
learner.save(cnn_model_path)

In [16]:
# Also pickle the whole learner object, using the highest pickle protocol.
cnn_learner_path = f"{path}/cnn_learner.pkl"
with open(cnn_learner_path, "wb") as sink:
    pickle.dump(learner, sink, protocol=pickle.HIGHEST_PROTOCOL)

### Train with weighted BCE loss

#### Create Learner to train CNN

In [10]:
import torch

# Loss parameter forwarded to CreateLearner and FinetuneAll; this section of
# the notebook ("Train with weighted BCE loss") runs with gamma set to 0.
gamma = 0.0

# Convert the positive-class weights to a float32 tensor on `device`.
# When the notebook is run top to bottom, `pos_weight` is already a tensor at
# this point; torch.as_tensor handles that case (as well as a list or ndarray)
# without going through the legacy torch.Tensor(...) constructor.
pos_weight = torch.as_tensor(pos_weight, dtype=torch.float32, device=device)

# Build a fresh CNN learner for this loss configuration.
create_learner = CreateLearner()
learner = create_learner.execute(
    cnn_hparams=cnn_hyperparameters,
    embedding_name=embedding_name,
    torch_fields=torch_fields,
    processed_data=simple_datasets,
    batch_size=batch_size,
    label_number=len(labels),
    one_hot_encoder=one_hot_encoder,
    pos_weight=pos_weight,
    gamma=gamma,
)

2022-09-13 08:47:47,397 :: train/train.py/execute :: INFO :: CNN learner compiled for the text classification task


#### Training

Train the high-level layers (embedding layer frozen)

In [11]:
# First training phase for this run; the log describes it as training the
# CNN and classification layers. The trained learner replaces `learner`.
train_high_levels = TrainHighLevels()
learner = train_high_levels.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    cnn_hparams=cnn_hyperparameters,
)

2022-09-13 08:47:47,409 :: train/train.py/execute :: INFO :: Starting training of CNN and classification layers


Embedding(65871, 300)


Training Epoch [1/4]:  20%|███████████████████▍                                                                              | 202/1018 [02:03<08:36,  1.58it/s, loss=1.05]2022-09-13 08:50:00,775 :: learner/learner.py/evaluate :: INFO :: Validation:
2022-09-13 08:50:00,776 :: learner/learner.py/evaluate :: INFO :: 	hamming_loss = 0.158
2022-09-13 08:50:00,777 :: learner/learner.py/evaluate :: INFO :: 	f1_score = 0.292
2022-09-13 08:50:00,778 :: learner/learner.py/evaluate :: INFO :: 	loss = 66.031
2022-09-13 08:50:00,779 :: learner/learner.py/evaluate :: INFO :: ****************************************
Training Epoch [1/4]:  40%|██████████████████████████████████████▉                                                           | 405/1018 [03:35<03:52,  2.63it/s, loss=1.83]2022-09-13 08:51:32,324 :: learner/learner.py/evaluate :: INFO :: Validation:
2022-09-13 08:51:32,325 :: learner/learner.py/evaluate :: INFO :: 	hamming_loss = 0.123
2022-09-13 08:51:32,326 :: learner/learner.py/evaluate

Train all model layers

In [12]:
# Second training phase for this run; the log reports that the whole CNN
# network is unfrozen and then finetuned. The result replaces `learner`.
finetune_all = FinetuneAll()
learner = finetune_all.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    cnn_hparams=cnn_hyperparameters,
    processed_data=simple_datasets,
    batch_size=batch_size,
    pos_weight=pos_weight,
    gamma=gamma,
)

2022-09-13 09:21:59,947 :: train/train.py/execute :: INFO :: Unfreezing the whole CNN network
2022-09-13 09:21:59,960 :: train/train.py/execute :: INFO :: Starting finetuning of the whole CNN network
Training Epoch [1/5]:  20%|███████████████████▏                                                                             | 202/1018 [02:58<11:45,  1.16it/s, loss=0.671]2022-09-13 09:25:08,201 :: learner/learner.py/evaluate :: INFO :: Validation:
2022-09-13 09:25:08,202 :: learner/learner.py/evaluate :: INFO :: 	hamming_loss = 0.073
2022-09-13 09:25:08,202 :: learner/learner.py/evaluate :: INFO :: 	f1_score = 0.568
2022-09-13 09:25:08,205 :: learner/learner.py/evaluate :: INFO :: 	loss = 0.487
2022-09-13 09:25:08,206 :: learner/learner.py/evaluate :: INFO :: ****************************************
Training Epoch [1/5]:  40%|██████████████████████████████████████▉                                                           | 405/1018 [05:24<06:49,  1.50it/s, loss=0.47]2022-09-13 09:27:33,6

Epoch    19: reducing learning rate of group 0 to 5.0000e-04.


Training Epoch [4/5]:  40%|██████████████████████████████████████▌                                                          | 405/1018 [05:14<06:38,  1.54it/s, loss=0.278]2022-09-13 10:05:44,923 :: learner/learner.py/evaluate :: INFO :: Validation:
2022-09-13 10:05:44,924 :: learner/learner.py/evaluate :: INFO :: 	hamming_loss = 0.026
2022-09-13 10:05:44,925 :: learner/learner.py/evaluate :: INFO :: 	f1_score = 0.752
2022-09-13 10:05:44,926 :: learner/learner.py/evaluate :: INFO :: 	loss = 0.467
2022-09-13 10:05:44,927 :: learner/learner.py/evaluate :: INFO :: ****************************************
Training Epoch [4/5]:  60%|█████████████████████████████████████████████████████████▉                                       | 608/1018 [07:35<04:22,  1.56it/s, loss=0.222]2022-09-13 10:08:06,492 :: learner/learner.py/evaluate :: INFO :: Validation:
2022-09-13 10:08:06,494 :: learner/learner.py/evaluate :: INFO :: 	hamming_loss = 0.024
2022-09-13 10:08:06,495 :: learner/learner.py/evaluate 

#### Evaluation

In [13]:
# Evaluate this run's learner; the logged "Test results" are a per-label
# dict of precision / recall / f1-score / support, kept in `perfs`.
evaluate_classifier = EvaluateClassifier()
perfs = evaluate_classifier.execute(
    cnn_learner=learner,
    device=device,
    torch_iterators=torch_iterators,
    multilabel_binarizer=multilabel_binarizer,
)

2022-09-13 10:25:53,092 :: train/train.py/execute :: INFO :: Test results
2022-09-13 10:25:53,093 :: train/train.py/execute :: INFO :: {'animalerie': {'precision': 0.8261455525606469, 'recall': 0.9316109422492401, 'f1-score': 0.8757142857142858, 'support': 658}, 'auto et moto': {'precision': 0.5509433962264151, 'recall': 0.8664688427299704, 'f1-score': 0.6735870818915801, 'support': 337}, 'bagages et sacs': {'precision': 0.65814696485623, 'recall': 0.824, 'f1-score': 0.7317939609236234, 'support': 250}, 'beaute et parfum': {'precision': 0.9090909090909091, 'recall': 0.9339774557165862, 'f1-score': 0.9213661636219221, 'support': 621}, 'bebe et puericulture': {'precision': 0.6553911205073996, 'recall': 0.8378378378378378, 'f1-score': 0.7354685646500593, 'support': 370}, 'bijoux': {'precision': 0.8134110787172012, 'recall': 0.9238410596026491, 'f1-score': 0.8651162790697675, 'support': 302}, 'bricolage': {'precision': 0.6, 'recall': 0.8131241084165478, 'f1-score': 0.6904906117504542, 'sup

In [14]:
import pandas as pd

# One row per label, one column per metric (transpose of the dict-of-dicts frame).
perfs_table = pd.DataFrame(perfs).transpose()
perfs_table

Unnamed: 0,precision,recall,f1-score,support
animalerie,0.826146,0.931611,0.875714,658.0
auto et moto,0.550943,0.866469,0.673587,337.0
bagages et sacs,0.658147,0.824,0.731794,250.0
beaute et parfum,0.909091,0.933977,0.921366,621.0
bebe et puericulture,0.655391,0.837838,0.735469,370.0
bijoux,0.813411,0.923841,0.865116,302.0
bricolage,0.6,0.813124,0.690491,701.0
cd et vinyles,0.222222,0.761905,0.344086,21.0
chaussures et accessoires,0.837456,0.922179,0.877778,514.0
"commerce, industrie et science",0.22449,0.392857,0.285714,28.0


### Measure prediction time on the test set

In [27]:
import timeit

# Total wall-clock seconds for 10 consecutive predict() passes over the test
# iterator (divide by 10 for the per-pass time). The notebook-wide `device`
# setting is used instead of a hard-coded 'cpu' so the timing follows the
# device the rest of the notebook runs on.
# NOTE(review): assumes predict()'s second positional argument is the device — confirm.
timeit.timeit(lambda: learner.predict(torch_iterators.test, device), number=10)

115.75525344512425