<a href="https://colab.research.google.com/github/Amiri-007/kaggle_jigsaw_RDS/blob/master/ToxicCommentClassification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE
# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
import kagglehub
kagglehub.login()


In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

jigsaw_unintended_bias_in_toxicity_classification_path = kagglehub.competition_download('jigsaw-unintended-bias-in-toxicity-classification')
mahmoudelkarargy_dataaugumented_cleaned_path = kagglehub.dataset_download('mahmoudelkarargy/dataaugumented-cleaned')

print('Data source import complete.')


In [None]:
!pip install torchmetrics
!pip install loguru

In [None]:
import os
import time
import datetime
from typing import Any, Union, Dict, List
import uuid
import json

import pandas as pd
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchtext
import nltk
import sklearn
import transformers
import torchmetrics as tm
from torchmetrics import MetricCollection, Metric, Accuracy, Precision, Recall, AUROC, HammingDistance, F1Score, ROC, PrecisionRecallCurve


from loguru import logger
from tqdm.auto import tqdm
tqdm.pandas()

import warnings
warnings.filterwarnings("ignore")

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import re
from bs4 import BeautifulSoup
import unicodedata
import seaborn as sns
import zipfile
import os

In [None]:
# Directory holding the competition archives, and the directory they are unpacked into.
dataset_path = "/kaggle/input/jigsaw-toxic-comment-classification-challenge"
extract_dir = "/kaggle/working/unzipped_files"

# The destination must exist before anything is extracted into it.
os.makedirs(extract_dir, exist_ok=True)

# Unpack every .zip archive sitting directly inside the dataset directory;
# anything that is not a .zip is ignored.
for file_name in os.listdir(dataset_path):
    if not file_name.endswith(".zip"):
        continue
    zip_file_path = os.path.join(dataset_path, file_name)
    print(f"Unzipping {zip_file_path}...")
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)

print("All files unzipped successfully!")


In [None]:
# dataset_path = "/kaggle/input/jigsaw-unintended-bias-in-toxicity-classification"
dataset_path = "/kaggle/working/unzipped_files"

CUSTOME_NAME = "roberta-fl-augumented"

# Dataset
DATA_DIR_PATH = dataset_path
TRAIN_DATASET_PATH = os.path.join(DATA_DIR_PATH, "train.csv")
# TRAIN_DATASET_PATH = "/kaggle/working/augumented_by_25k_small_data.csv"
TEST_DATASET_PATH = os.path.join(DATA_DIR_PATH, "test.csv")

#LABEL_LIST = ['severe_toxicity', 'obscene', 'threat', 'insult', 'identity_attack', 'sexual_explicit']
LABEL_LIST = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

In [None]:
# Session: all paths for this run live under a directory named after
# CUSTOME_NAME plus a microsecond-resolution timestamp.
SESSION_DIR_PATH = os.path.abspath("./session")
SESSION_DATETIME = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S-%f")
SESSION_NAME = f"{CUSTOME_NAME}_{SESSION_DATETIME}"
CURRENT_SESSION_DIR_PATH = os.path.join(SESSION_DIR_PATH, SESSION_NAME)
# Create the session directory
os.makedirs(CURRENT_SESSION_DIR_PATH, exist_ok=True)

# File layout inside `CURRENT_SESSION_DIR_PATH`: every file name is prefixed
# with the session name.
LOG_FILE_NAME = f"{SESSION_NAME}.loguru.log"
MODEL_FILE_NAME = f"{SESSION_NAME}.model"
TEST_FILE_NAME = f"{SESSION_NAME}.test.csv"
VALIDATION_DATASET_NAME = f"{SESSION_NAME}.jigsaw2019-validation.csv"
VALIDATION_FILE_NAME = f"{SESSION_NAME}.validation.csv"
METRIC_FILE_NAME = f"{SESSION_NAME}.metric.json"
LOG_FILE_PATH = os.path.join(CURRENT_SESSION_DIR_PATH, LOG_FILE_NAME)
MODEL_FILE_PATH = os.path.join(CURRENT_SESSION_DIR_PATH, MODEL_FILE_NAME)
TEST_FILE_PATH = os.path.join(CURRENT_SESSION_DIR_PATH, TEST_FILE_NAME)
VALIDATION_DATASET_FILE_PATH = os.path.join(CURRENT_SESSION_DIR_PATH, VALIDATION_DATASET_NAME)
VALIDATION_FILE_PATH = os.path.join(CURRENT_SESSION_DIR_PATH, VALIDATION_FILE_NAME)
METRIC_FILE_PATH = os.path.join(CURRENT_SESSION_DIR_PATH, METRIC_FILE_NAME)

# CUDA: use the GPU when torch reports one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
logger.add(LOG_FILE_PATH, level="TRACE")
logger.info(f"{SESSION_NAME=}")
logger.info(f"{TRAIN_DATASET_PATH=}")
logger.info(f"{TEST_DATASET_PATH=}")
logger.info(f"{CURRENT_SESSION_DIR_PATH=}")
logger.info(f"{LABEL_LIST=}")

In [None]:
# Pre-flight checks: fail fast if the datasets or a CUDA device are missing.
logger.info(f"Checking consistency...")


# Both CSV files must exist before any loading is attempted.
if not os.path.exists(TRAIN_DATASET_PATH):
    logger.critical(f"Train dataset does not exist !")
    raise RuntimeError("Train dataset does not exist !")
if not os.path.exists(TEST_DATASET_PATH):
    logger.critical(f"Test dataset does not exist !")
    raise RuntimeError("Test dataset does not exist !")
logger.success("Datasets are reachable")


# The notebook refuses to run without CUDA.
GPU_IS_AVAILABLE = torch.cuda.is_available()
GPU_COUNT = torch.cuda.device_count()
logger.info(f"{GPU_IS_AVAILABLE=}")
logger.info(f"{GPU_COUNT=}")
if not GPU_IS_AVAILABLE:
    logger.critical("GPU and CUDA are not available !")
    raise RuntimeError("GPU and CUDA are not available !")
logger.success("GPU and CUDA are available")
logger.info(f"{device=}")
for gpu_id in range(GPU_COUNT):
    # Query the device being iterated over; asking for device 0 every time
    # logged the first GPU's name for all GPUs.
    gpu_name = torch.cuda.get_device_name(gpu_id)
    logger.info(f"GPU {gpu_id} : {gpu_name}")

In [None]:
all_train_df = pd.read_csv(TRAIN_DATASET_PATH, index_col=0)
logger.success("Dataset loaded !")

In [None]:
all_train_df.head()

In [None]:
test_df = pd.read_csv(TEST_DATASET_PATH)

## Data Pre-Processing

### Data Cleaning Steps

In this section, we performed several data cleaning steps to prepare our dataset for further analysis and modeling. The steps included:

1. **Removing Samples with Labels Equal to -1:**
   - We removed all rows where the `toxic` column had a value of -1.

2. **Text Cleaning:**
   - **Remove HTML Tags:** We used BeautifulSoup to strip HTML tags from the text.
   - **Remove URLs:** We used regular expressions to remove URLs from the text.
   - **Remove Diacritics:** We used the `unicodedata` library to remove diacritics from the text.
   - **Transform to Lowercase:** We converted all text to lowercase.
   - **Remove Extra Whitespaces:** We used regular expressions to remove extra whitespaces from the text.

3. **Remove NA or Empty Comments:**
   - We removed rows where the `comment_text` column was NA or empty after the cleaning steps.

The cleaning operations that are always applied include:

- Replacing newline characters with spaces
- Removing any non-alphanumeric characters (except spaces)
- Removing any numbers
- Removing any extra spaces
- Removing any non-ASCII characters

In [None]:
print(all_train_df.iloc[28]['comment_text'])
print("Toxicity Level: ",all_train_df.iloc[28]['toxic'])


In [None]:
# Show one comment together with the label of that SAME row (the label used to
# be read from row 4 while the text came from row 7).
sample_row = all_train_df.iloc[7]
print(sample_row['comment_text'])
print("Toxicity Level: ",sample_row['toxic'])

In [None]:
# Remove rows with NaN values in 'comment_text' column
all_train_df = all_train_df.dropna(subset=['comment_text'])

In [None]:
import re

