In [None]:
from __future__ import annotations

import os
import re
import typing as t
from pathlib import Path

import pandas as pd
import torch
import torch.nn.functional as torch_f
import typing_extensions as t_ext
import wandb
from bs4 import BeautifulSoup
from torch.optim import Optimizer, AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from transformers.models.auto.modeling_auto import AutoModel
from transformers.models.auto.tokenization_auto import AutoTokenizer

In [None]:
class TextCleaner:
    """Strips links, markup, emojis and punctuation from free-form text."""

    # URLs that start with an http(s) scheme or with a bare "www." prefix.
    _RE_WEBSITE_LINK = re.compile(r'https?://\S+|www\.\S+')
    # One or more consecutive characters from the code point ranges listed below.
    _RE_EMOJI = re.compile('['
        u'\U0001F600-\U0001F64F'  # emoticons
        u'\U0001F300-\U0001F5FF'  # symbols & pictographs
        u'\U0001F680-\U0001F6FF'  # transport & map symbols
        u'\U0001F1E0-\U0001F1FF'  # flags (iOS)
        u'\U00002702-\U000027B0'
        u'\U000024C2-\U0001F251'
        ']+', flags=re.UNICODE)
    # Any single character that is neither an ASCII letter nor a digit.
    _RE_SPECIAL_CHARACTERS = re.compile(r'[^a-zA-Z\d]')
    # A run of one or more space characters.
    _RE_EXTRA_SPACES = re.compile(r' +')

    def __init__(self):
        """Nothing to set up: the cleaner keeps no per-instance state."""

    def clean(self, text: str) -> str:
        """
        Cleans text into a basic form for NLP. The steps run in this order:
        1. Remove embedded URL links
        2. Remove HTML tags (only the text content of the markup is kept)
        3. Remove emojis
        4. Replace special characters like &, #, etc. with a space
        5. Collapse runs of spaces into a single space
        6. Strip leading and trailing whitespace

        text - Text piece to be cleaned.

        Returns the cleaned text.
        """
        without_links = self._RE_WEBSITE_LINK.sub(r'', text)

        # Parsing the text as markup and reading back only its text drops the tags.
        markup_free = BeautifulSoup(without_links, 'lxml').get_text()

        without_emojis = self._RE_EMOJI.sub(r'', markup_free)

        # Special characters become spaces (not empty strings) so that the
        # words they separated stay separated.
        alphanumeric_only = self._RE_SPECIAL_CHARACTERS.sub(" ", without_emojis)
        single_spaced = self._RE_EXTRA_SPACES.sub(' ', alphanumeric_only)

        return single_spaced.strip()

In [None]:
class _TokenizedText(t_ext.TypedDict):
    """Tensor form of the tokenizer outputs that are fed to the model.

    Built by `_preprocess_tokenizer_output`: each value is the tokenizer output
    entry of the same name wrapped in `torch.tensor`.
    """
    # Tensor built from the tokenizer's 'input_ids' entry.
    input_ids: torch.Tensor
    # Tensor built from the tokenizer's 'attention_mask' entry.
    attention_mask: torch.Tensor
    # Tensor built from the tokenizer's 'token_type_ids' entry.
    token_type_ids: torch.Tensor


def _preprocess_tokenizer_output(output: t.Dict[str, t.Any]) -> _TokenizedText:
    return {
        'input_ids': torch.tensor(output['input_ids']),
        'attention_mask': torch.tensor(output['attention_mask']),
        'token_type_ids': torch.tensor(output['token_type_ids']),
    }


class TrainDataset(Dataset):
    """
    Dataset of (tokenized text, target score) pairs for regression training.

    Every row of the frame must provide a 'text' column, which is cleaned with
    `TextCleaner` and tokenized, and an 'average' column, which becomes the
    float target.
    """

    def __init__(self, df: pd.DataFrame, tokenizer: AutoTokenizer, max_len: int):
        """
        df - Frame with 'text' and 'average' columns.
        tokenizer - Callable tokenizer applied to the cleaned text.
        max_len - Length every example is padded / truncated to.
        """
        super().__init__()
        self._df = df
        self._tokenizer = tokenizer
        self._max_len = max_len
        self._cleaner = TextCleaner()

    def __len__(self) -> int:
        """Number of rows in the underlying frame."""
        return len(self._df)

    def __getitem__(self, idx: int) -> t.Tuple[_TokenizedText, torch.Tensor]:
        """Returns the tokenized text and the scalar target of the row at position idx."""
        row = self._df.iloc[idx]
        cleaned_text = self._cleaner.clean(str(row['text']))
        encoding = self._tokenizer(
            cleaned_text,
            add_special_tokens=True,
            truncation=True,
            padding='max_length',
            max_length=self._max_len,
            return_attention_mask=True,
            return_token_type_ids=True)  # type: ignore
        features = _preprocess_tokenizer_output(encoding)
        target = torch.tensor(float(t.cast(t.SupportsFloat, row['average'])))
        return features, target


