# Import packages

In [55]:
import json
from functools import partial
from pathlib import Path

import pandas as pd
import pytorch_lightning
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from attribute_extraction.models.attribute_classification import MultiAttributeClassifier
from attribute_extraction.models.mapper import Mapper
from attribute_extraction.models.train_utils import (
    AttributeDataset,
    MultiAttributeBatchSampler,
    build_callbacks,
    collate_fun_generator,
)
from attribute_extraction.utils.data_balancing_utils import add_weight_for_data_balancing


# Set up variables and hyperparameters

In [56]:
# Column names shared by every cell below.
attribute_code_col = "attribute_code"
attribute_lov_col = "lov_code"
context_col = "context"
weight_col = "weight"
mapped_lov_col = "mapped_lov_code"

# Output folder for hyper-parameters, the mapper, checkpoints and metrics.
# Kept as a string with a trailing slash because other cells build paths with
# f"{local_path}<file name>".
local_path = "../outputs_train_workflow/"

model_name = "distilbert-base-multilingual-cased"
experiment_description = "Multi-task small GRU model with multilanguage distilledBert tokenizer"
num_epoch = 20
max_len = 512
batch_size = 128
freeze_backbone = True
learning_rate = 1e-4
projection_dim = 256
dropout = 0.2
data_balance = True
# Quantile bounds forwarded to add_weight_for_data_balancing.
upper_qn = 0.9
lower_qn = 0.1
lov_attribute_codes = ["02419", "01746", "00562", "15344", "99999"]
train_set_uri = "../data/train_formatted.csv"
val_set_uri = "../data/val_formatted.csv"
# This previously pointed at train_formatted.csv, so the "test" metrics were
# computed on the training data.
# NOTE(review): confirm the test split file name in ../data.
test_set_uri = "../data/test_formatted.csv"

# Everything needed to reproduce the run; written next to the model outputs.
hyper_parameters = {
    "model_name": model_name,
    "experiment_description": experiment_description,
    "num_epoch": num_epoch,
    "max_len": max_len,
    "batch_size": batch_size,
    "freeze_backbone": freeze_backbone,
    "learning_rate": learning_rate,
    "projection_dim": projection_dim,
    "dropout": dropout,
    "data_balance": data_balance,
    "upper_qn": upper_qn,
    "lower_qn": lower_qn,
    "lov_attribute_codes": lov_attribute_codes,
    "train_set_uri": train_set_uri,
    "val_set_uri": val_set_uri,
    "test_set_uri": test_set_uri,
}

# Create the output folder so a fresh run does not fail on the first open().
output_dir = Path(local_path)
output_dir.mkdir(parents=True, exist_ok=True)

with open(output_dir / "hyper_parameters.json", "w") as f:
    json.dump(hyper_parameters, f, indent=4)

# Loading data

In [57]:
# Read the three splits in train / val / test order, each with a fresh 0..n-1 index.
data_train, data_val, data_test = [
    pd.read_csv(split_uri).reset_index(drop=True)
    for split_uri in (train_set_uri, val_set_uri, test_set_uri)
]

In [58]:
def zero_pad_codes(df: pd.DataFrame, columns, width: int = 5) -> pd.DataFrame:
    """Return a copy of ``df`` with ``columns`` as zero-padded fixed-width strings.

    Codes are compared as strings such as "02419" (see ``lov_attribute_codes``).
    read_csv will likely parse digit-only columns as integers, which drops the
    leading zeros; padding restores them.

    Args:
        df: A split's dataframe.
        columns: Names of the code columns to normalise.
        width: Target string width (5 for this dataset).

    Returns:
        A new dataframe; ``df`` itself is not modified.

    NOTE: a missing code becomes "00nan" (str(nan) padded). Check for NaNs
    upstream if that matters.
    """
    return df.assign(
        **{col: df[col].map(lambda code: str(code).zfill(width)) for col in columns}
    )


code_columns = [attribute_code_col, attribute_lov_col]
data_train = zero_pad_codes(data_train, code_columns)
data_val = zero_pad_codes(data_val, code_columns)
data_test = zero_pad_codes(data_test, code_columns)