# Ordered (pattern, replacement) rules. The two irregular contractions come
# first and are matched case-insensitively: otherwise a sentence-initial
# "Won't"/"Can't" falls through to the generic "n't" rule and turns into
# "Wo not"/"Ca not".
_CONTRACTION_RULES = (
    # specific
    (re.compile(r"won't", re.IGNORECASE), "will not"),
    (re.compile(r"can\'t", re.IGNORECASE), "can not"),
    # general
    (re.compile(r"n\'t"), " not"),
    (re.compile(r"\'re"), " are"),
    (re.compile(r"\'s"), " is"),
    (re.compile(r"\'d"), " would"),
    (re.compile(r"\'ll"), " will"),
    (re.compile(r"\'t"), " not"),
    (re.compile(r"\'ve"), " have"),
    (re.compile(r"\'m"), " am"),
)


def decontracted(phrase):
    """Expand common English contractions in `phrase`.

    The rules are applied one after another, in the order listed in
    `_CONTRACTION_RULES`. The irregular forms "won't" and "can't" are matched
    regardless of letter case and are always replaced by the lowercase
    expansions "will not" / "can not"; the generic suffix rules match
    lowercase suffixes only. Only the ASCII apostrophe (') is recognised.

    Args:
        phrase: the text to expand.

    Returns:
        The text with contractions expanded.
    """
    for pattern, replacement in _CONTRACTION_RULES:
        phrase = pattern.sub(replacement, phrase)
    return phrase

# Stop-word list; the negations 'no', 'nor' and 'not' are deliberately left out of it so that negation is never treated as a stop word
stopwords= ['i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', "you're", "you've",\
            "you'll", "you'd", 'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', \
            'she', "she's", 'her', 'hers', 'herself', 'it', "it's", 'its', 'itself', 'they', 'them', 'their',\
            'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', "that'll", 'these', 'those', \
            'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'having', 'do', 'does', \
            'did', 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', \
            'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 'during', 'before', 'after',\
            'above', 'below', 'to', 'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further',\
            'then', 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each', 'few', 'more',\
            'most', 'other', 'some', 'such', 'only', 'own', 'same', 'so', 'than', 'too', 'very', \
            's', 't', 'can', 'will', 'just', 'don', "don't", 'should', "should've", 'now', 'd', 'll', 'm', 'o', 're', \
            've', 'y', 'ain', 'aren', "aren't", 'couldn', "couldn't", 'didn', "didn't", 'doesn', "doesn't", 'hadn',\
            "hadn't", 'hasn', "hasn't", 'haven', "haven't", 'isn', "isn't", 'ma', 'mightn', "mightn't", 'mustn',\
            "mustn't", 'needn', "needn't", 'shan', "shan't", 'shouldn', "shouldn't", 'wasn', "wasn't", 'weren', "weren't", \
            'won', "won't", 'wouldn', "wouldn't"]



In [None]:
# Normalise every training comment: expand contractions, blank out carriage
# returns / double quotes / newlines, turn every run of non-alphanumeric
# characters into one space, squeeze whitespace and lowercase.
preprocessed_comments = []
# tqdm only adds a progress bar around the iteration
for sentence in tqdm(all_train_df['comment_text'].values):
    sent = decontracted(sentence).replace('\r', ' ').replace('\"', ' ').replace('\n', ' ')
    sent = re.sub('[^A-Za-z0-9]+', ' ', sent)
    # split() + join collapses any remaining whitespace runs
    sent = ' '.join(sent.split())
    preprocessed_comments.append(sent.lower().strip())

In [None]:
all_train_df['comment_text'] = preprocessed_comments
print(all_train_df.iloc[1]['comment_text'])

In [None]:
# Apply the same normalisation to the test comments: expand contractions,
# replace carriage returns / double quotes / newlines by spaces, turn every run
# of non-alphanumeric characters into one space, squeeze whitespace, lowercase.
preprocessed_comments_test = []
# tqdm only adds a progress bar around the iteration
for sentence in tqdm(test_df['comment_text'].values):
    sent = decontracted(sentence)
    sent = sent.replace('\r', ' ').replace('\"', ' ').replace('\n', ' ')
    sent = re.sub('[^A-Za-z0-9]+', ' ', sent)
    # split() + join collapses any remaining whitespace runs
    sent = ' '.join(sent.split())
    preprocessed_comments_test.append(sent.lower().strip())

In [None]:
test_df['comment_text'] = preprocessed_comments_test

In [None]:
# Step 1: Remove samples with labels equal to -1
all_train_df = all_train_df[all_train_df['toxic'] != -1]
# Step 2: Remove rows with NA or empty comments
all_train_df = all_train_df[all_train_df['comment_text'].notna() & (all_train_df['comment_text'].str.strip() != '')]


In [None]:
# Function to clean text
def clean_text(text):
    """Return a cleaned copy of `text`.

    Steps, in order: strip HTML markup with BeautifulSoup, delete tokens that
    start with http / www / https, drop combining marks (diacritics) after NFD
    normalisation, lowercase, and collapse every whitespace run into a single
    space with no leading or trailing whitespace.
    """
    # Keep only the text content of any HTML markup
    without_tags = BeautifulSoup(text, "html.parser").get_text()
    # Delete URL-like tokens
    without_urls = re.sub(r'http\S+|www\S+|https\S+', '', without_tags, flags=re.MULTILINE)
    # NFD splits accented letters into base letter + combining mark ('Mn');
    # dropping the marks leaves the bare letters
    decomposed = unicodedata.normalize('NFD', without_urls)
    base_chars = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    lowered = ''.join(base_chars).lower()
    # Squeeze whitespace
    return re.sub(r'\s+', ' ', lowered).strip()

# Apply text cleaning to the 'comment_text' column
all_train_df['comment_text'] = all_train_df['comment_text'].apply(clean_text)


In [None]:
all_train_df.head()

In [None]:
# Save the cleaned dataset to a new CSV file if needed
# all_train_df.to_csv('/kaggle/working/train_data_without_null.csv', index=True)

## Data analysis

In [None]:
# Plot the distribution of toxicity
plt.figure(figsize=(10, 6))
sns.histplot(all_train_df['toxic'], bins=50, kde=True)
plt.title('Distribution of Toxicity Scores')
plt.xlabel('Toxicity Score')
plt.ylabel('Frequency')
plt.show()


In [None]:
# Collapse the six label columns into one binary flag: toxic_label is 1 when at
# least one of them equals 1 for the row, and 0 otherwise.
all_train_df['toxic_label'] = (
    all_train_df[['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']]
    .eq(1)
    .any(axis=1)
    .astype(int)
)

# Verify the changes
print(all_train_df[['toxic', 'toxic_label']].head())

In [None]:
# Calculate the number and percentage of toxic and non-toxic comments.
# reindex([0, 1]) pins the order to (non-toxic, toxic): value_counts() sorts by
# frequency, so without it the hard-coded labels/colours below could end up on
# the wrong wedge, and a class with no rows would raise a KeyError in the prints.
toxic_counts = all_train_df['toxic_label'].value_counts().reindex([0, 1], fill_value=0)
toxic_percentage = all_train_df['toxic_label'].value_counts(normalize=True).reindex([0, 1], fill_value=0.0) * 100
total_comments = int(toxic_counts.sum())

# Plot the pie chart
plt.figure(figsize=(8, 8))
plt.pie(
    toxic_counts,
    labels=['Non-Toxic', 'Toxic'],
    # Round instead of truncating: the percentage handed to autopct is a float,
    # so int() alone could display a count that is one too small.
    autopct=lambda p: f'{p:.1f}% ({int(round(p * total_comments / 100))})',
    startangle=140,
    colors=['skyblue', 'salmon'],
)
plt.title('Percentage and Number of Toxic and Non-Toxic Comments')
plt.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle.
plt.show()

# Print the number and percentage values
print('Number of non-toxic comments:', toxic_counts[0])
print('Number of toxic comments:', toxic_counts[1])
print('Percentage of non-toxic comments: {:.2f}%'.format(toxic_percentage[0]))
print('Percentage of toxic comments: {:.2f}%'.format(toxic_percentage[1]))
print('Total Number: ', all_train_df.shape[0])

In [None]:
# Define the toxicity columns
toxicity_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