class ValidDataset(Dataset):
    """
    Dataset of tokenized text pairs used for ranking validation.

    Every row of the frame must provide 'more_toxic' and 'less_toxic' columns;
    both texts are cleaned with `TextCleaner` and tokenized the same way.
    """

    def __init__(self, df: pd.DataFrame, tokenizer: AutoTokenizer, max_len: int) -> None:
        """
        df - Frame with 'more_toxic' and 'less_toxic' columns.
        tokenizer - Callable tokenizer applied to the cleaned texts.
        max_len - Length every example is padded / truncated to.
        """
        super().__init__()
        self._df = df
        self._tokenizer = tokenizer
        self._max_len = max_len
        self._cleaner = TextCleaner()

    def __len__(self) -> int:
        """Number of rows in the underlying frame."""
        return len(self._df)

    def _encode(self, raw_text: t.Any) -> _TokenizedText:
        """Cleans one text value and tokenizes it into fixed-length tensors."""
        encoding = self._tokenizer(
            self._cleaner.clean(str(raw_text)),
            add_special_tokens=True,
            truncation=True,
            padding='max_length',
            max_length=self._max_len,
            return_attention_mask=True,
            return_token_type_ids=True)  # type: ignore
        return _preprocess_tokenizer_output(encoding)

    def __getitem__(self, idx: int) -> t.Tuple[_TokenizedText, _TokenizedText]:
        """Returns the (more toxic, less toxic) tokenized texts of the row at position idx."""
        row = self._df.iloc[idx]
        more_toxic = self._encode(row['more_toxic'])
        less_toxic = self._encode(row['less_toxic'])
        return more_toxic, less_toxic