In [59]:
# Peek at the first cleaned description (index is 0..n-1 after reset_index).
data_train["description_clean"].iloc[0]

"Facile d'entretien tapis de coton \xa0 \xa0Ce coton naturel Flachflorteppich 100% obtient son look spécial par le motif de rayures élégantes dans des couleurs harmonieuses et douces pompons sur les coins. La conception à haut contraste fixe habilement le ton dans la chambre et donne à chaque dispositif a un facteur de confort. Créer une ambiance harmonieuse conçue combinée avec tampon de laine ou un coussin de laine qui respire le confort. \xa0Le matériau de laine dense, doux, grâce aux propriétés d'isolation thermique, même sur un carrelage froid de sol agréablement doux et le réchauffement. \xa0 \xa0Vous recevez ce tapis en coton comme une sorte unique. En raison du matériau naturel et la production de tapis tissés à la main pas la même chose. En forme, la couleur et motif, il peut y avoir des différences allant jusqu'à 30% de l'image et d'autres tapis de cette qualité."

In [60]:
def build_context(df: pd.DataFrame) -> pd.Series:
    """Concatenate the title and cleaned description into the model's input text.

    Missing titles or descriptions are treated as empty strings. With a plain
    ``title + " " + description`` a single NaN turns the whole context into NaN,
    which would then reach the tokenizer as a float instead of text.

    Args:
        df: A split with ``title`` and ``description_clean`` columns.

    Returns:
        A string Series aligned with ``df``'s index.
    """
    title = df["title"].fillna("").astype(str)
    description = df["description_clean"].fillna("").astype(str)
    return title + " " + description


data_train[context_col] = build_context(data_train)
data_val[context_col] = build_context(data_val)
data_test[context_col] = build_context(data_test)

# Building Mapper

In [61]:
# Fit one label mapping over the union of all splits, presumably so that every
# (attribute code, LOV code) pair present anywhere gets a class id.
# NOTE(review): this exposes the test split's label vocabulary to the mapper.
all_splits = pd.concat([data_train, data_val, data_test])

mapper = Mapper(
    attribute_code_col=attribute_code_col,
    attribute_value_col=attribute_lov_col,
)
mapper.fit(all_splits)

# Persist the mapping next to the model outputs.
mapper.save(f"{local_path}mapper.json")

# Mapping Columns

In [62]:
# Add the mapper's label column (`mapped_lov_col`) to every split.
data_train, data_val, data_test = [
    mapper.map_dataframe(split_df, mapped_col_name=mapped_lov_col)
    for split_df in (data_train, data_val, data_test)
]

# Data balancing

In [63]:
# Per-row weights in `weight_col` are handed to MultiAttributeBatchSampler when
# the dataloaders are built. Only the training split is ever re-weighted.
if not data_balance:
    data_train[weight_col] = 1
else:
    data_train = add_weight_for_data_balancing(
        df=data_train,
        label_col=mapped_lov_col,
        weight_col=weight_col,
        attribute_code_col=attribute_code_col,
        upper_qn=upper_qn,
        lower_qn=lower_qn,
    )

# Validation and test always use uniform weights.
data_val[weight_col] = 1
data_test[weight_col] = 1

# Model Initialisation

In [64]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Scheduler sizing is estimated from the row count: rows * epochs / batch_size.
# Warmup covers ~1% of one epoch's estimated batches.
# NOTE(review): MultiAttributeBatchSampler splits batches by attribute code, so
# the real number of batches per epoch may exceed this estimate. Confirm.
n_train_rows = len(data_train)
total_steps_estimate = n_train_rows * num_epoch // batch_size
warmup_step_count = n_train_rows // batch_size // 100

# NOTE(review): the training summary reports 0 non-trainable params even though
# freeze_backbone=True. Verify the flag is honoured by the model.
model = MultiAttributeClassifier(
    vocab_size=tokenizer.vocab_size,
    class_config=mapper.mappings,
    freeze_backbone=freeze_backbone,
    warmup_steps=warmup_step_count,
    estimated_stepping_batches=total_steps_estimate,
    num_cycles=num_epoch,
    projection_dim=projection_dim,
    dropout=dropout,
    learning_rate=learning_rate,
)