# Check for distribution of different toxicity types
toxicity_distributions = all_train_df[toxicity_columns].sum()

# Plot the distribution of different toxicity types
plt.figure(figsize=(12, 8))
toxicity_distributions.plot(kind='bar')
plt.title('Distribution of Different Toxicity Types')
plt.xlabel('Toxicity Type')
plt.ylabel('Count')
plt.xticks(rotation=45)
plt.show()

# Print the number of samples for each toxicity type
for toxicity_type in toxicity_columns:
    count = (all_train_df[toxicity_type] > 0).sum()
    print(f'Number of samples for {toxicity_type}: {count}')

# Find the number of comments that belong to only one class.
# The "exactly one label set" mask does not depend on the column being counted,
# so it is computed once instead of once per toxicity type.
single_label_rows = all_train_df[toxicity_columns].sum(axis=1) == 1

exclusive_counts = {
    toxicity_type: ((all_train_df[toxicity_type] == 1) & single_label_rows).sum()
    for toxicity_type in toxicity_columns
}

# Print the number of exclusive samples for each toxicity type
for toxicity_type, count in exclusive_counts.items():
    print(f'Number of comments exclusively for {toxicity_type}: {count}')

In [None]:
print(all_train_df.shape[0])

In [None]:
# all_train_df.to_csv('/kaggle/working/small_data.csv', index=True)


### Data Augmentation

Now, let's perform the data augmentations step by step:

#### 1. Unique Words Augmentation

In [None]:
!pip install nlpaug
!python3 -m nltk.downloader wordnet
!unzip /usr/share/nltk_data/corpora/wordnet.zip -d /usr/share/nltk_data/corpora

In [None]:
import nlpaug.augmenter.word as naw
import nlpaug.flow as naf
import nlpaug.augmenter.char as nac
import nlpaug.augmenter.sentence as nas

from nlpaug.util import Action
import random

In [None]:
# Synonym Augmentation using WordNet
synonym_aug = naw.SynonymAug(aug_src='wordnet')

# Contextual Word Embedding Augmentation using BERT for Insertion
insert_aug = naw.ContextualWordEmbsAug(model_path='bert-base-uncased', action="insert")

# Random Swap Augmentation
swap_aug = naw.RandomWordAug(action="swap")

# Random Deletion Augmentation
delete_aug = naw.RandomWordAug(action="delete")


In [None]:
def unique_words_augmentation(comment):
    """Return `comment` with repeated words removed.

    Words are whitespace-separated tokens. The first occurrence of each word is
    kept, in its original position, and the survivors are joined with single
    spaces. An empty comment yields an empty string.
    """
    # dict.fromkeys de-duplicates while preserving insertion order; the former
    # list(set(words)) produced an arbitrary word order that could differ from
    # one interpreter run to the next, making the augmented data irreproducible.
    return ' '.join(dict.fromkeys(comment.split()))

def random_mask(comment, mask_ratio=0.2):
    """Return `comment` with a random fraction of its words removed.

    "Masking" here means deletion: int(len(words) * mask_ratio) distinct word
    positions are drawn with `random.sample` and dropped; the remaining words
    keep their order and are joined with single spaces.

    Args:
        comment: whitespace-separated text.
        mask_ratio: fraction of the words to drop (0.2 drops a fifth).

    Returns:
        The comment without the dropped words.
    """
    words = comment.split()
    num_words_to_mask = int(len(words) * mask_ratio)
    # A set gives O(1) membership tests in the filter below (the drawn indices
    # used to stay in a list, which is scanned for every word).
    mask_indices = set(random.sample(range(len(words)), num_words_to_mask))
    return ' '.join(word for idx, word in enumerate(words) if idx not in mask_indices)


In [None]:
# all_train_df = pd.read_csv("/kaggle/working/train_data_cleaned.csv", index_col=0)
# logger.success("Dataset loaded !")

In [None]:
all_train_df.head()

In [None]:

# Augment a subset of the data
augmented_data = []
# toxic_label is 1 when at least one of the six label columns equals 1
all_train_df['toxic_label'] = (
    (all_train_df['toxic'] == 1) |
    (all_train_df['severe_toxic'] == 1) |
    (all_train_df['obscene'] == 1) |
    (all_train_df['threat'] == 1) |
    (all_train_df['insult'] == 1) |
    (all_train_df['identity_hate'] == 1)
).astype(int)

# Define the classes and their respective iteration sizes
classes = ['threat', 'identity_hate', 'severe_toxic']
iteration_sizes = [2000, 2000, 2000]  # Adjust these sizes based on your needs

# The six augmentations applied to every selected comment, in this order.
# For the nlpaug augmenters, element [0] of what `.augment(...)` returns is kept.
augmentation_steps = [
    unique_words_augmentation,
    random_mask,
    lambda text: synonym_aug.augment(text)[0],
    lambda text: insert_aug.augment(text)[0],
    lambda text: swap_aug.augment(text)[0],
    lambda text: delete_aug.augment(text)[0],
]

for class_name, iteration_size in zip(classes, iteration_sizes):
    toxic_comments = all_train_df[(all_train_df['toxic_label'] == 1) & (all_train_df[class_name] == 1)]

    # Only the first `iteration_size` rows of the class (or all of them when
    # there are fewer) are augmented.
    for i in range(min(iteration_size, len(toxic_comments))):
        print(f"Augmenting {class_name}, Iteration: {i}")
        # One positional lookup per row; the text is read from the same dict
        original_row = toxic_comments.iloc[i].to_dict()
        comment = original_row['comment_text']

        # Every step yields one new row: a copy of the original row (labels
        # included) whose comment_text is replaced by the augmented text.
        for augment in augmentation_steps:
            new_row = original_row.copy()
            new_row['comment_text'] = augment(comment)
            augmented_data.append(new_row)

# Convert the augmented data to a DataFrame
augmented_df = pd.DataFrame(augmented_data)

In [None]:
# # Augment a subset of the data
# augmented_data = []

# # Filter for toxic comments (target >= threshold)
# toxic_comments = all_train_df[all_train_df['target'] >= threshold]

# for i in range(5000):  # Adjust this number based on your computational power
#     print(f"In Iteration: {i}")
#     comment = toxic_comments['comment_text'].iloc[i]
# #     print(f"Original comment: {comment}\n")

#     # Get the original row as a dictionary
#     original_row = toxic_comments.iloc[i].to_dict()

#     # Unique Words Augmentation
#     unique_comment = unique_words_augmentation(comment)
#     new_row = original_row.copy()
#     new_row['comment_text'] = unique_comment
#     augmented_data.append(new_row)
# #     print("Unique Words Augmentation: " + unique_comment)

#     # Random Mask
#     masked_comment = random_mask(comment)
#     new_row = original_row.copy()
#     new_row['comment_text'] = masked_comment
#     augmented_data.append(new_row)
# #     print("Random Mask Augmentation: " + masked_comment)

#     augmented_comment = synonym_aug.augment(comment)[0]
#     new_row = original_row.copy()
#     new_row['comment_text'] = augmented_comment
#     augmented_data.append(new_row)
# #     print("Synonym Replacement Augmentation: " + str(augmented_comment))

#     augmented_comment = insert_aug.augment(comment)[0]
#     new_row = original_row.copy()
#     new_row['comment_text'] = augmented_comment
#     augmented_data.append(new_row)
# #     print("Random Insert Augmentation: " + str(augmented_comment))

#     augmented_comment = swap_aug.augment(comment)[0]
#     new_row = original_row.copy()
#     new_row['comment_text'] = augmented_comment
#     augmented_data.append(new_row)
# #     print("Random Swap Augmentation: " + str(augmented_comment))

#     augmented_comment = delete_aug.augment(comment)[0]
#     new_row = original_row.copy()
#     new_row['comment_text'] = augmented_comment
#     augmented_data.append(new_row)
# #     print("Random Delete Augmentation: " + str(augmented_comment))
# #     print("\n\n")

# augmented_df = pd.DataFrame(augmented_data)

In [None]:
print(augmented_df.shape[0])

In [None]:
augmented_df.head()

In [None]:
# Combine with the original data
print(all_train_df.shape[0])
print(augmented_df.shape[0])

all_train_df = pd.concat([all_train_df, augmented_df], ignore_index=False)
print(all_train_df.shape[0])