In [None]:
class Model(torch.nn.Module):
    """Pretrained encoder followed by a small head that emits one score per example."""

    def __init__(self, checkpoint: str, output_logits: int, dropout: float):
        """
        checkpoint - Name or path handed to `AutoModel.from_pretrained`.
        output_logits - Width of the encoder's pooled output; it sizes the layer
                        norm and the first linear layer of the head.
        dropout - Dropout probability used before and inside the head.
        """
        super().__init__()
        # return_dict=False makes the encoder return a plain tuple, which
        # `forward` unpacks.
        self.bert = AutoModel.from_pretrained(checkpoint, return_dict=False)
        self.layer_norm = torch.nn.LayerNorm(output_logits)
        self.dropout = torch.nn.Dropout(dropout)
        head_layers = [
            torch.nn.Linear(output_logits, 128),
            torch.nn.SiLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(128, 1),
        ]
        self.dense = torch.nn.Sequential(*head_layers)

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Scores a batch: encoder pooled output -> layer norm -> dropout -> head."""
        encoder_outputs = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask)
        # The encoder must return exactly two values; only the second (pooled)
        # one is used.
        _, pooled = encoder_outputs
        return self.dense(self.dropout(self.layer_norm(pooled)))

In [None]:
class Metric:
    """Base class for accumulating metrics; subclasses provide `compute` and `reset`."""

    def compute(self) -> float:
        """Returns the metric value for everything accumulated so far."""
        raise NotImplementedError

    def reset(self):
        """Discards everything accumulated so far."""
        raise NotImplementedError

    def compute_and_reset(self) -> float:
        """Returns the current metric value and then clears the accumulated state."""
        result = self.compute()
        self.reset()
        return result


class Accuracy(Metric):
    """Fraction of elements equal to 1 among all elements seen since the last reset."""

    def __init__(self):
        self._num_correct = 0
        self._num_total = 0

    def update(self, is_correct_tensor: torch.Tensor):
        """
        Adds one batch of outcomes.

        is_correct_tensor - Tensor in which elements equal to 1 (or True) count
                            as correct and every other element as incorrect.
        """
        hits = (is_correct_tensor == 1).sum().item()
        # Every element is either equal to 1 or not, so the matches plus the
        # mismatches always add up to the element count.
        self._num_correct += hits
        self._num_total += is_correct_tensor.numel()

    def compute(self) -> float:
        """Returns correct / total; at least one element must have been added."""
        assert self._num_total > 0
        return self._num_correct / self._num_total

    def reset(self):
        """Clears both counters."""
        self._num_correct = self._num_total = 0


class Loss(Metric):
    """Mean of all loss values seen since the last reset."""

    def __init__(self):
        self._value_list: t.List[float] = []

    def update(self, loss_tensor: torch.Tensor):
        """Appends every element of loss_tensor to the stored values."""
        self._value_list += loss_tensor.flatten().tolist()

    def compute(self) -> float:
        """Returns the arithmetic mean of the stored values; at least one is required."""
        count = len(self._value_list)
        assert count > 0
        return sum(self._value_list) / count

    def reset(self):
        """Drops all stored values."""
        self._value_list.clear()


# Best validation accuracy reached so far, keyed by checkpoint path.
# `do_train_iteration` is called once per epoch, so the best score has to
# outlive a single call; otherwise the first validation of a later epoch would
# always overwrite the checkpoint, even with a worse model.
_BEST_ACCURACY_BY_CHECKPOINT: t.Dict[str, float] = {}


def do_train_iteration(
        train_data_loader: DataLoader,
        valid_data_loader: DataLoader,
        epoch: int,
        model: Model,
        device: str,
        optimizer: Optimizer,
        scheduler: t.Any,
        validate_every_n_steps: int,
        to_checkpoint: str,
        margin_list: t.Optional[t.List[float]] = None,
        decision_margin: float = 0.0) -> float:
    """
    Trains the model for one pass over train_data_loader with an MSE loss.

    Every validate_every_n_steps global steps the mean training loss since the
    previous validation is printed, the model is validated, both values are
    logged to wandb, and the model weights are saved to to_checkpoint if the
    validation accuracy beats the best one recorded for that path.

    train_data_loader - Yields (tokenized text, target) batches.
    valid_data_loader - Passed through to `do_valid_iteration`.
    epoch - Zero-based epoch index. It offsets the global step counter, and
            epoch 0 resets the best accuracy recorded for to_checkpoint.
    model - Model to train; it is left in training mode.
    device - Device the batches are moved to.
    optimizer - Stepped once per batch.
    scheduler - Learning rate scheduler, stepped once per batch.
    validate_every_n_steps - Validation period in global steps.
    to_checkpoint - Path the best model state dict is saved to.
    margin_list - Passed through to `do_valid_iteration`.
    decision_margin - Passed through to `do_valid_iteration`.

    Returns the best validation accuracy recorded so far for to_checkpoint.
    """
    if epoch == 0:
        # A new training run starts from scratch, even if the same path was
        # used by an earlier run in this kernel.
        _BEST_ACCURACY_BY_CHECKPOINT[to_checkpoint] = 0.0
    best_accuracy = _BEST_ACCURACY_BY_CHECKPOINT.get(to_checkpoint, 0.0)

    loss_metric = Loss()
    model.train()
    data_iter = tqdm(train_data_loader)
    for i, (tokenized_text, y) in enumerate(data_iter, start=1):
        # Global step across epochs, so validation and logging stay periodic.
        step = epoch * len(train_data_loader) + i
        input_ids, attention_mask, token_type_ids, y = (
            tokenized_text['input_ids'].to(device),
            tokenized_text['attention_mask'].to(device),
            tokenized_text['token_type_ids'].to(device),
            y.to(device))
        y_hat = model(input_ids, token_type_ids, attention_mask)
        # The model emits one column per example; drop it to match the targets.
        loss = torch_f.mse_loss(y_hat.squeeze(1), y)
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        loss_metric.update(loss.cpu())
        if step % validate_every_n_steps == 0:
            loss_val = loss_metric.compute_and_reset()
            print(f'Step: {step}. Loss: {loss_val}')
            accuracy = do_valid_iteration(
                data_loader=valid_data_loader,
                model=model,
                step=step,
                device=device,
                margin_list=margin_list,
                decision_margin=decision_margin)
            wandb.log({'step': step, 'train_loss': loss_val, 'valid_accuracy': accuracy})
            if accuracy > best_accuracy:
                print(f'Best accuracy improved from {best_accuracy} to {accuracy}. Saving the model.')
                torch.save(model.state_dict(), to_checkpoint)
                best_accuracy = accuracy
                _BEST_ACCURACY_BY_CHECKPOINT[to_checkpoint] = best_accuracy
            # Validation switched the model to eval mode; go back to training.
            model.train()
    return best_accuracy


def do_valid_iteration(
        data_loader: DataLoader,
        model: Model,
        step: int,
        device: str,
        margin_list: t.Optional[t.List[float]] = None,
        decision_margin: float = 0.0) -> float:
    """
    Measures how often the model scores the 'more toxic' text above the 'less toxic' one.

    A pair counts as correct for a margin when
    score(more) - score(less) >= margin. The accuracy is tracked for every
    margin in margin_list and all of them are printed.

    data_loader - Yields (more toxic, less toxic) tokenized batches.
    model - Model to evaluate; it is left in eval mode.
    step - Global step, used only in the progress bar and printed messages.
    device - Device the batches are moved to.
    margin_list - Margins to report; defaults to [0.0].
    decision_margin - Margin whose accuracy is returned; must be in margin_list.

    Returns the accuracy for decision_margin.
    """
    margins = [0.0] if margin_list is None else margin_list
    assert decision_margin in margins
    accuracy_by_margin = {margin: Accuracy() for margin in margins}

    def score(batch: _TokenizedText) -> torch.Tensor:
        # Model scores for one side of the pairs.
        return model(
            batch['input_ids'].to(device),
            batch['token_type_ids'].to(device),
            batch['attention_mask'].to(device))

    model.eval()
    with torch.no_grad():
        progress = tqdm(data_loader, desc=f'Step: {step}. Validation')
        for more_toxic_batch, less_toxic_batch in progress:
            score_gap = score(more_toxic_batch) - score(less_toxic_batch)
            for margin, accuracy_metric in accuracy_by_margin.items():
                accuracy_metric.update((score_gap >= margin).cpu())

    accuracy_str = ', '.join(
        f'{margin} = {metric.compute()}' for margin, metric in accuracy_by_margin.items())
    print(f'Step: {step}. Valid accuracy: {accuracy_str}')
    return accuracy_by_margin[decision_margin].compute()

In [None]:
# Parameters

# Selects the Kaggle directory layout (True) or the local one (False).
IS_KAGGLE = True

if IS_KAGGLE:
    ROOT_DIR_PATH = Path('/kaggle')
    DATA_DIR_PATH = ROOT_DIR_PATH / 'input'
    MODELS_DIR_PATH = ROOT_DIR_PATH / 'working/models'
else:
    ROOT_DIR_PATH = Path('/home/jovyan/jigsaw-toxic')
    DATA_DIR_PATH = ROOT_DIR_PATH / 'data'
    MODELS_DIR_PATH = ROOT_DIR_PATH / 'models'

# Dataset directory and the two CSV files read from it.
JIGSAW_TOXIC_20211212_DIR_PATH = DATA_DIR_PATH / 'jigsaw-toxic-20211212'
TRAIN_CSV_PATH = JIGSAW_TOXIC_20211212_DIR_PATH / 'offenseval_2020_train.csv'
VALID_CSV_PATH = JIGSAW_TOXIC_20211212_DIR_PATH / 'valid.csv'

In [None]:
# Create the checkpoint directory together with any missing parents; nothing
# happens if it already exists. Done in Python rather than through the shell
# so the path needs no quoting and the cell does not depend on `mkdir`.
MODELS_DIR_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the training table first, then the validation table, from their CSV files.
train_df, valid_df = (
    pd.read_csv(csv_path) for csv_path in (TRAIN_CSV_PATH, VALID_CSV_PATH))

In [None]:
# Preview the first rows of the training data instead of rendering the whole frame.
train_df.head()

In [None]:
# Preview the first rows of the validation pairs instead of rendering the whole frame.
valid_df.head()

In [None]:
def main(
        train_df: pd.DataFrame,
        valid_df: pd.DataFrame,
        from_checkpoint: str,
        to_checkpoint: str,
        num_epochs: int,
        batch_size: int,
        max_len: int,
        num_workers: int,
        device: str,
        output_logits: int,
        dropout: float,
        lr: float,
        min_lr: float,
        t_0: int,
        t_mult: int,
        margin_list: t.List[float],
        decision_margin: float,
        validate_every_n_steps: int,
        is_kaggle: bool = True):
    """
    Fine-tunes a pretrained encoder as a regressor and tracks the run in wandb.

    train_df - Training frame, consumed by `TrainDataset`.
    valid_df - Validation frame, consumed by `ValidDataset`.
    from_checkpoint - Pretrained checkpoint used for both tokenizer and model.
    to_checkpoint - Path the best model state dict is saved to.
    num_epochs - Number of passes over the training data.
    batch_size - Training batch size; validation uses twice this value.
    max_len - Length every example is padded / truncated to.
    num_workers - Worker processes per data loader.
    device - Device the model and the batches are moved to.
    output_logits - Width of the encoder's pooled output, see `Model`.
    dropout - Dropout probability, see `Model`.
    lr - Initial learning rate of AdamW.
    min_lr - Lower bound of the cosine schedule (eta_min).
    t_0 - Scheduler steps until the first restart.
    t_mult - Factor by which the restart period grows after each restart.
    margin_list - Margins reported during validation.
    decision_margin - Margin whose accuracy decides which checkpoint is kept.
    validate_every_n_steps - Validation period in global steps.
    is_kaggle - Read the wandb API key from Kaggle secrets when True,
                otherwise from the WANDB_API_KEY environment variable.
    """
    if is_kaggle:
        # Imported here so that `kaggle_secrets` is not needed when is_kaggle
        # is False.
        from kaggle_secrets import UserSecretsClient
        user_secrets = UserSecretsClient()
        api_key = user_secrets.get_secret('wandb-api-token')
    else:
        api_key = os.environ['WANDB_API_KEY']
    wandb.login(key=api_key)

    # The hyperparameters are passed to `wandb.init` so that they are attached
    # to the run. Assigning a dict to `wandb.config` afterwards would only
    # replace the config object instead of updating it.
    run_config = {
        'from_checkpoint': from_checkpoint,
        'lr': lr,
        'min_lr': min_lr,
        'batch_size': batch_size,
        'dropout': dropout,
        'output_logits': output_logits,
        'max_len': max_len,
        'num_epochs': num_epochs,
        'validate_every_n_steps': validate_every_n_steps,
        'optimizer': 'adam_w',
        'scheduler': 'cosine_annealing_warm_restarts',
        't_0': t_0,
        't_mult': t_mult,
        'margin_list': margin_list,
        'decision_margin': decision_margin,
    }
    wandb.init(project='kaggle-jigsaw-toxic', entity='andrei-papou', config=run_config)

    tokenizer = AutoTokenizer.from_pretrained(from_checkpoint)
    train_dataset = TrainDataset(train_df, tokenizer=tokenizer, max_len=max_len)
    valid_dataset = ValidDataset(valid_df, tokenizer=tokenizer, max_len=max_len)
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        pin_memory=True)
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=batch_size * 2,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True)
    model = Model(checkpoint=from_checkpoint, output_logits=output_logits, dropout=dropout)
    model = model.to(device)
    wandb.watch(model)
    optimizer = AdamW(model.parameters(), lr=lr)
    # The scheduler is stepped once per batch inside `do_train_iteration`, so
    # t_0 is measured in optimizer steps.
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=t_0, T_mult=t_mult, eta_min=min_lr)

    for epoch in range(num_epochs):
        do_train_iteration(
            train_data_loader=train_data_loader,
            valid_data_loader=valid_data_loader,
            model=model,
            epoch=epoch,
            device=device,
            optimizer=optimizer,
            scheduler=scheduler,
            to_checkpoint=to_checkpoint,
            margin_list=margin_list,
            decision_margin=decision_margin,
            validate_every_n_steps=validate_every_n_steps)

In [None]:
main(
    # Data
    train_df=train_df,
    valid_df=valid_df,
    # Checkpoints
    from_checkpoint='roberta-base',
    to_checkpoint=str(MODELS_DIR_PATH / 'offenseval-2020-regression.pt'),
    # Model
    output_logits=768,
    dropout=0.2,
    max_len=256,
    # Optimisation
    num_epochs=2,
    batch_size=48,
    lr=1e-4,
    min_lr=1e-6,
    t_0=400,
    t_mult=2,
    # Validation
    validate_every_n_steps=400,
    margin_list=[0.0, 0.05, 0.1],
    decision_margin=0.0,
    # Runtime
    device='cuda',
    num_workers=2,
    is_kaggle=IS_KAGGLE)