# 🙊Toxic comments with Lightning⚡Flash

[Flash](https://lightning-flash.readthedocs.io/en/stable) makes complex AI recipes for over 15 tasks across 7 data domains accessible to all.

In a nutshell, Flash is the production grade research framework you always dreamed of but didn't have time to build.

In [None]:
# Install Lightning Flash with the text extras. The active line installs the
# `fix/serialize_tokenizer` branch archive; the two commented lines are the
# PyPI release and the `master` branch alternatives.
# NOTE(review): a branch archive is not reproducible (the branch can move or be
# deleted) — pin a release/tag or commit once the fix is available.
# ! pip install -q lightning-flash[text]
# ! pip install -q 'https://github.com/PyTorchLightning/lightning-flash/archive/refs/heads/master.zip#egg=lightning-flash[text]'
! pip install -q 'https://github.com/PyTorchLightning/lightning-flash/archive/refs/heads/fix/serialize_tokenizer.zip#egg=lightning-flash[text]'
# NOTE(review): mplfinance is not imported anywhere in this notebook — confirm and drop.
! pip install -q mplfinance
# NOTE(review): --force-reinstall reinstalls pandas *and* its dependencies (e.g. numpy)
# even when up to date; confirm this is still needed.
! pip install -q --upgrade pandas --force-reinstall
# Record the installed lightning / torch package versions in the output.
! pip list | grep -E "lightning|torch"

In [None]:
# Build wheels for the Flash branch and its dependencies into ./frozen_packages,
# presumably so they can be installed later without internet access — TODO confirm.
# ! pip download -q lightning-flash[text] --prefer-binary --dest frozen_packages
! pip wheel -q 'https://github.com/PyTorchLightning/lightning-flash/archive/refs/heads/fix/serialize_tokenizer.zip#egg=lightning-flash[text]' --wheel-dir frozen_packages
# Delete the torch wheel(s); presumably the target environment already provides torch.
! rm frozen_packages/torch-*
# Show which wheels were collected.
! ls frozen_packages

In [None]:
# List the files available in the Jigsaw multilingual toxic-comment input dataset.
! ls /kaggle/input/jigsaw-multilingual-toxic-comment-classification

## Data exploration & preparation

Load the labelled training comments and the comments to be scored, and take a first look at the label distribution.

In [None]:
import pandas as pd

# Labelled training comments: `comment_text` plus the toxicity label columns.
csv_train = "/kaggle/input/jigsaw-multilingual-toxic-comment-classification/jigsaw-toxic-comment-train.csv"
df_train = pd.read_csv(csv_train)
display(df_train.head())

# One log-scaled 2-bin histogram per numeric (label) column, three per row.
# `sharex` only takes effect with `subplots=True`; overlaying all label columns
# on a single axis made most of the bars invisible.
_ = df_train.plot.hist(
    bins=2,
    grid=True,
    sharex=True,
    logy=True,
    subplots=True,
    layout=(-1, 3),
    figsize=(15, 6),
)

In [None]:
# Comments the competition asks us to score.
csv_comments = "/kaggle/input/jigsaw-toxic-severity-rating/comments_to_score.csv"
df_comments = pd.read_csv(csv_comments)

# The prediction and submission cells read `text` and `comment_id`; fail early and
# clearly if the file layout differs.
missing_cols = {"comment_id", "text"} - set(df_comments.columns)
assert not missing_cols, f"comments_to_score.csv is missing columns: {sorted(missing_cols)}"

display(df_comments.head())

### ToDo

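Aggregate the six toxicity labels into a single binary target: `sum` counts how many labels are set per comment, and `any` flags comments with at least one. A finer-grained aggregation (e.g. weighting the more severe labels) may suit this severity-ranking competition better.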

In [None]:
# Collapse the six label columns into two summary columns:
#   sum - how many of the labels are set for a comment
#   any - 1 when at least one label is set, else 0 (used as the training target)
df_train = df_train.assign(
    sum=lambda frame: frame[["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]].sum(axis=1),
    any=lambda frame: (frame["sum"] > 0).astype(int),
)

# Class balance of the binary target.
_ = df_train["any"].plot.hist(bins=2, grid=True)

# Training with Lightning Flash

See the classification docs: https://lightning-flash.readthedocs.io/en/stable/reference/text_classification.html

In [None]:
import torch

import flash
from flash.text import TextClassificationData, TextClassifier

### 1. Create the DataModule

In [None]:
# Hold out a random, seeded 10% of the labelled comments for validation.
# Previously the training frame was also passed as the validation frame, so the
# validation metrics (and the LR finder below) only measured memorisation.
df_val = df_train.sample(frac=0.1, random_state=42)
df_trn = df_train.drop(index=df_val.index)

# Tokenise `comment_text` with the xlm-roberta-base tokenizer and predict the
# aggregated binary `any` label (the single `toxic` label is an alternative target).
datamodule = TextClassificationData.from_data_frame(
    input_field="comment_text",
    target_fields="any",  # "toxic",
    train_data_frame=df_trn,
    val_data_frame=df_val,
    backbone="xlm-roberta-base",
    batch_size=64,
    num_workers=0,
)

### 2. Build the task

In [None]:
from torchmetrics import F1, Precision

model = TextClassifier(
    backbone=datamodule.backbone,
    num_classes=datamodule.num_classes,
    metrics=[Precision(), F1()],
)
model.model.save_pretrained("./used-HF-model")
! ls -l ./used-HF-model

### 3. Create the trainer and finetune the model

In [None]:
import torch
from pytorch_lightning.loggers import CSVLogger

# Use every visible GPU; 0 means a CPU-only run.
n_gpus = torch.cuda.device_count()

# The CSVLogger writes metrics.csv under logs/, which the next cell plots.
logger = CSVLogger(save_dir='logs/')
trainer = flash.Trainer(
    max_epochs=10,
    logger=logger,
    gpus=n_gpus,
    # One optimizer step every 12 batches (effective batch = 12 x datamodule batch size).
    accumulate_grad_batches=12,
    gradient_clip_val=0.1,
    # Native 16-bit AMP in PyTorch Lightning requires a GPU; asking for precision=16
    # on a CPU-only kernel raises a configuration error, so fall back to 32-bit there.
    precision=16 if n_gpus > 0 else 32,
    auto_lr_find=True,
)

# Run the LR finder (65 steps between 1e-5 and 0.1); it stores the suggestion on the model.
trainer.tune(model, datamodule=datamodule, lr_find_kwargs=dict(min_lr=1e-5, max_lr=0.1, num_training=65))
print(f"Learning Rate: {model.learning_rate}")

# Flash's "freeze" strategy keeps the backbone frozen for the whole run.
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# Save the model!
trainer.save_checkpoint("text_classification_model.pt")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

# Training curves from the CSVLogger output: drop the `step` column and index by epoch.
metrics = (
    pd.read_csv(f'{trainer.logger.log_dir}/metrics.csv')
    .drop(columns="step")
    .set_index("epoch")
)

# Preview only the columns that received at least one value; plot every column vs. epoch.
display(metrics.dropna(axis=1, how="all").head())
g = sns.relplot(data=metrics, kind="line")
plt.gcf().set_size_inches(15, 5)

In [None]:
# Optional, currently disabled: copy the Hugging Face cache (~/.cache/huggingface)
# into ./cache/huggingface, skipping *.lock files. Presumably meant to keep the
# downloaded backbone/tokenizer files with the notebook outputs — TODO confirm or delete.
# ! ls -l ~/.cache/huggingface/
# ! mkdir -p cache/huggingface
# ! rsync -ahv ~/.cache/huggingface/ cache/huggingface --exclude="*.lock"
# ! ls -l cache/huggingface

### 4. Classify new comments

In [None]:
import math
from flash.core.classification import Logits, Probabilities
from tqdm.auto import tqdm

# Make `model.predict` return the raw logit vector for each input.
model.output = Logits()
# Single-call alternative: predictions = model.predict(df_comments["text"])

# Predict chunk by chunk (chunk = datamodule batch size) so tqdm can show progress.
chunk = datamodule.batch_size
texts = df_comments["text"]
predictions = []
for start in tqdm(range(0, len(texts), chunk)):
    predictions.extend(model.predict(texts[start:start + chunk]))

print(f"inputs={len(df_comments)} ; preds={len(predictions)}")
print(predictions[0])

In [None]:
import numpy as np

# `predictions` holds one logit vector per comment from the prediction cell.
logits = np.asarray(predictions, dtype=float)
# This cell overwrites `predictions` with 1-D scores, so a second run would fail obscurely.
assert logits.ndim == 2, "expected one logit vector per comment - re-run the prediction cell first"

# Score = log-probability of the last class (presumably `any == 1`, i.e. toxic),
# computed as a numerically stable log-softmax. The raw last logit alone is not
# monotonic in P(toxic) because it ignores the other class's logit, which matters
# here since comments are ranked by score.
predictions = logits[:, -1] - np.logaddexp.reduce(logits, axis=1)
_ = plt.hist(predictions, bins=25)

In [None]:
df_submit = pd.DataFrame(zip(df_comments["comment_id"], predictions), columns=("comment_id", "score"))
df_submit.set_index("comment_id", inplace=True)
df_submit.to_csv("submission.csv")

! head submission.csv