In [None]:
all_train_df.head()

In [None]:
# all_train_df.iloc[159576]

In [None]:
all_train_df.to_csv('/kaggle/working/train_data_augumented_by_21k.csv', index=True)
# all_train_df = pd.read_csv("/kaggle/input/dataaugumented-cleaned/train_data_augumented_by_21k.csv")
# logger.success("Dataset loaded !")

In [None]:
# Check for distribution of different toxicity types
toxicity_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
toxicity_distributions = all_train_df[toxicity_columns].sum()

# Plot the distribution of different toxicity types
plt.figure(figsize=(12, 8))
toxicity_distributions.plot(kind='bar')
plt.title('Distribution of Different Toxicity Types')
plt.xlabel('Toxicity Type')
plt.ylabel('Count')
plt.xticks(rotation=45)
plt.show()

# Print the number of samples for each toxicity type
for toxicity_type in toxicity_columns:
    count = (all_train_df[toxicity_type] > 0).sum()
    print(f'Number of samples for {toxicity_type}: {count}')

In [None]:
# Calculate the number and percentage of toxic and non-toxic comments.
# reindex([0, 1]) pins the order to (non-toxic, toxic): value_counts() sorts by
# frequency, so without it the hard-coded labels/colours below could end up on
# the wrong wedge, and a class with no rows would raise a KeyError in the prints.
toxic_counts = all_train_df['toxic_label'].value_counts().reindex([0, 1], fill_value=0)
toxic_percentage = all_train_df['toxic_label'].value_counts(normalize=True).reindex([0, 1], fill_value=0.0) * 100
total_comments = int(toxic_counts.sum())

# Plot the pie chart
plt.figure(figsize=(8, 8))
plt.pie(
    toxic_counts,
    labels=['Non-Toxic', 'Toxic'],
    # Round instead of truncating: the percentage handed to autopct is a float,
    # so int() alone could display a count that is one too small.
    autopct=lambda p: f'{p:.1f}% ({int(round(p * total_comments / 100))})',
    startangle=140,
    colors=['skyblue', 'salmon'],
)
plt.title('Percentage and Number of Toxic and Non-Toxic Comments')
plt.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle.
plt.show()

# Print the number and percentage values
print('Number of non-toxic comments:', toxic_counts[0])
print('Number of toxic comments:', toxic_counts[1])
print('Percentage of non-toxic comments: {:.2f}%'.format(toxic_percentage[0]))
print('Percentage of toxic comments: {:.2f}%'.format(toxic_percentage[1]))
print('Total Number: ', all_train_df.shape[0])

In [None]:
all_train_df.head()

# Dataset


In [None]:
# Drop rows that lack the comment text or any of the labels.
# The previous `all_train_df[~all_train_df.isna()]` indexed with a boolean
# DataFrame, which masks individual cells and therefore removed no rows at all.
train_df = all_train_df.dropna(subset=['comment_text'] + LABEL_LIST)
# removing sample with labels equal to -1
train_df = train_df.loc[train_df['toxic'] >= 0]
# The old index is kept as a column; rows get a fresh 0..n-1 index
train_df.reset_index(inplace=True)

In [None]:
# Sample 10,000 rows for validation
validation_df = train_df.sample(n=10_000, random_state=42)

# Remove the validation samples from the training dataset
train_df = train_df.drop(validation_df.index)

train_df[LABEL_LIST] = (train_df[LABEL_LIST]>=0.5).astype(int)
validation_df[LABEL_LIST] = (validation_df[LABEL_LIST]>=0.5).astype(int)

In [None]:
train_df.head()

In [None]:
train_df.shape[0]

In [None]:
validation_df.shape[0]

In [None]:
class JigsawDataset(Dataset):
    """Torch dataset that serves tokenized comments with their label vector.

    Each item is a dict with:
      * ``index``  - the positional index that was requested,
      * ``ids``    - the flattened ``input_ids`` tensor from the tokenizer,
      * ``mask``   - the flattened ``attention_mask`` tensor from the tokenizer,
      * ``labels`` - a float tensor holding the row's values for LABEL_LIST.

    Args:
        data_df: DataFrame with a ``comment_text`` column. When at least one
            LABEL_LIST column is missing, every LABEL_LIST column is set to 0.
        tokenizer: object exposing ``encode_plus`` (called with the keyword
            arguments shown in ``text_to_token_and_mask``).
        max_length: length every comment is padded / truncated to.
    """

    def __init__(self, data_df, tokenizer, max_length=128):
        self.data = data_df
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.labels_present = all(label in data_df.columns for label in LABEL_LIST)
        if not self.labels_present:
            # Add columns for labels if not present.
            # NOTE: this writes the placeholder columns into the DataFrame that
            # was passed in, so the caller sees them too.
            for label in LABEL_LIST:
                self.data[label] = 0

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Fetch the row once and read both the text and the labels from it
        row = self.data.iloc[index]
        comment = row["comment_text"]
        label = torch.tensor(row[LABEL_LIST].tolist(), dtype=torch.float)

        token_list, attention_mask = self.text_to_token_and_mask(comment)

        return dict(index=index, ids=token_list, mask=attention_mask, labels=label)

    def text_to_token_and_mask(self, input_text):
        """Tokenize `input_text`; return ``(input_ids, attention_mask)`` as flattened tensors."""
        # Use the tokenizer handed to the constructor. The method used to read
        # a module-level `tokenizer`, silently ignoring the constructor argument
        # (and failing when no such global existed).
        tokenization_dict = self.tokenizer.encode_plus(input_text,
                                add_special_tokens=True,
                                max_length=self.max_length,
                                padding='max_length',
                                truncation=True,
                                return_attention_mask=True,
                                return_tensors='pt')
        token_list = tokenization_dict["input_ids"].flatten()
        attention_mask = tokenization_dict["attention_mask"].flatten()
        return (token_list, attention_mask)

In [None]:
def set_lr(optim, lr):
    '''
    Overwrite the 'lr' entry of every parameter group of `optim` with `lr`
    and hand the same optimizer object back.
    '''
    for group in optim.param_groups:
        group.update(lr=lr)
    return optim

In [None]:
# Transformer class and functions for models and predictions

class TransformerClassifierStack(nn.Module):
    """Classification head on top of a transformer backbone.

    The first-token vectors of the backbone's last four hidden states are
    concatenated (4 * hidden_size features) and fed through
    dropout -> linear -> tanh -> dropout -> linear, producing one raw score per
    label.

    Args:
        tr_model: backbone; must expose ``config.hidden_size`` and return
            ``hidden_states`` when called with ``output_hidden_states=True``.
        nb_labels: number of output scores.
        dropout_prob: dropout probability used before both linear layers.
        freeze: when True, every backbone parameter gets requires_grad=False.
    """

    def __init__(self, tr_model, nb_labels, dropout_prob=0.4, freeze=False):
        super().__init__()
        self.tr_model = tr_model

        # Features of the 4 last hidden states are stacked side by side
        self.hidden_dim = 4 * tr_model.config.hidden_size

        # Classification head: hidden layer, then the output layer
        self.dropout = nn.Dropout(dropout_prob)
        self.hl = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.last_l = nn.Linear(self.hidden_dim, nb_labels)

        # Backbone parameters are trainable unless `freeze` is set
        trainable = not freeze
        for param in self.tr_model.parameters():
            param.requires_grad = trainable

        # Xavier-initialise the weights of the two head layers
        for layer in (self.hl, self.last_l):
            torch.nn.init.xavier_uniform_(layer.weight)

    def forward(self, ids, mask):
        # ids and mask are passed to the backbone as input_ids / attention_mask
        tr_output = self.tr_model(input_ids=ids,
                                  attention_mask=mask,
                                  output_hidden_states=True)
        hidden_states = tr_output['hidden_states']

        # Position 0 of each of the last four hidden states, most recent first:
        # four [batch_size, hidden_size] tensors
        first_token_states = [hidden_states[-depth][:, 0, :] for depth in range(1, 5)]

        # features_vec = [batch_size, hidden_size * 4]
        features_vec = torch.cat(first_token_states, dim=-1)

        x = self.hl(self.dropout(features_vec))
        x = self.dropout(torch.tanh(x))

        # result = [batch_size, nb_labels]
        return self.last_l(x)

