## <center>Test Task from
<center> <img src='https://newdermis.ru/wa-data/public/shop/products/14/webp/data/public/site/Logo_kazan.webp' style='width:400px;'>

**by Liliya Kazykhanova**

## Product categorization: Text SingleTask

# <p style="text-align:center;font-size:100%;">1. Install and Import</p>

In [1]:
import random
import numpy as np
import pandas as pd
import re
from pathlib import Path
from tqdm.auto import tqdm #progress bar

from sklearn.model_selection import StratifiedShuffleSplit

# PyTorch libs
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchmetrics import F1Score, Precision, Recall
# from torch.optim.lr_scheduler import ChainedScheduler, LinearLR, ExponentialLR

import gc

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import get_scheduler
from transformers import DataCollatorWithPadding

import warnings
warnings.filterwarnings("ignore")
import logging 
logging.disable(logging.WARNING)



In [2]:
# Set random seeds
def set_seed(seed: int = 24):
    """
    Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` (if
    installed).

    Args:
        seed (:obj:`int`): The seed to set.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    
set_seed()

# <p style="text-align:center;font-size:100%;">2. Load Data </p>

In [3]:
# Reading train and test file
DATA_DIR = '/kaggle/input/kazan-express-test/'
df_train = pd.read_parquet(DATA_DIR+'/train.parquet', engine='pyarrow')
df_test = pd.read_parquet(DATA_DIR+'/test.parquet', engine='pyarrow')

print('Train data:\n')
display(df_train.head())
print('\nTest data:\n')
display(df_test.head())

Train data:



Unnamed: 0,product_id,category_id,sale,shop_id,shop_title,rating,text_fields,category_name
0,325286,12171,False,9031,Aksik,5.0,"{""title"": ""Зарядный кабель Borofone BX1 Lightn...",Все категории->Электроника->Смартфоны и телефо...
1,888134,14233,False,18305,Sela,5.0,"{""title"": ""Трусы Sela"", ""description"": ""Трусы-...",Все категории->Одежда->Женская одежда->Белье и...
3,1267173,13429,False,16357,ЮНЛАНДИЯ канцтовары,5.0,"{""title"": ""Гуашь \""ЮНЫЙ ВОЛШЕБНИК\"", 12 цветов...",Все категории->Хобби и творчество->Рисование->...
4,1416943,2789,False,34666,вася-nicotine,4.0,"{""title"": ""Колба для кальяна Крафт (разные цве...",Все категории->Хобби и творчество->Товары для ...
5,1058275,12834,False,26389,Lim Market,4.6,"{""title"": ""Пижама женская, однотонная с шортам...",Все категории->Одежда->Женская одежда->Домашня...



Test data:



Unnamed: 0,product_id,sale,shop_id,shop_title,rating,text_fields
1,1997646,False,22758,Sky_Electronics,5.0,"{""title"": ""Светодиодная лента Smart led Strip ..."
2,927375,False,17729,Di-Di Market,4.405941,"{""title"": ""Стекло ПЛЕНКА керамик матовое Honor..."
3,1921513,False,54327,VisionStore,4.0,"{""title"": ""Проводные наушники с микрофоном jac..."
4,1668662,False,15000,FORNAILS,5.0,"{""title"": ""Декоративная табличка \""Правила кух..."
5,1467778,False,39600,МОЯ КУХНЯ,5.0,"{""title"": ""Подставка под ложку керамическая, п..."


# <p style="text-align:center;font-size:100%;">3. Data Preprocessing: Extract text fields</p>

In [4]:
def extract_text_from_html(x: str):
    extracted_text = re.sub('<[^<>]+>', ' ', x).strip()
    text_with_no_extra_spaces = re.sub(' +', ' ', extracted_text)
    
    return text_with_no_extra_spaces

In [5]:
df_train['text_fields_dict'] = df_train['text_fields'].apply(lambda x: eval(x))
df_train['title'] = df_train['text_fields_dict'].apply(lambda x: x['title'])
df_train['description'] = df_train['text_fields_dict'].apply(lambda x: extract_text_from_html(x['description']))

df_train.head()

Unnamed: 0,product_id,category_id,sale,shop_id,shop_title,rating,text_fields,category_name,text_fields_dict,title,description
0,325286,12171,False,9031,Aksik,5.0,"{""title"": ""Зарядный кабель Borofone BX1 Lightn...",Все категории->Электроника->Смартфоны и телефо...,{'title': 'Зарядный кабель Borofone BX1 Lightn...,Зарядный кабель Borofone BX1 Lightning для айф...,Зарядный кабель Borofone BX1 подходит для заря...
1,888134,14233,False,18305,Sela,5.0,"{""title"": ""Трусы Sela"", ""description"": ""Трусы-...",Все категории->Одежда->Женская одежда->Белье и...,"{'title': 'Трусы Sela', 'description': 'Трусы-...",Трусы Sela,Трусы-слипы из эластичного бесшовного трикотаж...
3,1267173,13429,False,16357,ЮНЛАНДИЯ канцтовары,5.0,"{""title"": ""Гуашь \""ЮНЫЙ ВОЛШЕБНИК\"", 12 цветов...",Все категории->Хобби и творчество->Рисование->...,"{'title': 'Гуашь ""ЮНЫЙ ВОЛШЕБНИК"", 12 цветов п...","Гуашь ""ЮНЫЙ ВОЛШЕБНИК"", 12 цветов по 35 мл, БО...",Гуашь высшего качества ЮНЛАНДИЯ поможет создат...
4,1416943,2789,False,34666,вася-nicotine,4.0,"{""title"": ""Колба для кальяна Крафт (разные цве...",Все категории->Хобби и творчество->Товары для ...,{'title': 'Колба для кальяна Крафт (разные цве...,Колба для кальяна Крафт (разные цвета),Универсальная колба для кальяна Craft подходит...
5,1058275,12834,False,26389,Lim Market,4.6,"{""title"": ""Пижама женская, однотонная с шортам...",Все категории->Одежда->Женская одежда->Домашня...,"{'title': 'Пижама женская, однотонная с шортам...","Пижама женская, однотонная с шортами",Лёгкая ткань! Комфортная посадка! Идеальная дл...


In [6]:
# Join text features into single string
df_train['text_feature'] = df_train[['shop_title', 'title', 'description']].agg('. '.join, axis=1)
df_train.head(3)

Unnamed: 0,product_id,category_id,sale,shop_id,shop_title,rating,text_fields,category_name,text_fields_dict,title,description,text_feature
0,325286,12171,False,9031,Aksik,5.0,"{""title"": ""Зарядный кабель Borofone BX1 Lightn...",Все категории->Электроника->Смартфоны и телефо...,{'title': 'Зарядный кабель Borofone BX1 Lightn...,Зарядный кабель Borofone BX1 Lightning для айф...,Зарядный кабель Borofone BX1 подходит для заря...,Aksik. Зарядный кабель Borofone BX1 Lightning ...
1,888134,14233,False,18305,Sela,5.0,"{""title"": ""Трусы Sela"", ""description"": ""Трусы-...",Все категории->Одежда->Женская одежда->Белье и...,"{'title': 'Трусы Sela', 'description': 'Трусы-...",Трусы Sela,Трусы-слипы из эластичного бесшовного трикотаж...,Sela. Трусы Sela. Трусы-слипы из эластичного б...
3,1267173,13429,False,16357,ЮНЛАНДИЯ канцтовары,5.0,"{""title"": ""Гуашь \""ЮНЫЙ ВОЛШЕБНИК\"", 12 цветов...",Все категории->Хобби и творчество->Рисование->...,"{'title': 'Гуашь ""ЮНЫЙ ВОЛШЕБНИК"", 12 цветов п...","Гуашь ""ЮНЫЙ ВОЛШЕБНИК"", 12 цветов по 35 мл, БО...",Гуашь высшего качества ЮНЛАНДИЯ поможет создат...,"ЮНЛАНДИЯ канцтовары. Гуашь ""ЮНЫЙ ВОЛШЕБНИК"", 1..."


## <p style="text-align:center;font-size:100%;"> Drop small categories</p>

In [7]:
# Drop categories with number of samples less than 2
valid_categories = df_train['category_id'].value_counts()[df_train['category_id'].value_counts() >= 2].index.values
print(f"Number of categories with number of samples less than 2: {(df_train.category_id.nunique() - len(valid_categories))}\n")

df_train = df_train[lambda x: x['category_id'].isin(valid_categories)]
df_train = df_train.reset_index()

df_train.shape

Number of categories with number of samples less than 2: 4



(91116, 13)

In [8]:
df_train.head()

Unnamed: 0,index,product_id,category_id,sale,shop_id,shop_title,rating,text_fields,category_name,text_fields_dict,title,description,text_feature
0,0,325286,12171,False,9031,Aksik,5.0,"{""title"": ""Зарядный кабель Borofone BX1 Lightn...",Все категории->Электроника->Смартфоны и телефо...,{'title': 'Зарядный кабель Borofone BX1 Lightn...,Зарядный кабель Borofone BX1 Lightning для айф...,Зарядный кабель Borofone BX1 подходит для заря...,Aksik. Зарядный кабель Borofone BX1 Lightning ...
1,1,888134,14233,False,18305,Sela,5.0,"{""title"": ""Трусы Sela"", ""description"": ""Трусы-...",Все категории->Одежда->Женская одежда->Белье и...,"{'title': 'Трусы Sela', 'description': 'Трусы-...",Трусы Sela,Трусы-слипы из эластичного бесшовного трикотаж...,Sela. Трусы Sela. Трусы-слипы из эластичного б...
2,3,1267173,13429,False,16357,ЮНЛАНДИЯ канцтовары,5.0,"{""title"": ""Гуашь \""ЮНЫЙ ВОЛШЕБНИК\"", 12 цветов...",Все категории->Хобби и творчество->Рисование->...,"{'title': 'Гуашь ""ЮНЫЙ ВОЛШЕБНИК"", 12 цветов п...","Гуашь ""ЮНЫЙ ВОЛШЕБНИК"", 12 цветов по 35 мл, БО...",Гуашь высшего качества ЮНЛАНДИЯ поможет создат...,"ЮНЛАНДИЯ канцтовары. Гуашь ""ЮНЫЙ ВОЛШЕБНИК"", 1..."
3,4,1416943,2789,False,34666,вася-nicotine,4.0,"{""title"": ""Колба для кальяна Крафт (разные цве...",Все категории->Хобби и творчество->Товары для ...,{'title': 'Колба для кальяна Крафт (разные цве...,Колба для кальяна Крафт (разные цвета),Универсальная колба для кальяна Craft подходит...,вася-nicotine. Колба для кальяна Крафт (разные...
4,5,1058275,12834,False,26389,Lim Market,4.6,"{""title"": ""Пижама женская, однотонная с шортам...",Все категории->Одежда->Женская одежда->Домашня...,"{'title': 'Пижама женская, однотонная с шортам...","Пижама женская, однотонная с шортами",Лёгкая ткань! Комфортная посадка! Идеальная дл...,"Lim Market. Пижама женская, однотонная с шорта..."


In [9]:
print(f'Max value of category_id: {df_train.category_id.min()}')
print(f'Min value of category_id: {df_train.category_id.max()}')

Max value of category_id: 2599
Min value of category_id: 15076


## <center> Ordered X-th category (main)

In [10]:
# Needed for correct work of loss function: get category number from 0 to max
category_X_map = dict([(x[1], x[0]) for x in enumerate(df_train['category_id'].unique())])
df_train['category_X_id'] = df_train['category_id'].apply(lambda x: category_X_map[x])

df_train.shape

(91116, 14)

In [11]:
print(f'Max value of category_X_id: {df_train.category_X_id.min()}')
print(f'Min value of category_X_id: {df_train.category_X_id.max()}')

Max value of category_X_id: 0
Min value of category_X_id: 869


## <center> Keep relevant columns

In [12]:
df_train_final = df_train[[
    'product_id',
    'category_X_id',
    'category_name',
    'text_feature',
]]

df_train_final.head(3)

Unnamed: 0,product_id,category_X_id,category_name,text_feature
0,325286,0,Все категории->Электроника->Смартфоны и телефо...,Aksik. Зарядный кабель Borofone BX1 Lightning ...
1,888134,1,Все категории->Одежда->Женская одежда->Белье и...,Sela. Трусы Sela. Трусы-слипы из эластичного б...
2,1267173,2,Все категории->Хобби и творчество->Рисование->...,"ЮНЛАНДИЯ канцтовары. Гуашь ""ЮНЫЙ ВОЛШЕБНИК"", 1..."


***

In [13]:
NUM_CLASSES_CAT_X = len(df_train_final['category_X_id'].unique())
print(f'Number of classes to train: {NUM_CLASSES_CAT_X}')

Number of classes to train: 870


***

# <p style="text-align:center;font-size:100%;">4. Steps before model building: Split on train and valid set</p>


In [14]:
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)

X = df_train_final.values
y = df_train_final['category_X_id'].values

for i, (train_index, valid_index) in enumerate(splitter.split(X, y)):
    print(f"  Train: index = {train_index}")
    print(f"  Valid: index = {valid_index}")

  Train: index = [72371 85938 81148 ... 48760 77109 67262]
  Valid: index = [ 8298 54709 65490 ... 29941 17461 86102]


In [15]:
train_data = df_train_final.iloc[train_index]
valid_data = df_train_final.iloc[valid_index]

In [16]:
# Save train and valid datasets to pkl format for futher work
train_data.to_pickle('./train_data.pkl')
valid_data.to_pickle('./valid_data.pkl')

## <center> Create dataset

In [17]:
"""
Create datasets: exctract datasets from pkl format
"""
df_dataset = load_dataset("pandas", data_files={"train": "./train_data.pkl", "valid": "./valid_data.pkl"})
df_dataset

Downloading and preparing dataset pandas/default to /root/.cache/huggingface/datasets/pandas/default-873f3e231b30e5f8/0.0.0/6197c1e855b639d75a767140856841a562b7a71d129104973fe1962594877ade...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Dataset pandas downloaded and prepared to /root/.cache/huggingface/datasets/pandas/default-873f3e231b30e5f8/0.0.0/6197c1e855b639d75a767140856841a562b7a71d129104973fe1962594877ade. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['product_id', 'category_X_id', 'category_name', 'text_feature', '__index_level_0__'],
        num_rows: 72892
    })
    valid: Dataset({
        features: ['product_id', 'category_X_id', 'category_name', 'text_feature', '__index_level_0__'],
        num_rows: 18224
    })
})

## <center> Process dataset (tokenize)

In [18]:
"""
Multilingual MiniLM uses the same tokenizer as XLM-R
But the Transformer architecture of model is the same as BERT
"""
checkpoint = "microsoft/Multilingual-MiniLM-L12-H384"

model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

Downloading config.json:   0%|          | 0.00/430 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/471M [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/615 [00:00<?, ?B/s]

Downloading (…)tencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/9.10M [00:00<?, ?B/s]

In [19]:
# Let's look at the model structure
model.bert

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(250037, 384, padding_idx=0)
    (position_embeddings): Embedding(512, 384)
    (token_type_embeddings): Embedding(2, 384)
    (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=384, out_features=384, bias=True)
            (key): Linear(in_features=384, out_features=384, bias=True)
            (value): Linear(in_features=384, out_features=384, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=384, out_features=384, bias=True)
            (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
 

In [20]:
def tokenize_function(example):
    """
    Function to get tokens from text
    """
    return tokenizer(example["text_feature"], truncation=True)

In [21]:
tokenized_datasets = df_dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(["__index_level_0__", "text_feature", "category_name"])
tokenized_datasets = tokenized_datasets.rename_column("category_X_id", "labels") # the model expects the argument to be named labels
# tokenized_datasets.set_format("torch") or tokenized_datasets = tokenized_datasets.with_format("torch")
tokenized_datasets["train"].column_names

  0%|          | 0/73 [00:00<?, ?ba/s]

  0%|          | 0/19 [00:00<?, ?ba/s]

['product_id', 'labels', 'input_ids', 'attention_mask']

In [22]:
# Check tokenizer results
sample_id = 2
print('Original text: ')
print(df_dataset["train"][sample_id]["text_feature"])
print()
print('Tokens: ')
# This tokenizer is a subword tokenizer: it splits the words until it obtains tokens that can be represented by its vocabulary
print(tokenizer.tokenize(df_dataset["train"][sample_id]["text_feature"]))
print()
print('Tokenizer decoded text: ')
print(tokenizer.decode(tokenized_datasets["train"][sample_id]["input_ids"]))
print()
print('Tokenizer inputs-ids (how model gets data): ')
print(tokenized_datasets["train"][sample_id]["input_ids"])

Original text: 
Чехлович. Samsung Galaxy A80 чехол силиконовый под кожу "полигоны". Чехол изготовлен из силикона, защищающего телефон от ударов, а также от воздействия пыли, грязи и воды. Его мягкая поверхность позволяет надёжно удерживать устройство в руках.&nbsp;Чехол не утолщает телефон, поэтому его удобно носить в кармане или сумке. По любым вопросам по товару пишите нам, нажав на кнопку «Спросить продавца» под карточкой товара.&nbsp;Если не нашли чехол или стекло на вашу модель, тоже пишите, поможем найти.

Tokens: 
['▁Чех', 'л', 'ович', '.', '▁Samsung', '▁Galaxy', '▁A', '80', '▁че', 'хол', '▁сили', 'кон', 'овый', '▁под', '▁ко', 'жу', '▁"', 'пол', 'иг', 'оны', '".', '▁Чех', 'ол', '▁изготовлен', '▁из', '▁сили', 'ко', 'на', ',', '▁защища', 'ющего', '▁телефон', '▁от', '▁удар', 'ов', ',', '▁а', '▁также', '▁от', '▁воздействия', '▁пыл', 'и', ',', '▁гряз', 'и', '▁и', '▁воды', '.', '▁Его', '▁', 'мягк', 'ая', '▁поверх', 'ность', '▁позволяет', '▁на', 'дё', 'жно', '▁у', 'держ', 'ивать', '▁ус

## <center> Create dataloader

In [23]:
"""
Define a collate function that will apply the correct amount of padding to the items of the dataset we want to batch together
"""
batch_size = 8
data_collator = DataCollatorWithPadding(tokenizer=tokenizer) # It puts together all the samples in a batch

train_dataloader = DataLoader(
    tokenized_datasets["train"], shuffle=True, batch_size=batch_size, collate_fn=data_collator
)

valid_dataloader = DataLoader(
    tokenized_datasets["valid"], batch_size=batch_size, collate_fn=data_collator
)

for batch in train_dataloader:
    break
{k: v.shape for k, v in batch.items()}

{'product_id': torch.Size([8]),
 'labels': torch.Size([8]),
 'input_ids': torch.Size([8, 485]),
 'attention_mask': torch.Size([8, 485])}

In [24]:
for step, batch in enumerate(train_dataloader):
    print(batch["input_ids"].shape)
    if step > 5:
        break

torch.Size([8, 512])
torch.Size([8, 208])
torch.Size([8, 512])
torch.Size([8, 198])
torch.Size([8, 400])
torch.Size([8, 391])
torch.Size([8, 512])


# <p style="text-align:center;font-size:100%;">5. Model</p>

**Training process:**
- Transfer-learning BERT on our classes
- We do domain adaptation by fine-tuning BERT's embedding layer

In [25]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device used: {device}")

Device used: cuda


In [26]:
class ProductCategorizationModel(nn.Module):
    """
    @param    bert: Model object
    @param    freeze_bert (bool): Set `False` to fine-tune the BERT model
    """
    def __init__(self, checkpoint, num_classes):
        """
        Architecture:
        - Transformer
        - Final FC linear layer to get output for multiclass classification
        """
        super(ProductCategorizationModel, self).__init__()
        
        # Load pretrained multi-language BERT model
        self.pretrained_model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
        # Keep only transformer part, do not include classification head since we
        # are going to add out own classification head.
        self.pretrained_model = self.pretrained_model.bert
        
        self.fc1 = nn.Linear(in_features=384, out_features=192)
        self.fc2 = nn.Linear(in_features=192, out_features=num_classes)
        
        self.relu = nn.ReLU()
        # Dropout: this regularization technique helps prevent overfitting by reducing the reliance on specific neurons
        self.dropout = nn.Dropout(0.1)
        
        
    def forward(self, input_ids, attention_mask):
        # Attention mask is a binary tensor indicating the position of the padded indices so that the model does not attend to them. 
        bert_output = self.pretrained_model(input_ids=input_ids, attention_mask=attention_mask)
        text_embedding = bert_output[0][:, 0]
             
        x = self.dropout(text_embedding)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x
    

    def freeze_bert(self):
        """
        Freezes the parameters of BERT so only the weights 
        of the custom classifier are modified.
        """
        for name, param in self.pretrained_model.named_parameters():
            if "classifier" not in name: # classifier layer
                param.requires_grad = False
        
#         for param in self.pretrained_model.parameters():
#             param.requires_grad = False
                
        
    def unfreeze_emdedding(self):
        """
        Unfreeze Embedding layer of BERT for domain adaptation.
        """
        for name, param in model.named_parameters():
            if "embeddings" in name:
                param.requires_grad = True
#         for param in self.pretrained_model.embeddings.parameters():
#             param.requires_grad = True

In [27]:
model = ProductCategorizationModel(checkpoint=checkpoint, num_classes=NUM_CLASSES_CAT_X).to(device)

In [28]:
# for name, param in model.named_parameters():
#     print(name, param.requires_grad)

In [29]:
model_num_parameters = sum(p.numel() for p in model.parameters()) / 1000000
print(f"'>>> Number of parameters: {round(model_num_parameters)}M'")

'>>> Number of parameters: 118M'


***

In [30]:
# Define functions
def save_checkpoint(state, filename="/kaggle/working/checkpoint.pt"):
    """
    Save checkpoint in order to get the best model parameters
    """
    print("=> Saving checkpoint")
    torch.save(state, filename)
    

def load_checkpoint(checkpoint, model, optimizer):
    """
    Load checkpoint with the best model parameters
    """
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])


# Define Classes
class EarlyStopping():
    """
    regularization techniques to combat the overfitting issue
    """
    def __init__(self, max_epochs):
        self.max_epochs = max_epochs
        self.current_epochs = 0
        self.best_loss = np.inf

    def __call__(self, current_loss):
        if current_loss < self.best_loss:
            self.best_loss = current_loss
            self.current_epochs = 0
        else:
            self.current_epochs += 1
            if self.current_epochs >= self.max_epochs:
                print("Early Stopping!")
                return True

            
class SaveBestModel():
    def __init__(self):
        self.best_valid_loss = np.inf
        self.best_model = None
        
    def __call__(self, current_valid_loss, model, optimizer):
        if current_valid_loss < self.best_valid_loss:
            self.best_model = model
            self.best_valid_loss = current_valid_loss
            print(f"Best validation loss: {self.best_valid_loss}")
            checkpoint = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
            save_checkpoint(checkpoint, filename='best_model.pt')
        return self.best_model

In [36]:
def model_train(train_loader, model, optimizer, criterion, scheduler, num_epochs, device):
    # Train Network
    train_loss_sum = 0
    running_train_outputs, running_train_labels = None, None
    train_progress_bar = tqdm(range(len(train_dataloader)))
    
    model.train()
    model.freeze_bert()
    model.unfreeze_emdedding()
    
    for bdx, batch in enumerate(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()} # input and target values
        
        # forward -> get model prediciton
        outputs = model(input_ids=batch["input_ids"], attention_mask=batch['attention_mask'])
        loss = criterion(outputs, batch["labels"]) # Compute loss
        
        """
        * zero_grad clears old gradients from the last step (otherwise you’d just accumulate the gradients from all loss.backward() calls).
        * loss.backward() computes the derivative of the loss w.r.t. the parameters (or anything requiring gradients) using backpropagation.
        * opt.step() causes the optimizer to take a step based on the gradients of the parameters.
        """
        # backward
        optimizer.zero_grad()
        loss.backward() # calculate the gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # Clip the norm of the gradients to 1.0. This is to help prevent the "exploding gradients" problem
        
        # gradient descent or adam step
        optimizer.step() # The optimizer dictates the "update rule"--how the parameters are modified based on their gradients, the learning rate, etc.
        lr_scheduler.step()
        
        # Aggregate metrics
        train_loss_sum += loss.item()
        if bdx == 0:
            running_train_outputs = outputs.detach().cpu()
            running_train_labels = batch["labels"].detach().cpu()
        else:
            running_train_outputs = torch.cat([running_train_outputs, outputs.detach().cpu()], axis=0)
            running_train_labels = torch.cat([running_train_labels, batch["labels"].detach().cpu()], axis=0)
            
        train_progress_bar.update(1)
        
    # Compute metrics
    train_loss_sum /= len(train_dataloader)
    train_f1 = f1(running_train_outputs, running_train_labels).item()
    train_precision = precision(running_train_outputs, running_train_labels).item()
    train_recall = recall(running_train_outputs, running_train_labels).item()
    print(f'Epoch: {epoch+1}. Loss (train): {train_loss_sum:.4f}. F1 (train): {train_f1:.4f}. Precision (train): {train_precision:.4f}. Recall: (train): {train_recall:.4f}')
    
    # Clean garbage data
#     del running_train_outputs
#     del running_train_labels
    torch.cuda.empty_cache()

In [32]:
def model_valid(valid_loader, model, optimizer, criterion, num_epochs, device):
    valid_loss_sum = 0
    valid_progress_bar = tqdm(range(len(valid_dataloader)))
    running_valid_outputs, running_valid_labels = None, None
            
    # Eval model turns off some features like dropout and batch normalization.
    # Does not automatically disable gradient computation for all operations.
    model.eval()
    # Temporarly turn off gradient computation for given block of code.
    
    with torch.no_grad():
        for bdx, batch in enumerate(valid_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            # Get model prediciton.
            outputs = model(input_ids=batch["input_ids"], attention_mask=batch['attention_mask'])
            # Compute loss.
            val_loss = criterion(outputs, batch["labels"])
                    
            # Aggregate metrics.
            valid_loss_sum += val_loss.item()
            if bdx == 0:
                running_valid_outputs = outputs.detach().cpu()
                running_valid_labels = batch["labels"].detach().cpu()
            else:
                running_valid_outputs = torch.cat([running_valid_outputs, outputs.detach().cpu()], axis=0)
                running_valid_labels = torch.cat([running_valid_labels, batch["labels"].detach().cpu()], axis=0)
                    
            valid_progress_bar.update(1)

    # Compute metrics.
    valid_loss_sum /= len(valid_dataloader)
    valid_f1 = f1(running_valid_outputs, running_valid_labels).item()
    valid_precision = precision(running_valid_outputs, running_valid_labels).item()
    valid_recall = recall(running_valid_outputs, running_valid_labels).item()
    print(f'Epoch: {epoch+1}. Loss (valid): {valid_loss_sum:.4f}. F1 (valid): {valid_f1:.4f}. Precision (valid): {valid_precision:.4f}. Recall: (valid): {valid_recall:.4f}')
    
    # Clean garbage data
#     del running_valid_outputs
#     del running_valid_labels
    torch.cuda.empty_cache()
        
    return val_loss, valid_f1

***

# <p style="text-align:center;font-size:100%;">6. NN model training</p>

In [33]:
"""
Metrics
"""
# How to compute metric? Micro or Macro?
# =======================================================================================================
# Micro is a weighted version of metric, so small classes will be underweighted.
# Calculate metrics globally by counting the total number of times each class was correctly predicted and incorrectly predicted.

# Macro is class-imbalance agnostic, and cares only about correctly prediction a class
# Calculate metrics for each "class" independently, and find their unweighted mean.
# This does not take label imbalance into account.
# =======================================================================================================
# Choice: "micro", since we have a lot of classes with n_samples < 10, it will be extremely
# taugh and LONG for model to learn how to identify them, without proper handling of imbalanced classes.
# To avoid that issue, while we do not solve imbalanced classes issue, it's better to use "micro" method.
# -------------------------------------------------------------------------------------------------------
# !!! The right way to do it is to use MACRO! But since we DO NOT handle imbalanced classes, it is
# better to use MICRO average method for now.
# -------------------------------------------------------------------------------------------------------
# Read more: https://hackmd.io/@antolaga/H1rGPefeK
# Read more: https://www.educative.io/answers/what-is-the-difference-between-micro-and-macro-averaging

metric_average_method = 'micro'

f1 = F1Score(task="multiclass", average=metric_average_method, num_classes=NUM_CLASSES_CAT_X).to('cpu')
precision = Precision(task="multiclass", average=metric_average_method, num_classes=NUM_CLASSES_CAT_X).to('cpu')
recall = Recall(task="multiclass", average=metric_average_method, num_classes=NUM_CLASSES_CAT_X).to('cpu')

In [37]:
# Loss, optimizer, other parameters
num_epochs = 5
learning_rate = 1e-3

criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

num_training_steps = num_epochs * len(train_dataloader)

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

print(f'Number of training steps: {num_training_steps}\n')

early_stopping = EarlyStopping(max_epochs = 3)
save_best_model = SaveBestModel()

Number of training steps: 45560



In [38]:
## Train Model
for epoch in range(num_epochs): 
    model_train(train_dataloader, model, optimizer, criterion, lr_scheduler, num_epochs, device)
    if (epoch+1) % 2 == 0: pass
    else:
        val_loss, valid_f1 = model_valid(valid_dataloader, model, optimizer, criterion, num_epochs, device)
        best_model = save_best_model(val_loss, model, optimizer)
        if early_stopping(val_loss): break
        print(50 * '-')

  0%|          | 0/9112 [00:00<?, ?it/s]

Epoch: 1. Loss (train): 3.9752. F1 (train): 0.2993. Precision (train): 0.2993. Recall: (train): 0.2993


  0%|          | 0/2278 [00:00<?, ?it/s]

Epoch: 1. Loss (valid): 2.6699. F1 (valid): 0.4938. Precision (valid): 0.4938. Recall: (valid): 0.4938
Best validation loss: 2.7640159130096436
=> Saving checkpoint
--------------------------------------------------


  0%|          | 0/9112 [00:00<?, ?it/s]

Epoch: 2. Loss (train): 2.0969. F1 (train): 0.5830. Precision (train): 0.5830. Recall: (train): 0.5830


  0%|          | 0/9112 [00:00<?, ?it/s]

Epoch: 3. Loss (train): 1.3737. F1 (train): 0.7039. Precision (train): 0.7039. Recall: (train): 0.7039


  0%|          | 0/2278 [00:00<?, ?it/s]

Epoch: 3. Loss (valid): 1.5005. F1 (valid): 0.7212. Precision (valid): 0.7212. Recall: (valid): 0.7212
Best validation loss: 1.3898680210113525
=> Saving checkpoint
--------------------------------------------------


  0%|          | 0/9112 [00:00<?, ?it/s]

KeyboardInterrupt: 

***

# <p style="text-align:center;font-size:100%;">7. Evaluation</p>

In [39]:
"""
Load the best model checkpoint
"""
checkpoint = torch.load("/kaggle/working/best_model.pt")

# Model
best_model = model.to(device)

# Optimizer
optimizer = torch.optim.Adam(best_model.parameters(), lr=learning_rate)

# Load Checkpoint
load_checkpoint(checkpoint, best_model, optimizer)

=> Loading checkpoint


In [40]:
# f1_micro, f1_macro scores compute
f1_micro = F1Score(task="multiclass", average="micro", num_classes=NUM_CLASSES_CAT_X).to('cpu')
f1_macro = F1Score(task="multiclass", average="macro", num_classes=NUM_CLASSES_CAT_X).to('cpu')

In [41]:
def evaluate_model(df_data, model, valid_dataloader, device, tokenizer):
    model.eval()
    
    with torch.no_grad():
        
        progress_bar = tqdm(range(len(valid_dataloader)))
        running_outputs, running_labels = None, None
        text_description = []
        
        for bdx, batch in enumerate(valid_dataloader):
            
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(input_ids=batch["input_ids"], attention_mask=batch['attention_mask'])
            
            if bdx == 0:
                running_outputs = outputs.detach().cpu()
                running_labels = batch["labels"].detach().cpu()
            else:
                running_outputs = torch.cat([running_outputs, outputs.detach().cpu()], axis=0)
                running_labels = torch.cat([running_labels, batch["labels"].detach().cpu()], axis=0)
            
            text_description += [tokenizer.decode(batch["input_ids"][x, :]).replace('<pad>', '') for x in range(len(batch["input_ids"]))]
            progress_bar.update(1)
            
    true_labels = running_labels.numpy()
    predicted_labels = torch.argmax(running_outputs, axis=1).numpy()
    predicted_category_names, true_category_names = [], []
    
    for pred_label, true_label in zip(predicted_labels, true_labels):
        predicted_category_names.append(df_data[lambda x: x['category_X_id'] == pred_label]['category_name'].values[0])
        true_category_names.append(df_data[lambda x: x['category_X_id'] == true_label]['category_name'].values[0])
        
    return {
        "text_description": text_description,
        "predicted_category_id": predicted_labels,
        "running_outputs": running_outputs.numpy(),
        "true_category_id": true_labels,
        "predicted_category_names": predicted_category_names,
        "true_category_names": true_category_names,
    }

In [42]:
prediction_batch = evaluate_model(df_train_final, model, valid_dataloader, device, tokenizer)

  0%|          | 0/2278 [00:00<?, ?it/s]

In [43]:
f1_macro_metric = f1_macro(
    torch.tensor(prediction_batch["running_outputs"]), 
    torch.tensor(prediction_batch["true_category_id"]),
).item()

f1_micro_metric = f1_micro(
    torch.tensor(prediction_batch["running_outputs"]), 
    torch.tensor(prediction_batch["true_category_id"]),
).item()

print(f"F1(micro): {f1_micro_metric:.4f}. F1(macro): {f1_macro_metric:.4f}")

F1(micro): 0.7212. F1(macro): 0.3131


In [44]:
n_samples_to_check = 5

for i in range(n_samples_to_check):
    print(f'Sample: {i}')
    print(f'Text description: {prediction_batch["text_description"][i]}')
    print(f'True category: {prediction_batch["true_category_names"][i]}')
    print(f'Predicted category: {prediction_batch["predicted_category_names"][i]}')
    print()

Sample: 0
Text description: <s> Бантик. Мешочек новогодний "Снежинки" WF-607, 10*12см, цвет золотой. Мешочек новогодний "Снежинки" WF-607, 10*12см, цвет золотой</s>
True category: Все категории->Товары для дома->Товары для праздников->Подарочная упаковка->Мешочки подарочные
Predicted category: Все категории->Хобби и творчество->Рисование->Краски, пигменты

Sample: 1
Text description: <s> Mom&Me. Колготки теплые для беременных с высокой регулируемой талией. Комфортные и мягкие колготки для будущих мам с укрепленным мыском и пяткой. Благодаря специальной анатомической вставке на животе и регулируемой резинке колготки можно носить с на любом сроке беременности.&nbsp; Если видео не воспроизводится нажмите на синий текст ошибки и вы сможете посмотреть обзор в youtube.</s>
True category: Все категории->Одежда->Женская одежда->Одежда для беременных->Колготки
Predicted category: Все категории->Одежда->Женская одежда->Брюки и джинсы->Брюки

Sample: 2
Text description: <s> Texноmall. Карта памят