# DataLoader Initialisation

In [65]:
# Worker processes per DataLoader (previously a repeated literal 4).
num_loader_workers = 4


def make_dataset(df):
    """Wrap one split's dataframe as an AttributeDataset using the configured columns."""
    return AttributeDataset(
        data=df,
        context_col_name=context_col,
        label_col_name=mapped_lov_col,
        attribute_code_col_name=attribute_code_col,
    )


def make_loader(dataset, df, num_workers=num_loader_workers):
    """Build a DataLoader whose batches come from a MultiAttributeBatchSampler over ``df``.

    The sampler is given ``split_col=attribute_code_col`` and ``weight_col``.
    Presumably each batch holds a single attribute code, sampled using the row
    weights (confirm in train_utils). Collation tokenizes the context with the
    shared tokenizer, truncated to ``max_len``.
    """
    return DataLoader(
        dataset=dataset,
        batch_sampler=MultiAttributeBatchSampler(
            data=df,
            batch_size=batch_size,
            split_col=attribute_code_col,
            weight_col=weight_col,
        ),
        collate_fn=partial(
            collate_fun_generator,
            tokenizer=tokenizer,
            max_len=max_len,
        ),
        num_workers=num_workers,
    )


train = make_dataset(data_train)
validation = make_dataset(data_val)
test = make_dataset(data_test)

train_loader = make_loader(train, data_train)
validation_loader = make_loader(validation, data_val)
test_loader = make_loader(test, data_test)

# Build Callbacks

In [66]:
# build_callbacks yields the callback list and the logger handed to the Trainer.
checkpoint_root = Path(local_path)
callbacks, metric_logger = build_callbacks(model_name="model", output_path=checkpoint_root)

# Model Training

In [67]:
# Use one GPU with mixed precision when CUDA is present; otherwise CPU in fp32.
use_gpu = torch.cuda.is_available()

# NOTE(review): the last run failed with "ModelCheckpoint(monitor='validation_loss')
# not found", and only train_* metrics were logged. Check that validation_loader
# yields batches and that the model logs "validation_loss" in its validation step.
trainer = pytorch_lightning.Trainer(
    gpus=1 if use_gpu else 0,
    precision=16 if use_gpu else 32,
    max_epochs=num_epoch,
    logger=metric_logger,
    callbacks=callbacks,
    log_every_n_steps=1,
)

trainer.fit(model, train_loader, validation_loader)

GPU available: False, used: False
TPU available: None, using: 0 TPU cores
Missing logger folder: ../outputs_train_workflow/model

  | Name                 | Type             | Params
----------------------------------------------------------
0 | encoder              | Embedding        | 30.6 M
1 | classification_heads | ModuleDict       | 2.9 M 
2 | ce_loss              | CrossEntropyLoss | 0     
----------------------------------------------------------
33.5 M    Trainable params
0         Non-trainable params
33.5 M    Total params
133.906   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

MisconfigurationException: ModelCheckpoint(monitor='validation_loss') not found in the returned metrics: ['train_loss', 'train_accuracy_02419', 'train_accuracy_99999', 'train_accuracy_00562', 'train_accuracy_15344', 'train_accuracy_01746']. HINT: Did you call self.log('validation_loss', tensor) in the LightningModule?

# Model Evaluation

In [None]:
# `local_path` is a str, so `local_path / "metrics.json"` raised TypeError.
# Build the file paths from a Path instead.
output_dir = Path(local_path)

# trainer.test returns one metrics dict per test dataloader.
test_metrics = trainer.test(model, test_loader)
# Same metrics with keys in alphabetical order, for easier reading/diffing.
# Sorting every dict (not just [0]) also avoids an IndexError on an empty result.
test_metrics_sort = [dict(sorted(metrics.items())) for metrics in test_metrics]

with open(output_dir / "metrics.json", "w") as f:
    json.dump(test_metrics, f, indent=4)

with open(output_dir / "metrics_sorted.json", "w") as f:
    json.dump(test_metrics_sort, f, indent=4)