def load_roberta_model(nb_labels):
    '''
    Build a classifier from the stock `roberta-base` weights (no task-specific
    checkpoint is loaded) and return it with the matching tokenizer.
    The backbone is created with freeze=True, so only the classification head
    has trainable parameters.
    '''
    logger.info("transformers.RobertaTokenizer : roberta-base")
    logger.info("transformers.AutoModel : roberta-base")
    tokenizer = transformers.RobertaTokenizer.from_pretrained('roberta-base')
    backbone = transformers.AutoModel.from_pretrained('roberta-base')
    classifier = TransformerClassifierStack(backbone, nb_labels, freeze=True)
    return classifier, tokenizer


def load_roberta_pretrained(path, nb_labels, lr=2e-5):
    '''
    Rebuild a RoBERTa classifier and its optimizer from a saved checkpoint.

    The checkpoint at `path` must hold the keys 'state_dict' (model weights)
    and 'optimizer_dict' (optimizer state). The learning rate stored in the
    checkpoint is overridden with `lr`.

    Returns (model, tokenizer, optimizer).
    '''
    tokenizer = transformers.RobertaTokenizer.from_pretrained('roberta-base')
    tr_model = transformers.AutoModel.from_pretrained('roberta-base')
    model = TransformerClassifierStack(tr_model, nb_labels)

    # Map every stored tensor to the CPU while loading: the model built above
    # lives on the CPU, and this lets a checkpoint saved from a GPU be restored
    # on a machine without that GPU.
    loaded = torch.load(path, map_location="cpu")
    model.load_state_dict(loaded['state_dict'])

    optimizer = transformers.AdamW(model.parameters(), lr=lr)
    optimizer.load_state_dict(loaded['optimizer_dict'])
    # load_state_dict restored the checkpoint's learning rate; force the requested one
    optimizer = set_lr(optimizer, lr)

    return model, tokenizer, optimizer

def preds_fn(batch, model, device):
    '''
    Run `model` on one batch: move batch['ids'] and batch['mask'] to `device`
    and return whatever the model returns for them.
    '''
    return model(batch['ids'].to(device), batch['mask'].to(device))

In [None]:
# Load the model
model, tokenizer = load_roberta_model(nb_labels=len(LABEL_LIST))
logger.success("Model loaded !")

In [None]:
# Training / data-loading settings; each one is written to the log below.
BATCH_SIZE = 32
LR=1e-4
PIN_MEMORY = True
NUM_WORKERS = 0
# NOTE(review): this is the string "None", not the None singleton. Confirm that
# is intended before handing it to anything that expects an int or None.
PREFETCH_FACTOR = "None"
NUM_EPOCHS = 3
logger.info(f"{BATCH_SIZE=}")
logger.info(f"{LR=}")
logger.info(f"{PIN_MEMORY=}")
logger.info(f"{NUM_WORKERS=}")
logger.info(f"{PREFETCH_FACTOR=}")
logger.info(f"{NUM_EPOCHS=}")

### **Losses**

In [None]:
class FocalLoss(nn.Module):
    """Focal variant of binary cross-entropy computed on raw logits.

    The element-wise BCE-with-logits loss is scaled by ``(1 - p_t) ** gamma``,
    where ``p_t`` is the sigmoid probability assigned to the true class of each
    element.

    Args:
        gamma: focusing exponent; 0 gives plain BCE-with-logits.
        reduction: "mean" or "sum"; any other value returns the unreduced,
            element-wise loss.
        pos_weight: optional tensor forwarded to
            ``F.binary_cross_entropy_with_logits`` as its ``pos_weight``.
    """

    def __init__(self,
                 gamma: float = 2,
                 reduction: str = "mean",
                 pos_weight: Union[torch.Tensor, None] = None):
        super(FocalLoss, self).__init__()
        self.gamma= gamma
        self.reduction = reduction
        # Registered as a buffer so that `.to(device)` / `.cuda()` on the loss
        # moves it along with the module (a plain attribute stayed behind on
        # its original device). Non-persistent: state_dict() is unchanged.
        self.register_buffer("pos_weight", pos_weight, persistent=False)

    def forward(self, inputs: torch.Tensor,
                targets: torch.Tensor):
        p = torch.sigmoid(inputs)
        ce_loss = F.binary_cross_entropy_with_logits(
            inputs, targets, reduction="none", pos_weight=self.pos_weight
        )
        # Probability given to the correct class of every element
        p_t =  p * targets + (1 - p) * (1 - targets)
        loss = ce_loss * ((1 - p_t) ** self.gamma)

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss

In [None]:

class ResampleLoss(nn.Module):
    """Re-weighted binary cross-entropy with optional focal modulation.

    Three optional ingredients are configured at construction time:

    * a re-weighting scheme (``reweight_func``) derived from ``class_freq``,
    * a logit regularisation step (``logit_reg``) that adds a frequency-based
      bias to the logits and rescales the logits of negative labels,
    * a focal factor ``alpha_t * (1 - pt) ** gamma`` (``focal``).

    The frequency tensors are registered as non-persistent buffers. They are
    created on CUDA when it is available and on CPU otherwise, and they follow
    the module through ``.to(...)``.

    Args:
        use_sigmoid: Must be truthy; only the sigmoid / binary cross-entropy
            criterion is available here.
        partial: Must be falsy; no partial criterion is available here.
        loss_weight: Scalar multiplied into the final loss.
        reduction: Default reduction: ``'none'``, ``'mean'`` or ``'sum'``.
        reweight_func: ``None``, ``'inv'``, ``'sqrt_inv'``, ``'rebalance'`` or
            ``'CB'``. Any other value disables re-weighting.
        weight_norm: ``None``, or a string containing ``'by_instance'`` or
            ``'by_batch'``.
        focal: Dict with keys ``focal`` (enable flag), ``alpha``, ``gamma``.
        map_param: Dict with keys ``alpha``, ``beta``, ``gamma`` used by the
            ``'rebalance'`` scheme.
        CB_loss: Dict with keys ``CB_beta`` and ``CB_mode`` (a string
            containing ``'by_class'``, ``'average_n'``, ``'average_w'`` or
            ``'min_n'``) used by the ``'CB'`` scheme.
        logit_reg: Dict that may hold ``neg_scale`` and ``init_bias``; an
            empty dict disables logit regularisation.
        class_freq: Per-class frequencies, anything ``np.asarray`` accepts.
        train_num: Value divided by ``class_freq`` (presumably the number of
            training samples -- confirm with the caller).

    Raises:
        RuntimeError: If a criterion other than sigmoid BCE is requested.
        TypeError: If ``class_freq`` or ``train_num`` is missing.
    """

    # NOTE: the dict defaults below are shared between instances; they are
    # only ever read, never mutated.
    def __init__(self,
                 use_sigmoid=True, partial=False,
                 loss_weight=1.0, reduction='mean',
                 reweight_func=None,  # None, 'inv', 'sqrt_inv', 'rebalance', 'CB'
                 weight_norm=None, # None, 'by_instance', 'by_batch'
                 focal=dict(
                     focal=True,
                     alpha=0.5,
                     gamma=2,
                 ),
                 map_param=dict(
                     alpha=10.0,
                     beta=0.2,
                     gamma=0.1
                 ),
                 CB_loss=dict(
                     CB_beta=0.9,
                     CB_mode='average_w'  # 'by_class', 'average_n', 'average_w', 'min_n'
                 ),
                 logit_reg=dict(
                     neg_scale=5.0,
                     init_bias=0.1
                 ),
                 class_freq=None,
                 train_num=None):
        super().__init__()

        assert (use_sigmoid is True) or (partial is False)
        self.use_sigmoid = use_sigmoid
        self.partial = partial
        self.loss_weight = loss_weight
        self.reduction = reduction
        # Only the plain sigmoid / binary cross-entropy criterion exists.
        if not self.use_sigmoid or self.partial:
            raise RuntimeError("Not defined here")
        self.cls_criterion = binary_cross_entropy

        # reweighting function
        self.reweight_func = reweight_func

        # normalization (optional)
        self.weight_norm = weight_norm

        # focal loss params
        self.focal = focal['focal']
        self.gamma = focal['gamma']
        self.alpha = focal['alpha']

        # mapping function params
        self.map_alpha = map_param['alpha']
        self.map_beta = map_param['beta']
        self.map_gamma = map_param['gamma']

        # CB loss params (optional)
        self.CB_beta = CB_loss['CB_beta']
        self.CB_mode = CB_loss['CB_mode']

        if class_freq is None or train_num is None:
            raise TypeError(
                "ResampleLoss requires both `class_freq` and `train_num`")

        # Prefer CUDA when present (the historical behaviour) but keep the
        # loss usable on CPU-only hosts.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.register_buffer(
            "class_freq",
            torch.from_numpy(np.asarray(class_freq)).float().to(device),
            persistent=False)
        self.num_classes = self.class_freq.shape[0]
        self.train_num = train_num # only used to be divided by class_freq

        # regularization params
        self.logit_reg = logit_reg
        self.neg_scale = logit_reg.get('neg_scale', 1.0)
        init_bias_scale = logit_reg.get('init_bias', 0.0)

        self.register_buffer(
            "init_bias",
            - torch.log(self.train_num / self.class_freq - 1) * init_bias_scale,
            persistent=False)
        self.register_buffer(
            "freq_inv",
            torch.ones_like(self.class_freq) / self.class_freq,
            persistent=False)
        self.register_buffer(
            "propotion_inv",
            self.train_num / self.class_freq,
            persistent=False)

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Compute the loss of the logits ``cls_score`` against ``label``.

        Args:
            cls_score: Logits.
            label: Targets; compared with ``1`` to pick the focal ``alpha``.
            weight: Ignored; weights always come from the configured
                re-weighting scheme.
            avg_factor: Forwarded to the criterion when the focal term is on.
            reduction_override: Optional replacement for ``self.reduction``.

        Returns:
            The loss multiplied by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        weight = self.reweight_functions(label)

        cls_score, weight = self.logit_reg_functions(label.float(), cls_score, weight)

        if self.focal:
            logpt = self.cls_criterion(
                cls_score.clone(), label, weight=None, reduction='none',
                avg_factor=avg_factor)
            # pt is sigmoid(logit) for pos or sigmoid(-logit) for neg
            pt = torch.exp(-logpt)
            wtloss = self.cls_criterion(
                cls_score, label.float(), weight=weight, reduction='none')
            alpha_t = torch.where(label==1, self.alpha, 1-self.alpha)
            loss = alpha_t * ((1 - pt) ** self.gamma) * wtloss
            loss = reduce_loss(loss, reduction)
        else:
            loss = self.cls_criterion(cls_score, label.float(), weight,
                                      reduction=reduction)

        loss = self.loss_weight * loss
        return loss

    def reweight_functions(self, label):
        """Return the weights for ``label`` or ``None`` if re-weighting is off."""
        if self.reweight_func is None:
            return None
        elif self.reweight_func in ['inv', 'sqrt_inv']:
            weight = self.RW_weight(label.float())
        # Exact comparisons: a substring test (``x in 'rebalance'``) would
        # also accept fragments such as 'C', 'B' or the empty string.
        elif self.reweight_func == 'rebalance':
            weight = self.rebalance_weight(label.float())
        elif self.reweight_func == 'CB':
            weight = self.CB_weight(label.float())
        else:
            return None

        if self.weight_norm is not None:
            if 'by_instance' in self.weight_norm:
                max_by_instance, _ = torch.max(weight, dim=-1, keepdim=True)
                weight = weight / max_by_instance
            elif 'by_batch' in self.weight_norm:
                weight = weight / torch.max(weight)

        return weight

    def logit_reg_functions(self, labels, logits, weight=None):
        """Apply the configured logit regularisation.

        Returns:
            ``(logits, weight)`` after the optional bias shift and the
            optional rescaling of negative-label entries. The input
            ``logits`` tensor is left untouched.
        """
        if not self.logit_reg:
            return logits, weight
        if 'init_bias' in self.logit_reg:
            # Out-of-place on purpose: ``+=`` would silently modify the
            # caller's tensor (typically the model output).
            logits = logits + self.init_bias
        if 'neg_scale' in self.logit_reg:
            logits = logits * (1 - labels) * self.neg_scale  + logits * labels
            if weight is not None:
                weight = weight / self.neg_scale * (1 - labels) + weight * labels
        return logits, weight

    def rebalance_weight(self, gt_labels):
        """Return the 'rebalance' weights for a batch of labels."""
        repeat_rate = torch.sum( gt_labels.float() * self.freq_inv, dim=1, keepdim=True)
        pos_weight = self.freq_inv.clone().detach().unsqueeze(0) / repeat_rate
        # pos and neg are equally treated
        weight = torch.sigmoid(self.map_beta * (pos_weight - self.map_gamma)) + self.map_alpha
        return weight

    def _cb_term(self, sample_count):
        """Return the class-balanced factor ``(1 - beta) / (1 - beta ** n)``."""
        return (1 - self.CB_beta) / (1 - torch.pow(self.CB_beta, sample_count))

    def CB_weight(self, gt_labels):
        """Return the class-balanced weights selected by ``CB_mode``.

        Raises:
            NameError: If ``CB_mode`` matches none of the known modes.
        """
        # NOTE(review): 'average_n' and 'average_w' divide by the number of
        # positive labels per row; a row without positives yields 0/0.
        if  'by_class' in self.CB_mode:
            weight = self._cb_term(self.class_freq)
        elif 'average_n' in self.CB_mode:
            avg_n = torch.sum(gt_labels * self.class_freq, dim=1, keepdim=True) / \
                    torch.sum(gt_labels, dim=1, keepdim=True)
            weight = self._cb_term(avg_n)
        elif 'average_w' in self.CB_mode:
            weight_ = self._cb_term(self.class_freq)
            weight = torch.sum(gt_labels * weight_, dim=1, keepdim=True) / \
                     torch.sum(gt_labels, dim=1, keepdim=True)
        elif 'min_n' in self.CB_mode:
            min_n, _ = torch.min(gt_labels * self.class_freq +
                                 (1 - gt_labels) * 100000, dim=1, keepdim=True)
            weight = self._cb_term(min_n)
        else:
            raise NameError
        return weight

    def RW_weight(self, gt_labels, by_class=True):
        """Return inverse-proportion weights (square-rooted for 'sqrt_inv').

        With ``by_class=False`` the weights are averaged over the positive
        labels of each row instead of being returned per class.
        """
        if 'sqrt' in self.reweight_func:
            weight = torch.sqrt(self.propotion_inv)
        else:
            weight = self.propotion_inv
        if not by_class:
            sum_ = torch.sum(weight * gt_labels, dim=1, keepdim=True)
            weight = sum_ / torch.sum(gt_labels, dim=1, keepdim=True)
        return weight


def reduce_loss(loss, reduction):
    """Collapse an element-wise loss according to ``reduction``.

    Args:
        loss (Tensor): Element-wise loss values.
        reduction (str): One of "none", "mean" or "sum".

    Returns:
        Tensor: ``loss`` itself for "none", its mean for "mean", its sum
        for "sum".
    """
    # Numeric codes produced by the lookup: 0 = none, 1 = mean, 2 = sum.
    mode = F._Reduction.get_enum(reduction)
    if mode == 1:
        return loss.mean()
    if mode == 2:
        return loss.sum()
    if mode == 0:
        return loss


def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    """Scale a loss element-wise, then reduce it.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Optional element-wise weights multiplied into ``loss``.
        reduction (str): Same meaning as in PyTorch's built-in losses.
        avg_factor (float): Optional divisor that replaces the element count
            when ``reduction`` is "mean".

    Returns:
        Tensor: The weighted and reduced loss.

    Raises:
        ValueError: If ``avg_factor`` is given together with a reduction
            other than "mean" or "none".
    """
    scaled = loss if weight is None else loss * weight

    # Without an explicit averaging factor the ordinary reduction applies.
    if avg_factor is None:
        return reduce_loss(scaled, reduction)

    if reduction == 'mean':
        return scaled.sum() / avg_factor
    if reduction == 'none':
        return scaled
    raise ValueError('avg_factor can not be used with reduction="sum"')


def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None):
    """Binary cross-entropy on logits with optional element-wise weights.

    Args:
        pred (Tensor): Logits.
        label (Tensor): Targets; cast to float before use.
        weight (Tensor): Optional weights, cast to float and applied inside
            the element-wise cross-entropy.
        reduction (str): Reduction applied to the element-wise loss.
        avg_factor (float): Optional averaging factor handed to the reducer.

    Returns:
        Tensor: The reduced loss.
    """
    elementwise_weight = None if weight is None else weight.float()

    raw = F.binary_cross_entropy_with_logits(
        pred, label.float(), elementwise_weight, reduction='none')

    # The weights are already folded into ``raw``; only the reduction remains.
    return weight_reduce_loss(raw, reduction=reduction, avg_factor=avg_factor)

In [None]:
def get_label_weights_bce(df, classes=LABEL_LIST):
    """Return, per label column, the ratio of zero rows to non-zero rows.

    Args:
        df: DataFrame containing one column per entry of ``classes``.
        classes: Column names to process, in output order.

    Returns:
        1-D tensor with one ``zeros / non_zeros`` ratio per class.
    """
    total = len(df)
    ratios = []
    for name in classes:
        zeros = int((df[name] == 0).sum())
        ratios.append(zeros / (total - zeros))
    return torch.tensor(ratios, dtype=torch.get_default_dtype())

def get_label_inv_freq(df, classes=LABEL_LIST):
    """Return, per label column, the fraction of rows whose value is zero.

    Args:
        df: DataFrame containing one column per entry of ``classes``.
        classes: Column names to process, in output order.

    Returns:
        1-D tensor with one ``zeros / len(df)`` fraction per class.
    """
    total = len(df)
    fractions = []
    for name in classes:
        zeros = int((df[name] == 0).sum())
        fractions.append(zeros / total)
    return torch.tensor(fractions, dtype=torch.get_default_dtype())

def get_nb_samples_lab(df, classes=LABEL_LIST):
    """Count non-zero and zero rows for every label column.

    Args:
        df: DataFrame containing one column per entry of ``classes``.
        classes: Column names to process, in output order.

    Returns:
        Tuple ``(non_zero_counts, zero_counts)`` of 1-D tensors, one entry
        per class.
    """
    total = len(df)
    zero_counts = [int((df[name] == 0).sum()) for name in classes]
    one_counts = [total - zeros for zeros in zero_counts]
    return torch.tensor(one_counts), torch.tensor(zero_counts)

## Data and Data Loaders

In [None]:
# Training data: reshuffled on every pass over the loader.
train_dataset = JigsawDataset(train_df, tokenizer)
train_dataloader = DataLoader(train_dataset,
                             batch_size=BATCH_SIZE,
                             shuffle=True,
                             num_workers=NUM_WORKERS,
                             pin_memory=PIN_MEMORY)

# Validation data: iterated in a fixed order. The validation metrics are
# accumulated over the whole set, so shuffling adds nothing and only makes
# per-batch logs non-reproducible.
validation_dataset = JigsawDataset(validation_df, tokenizer)
validation_dataloader = DataLoader(validation_dataset,
                             batch_size=BATCH_SIZE,
                             shuffle=False,
                             num_workers=NUM_WORKERS,
                             pin_memory=PIN_MEMORY)


# Loss and optimizer; both are logged so the run records its configuration.
criterion = FocalLoss()
logger.info(criterion)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
logger.info(optimizer)

# Move the model and the criterion to the training device.
model.to(device)
criterion.to(device)

In [None]:
class HammingLossWithoutThreshold(Metric):
    """Mean absolute difference between predictions and targets.

    Predictions are used as given (no thresholding). The metric accumulates
    the summed absolute error and the number of rows seen, and reports
    ``total / (num_classes * nbr_sample)``.

    Args:
        num_classes: Expected size of the second dimension of ``preds``.
        dist_sync_on_step: Forwarded to the ``Metric`` base class.
    """

    def __init__(self, num_classes=1, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.num_classes = num_classes

        # Running sum of absolute errors and running row count.
        self.add_state("total", default=torch.zeros((), dtype=torch.float32), dist_reduce_fx="sum")
        self.add_state("nbr_sample", default=torch.zeros((), dtype=torch.int64), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate one 2-D batch of predictions and targets.

        Raises:
            AttributeError: If the second dimension of ``preds`` differs
                from ``num_classes``.
        """
        batch_rows, batch_cols = preds.shape
        if batch_cols != self.num_classes:
            raise AttributeError("`num_classes` != `current_nbr_category` detected in `pred` parameter")

        batch_error = torch.absolute(target - preds).sum()

        self.total += batch_error.float()
        self.nbr_sample += batch_rows

    def compute(self):
        """Return the accumulated error averaged over rows and classes."""
        element_count = self.num_classes * self.nbr_sample
        return self.total / element_count

In [None]:
class RebalancedHammingLossWithoutThreshold(Metric):
    """Absolute-error metric that weights positives and negatives equally.

    For every class the absolute errors on positive targets (``target == 1``)
    and on negative targets (``target == 0``) are accumulated separately.
    ``compute`` rescales each part by ``nbr_sample / (2 * count)`` and divides
    by ``nbr_sample``, so both parts contribute half of the class score
    regardless of how frequent they are.

    Args:
        num_classes: Expected size of the second dimension of ``preds``.
        average: ``"macro"`` returns the mean over classes; any other value
            (e.g. ``None``) returns one score per class.
        dist_sync_on_step: Forwarded to the ``Metric`` base class.
    """

    def __init__(self, num_classes=1, average="macro", dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.num_classes = num_classes

        # average = "macro" or None
        self.average = average

        # Number of positive (1) and negative (0) targets per class
        self.add_state(
            "number_positive",
            default=torch.tensor([0 for _ in range(num_classes)]),
            dist_reduce_fx="sum",
        )
        self.add_state(
            "number_negative",
            default=torch.tensor([0 for _ in range(num_classes)]),
            dist_reduce_fx="sum",
        )

        # Summed absolute error on positive / negative targets per class
        self.add_state(
            "hamming_loss_positive",
            default=torch.tensor([0.0 for _ in range(num_classes)]),
            dist_reduce_fx="sum",
        )
        self.add_state(
            "hamming_loss_negative",
            default=torch.tensor([0.0 for _ in range(num_classes)]),
            dist_reduce_fx="sum",
        )

        self.add_state("nbr_sample", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate one 2-D batch of predictions and targets.

        Raises:
            AttributeError: If the second dimension of ``preds`` differs
                from ``num_classes``.
        """
        current_nbr_sample, current_nbr_category = preds.shape
        if current_nbr_category != self.num_classes:
            raise AttributeError(
                "`num_classes` != `current_nbr_category` detected in `pred` parameter"
            )

        # Number of positive (1) and negative (0) targets per class
        current_number_positive = target.sum(axis=0)
        current_number_negative = current_nbr_sample - target.sum(axis=0)

        self.number_positive += current_number_positive.int()
        self.number_negative += current_number_negative.int()

        self.nbr_sample += current_nbr_sample

        # Absolute error for all classes at once; the masks route every
        # element to the positive or negative accumulator of its class.
        absolute_error = torch.absolute(target - preds)
        no_error = torch.zeros_like(absolute_error)

        self.hamming_loss_positive += torch.where(
            target == 1, absolute_error, no_error
        ).sum(dim=0)
        self.hamming_loss_negative += torch.where(
            target == 0, absolute_error, no_error
        ).sum(dim=0)

    def compute(self):
        """Return the rebalanced score (scalar for "macro", else per class)."""
        # NOTE(review): a class without any positive (or negative) target
        # makes the matching factor infinite, and inf * 0.0 is NaN.
        factor_pos = self.nbr_sample / (2 * self.number_positive)
        factor_neg = self.nbr_sample / (2 * self.number_negative)

        rebalanced_hamming_loss_per_class = torch.multiply(
            self.hamming_loss_positive, factor_pos
        ) + torch.multiply(self.hamming_loss_negative, factor_neg)
        if self.average == "macro":
            return rebalanced_hamming_loss_per_class.sum() / (
                self.nbr_sample * self.num_classes
            )
        return rebalanced_hamming_loss_per_class / (self.nbr_sample)

In [None]:
# Build the dictionary of metrics; every metric is sized from LABEL_LIST so
# that the set of labels is defined in exactly one place.
num_classes = len(LABEL_LIST)
print(num_classes)
train_metric_dict = dict()

# AUROC, macro-averaged over the labels
auroc_macro = AUROC(task="multilabel",num_labels=num_classes, average="macro")
train_metric_dict["auroc_macro"] = auroc_macro

# AUROC per class
auroc_per_class = AUROC(task="multilabel",num_labels=num_classes, average=None)
train_metric_dict["auroc_per_class"] = auroc_per_class


# F1 score, macro-averaged over the labels
f1 = F1Score(task="multilabel",num_labels=num_classes, average="macro")
train_metric_dict["f1"] = f1

# F1 score per class
# (the misspelled key is kept: it is the name written to the metric export)
f1_per_calss = F1Score(task="multilabel",num_labels=num_classes, average=None)
train_metric_dict["f1_per_calss"] = f1_per_calss

# Hamming Distance without Threshold
hamming_loss_woutt = HammingLossWithoutThreshold(num_classes=num_classes)
train_metric_dict["hamming_loss_without_threshold"] = hamming_loss_woutt

# Rebalanced Hamming Distance without Threshold, macro-averaged
rebalanced_hamming_loss_woutt_macro = RebalancedHammingLossWithoutThreshold(
    num_classes=num_classes, average="macro"
)
train_metric_dict[
    "rebalanced_hamming_loss_without_threshold_macro"
] = rebalanced_hamming_loss_woutt_macro

# Rebalanced Hamming Distance without Threshold, per class
rebalanced_hamming_loss_woutt_per_class = RebalancedHammingLossWithoutThreshold(
    num_classes=num_classes, average=None
)
train_metric_dict[
    "rebalanced_hamming_loss_without_threshold_per_class"
] = rebalanced_hamming_loss_woutt_per_class

In [None]:
# One collection for training, plus an independent clone for validation;
# both are moved to the training device.
train_metric = MetricCollection(train_metric_dict).to(device)

validation_metric = train_metric.clone().to(device)

In [None]:
def serialize(object_to_serialize: Any, ensure_ascii: bool = True) -> str:
    """
    Serialize any object, i.e. convert an object to JSON
    Args:
        object_to_serialize (Any): The object to serialize
        ensure_ascii (bool, optional): If true (the default), every non-ASCII
            character in the output is escaped. If false, these characters
            are written as-is. Defaults to True.
    Returns:
            str: string of serialized object (JSON)
    """

    def dumper(obj: Any) -> Any:
        """
        Fallback used by json.dumps for objects it cannot encode natively.
        Args:
            obj (Any): The object to serialize
        Returns:
            Any: nested lists (or a scalar) for a tensor, the attribute
            dictionary for objects that have one, ``str(obj)`` otherwise
        """
        if isinstance(obj, torch.Tensor):
            # detach() first so that tensors still attached to the autograd
            # graph can be exported as well.
            return obj.detach().cpu().tolist()
        if hasattr(obj, "__dict__"):
            return obj.__dict__
        return str(obj)

    return json.dumps(object_to_serialize, default=dumper, ensure_ascii=ensure_ascii)

In [None]:
def export_metric(metric_collection, **kwargs):
    """
    Append the current values of a MetricCollection to the metric file as
    one JSON line.

    Args:
        metric_collection: MetricCollection whose ``compute()`` result is
            exported
        **kwargs: extra fields merged into the exported JSON object
    """
    with open(METRIC_FILE_PATH, "a") as sink:
        record = metric_collection.compute()
        record.update(kwargs)
        sink.write(serialize(record) + "\n")
    logger.success("Metrics are exported !")

In [None]:
def train_epoch(epoch_id=None):
    """
    Train the global model for one pass over ``train_dataloader``.

    A full validation run is triggered every 1,000 batches, starting with
    batch 0. Each step computes the loss on the logits, updates the training
    metrics with the sigmoid probabilities, back-propagates with the gradient
    norm clipped to 1.0, and refreshes the progress bar.

    Args:
        epoch_id: Identifier of the current epoch; used in log messages and
            forwarded to the validation runs.
    """
    model.train()
    logger.info(f"START EPOCH {epoch_id=}")

    progress = tqdm(train_dataloader, desc='training batch...', leave=False)
    for batch_id, batch in enumerate(progress):
        if batch_id % 1_000 == 0:
            # Use this function's own argument, not a global loop variable.
            valid_epoch(epoch_id=epoch_id, batch_id=batch_id)
            # valid_epoch switches the model to eval mode; switch back so
            # the remaining batches are trained in training mode.
            model.train()

        logger.trace(f"{batch_id=}")
        token_list_batch = batch["ids"].to(device)
        attention_mask_batch = batch["mask"].to(device)
        label_batch = batch["labels"].to(device)

        # Reset gradient
        optimizer.zero_grad()

        # Predict
        prediction_batch = model(token_list_batch, attention_mask_batch)
        # NOTE(review): squeeze() drops every size-1 dimension, including the
        # batch dimension of a single-sample batch -- confirm the model's
        # output shape.
        transformed_prediction_batch = prediction_batch.squeeze()

        # Loss
        loss = criterion(transformed_prediction_batch.to(torch.float32), label_batch.to(torch.float32))

        # Metrics
        proba_prediction_batch = torch.sigmoid(transformed_prediction_batch)
        train_metrics_collection_dict = train_metric(proba_prediction_batch.to(torch.float32), label_batch.to(torch.int32))
        logger.trace(train_metrics_collection_dict)

        # Backprop
        loss.backward()
        # gradient clip
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Update progress bar description
        progress_description = "Train Loss : {loss:.4f} - Train AUROC : {acc:.4f}"
        auc_roc = float(train_metrics_collection_dict["auroc_macro"])
        progress_description = progress_description.format(loss=loss.item(), acc=auc_roc)
        progress.set_description(progress_description)

    logger.info(f"END EPOCH {epoch_id=}")

In [None]:
@torch.no_grad()
def valid_epoch(epoch_id=None, batch_id=None):
    """
    Evaluate the global model on ``validation_dataloader`` and export the
    resulting metrics.

    The model is switched to eval mode and left in that mode. The validation
    metrics are reset, updated with the sigmoid probabilities of every batch,
    and exported together with the mean batch loss.

    Args:
        epoch_id: Identifier of the current epoch (logged and exported).
        batch_id: Identifier of the training batch that triggered this run,
            if any (logged and exported).
    """
    model.eval()
    logger.info(f"START VALIDATION {epoch_id=}{batch_id=}")
    validation_metric.reset()

    loss_list = []

    progress = tqdm(validation_dataloader, desc="valid batch...", leave=False)
    for batch in progress:

        token_list_batch = batch["ids"].to(device)
        attention_mask_batch = batch["mask"].to(device)
        label_batch = batch["labels"].to(device)

        # Predict
        prediction_batch = model(token_list_batch, attention_mask_batch)

        transformed_prediction_batch = prediction_batch.squeeze()

        # Loss
        loss = criterion(
            transformed_prediction_batch.to(torch.float32),
            label_batch.to(torch.float32),
        )

        loss_list.append(loss.item())

        proba_prediction_batch = torch.sigmoid(transformed_prediction_batch)

        # Metrics
        validation_metric(proba_prediction_batch.to(torch.float32), label_batch.to(torch.int32))

    loss_mean = np.mean(loss_list)
    logger.trace(validation_metric.compute())
    logger.info(f"END VALIDATION {epoch_id=}{batch_id=}")
    export_metric(validation_metric, epoch_id=epoch_id, batch_id=batch_id, loss=loss_mean)

In [None]:
# Main training loop: for each epoch, train, validate, then save the model.
# Release unused cached GPU memory before starting.
torch.cuda.empty_cache()
progress =  tqdm(range(1,NUM_EPOCHS+1), desc='training epoch...', leave=True)
for epoch in progress:
    # Train
    train_epoch(epoch_id=epoch)

    # Validation
    valid_epoch(epoch_id=epoch)

    # Save the whole model object; each epoch overwrites the same file.
    torch.save(model, MODEL_FILE_PATH)

# Evaluation

In [None]:
# try:
#     del train_df
#     del validation_df
# except NameError:
#     logger.warning("Train DataFrame & Validation DataFrame already deleted")