### Imports

In [None]:
from __future__ import annotations

import numpy as np
import random
import re
import statistics
import typing as t
from pathlib import Path

import nltk
import pandas as pd
import torch
import torch.nn.functional as torch_f
import typing_extensions as t_ext
import wandb
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from sklearn.feature_extraction.text import TfidfVectorizer
from torch.optim import Optimizer, AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm.notebook import tqdm
from transformers.models.auto.modeling_auto import AutoModel
from transformers.models.auto.tokenization_auto import AutoTokenizer
from wandb.wandb_run import Run as WAndBRun

In [None]:
tqdm.pandas()
# NLTK data used by the text cleaners below: the Punkt tokenizer models for
# `word_tokenize`, WordNet for `WordNetLemmatizer`, and the stop-word lists.
# NLTK >= 3.9 loads the Punkt models from `punkt_tab` rather than `punkt`, and
# some NLTK releases also require `omw-1.4` for WordNet lemmatization, so both
# generations of package names are requested. By default `nltk.download`
# reports (rather than raises on) a package that its index does not know.
for _nltk_package in ('punkt', 'punkt_tab', 'wordnet', 'omw-1.4', 'stopwords'):
    nltk.download(_nltk_package)
del _nltk_package

In [None]:
def seed_everything(seed: int):
    """Seed Python's, NumPy's and PyTorch's global random generators with ``seed``."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

seed_everything(42)

### Datasets

In [None]:
class _TokenizedText(t_ext.TypedDict):
    # Tokenizer output reduced to the two tensors passed to the encoder
    # (produced by _preprocess_tokenizer_output). TrainDataset builds one per
    # text; ValidDataset stacks one row per text chunk along dim 0.
    input_ids: torch.Tensor
    attention_mask: torch.Tensor


def _preprocess_tokenizer_output(output: t.Dict[str, t.Any]) -> _TokenizedText:
    """Keep only ``input_ids`` and ``attention_mask`` from a tokenizer output, as tensors.

    Any other keys in ``output`` (e.g. token type ids) are dropped.
    """
    input_ids = torch.tensor(output['input_ids'])
    attention_mask = torch.tensor(output['attention_mask'])
    return {'input_ids': input_ids, 'attention_mask': attention_mask}


class Tokenizer:
    """Tokenizer that splits on, and joins with, a single space character.

    Unlike ``str.split()`` with no argument, consecutive spaces yield empty
    tokens, so ``invert_tokenize(tokenize(x)) == x`` for every string ``x``.
    """

    def tokenize(self, x: str) -> t.List[str]:
        separator = ' '
        return x.split(separator)

    def invert_tokenize(self, x: t.List[str]) -> str:
        separator = ' '
        return separator.join(x)


class RandomlyReduceTokenLenTo:
    """Text augmentation that keeps a random subset of ``token_len`` tokens.

    Texts with at most ``token_len`` tokens are returned unchanged. Longer texts
    keep exactly ``token_len`` randomly chosen tokens, in their original order.
    Tokenization defaults to :class:`Tokenizer` (split/join on single spaces).
    """

    def __init__(self, token_len: int, tokenizer: t.Optional[Tokenizer] = None):
        self._token_len = token_len
        self._tokenizer = tokenizer if tokenizer is not None else Tokenizer()

    def __call__(self, text: str) -> str:
        token_list = self._tokenizer.tokenize(text)
        if len(token_list) <= self._token_len:
            return text
        # Sample positions *without* replacement so exactly `token_len` tokens
        # survive; `random.choices` samples with replacement, so duplicate picks
        # silently shortened the output. Sorting preserves the original order.
        # A negative token_len keeps nothing, as before.
        kept_positions = sorted(random.sample(range(len(token_list)), k=max(self._token_len, 0)))
        return self._tokenizer.invert_tokenize([token_list[pos] for pos in kept_positions])


class RandomSubsetPerEpochSampler(Sampler[int]):
    """Sampler yielding ``samples_per_epoch`` indices per iteration.

    Indices are popped from a shuffled permutation of ``range(len(data_source))``.
    When the permutation is exhausted, a fresh one is shuffled and ``real_epoch``
    is incremented, so consecutive iterations walk through the whole dataset
    without repeats within one pass. The indices for an iteration are drawn
    eagerly when ``__iter__`` is called.
    """

    @staticmethod
    def _build_index(data_source: t.Sized) -> t.List[int]:
        permutation = [*range(len(data_source))]
        random.shuffle(permutation)
        return permutation

    def __init__(self, data_source: t.Sized, samples_per_epoch: int):
        super().__init__(data_source)
        self._data_source = data_source
        self._samples_per_epoch = samples_per_epoch
        self._index: t.List[int] = self._build_index(data_source)
        self._real_epoch = 0

    def _sample_one(self) -> int:
        # Refill (and count a completed pass) only once the permutation is used up.
        if len(self._index) == 0:
            self._index = self._build_index(self._data_source)
            self._real_epoch += 1
        return self._index.pop()

    def __iter__(self) -> t.Iterator[int]:
        drawn: t.List[int] = []
        for _ in range(self._samples_per_epoch):
            drawn.append(self._sample_one())
        return iter(drawn)

    def __len__(self) -> int:
        return self._samples_per_epoch

    @property
    def real_epoch(self) -> int:
        """Number of completed passes over the full dataset."""
        return self._real_epoch

    @property
    def frac_left(self) -> float:
        """Fraction of the current pass not yet sampled."""
        remaining = len(self._index)
        return remaining / len(self._data_source)

    @property
    def frac_consumed(self) -> float:
        """Fraction of the current pass already sampled."""
        return 1.0 - self.frac_left


class TrainDataset(Dataset):
    """Training pairs of a more toxic and a less toxic comment.

    For each row, both texts yield:

    * the tokenized raw text (``more_toxic`` / ``less_toxic``), after the
      optional augmentations, padded/truncated to ``max_len``;
    * a float32 multi-label target parsed from the space-separated integer
      columns ``more_labels`` / ``less_labels``;
    * a dense float32 TF-IDF vector of the ``*_toxic_cleaned`` column.

    Augmentations only affect the text fed to the tokenizer, not the TF-IDF input.
    """

    def __init__(
            self,
            df: pd.DataFrame,
            tokenizer: AutoTokenizer,
            vectorizer: TfidfVectorizer,
            max_len: int,
            augmentation_list: t.Optional[t.List[t.Callable[[str], str]]] = None):
        super().__init__()
        self._df = df
        self._tokenizer = tokenizer
        self._vectorizer = vectorizer
        self._max_len = max_len
        self._augmentation_list = augmentation_list if augmentation_list is not None else []

    def __len__(self) -> int:
        return len(self._df)

    def _apply_augmentations(self, text: str) -> str:
        """Run every augmentation on ``text`` in list order."""
        for augmentation in self._augmentation_list:
            text = augmentation(text)
        return text

    def _tokenize(self, text: str) -> _TokenizedText:
        """Tokenize ``text`` with special tokens, padded/truncated to ``max_len``."""
        return _preprocess_tokenizer_output(self._tokenizer(
            text,
            add_special_tokens=True,
            truncation=True,
            padding='max_length',
            max_length=self._max_len,
            return_attention_mask=True))  # type: ignore

    @staticmethod
    def _parse_labels(raw_labels: str) -> torch.Tensor:
        """Parse a space-separated string of integer labels into a float32 tensor."""
        return torch.tensor([int(label) for label in raw_labels.split(' ')], dtype=torch.float32)

    def _vectorize(self, cleaned_text: str) -> torch.Tensor:
        """Dense float32 TF-IDF row for one cleaned text."""
        return torch.tensor(self._vectorizer.transform([cleaned_text]).toarray()[0], dtype=torch.float32)

    def __getitem__(self, idx: int) -> t.Tuple[_TokenizedText, _TokenizedText, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        record = self._df.iloc[idx]
        # Augment both texts before tokenizing, so the augmentations draw
        # random numbers in the same order as before.
        more_comment_text = self._apply_augmentations(str(record['more_toxic']))
        less_comment_text = self._apply_augmentations(str(record['less_toxic']))
        tokenized_text_more = self._tokenize(more_comment_text)
        tokenized_text_less = self._tokenize(less_comment_text)
        more_labels = self._parse_labels(record['more_labels'])
        less_labels = self._parse_labels(record['less_labels'])
        more_vector_tensor = self._vectorize(record['more_toxic_cleaned'])
        less_vector_tensor = self._vectorize(record['less_toxic_cleaned'])
        return tokenized_text_more, tokenized_text_less, more_labels, less_labels, more_vector_tensor, less_vector_tensor


def _split_str_to_chunk_list(s: str, chunk_size: int) -> t.List[str]:
    """Split ``s`` on single spaces into consecutive groups of ``chunk_size`` words.

    Each group is re-joined with single spaces, and the last group may be shorter.
    Consecutive spaces produce empty "words", which count toward the group size.
    A ``chunk_size`` below 1 behaves like 1. At least one chunk is always
    returned (``['']`` for an empty string).
    """
    words = s.split(' ')
    step = max(chunk_size, 1)
    return [' '.join(words[start:start + step]) for start in range(0, len(words), step)]


def valid_collate_fn(
        sample_list: t.List[t.Tuple[int, _TokenizedText, _TokenizedText, torch.Tensor, torch.Tensor]]
        ) -> t.Tuple[t.List[int], _TokenizedText, _TokenizedText, t.List[slice], t.List[slice], torch.Tensor, torch.Tensor]:
    """Collate chunked validation samples into flat batches plus per-sample slices.

    Each sample carries one row per text chunk. All rows are concatenated along
    dim 0. For every sample, a ``slice`` records which concatenated rows belong
    to it, separately for the "more" and the "less" text.

    Returns:
        ``(idx_list, more_tokenized, less_tokenized, more_slices, less_slices,
        more_vectors, less_vectors)``.
    """
    def _concat_side(parts):
        # Running offsets give each sample its contiguous block of rows.
        slices: t.List[slice] = []
        offset = 0
        for part in parts:
            n_rows = part['input_ids'].shape[0]
            slices.append(slice(offset, offset + n_rows))
            offset += n_rows
        collated: _TokenizedText = {
            'input_ids': torch.cat([part['input_ids'] for part in parts], dim=0),
            'attention_mask': torch.cat([part['attention_mask'] for part in parts], dim=0),
        }
        return collated, slices

    idx_list: t.List[int] = [sample[0] for sample in sample_list]
    more_collated, more_slices = _concat_side([sample[1] for sample in sample_list])
    less_collated, less_slices = _concat_side([sample[2] for sample in sample_list])
    return (
        idx_list,
        more_collated,
        less_collated,
        more_slices,
        less_slices,
        torch.cat([sample[3] for sample in sample_list], dim=0),
        torch.cat([sample[4] for sample in sample_list], dim=0),
    )


class ValidDataset(Dataset):
    """Validation pairs whose texts are split into word chunks and scored per chunk.

    ``__getitem__`` returns ``(idx, tokenized_more, tokenized_less, vectors_more,
    vectors_less)``. Each tokenized dict and each vector tensor has one row per
    chunk of the corresponding text. Every chunk of a text carries the same
    TF-IDF row, computed from the whole ``*_toxic_cleaned`` value.

    The dataset also keeps one error value per row, written by
    :meth:`track_error` and exposed by :meth:`get_df_with_error`.
    """

    def __init__(
            self,
            df: pd.DataFrame,
            tokenizer: AutoTokenizer,
            vectorizer: TfidfVectorizer,
            max_len: int,
            chunk_size: t.Optional[int] = None) -> None:
        """
        Args:
            df: rows with 'more_toxic', 'less_toxic', 'more_toxic_cleaned' and
                'less_toxic_cleaned' columns.
            tokenizer: applied to each chunk, padded/truncated to ``max_len``.
            vectorizer: fitted TF-IDF vectorizer applied to the cleaned texts.
            max_len: tokenizer sequence length per chunk.
            chunk_size: space-separated words per chunk. Defaults to ``max_len``
                (the previously hard-coded value). NOTE(review): with
                ``add_special_tokens=True``, a full chunk of ``max_len`` words
                exceeds ``max_len`` tokens if the tokenizer adds any special
                token and every word yields at least one token. The tail of such
                a chunk is then truncated, and a smaller ``chunk_size`` keeps it.
                Confirm a suitable value.
        """
        super().__init__()
        self._df = df
        self._tokenizer = tokenizer
        self._vectorizer = vectorizer
        self._max_len = max_len
        self._chunk_size = chunk_size if chunk_size is not None else max_len
        self._error: np.ndarray = np.zeros(len(df))

    def __len__(self) -> int:
        return len(self._df)

    def track_error(self, idx_list: t.List[int], error: np.ndarray):
        """Overwrite the stored error at the positional rows ``idx_list`` with ``error``."""
        self._error[idx_list] = error

    def get_df_with_error(self) -> pd.DataFrame:
        """Return a copy of the dataframe with the tracked errors in an 'error' column."""
        df = self._df.copy()
        df['error'] = self._error
        return df

    def _encode_text(self, text: str, cleaned_text: str) -> t.Tuple[_TokenizedText, torch.Tensor]:
        """Tokenize ``text`` chunk by chunk; pair each chunk with ``cleaned_text``'s TF-IDF row."""
        input_ids_list, attention_mask_list = [], []
        for chunk in _split_str_to_chunk_list(text, chunk_size=self._chunk_size):
            tokenized_chunk = _preprocess_tokenizer_output(self._tokenizer(
                chunk,
                add_special_tokens=True,
                truncation=True,
                padding='max_length',
                max_length=self._max_len,
                return_attention_mask=True))  # type: ignore
            input_ids_list.append(tokenized_chunk['input_ids'])
            attention_mask_list.append(tokenized_chunk['attention_mask'])
        # The TF-IDF row depends only on the whole cleaned text, not on the
        # chunk, so compute it once instead of once per chunk and repeat it.
        vector = torch.tensor(self._vectorizer.transform([cleaned_text]).toarray()[0], dtype=torch.float32)
        tokenized: _TokenizedText = {
            'input_ids': torch.stack(input_ids_list, dim=0),
            'attention_mask': torch.stack(attention_mask_list, dim=0),
        }
        return tokenized, torch.stack([vector] * len(input_ids_list), dim=0)

    def __getitem__(self, idx: int) -> t.Tuple[int, _TokenizedText, _TokenizedText, torch.Tensor, torch.Tensor]:
        record = self._df.iloc[idx]
        tokenized_more, vectors_more = self._encode_text(str(record['more_toxic']), record['more_toxic_cleaned'])
        tokenized_less, vectors_less = self._encode_text(str(record['less_toxic']), record['less_toxic_cleaned'])
        return idx, tokenized_more, tokenized_less, vectors_more, vectors_less


### Text cleaners

In [None]:
# Word tokenizer shared by the lemmatizing and stop-word cleaners below, and
# the English stop-word list from the NLTK corpus.
tokenizer: t.Callable[[str], t.List[str]] = nltk.tokenize.word_tokenize
stop_words = stopwords.words('english')


class _Cleaner:
    """Base text cleaner: a callable ``str -> str`` that returns its input unchanged."""

    def __call__(self, text: str) -> str:
        unchanged = text
        return unchanged


class URLCleaner(_Cleaner):
    """Replace URLs (``www.…`` or ``http(s)://…``) with ``'url'`` and drop the ``#`` of hashtags."""
    # Raw strings: the previous non-raw literal relied on '\.' and '\s' being
    # invalid escape sequences passed through unchanged, which is deprecated
    # (SyntaxWarning since Python 3.12). The compiled patterns are identical.
    _RE_URL_1 = re.compile(r'((www\.[^\s]+)|(https?://[^\s]+))')
    _RE_URL_2 = re.compile(r'#([^\s]+)')

    def __call__(self, text: str) -> str:
        text = self._RE_URL_1.sub('url', text)
        text = self._RE_URL_2.sub(r'\1', text)
        return text


class AbbrevCleaner(_Cleaner):
    """Expand common English contractions using plain substring replacement."""

    def __call__(self, text: str) -> str:
        # Order matters: "what's" and "'scuse" must run before the generic "'s",
        # and "can't" before the generic "n't".
        for contraction, expansion in (
                ('what\'s', 'what is '),
                ('\'ve', ' have '),
                ('can\'t', 'cannot '),
                ('n\'t', ' not '),
                ('i\'m', 'i am '),
                ('\'re', ' are '),
                ('\'d', ' would '),
                ('\'ll', ' will '),
                ('\'scuse', ' excuse '),
                ('\'s', ' ')):
            text = text.replace(contraction, expansion)
        return text


class UnicodeCleaner(_Cleaner):
    """Replace literal ``\\uXXXX`` escape text and every non-ASCII character with a space."""
    _RE_UNICODE_1 = re.compile(r'(\\u[0-9A-Fa-f]+)')
    _RE_UNICODE_2 = re.compile(r'[^\x00-\x7f]')

    def __call__(self, text: str) -> str:
        for pattern in (self._RE_UNICODE_1, self._RE_UNICODE_2):
            text = pattern.sub(r' ', text)
        return text


class RepeatPatternCleaner(_Cleaner):
    """Shorten runs of repeated characters.

    Applied in order:

    * a run of 3+ identical letters ending a word is cut to 2;
    * a run of 4+ identical letters followed by a word character is cut to 3;
    * runs of 2+ spaces become a single space.
    """
    _RE_REPEAT_1 = re.compile(r'([a-zA-Z])\1{2,}\b')
    _RE_REPEAT_2 = re.compile(r'([a-zA-Z])\1\1{2,}\B')
    _RE_REPEAT_3 = re.compile(r'[ ]{2,}')

    def __call__(self, text: str) -> str:
        substitutions = (
            (self._RE_REPEAT_1, r'\1\1'),
            (self._RE_REPEAT_2, r'\1\1\1'),
            (self._RE_REPEAT_3, ' '),
        )
        for pattern, replacement in substitutions:
            text = pattern.sub(replacement, text)
        return text


class AtUserCleaner(_Cleaner):
    """Replace ``@mentions`` (``@`` up to the next whitespace) with ``'atUser'``."""
    # Raw string: the previous non-raw '\s' is an invalid escape sequence
    # (SyntaxWarning since Python 3.12). The compiled pattern is identical.
    _RE_AT_USER = re.compile(r'@[^\s]+')

    def __call__(self, text: str) -> str:
        text = self._RE_AT_USER.sub('atUser', text)
        return text


class MultiToxicWordsCleaner(_Cleaner):
    """Normalise common spellings and obfuscations of a few toxic words.

    Substitutions run in this order:

    * split "fuckfuck";
    * collapse stretched or starred variants to "fuck";
    * split "haha";
    * collapse stretched or spaced variants to "shit";
    * map (a/@)(s/$)(s/$) to "ass";
    * map the word "fuk" to "fuck".

    NOTE(review): inside a character class '|' is a literal pipe, not
    alternation, so classes such as ``[u|*]`` also match '|'. They are kept
    as-is because treating '|' as a masking character may be deliberate.
    Confirm.
    """
    _RE_FUCK_1 = re.compile(r'(fuckfuck)')
    _RE_FUCK_2 = re.compile(r'(f+)( *)([u|*]+)( *)([c|*]+)( *)(k)+')
    _RE_HAHA = re.compile(r'(haha)')
    # The leading word boundary stops the optional spaces from gluing an 's'
    # that ends one word onto the following word: without it "this hit" became
    # "thishit" and "wash it" became "washit". Word-initial spaced spellings
    # such as "s h i t" still match. Trade-off: stretched forms inside a longer
    # word (e.g. "bullshiiit") are no longer collapsed.
    _RE_SHIT = re.compile(r'(\bs+ *h+ *i+ *t+)')
    _RE_ASS = re.compile(r'([a|@][$|s][s|$])')
    _RE_FUK = re.compile(r'(\bfuk\b)')

    def __call__(self, text: str) -> str:
        text = self._RE_FUCK_1.sub('fuck fuck ', text)
        text = self._RE_FUCK_2.sub('fuck', text)
        text = self._RE_HAHA.sub('ha ha ', text)
        text = self._RE_SHIT.sub('shit', text)
        text = self._RE_ASS.sub('ass', text)
        text = self._RE_FUK.sub('fuck', text)
        return text


class NumbersCleaner(_Cleaner):
    """Replace each digit run that starts a word, plus its preceding character, with a space.

    A digit run counts as starting a word at the beginning of the string or
    after a non-word character. Digits inside words (e.g. ``x4``) are kept.
    """
    _RE_NUMBERS = re.compile(r"(^|\W)\d+")

    def __call__(self, text: str) -> str:
        return self._RE_NUMBERS.sub(repl=' ', string=text)


class MultiPuncCleaner(_Cleaner):
    """Normalise repeated punctuation.

    Applied in order:

    * runs of 2+ identical ``!``, ``?`` or apostrophes become two of them,
      padded with spaces;
    * every ``!``, ``?`` and apostrophe (including those just produced) is
      padded with spaces;
    * runs of 2+ identical ``*``, ``_`` or ``:`` collapse to one.
    """
    _RE_1 = re.compile(r'([!?\'])\1+')
    _RE_2 = re.compile(r'([!?\'])')
    _RE_3 = re.compile(r'([*_:])\1+')

    def __call__(self, text: str) -> str:
        for pattern, replacement in ((self._RE_1, r' \1\1 '), (self._RE_2, r' \1 '), (self._RE_3, r'\1')):
            text = pattern.sub(replacement, text)
        return text


class Lemmatizer(t_ext.Protocol):
    """Structural type for objects reducing a word to its lemma (e.g. ``WordNetLemmatizer``)."""

    def lemmatize(self, word: str, pos: str = "n") -> str:
        """Return the lemma of ``word`` for the part-of-speech tag ``pos``."""


class ReplaceTokenCleaner:
    """Replace every occurrence of each token in ``token_set`` with ``replace_with``.

    Multi-character tokens are replaced first, longest first (ties broken
    alphabetically). All single-character tokens are then replaced together in
    one ``str.translate`` pass.
    """

    def __init__(self, token_set: t.Set[str], replace_with: str):
        self._token_set = token_set
        self._replace_with = replace_with
        # Iterating the set directly made the result depend on hash order when
        # tokens overlapped or `replace_with` contained a token. A fixed order
        # makes the cleaner deterministic.
        self._multi_char_tokens = sorted(
            (token for token in token_set if len(token) != 1),
            key=lambda token: (-len(token), token))
        # One C-level pass for all single-character tokens instead of one
        # `str.replace` pass per token.
        self._single_char_table = str.maketrans(
            {token: replace_with for token in token_set if len(token) == 1})

    def __call__(self, text: str) -> str:
        for token in self._multi_char_tokens:
            text = text.replace(token, self._replace_with)
        return text.translate(self._single_char_table)


class RemoveStopWordsCleaner:
    """Tokenize the text, drop stop words, and re-join the remaining tokens with spaces.

    ``stop_words`` defaults to NLTK's English stop-word list. Matching is exact
    (case-sensitive).
    """

    def __init__(self, tokenizer: t.Callable[[str], t.List[str]], stop_words: t.Optional[t.List[str]] = None):
        self._tokenizer = tokenizer
        # Stored as a frozenset: membership is tested once per token, which is
        # O(1) here instead of a linear scan of a list.
        self._stop_words = frozenset(stop_words if stop_words is not None else stopwords.words('english'))

    def __call__(self, text: str) -> str:
        return ' '.join([token for token in self._tokenizer(text) if token not in self._stop_words])


class LemmatizeCleaner:
    """Tokenize the text, lemmatize each token with default arguments, and re-join with spaces."""

    def __init__(self, tokenizer: t.Callable[[str], t.List[str]], lemmatizer: Lemmatizer):
        self._tokenizer = tokenizer
        self._lemmatizer = lemmatizer

    def __call__(self, text: str) -> str:
        lemmatize = self._lemmatizer.lemmatize
        return ' '.join(map(lemmatize, self._tokenizer(text)))


class TextCleanerList:
    """Compose ``str -> str`` cleaners: each one receives the previous one's output."""

    def __init__(self, cleaner_list: t.List[t.Callable[[str], str]]):
        self._cleaner_list = cleaner_list

    def __call__(self, text: str) -> str:
        result = text
        for step in self._cleaner_list:
            result = step(result)
        return result


# Text normalisation pipeline. Each step receives the previous step's output,
# so the order below is part of the behaviour. Lower-casing comes first, so
# the later regex cleaners see lower-case text.
# NOTE(review): LemmatizeCleaner runs before RemoveStopWordsCleaner, so a stop
# word that the lemmatizer rewrites no longer matches the stop-word list.
# Confirm this order is intended.
# NOTE(review): RemoveStopWordsCleaner is not passed the module-level
# `stop_words`; it loads its own English list through its default argument.
text_cleaner = TextCleanerList([
    lambda text: text.lower(),
    URLCleaner(),
    UnicodeCleaner(),
    NumbersCleaner(),
    AbbrevCleaner(),
    MultiToxicWordsCleaner(),
    MultiPuncCleaner(),
    RepeatPatternCleaner(),
    # Turn each listed punctuation character (apostrophe included) into a space.
    ReplaceTokenCleaner(
        token_set=set('"%&\'()+,-./:;<=>@[\\]^_`{|}~'),
        replace_with=' '),
    LemmatizeCleaner(
        tokenizer=tokenizer,
        lemmatizer=WordNetLemmatizer()),
    RemoveStopWordsCleaner(tokenizer),
])

### Model

In [None]:
class _WeightedAverageLinearRegressor(torch.nn.Linear):
    """Bias-free single-output linear layer whose weights are softmax-normalised.

    Because the effective weights are positive and sum to 1, the output is a
    learned weighted average of the input features.
    """

    def __init__(
            self,
            in_features: int,
            device: t.Optional[t.Union[str, torch.device]] = None,
            dtype: t.Optional[torch.dtype] = None):
        # `dtype` is a torch.dtype (e.g. torch.float32), which is what
        # torch.nn.Linear accepts; it was previously mis-annotated as `str`.
        super().__init__(in_features=in_features, out_features=1, bias=False, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The weight has shape (1, in_features); softmax over dim=1 normalises
        # across the input features. `self.bias` is None (bias=False).
        return torch_f.linear(x, torch_f.softmax(self.weight, dim=1), self.bias)


class _TransformerClassifier(torch.nn.Module):
    """Pretrained encoder followed by a small MLP head that produces per-class logits.

    The encoder is loaded with ``return_dict=False`` and must return exactly
    two outputs. The second one (named ``pooled_output`` here, presumably the
    pooled representation — confirm for the chosen checkpoint) feeds the head.
    Its width must equal ``output_logits``.
    """

    def __init__(self, checkpoint: str, output_logits: int, num_classes: int):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(checkpoint, return_dict=False)
        hidden_width = 128
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(output_logits, hidden_width),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_width, num_classes))

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        _, pooled_output = encoder_outputs
        return self.classifier(pooled_output)


class _LinearClassifier(torch.nn.Module):
    """One linear layer mapping a feature vector (in Model: the TF-IDF vector) to per-class logits."""

    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        self.inner = torch.nn.Linear(num_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.inner(x)
        return logits


class _ClassifierDispatcher(torch.nn.Module):
    """Gate in (0, 1) deciding how much to trust each of two classifiers.

    Both prediction tensors are concatenated along dim 1 (so each is assumed to
    have ``num_classes`` columns) and mapped through a linear layer and a
    sigmoid to a single value per row.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        self.inner = torch.nn.Linear(in_features=2 * num_classes, out_features=1)

    def forward(self, transformer_pred_tensor: torch.Tensor, linear_pred_tensor: torch.Tensor) -> torch.Tensor:
        joined = torch.cat((transformer_pred_tensor, linear_pred_tensor), dim=1)
        return torch.sigmoid(self.inner(joined))


class Model(torch.nn.Module):
    """Blend a transformer classifier with a TF-IDF linear classifier, then score.

    ``forward`` returns ``(pred_tensor, scores)``:

    * ``pred_tensor`` mixes the two classifiers' per-class outputs using a
      learned gate ``theta`` from ``_ClassifierDispatcher``;
    * ``scores`` is the weighted-average regressor applied to the sigmoid of
      ``pred_tensor``.
    """

    def __init__(
            self,
            transformer_checkpoint: str,
            transformer_output_logits: int,
            linear_num_features: int,
            num_classes: int):
        super().__init__()
        # Sub-modules are registered in this order; the attribute names are the
        # state_dict key prefixes.
        self.transformer_classifier = _TransformerClassifier(
            checkpoint=transformer_checkpoint,
            output_logits=transformer_output_logits,
            num_classes=num_classes)
        self.linear_classifier = _LinearClassifier(
            num_features=linear_num_features,
            num_classes=num_classes)
        self.classifier_dispatcher = _ClassifierDispatcher(num_classes=num_classes)
        self.regressor = _WeightedAverageLinearRegressor(in_features=num_classes)

    def forward_scores(self, label_preds: torch.Tensor) -> torch.Tensor:
        """Map per-class values (predicted probabilities or labels) to a scalar score."""
        return self.regressor(label_preds)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, tfidf_vector_repr: torch.Tensor) -> t.Tuple[torch.Tensor, torch.Tensor]:
        from_transformer = self.transformer_classifier(input_ids, attention_mask)
        from_linear = self.linear_classifier(tfidf_vector_repr)
        gate = self.classifier_dispatcher(from_transformer, from_linear)
        blended = from_transformer * gate + from_linear * (1.0 - gate)
        return blended, self.forward_scores(torch.sigmoid(blended))


### Metrics

In [None]:
class _Metric:
    """Accumulating metric interface: subclasses implement ``compute`` and ``reset``."""

    def compute(self) -> float:
        """Return the metric value over everything accumulated since the last reset."""
        raise NotImplementedError()

    def reset(self):
        """Discard all accumulated state."""
        raise NotImplementedError()

    def compute_and_reset(self) -> float:
        """Return ``compute()`` and then ``reset()`` (no reset if ``compute`` raises)."""
        result = self.compute()
        self.reset()
        return result


class Accuracy(_Metric):
    """Share of tensor entries equal to 1, over all updates since the last reset."""

    def __init__(self):
        self._num_correct = 0
        self._num_total = 0

    def update(self, is_correct_tensor: torch.Tensor):
        hits = (is_correct_tensor == 1).int().sum().item()
        # `== 1` and `!= 1` partition the tensor, so their counts add up to
        # the number of elements.
        self._num_correct += hits
        self._num_total += is_correct_tensor.numel()

    def compute(self) -> float:
        assert self._num_total > 0
        return self._num_correct / self._num_total

    def reset(self):
        self._num_correct = 0
        self._num_total = 0


class _FloatListMetric(_Metric):
    """Base for metrics computed over every scalar collected since the last reset."""

    def __init__(self):
        self._value_list: t.List[float] = []

    def update(self, value_tensor: torch.Tensor):
        # Flattened, so every element of a tensor of any shape counts as one value.
        self._value_list += value_tensor.flatten().tolist()

    def reset(self):
        del self._value_list[:]


class FloatListMean(_FloatListMetric):
    """Arithmetic mean (``statistics.mean``) of the collected values."""

    def compute(self) -> float:
        values = self._value_list
        return statistics.mean(values)


class FloatListStd(_FloatListMetric):
    """Sample standard deviation (``statistics.stdev``) of the collected values."""

    def compute(self) -> float:
        values = self._value_list
        return statistics.stdev(values)

### Loggers

In [None]:
_C = t.TypeVar('_C')


class ContextManagerList(t.Generic[_C]):
    """Enter a list of context managers together and exit them as one unit.

    Backed by ``contextlib.ExitStack``. Managers are exited in reverse order of
    entry. If a later ``__enter__`` raises, the already-entered managers are
    still exited. One manager's ``__exit__`` raising does not prevent the others
    from running. Exception suppression by any manager is propagated.
    """

    def __init__(self, cm_list: t.List[t.ContextManager[_C]]):
        self._cm_list = cm_list
        self._exit_stack = None  # set by __enter__, cleared by __exit__

    def __enter__(self) -> t.List[_C]:
        import contextlib

        with contextlib.ExitStack() as stack:
            entered = [stack.enter_context(cm) for cm in self._cm_list]
            # Every manager entered successfully: move their exit callbacks to a
            # new stack so this temporary `with` does not run them on exit.
            self._exit_stack = stack.pop_all()
        return entered

    def __exit__(self, *args, **kwargs):
        stack, self._exit_stack = self._exit_stack, None
        if stack is None:
            return None
        return stack.__exit__(*args, **kwargs)


class Logger:
    """Base logger usable as a context manager; subclasses implement the ``log_*`` methods."""

    def __enter__(self) -> Logger:
        return self

    def __exit__(self, *args, **kwargs):
        # Nothing to release by default; exceptions are not suppressed.
        return None

    def log_params(self, params: t.Dict[str, t.Any]):
        """Record run parameters (hyper-parameters, configuration)."""
        raise NotImplementedError()

    def log_metrics(self, step: int, metrics: t.Dict[str, float]):
        """Record metric values for the given step."""
        raise NotImplementedError()


class StdOutLogger(Logger):
    """Print parameters and metrics to stdout, sorted by name."""

    def log_params(self, params: t.Dict[str, t.Any]):
        print('Using params:')
        for name in sorted(params):
            value = params[name]
            print(f'\t{name} = {value}')

    def log_metrics(self, step: int, metrics: t.Dict[str, float]):
        print(f'Step {step} metrics:')
        for name in sorted(metrics):
            value = metrics[name]
            print(f'\t{name} = {value:.8f}')


class TensorBoardLogger(Logger):
    """Write scalar metrics to TensorBoard, optionally restricted to a whitelist of names."""

    def __init__(self, log_dir: str, metric_whitelist: t.Optional[t.Set[str]] = None) -> None:
        self._metric_whitelist = metric_whitelist
        self._writer = SummaryWriter(log_dir=log_dir)

    def __exit__(self, *args, **kwargs):
        # The base Logger.__exit__ is a no-op, so without this the writer was
        # never explicitly flushed/closed when the `with` block ended.
        # Exceptions are not suppressed.
        self._writer.close()

    def log_params(self, params: t.Dict[str, t.Any]):
        pass  # TODO: handle hyperparams properly.

    def log_metrics(self, step: int, metrics: t.Dict[str, float]):
        for metric_key, metric_val in metrics.items():
            if self._metric_whitelist is None or metric_key in self._metric_whitelist:
                self._writer.add_scalar(tag=metric_key, scalar_value=metric_val, global_step=step)


class WAndBLogger(Logger):
    """Log parameters and metrics to a Weights & Biases run opened by ``__enter__``.

    Logs in with ``api_key`` at construction. The run is created on
    ``__enter__`` and finished on ``__exit__``.
    """

    def __init__(self, user_name: str, api_key: str, project: str, run_id: str):
        wandb.login(key=api_key)
        self._user_name = user_name
        self._project = project
        self._run_id = run_id
        self._run: t.Optional[WAndBRun] = None

    @property
    def run(self) -> WAndBRun:
        """The active run; only valid inside the ``with`` block."""
        assert self._run is not None
        return self._run

    def __enter__(self) -> WAndBLogger:
        # `wandb.init` takes the run identifier as `id`; it has no `run_id`
        # keyword, so the previous call failed with a TypeError.
        self._run = wandb.init(project=self._project, entity=self._user_name, id=self._run_id)
        return self

    def __exit__(self, *args, **kwargs):
        # Finish the run when the `with` block ends (the base Logger.__exit__ is
        # a no-op). Exceptions are not suppressed.
        if self._run is not None:
            self._run.finish()
            self._run = None

    def log_params(self, params: t.Dict[str, t.Any]):
        self.run.config.update(params)

    def log_metrics(self, step: int, metrics: t.Dict[str, float]):
        self.run.log(step=step, data=metrics)


### Schedulers

In [None]:
class ForwardScheduledFloat:
    """A float that moves from ``start`` towards ``end`` by ``step`` on each :meth:`step` call.

    The value never passes ``end``: the last step is clamped to land exactly on
    it. If ``start >= end``, the value never changes.
    """

    def __init__(self, start: float, end: float, step: float):
        self._end = end
        self._step = step
        self._val = start

    @property
    def value(self) -> float:
        return self._val

    def step(self):
        if self._val < self._end:
            # Clamp so a step size that does not divide (end - start) cannot
            # overshoot `end`. For example, a loss weight scheduled towards 1.0
            # must not exceed it, or its complement `1 - value` turns negative.
            self._val = min(self._val + self._step, self._end)

### Loss

In [None]:
def margin_ranking_loss(
        more_scores: torch.Tensor,
        less_scores: torch.Tensor,
        diff: torch.Tensor,
        margin: float,
        device: str,) -> torch.Tensor:
    """Mean hinge loss ``max(0, less_scores - more_scores + margin * diff)``.

    ``diff`` scales the required margin elementwise (usual broadcasting applies).

    Args:
        more_scores: scores that should be higher.
        less_scores: scores that should be lower.
        diff: per-pair margin scale.
        margin: global margin multiplier.
        device: no longer used. Clamping needs no separately allocated zero
            tensor, which the previous version created on this device on every
            call. Kept so existing call sites remain valid.

    Returns:
        A scalar tensor with the mean loss. The value matches the previous
        ``torch.maximum`` form; only the sub-gradient at exactly zero may differ.
    """
    return (less_scores - more_scores + margin * diff).clamp(min=0.0).mean()

### Iteration functions

In [None]:
def do_train_iteration(
        data_loader: DataLoader,
        model: Model,
        device: str,
        optimizer: Optimizer,
        scheduler: CosineAnnealingWarmRestarts,
        ranking_loss_part: float,
        train_margin_list: t.Optional[t.List[float]] = None,
        train_decision_margin: float = 0.3,
        accumulate_gradient_steps: int = 1,
        num_steps: t.Optional[int] = None) -> t.Dict[str, float]:
    """Run one training pass and return the aggregated training metrics.

    For every (more toxic, less toxic) batch the loss is::

        ranking_loss_part * ranking + (1 - ranking_loss_part) * (cls_more + cls_less) / 2

    ``ranking`` is :func:`margin_ranking_loss` on the two model scores, with the
    per-pair margin ``train_decision_margin * (forward_scores(labels_more) -
    forward_scores(labels_less))``. ``cls_*`` are multilabel soft-margin losses
    on the predicted per-class outputs.

    Args:
        data_loader: yields TrainDataset items. Its sampler must be a
            RandomSubsetPerEpochSampler, which is read for the progress text.
        model: the Model being trained (switched to train mode here).
        device: device the batch tensors are moved to.
        optimizer: stepped (and its gradients zeroed) every
            ``accumulate_gradient_steps`` batches.
        scheduler: stepped together with the optimizer.
        ranking_loss_part: weight of the ranking loss; the classification part
            gets ``1 - ranking_loss_part``.
        train_margin_list: margins ``m`` for which the accuracy of
            ``score_more - score_less > m`` is tracked. Defaults to
            ``[train_decision_margin]``.
        train_decision_margin: margin multiplier of the ranking loss; must be in
            ``train_margin_list``.
        accumulate_gradient_steps: number of batches whose gradients are
            averaged into one optimizer step.
        num_steps: stop after this many batches (default: the whole loader).

    Returns:
        ``train_loss``, ``train_score_mean``, ``train_score_std`` and one
        ``train_accuracy_{m}`` entry per margin.
    """
    sampler: RandomSubsetPerEpochSampler = t.cast(RandomSubsetPerEpochSampler, data_loader.sampler)
    loss_metric = FloatListMean()
    train_score_mean = FloatListMean()
    train_score_std = FloatListStd()
    if train_margin_list is None:
        train_margin_list = [train_decision_margin]
    assert train_decision_margin in train_margin_list
    train_accuracy_dict = {m: Accuracy() for m in train_margin_list}

    model.train()
    # Start from clean gradients. When a previous call ended part-way through an
    # accumulation window, its leftover gradients would otherwise leak into this
    # call's first optimizer step.
    optimizer.zero_grad()
    data_iter = tqdm(data_loader, desc='Training', total=num_steps if num_steps is not None else len(data_loader))
    for step, (tokenized_text_more, tokenized_text_less, labels_more, labels_less, vector_tensor_more, vector_tensor_less) in enumerate(data_iter):
        (
            input_ids_more,
            attention_mask_more,
            input_ids_less,
            attention_mask_less,
            labels_more,
            labels_less,
            vector_tensor_more,
            vector_tensor_less,
        ) = (
            tokenized_text_more['input_ids'].to(device),
            tokenized_text_more['attention_mask'].to(device),
            tokenized_text_less['input_ids'].to(device),
            tokenized_text_less['attention_mask'].to(device),
            labels_more.to(device),
            labels_less.to(device),
            vector_tensor_more.to(device),
            vector_tensor_less.to(device),
        )
        preds_more, score_more = model(input_ids_more, attention_mask_more, vector_tensor_more)
        preds_less, score_less = model(input_ids_less, attention_mask_less, vector_tensor_less)
        # NOTE(review): these come from the trainable regressor, so the ranking
        # loss also back-propagates through the per-pair margin (with a positive
        # margin, pushing it down). Detach them if the margin is meant to be a
        # fixed target. Confirm.
        score_more_approx = model.forward_scores(labels_more)
        score_less_approx = model.forward_scores(labels_less)

        cls_more_loss = torch_f.multilabel_soft_margin_loss(preds_more, labels_more)
        cls_less_loss = torch_f.multilabel_soft_margin_loss(preds_less, labels_less)
        ranking_loss = margin_ranking_loss(score_more, score_less, margin=train_decision_margin, diff=score_more_approx - score_less_approx, device=device)
        loss = ranking_loss_part * ranking_loss + (1.0 - ranking_loss_part) * (cls_more_loss + cls_less_loss) / 2
        # Scale so gradients over an accumulation window are averaged rather
        # than summed (a no-op for accumulate_gradient_steps == 1). The unscaled
        # `loss` is what gets logged below.
        (loss / accumulate_gradient_steps).backward()

        if (step + 1) % accumulate_gradient_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

        with torch.no_grad():
            score_more_cpu = score_more.cpu()
            train_score_mean.update(score_more_cpu)
            train_score_std.update(score_more_cpu)
            score_less_cpu = score_less.cpu()
            train_score_mean.update(score_less_cpu)
            train_score_std.update(score_less_cpu)

            loss_metric.update(loss.cpu())
            for m, a in train_accuracy_dict.items():
                a.update(score_more - score_less > m)
        # NOTE(review): the mean/std metrics' compute() walks every value
        # collected so far, so refreshing this description on each step costs
        # time proportional to the number of batches already seen.
        epoch_str = f'epoch: {sampler.real_epoch} [{sampler.frac_consumed:.4f}]'
        accuracy_str = ', '.join([f'acc_{m}: {train_accuracy_dict[m].compute():.4f}' for m in sorted(train_accuracy_dict.keys())])
        data_iter.set_description(
            f'{epoch_str} th: {ranking_loss_part:.2f} loss: {loss_metric.compute():.4f}, {accuracy_str} '
            f'score_mean: {train_score_mean.compute():.6f}, score_std: {train_score_std.compute():.6f}')

        if num_steps is not None and step >= num_steps - 1:
            break

    train_metrics_to_track = {f'train_accuracy_{m}': a.compute() for m, a in train_accuracy_dict.items()}
    loss_val = loss_metric.compute_and_reset()
    return {
        'train_loss': loss_val,
        'train_score_mean': train_score_mean.compute(),
        'train_score_std': train_score_std.compute(),
        **train_metrics_to_track,
    }


@torch.no_grad()
def do_valid_iteration(
        data_loader: DataLoader,
        model: Model,
        device: str,
        margin_list: t.Optional[t.List[float]] = None,
        decision_margin: float = 0.0) -> t.Tuple[float, t.Dict[str, float]]:
    """Score every validation pair and report ranking accuracy.

    Each text arrives as several chunks (see ``valid_collate_fn``); a text's
    score is the maximum over its chunks. A pair is correct for margin ``m``
    when ``score_more - score_less > m``. The per-pair error
    ``max(0, score_less - score_more)`` is stored on the dataset through
    ``track_error``.

    Returns:
        The accuracy at ``decision_margin``, and a metrics dict containing
        ``valid_score_mean``, ``valid_score_std`` and one
        ``valid_accuracy_{m}`` entry per margin.
    """
    margins = margin_list if margin_list is not None else [decision_margin]
    assert decision_margin in margins
    accuracy_by_margin = {margin: Accuracy() for margin in margins}
    score_mean_metric = FloatListMean()
    score_std_metric = FloatListStd()

    def _max_over_chunks(chunk_scores, slices):
        # One row per original text: its highest-scoring chunk.
        return torch.cat([torch.max(chunk_scores[s], dim=0, keepdim=True)[0] for s in slices], dim=0)

    model.eval()
    progress = tqdm(data_loader, desc='Validation')
    for batch in progress:
        idx_list, tokenized_more, tokenized_less, slices_more, slices_less, vectors_more, vectors_less = batch
        _, chunk_scores_more = model(
            tokenized_more['input_ids'].to(device),
            tokenized_more['attention_mask'].to(device),
            vectors_more.to(device))
        _, chunk_scores_less = model(
            tokenized_less['input_ids'].to(device),
            tokenized_less['attention_mask'].to(device),
            vectors_less.to(device))
        score_more = _max_over_chunks(chunk_scores_more, slices_more)
        score_less = _max_over_chunks(chunk_scores_less, slices_less)

        more_cpu, less_cpu = score_more.cpu(), score_less.cpu()
        for scores_cpu in (more_cpu, less_cpu):
            score_mean_metric.update(scores_cpu)
            score_std_metric.update(scores_cpu)
        for margin, metric in accuracy_by_margin.items():
            metric.update(((score_more - score_less) > margin).cpu())
        pair_error = torch.maximum(torch.zeros_like(less_cpu), less_cpu - more_cpu).squeeze(1)
        data_loader.dataset.track_error(idx_list, pair_error)
        accuracy_str = ', '.join([f'acc_{m}: {accuracy_by_margin[m].compute():.4f}' for m in sorted(accuracy_by_margin)])
        progress.set_description(f'Validation. {accuracy_str}')

    metrics = {
        'valid_score_mean': score_mean_metric.compute(),
        'valid_score_std': score_std_metric.compute(),
    }
    metrics.update({f'valid_accuracy_{m}': a.compute() for m, a in accuracy_by_margin.items()})
    return accuracy_by_margin[decision_margin].compute(), metrics

### Main function

In [None]:
def main(
        train_df: pd.DataFrame,
        valid_df: pd.DataFrame,
        num_classes: int,
        from_checkpoint: str,
        to_checkpoint: str,
        error_artifact_dir_path: Path,
        logger_list: t.List[Logger],
        num_epochs: int,
        batch_size: int,
        max_len: int,
        num_workers: int,
        device: str,
        output_logits: int,
        lr: float,
        t_0: int,
        eta_min: float,
        train_margin_list: t.List[float],
        train_decision_margin: float,
        valid_margin_list: t.List[float],
        valid_decision_margin: float,
        accumulate_gradient_steps: int = 1,
        validate_every_n_steps: t.Optional[int] = None):
    """Train the pairwise ranking model and keep the best checkpoint on disk.

    Every "epoch" runs one training pass (``do_train_iteration``) followed by one
    validation pass (``do_valid_iteration``). The validation accuracy at
    ``valid_decision_margin`` decides checkpointing: whenever it beats the best value
    seen so far, ``model.state_dict()`` is saved to ``to_checkpoint``. After each
    epoch the validation dataset's error table is written to
    ``error_artifact_dir_path / f'{epoch}.csv'``.

    Args:
        train_df: training pairs. The TF-IDF vectorizer is fitted on its
            ``more_toxic_cleaned`` and ``less_toxic_cleaned`` columns.
        valid_df: validation pairs.
        num_classes: forwarded to ``Model``.
        from_checkpoint: pretrained name/path used for both the tokenizer and the
            model's transformer backbone.
        to_checkpoint: file path the best ``state_dict`` is saved to.
        error_artifact_dir_path: directory for the per-epoch validation error CSVs;
            created (with parents) if missing.
        logger_list: loggers that receive the hyperparameters once and the merged
            train/validation metrics after every epoch (``step=epoch``).
        num_epochs: number of train/validate rounds.
        batch_size: training batch size; validation uses ``2 * batch_size``.
        max_len: forwarded to both datasets.
        num_workers: ``DataLoader`` worker count for both loaders.
        device: torch device string the model is moved to.
        output_logits: forwarded to ``Model`` as ``transformer_output_logits``.
        lr: AdamW learning rate.
        t_0: ``T_0`` of ``CosineAnnealingWarmRestarts``.
        eta_min: minimum learning rate of ``CosineAnnealingWarmRestarts``.
        train_margin_list: forwarded to ``do_train_iteration``.
        train_decision_margin: forwarded to ``do_train_iteration``.
        valid_margin_list: margins at which validation accuracy is reported.
        valid_decision_margin: margin whose validation accuracy drives checkpointing.
        accumulate_gradient_steps: forwarded to ``do_train_iteration``.
        validate_every_n_steps: passed as ``samples_per_epoch`` to
            ``RandomSubsetPerEpochSampler``. Despite its name it therefore appears
            to count training *samples* per epoch, not optimizer steps (confirm
            against the sampler). ``None`` means one pass over the whole training set.
    """
    for logger in logger_list:
        logger.log_params({
            'from_checkpoint': from_checkpoint,
            'lr': lr,
            'batch_size': batch_size,
            'output_logits': output_logits,
            'max_len': max_len,
            'num_epochs': num_epochs,
            'optimizer': 'adam_w',
            'scheduler': 'cosine_annealing_warm_restarts',
            'accumulate_gradient_steps': accumulate_gradient_steps,
        })

    tokenizer = AutoTokenizer.from_pretrained(from_checkpoint)
    # Character n-gram TF-IDF features; fitted on both sides of every training pair only,
    # so no validation text leaks into the vocabulary.
    vectorizer = TfidfVectorizer(
        min_df=3, max_df=0.5, max_features=65536, analyzer='char_wb', ngram_range=(3, 5),
    ).fit(train_df['more_toxic_cleaned'].tolist() + train_df['less_toxic_cleaned'].tolist())
    train_dataset = TrainDataset(
        train_df,
        tokenizer=tokenizer,
        vectorizer=vectorizer,
        max_len=max_len,
        augmentation_list=[
            # RandomlyReduceTokenLenTo(token_len=max_len),
        ])
    valid_dataset = ValidDataset(
        valid_df,
        tokenizer=tokenizer,
        vectorizer=vectorizer,
        max_len=max_len)
    train_data_loader = DataLoader(
        dataset=train_dataset,
        sampler=RandomSubsetPerEpochSampler(
            data_source=train_dataset,
            samples_per_epoch=validate_every_n_steps if validate_every_n_steps is not None else len(train_dataset)),
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True)
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=batch_size * 2,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
        collate_fn=valid_collate_fn)  # type: ignore
    model = Model(
        transformer_checkpoint=from_checkpoint,
        transformer_output_logits=output_logits,
        num_classes=num_classes,
        linear_num_features=len(vectorizer.vocabulary_)).to(device)
    optimizer = AdamW(model.parameters(), lr=lr)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=t_0, eta_min=eta_min)

    # The per-epoch error CSVs are written here, so make sure the directory exists
    # instead of failing at the end of the first (expensive) epoch.
    error_artifact_dir_path.mkdir(parents=True, exist_ok=True)

    # Start below any achievable accuracy so the first epoch always produces a
    # checkpoint; with 0.0 and a strict `>` a model that never beats 0.0 was never saved.
    best_accuracy = float('-inf')
    for epoch in range(num_epochs):
        train_metrics_to_track = do_train_iteration(
            data_loader=train_data_loader,
            model=model,
            device=device,
            optimizer=optimizer,
            scheduler=scheduler,
            train_margin_list=train_margin_list,
            train_decision_margin=train_decision_margin,
            accumulate_gradient_steps=accumulate_gradient_steps,
            ranking_loss_part=0.5)
        accuracy, valid_metrics_to_track = do_valid_iteration(
            data_loader=valid_data_loader,
            model=model,
            device=device,
            margin_list=valid_margin_list,
            decision_margin=valid_decision_margin)
        for logger in logger_list:
            logger.log_metrics(step=epoch, metrics={
                **train_metrics_to_track,
                **valid_metrics_to_track,
            })
        if accuracy > best_accuracy:
            print(f'Best accuracy improved from {best_accuracy} to {accuracy}. Saving the model.')
            torch.save(model.state_dict(), to_checkpoint)
            best_accuracy = accuracy
        valid_dataset.get_df_with_error().to_csv(str(error_artifact_dir_path / f'{epoch}.csv'), index=False)

### Parameter definitions

In [None]:
# Parameters

# Switch for the execution environment; every path below is derived from it.
IS_KAGGLE = False

if IS_KAGGLE:
    ROOT_DIR_PATH = Path('/kaggle')
    DATA_DIR_PATH = ROOT_DIR_PATH / 'input'
    TENSORBOARD_DIR_PATH = ROOT_DIR_PATH / 'working/tensorboard'
    ARTIFACT_DIR_PATH = ROOT_DIR_PATH / 'working/artifacts'
    MODELS_DIR_PATH = ROOT_DIR_PATH / 'working/models'
else:
    ROOT_DIR_PATH = Path('/home/jovyan/jigsaw-toxic')
    DATA_DIR_PATH = ROOT_DIR_PATH / 'data/datasets'
    # Locally, TensorBoard logs live outside the project root.
    TENSORBOARD_DIR_PATH = Path('/home/jovyan/tensorboard')
    ARTIFACT_DIR_PATH = ROOT_DIR_PATH / 'artifacts'
    MODELS_DIR_PATH = ROOT_DIR_PATH / 'models'

# Pair CSVs (with `more_toxic` / `less_toxic` text columns) for training and validation.
DATASET_DIR_PATH = DATA_DIR_PATH / 'ccc-2017-multilabel'
TRAIN_CSV_PATH = DATASET_DIR_PATH / 'train_no_leak_pair_harder_3.csv'
VALID_CSV_PATH = DATASET_DIR_PATH / 'valid_pair.csv'

In [None]:
# Create potentially missing directories.
# Pure-Python equivalent of `mkdir -p`: runs outside IPython too (e.g. when the
# notebook is exported to a script) and, unlike an unquoted shell `$VAR`
# expansion, is not broken by paths containing spaces.
MODELS_DIR_PATH.mkdir(parents=True, exist_ok=True)
ARTIFACT_DIR_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
# Read dataframes: the training pairs first, then the validation pairs.
train_df, valid_df = (pd.read_csv(csv_path) for csv_path in (TRAIN_CSV_PATH, VALID_CSV_PATH))

In [None]:
# Notebook display: render the freshly loaded training pairs as this cell's output.
train_df

In [None]:
# Add a `<column>_cleaned` copy of both sides of every pair, for both splits.
# `progress_apply` is the tqdm-enabled `apply` registered by `tqdm.pandas()`.
for frame in (train_df, valid_df):
    for column in ('more_toxic', 'less_toxic'):
        frame[f'{column}_cleaned'] = frame[column].progress_apply(text_cleaner)

### Entrypoint

In [None]:
# NOTE(review): the local branch below uses `os`, which does not appear in the
# notebook's import cell; add `import os` (or switch to pathlib) before re-enabling.
# if IS_KAGGLE:
#     from kaggle_secrets import UserSecretsClient
#     user_secrets = UserSecretsClient()
#     wandb_api_key = user_secrets.get_secret('wandb-api-token')
# else:
#     with open(os.path.join(os.path.dirname(os.getcwd()), 'deploy/secrets/wandb_api_token.txt')) as f:
#         wandb_api_key = f.read()

# Run name; reused for the TensorBoard run directory, the error-artifact
# sub-directory and the checkpoint file name.
model_name = 'ccc-2017-multilabel-linear-harder-cls-loss_0p5'
error_artifact_dir_path = ARTIFACT_DIR_PATH / model_name
# parents=True so this works even if ARTIFACT_DIR_PATH itself was not created yet.
error_artifact_dir_path.mkdir(parents=True, exist_ok=True)
with ContextManagerList([
        # StdOutLogger(),
        TensorBoardLogger(
            log_dir=str(TENSORBOARD_DIR_PATH / f'jt-{model_name}'),
            # Presumably only these metric keys reach TensorBoard — confirm in
            # TensorBoardLogger. Accuracies at margin 1.0 are not listed.
            metric_whitelist={
                'train_loss',
                'train_accuracy_0.0',
                'train_accuracy_0.5',
                'valid_accuracy_0.0',
                'valid_accuracy_0.5',
            },
        ),
]) as logger_list:
    main(
        train_df=train_df,
        valid_df=valid_df,
        num_classes=6,
        from_checkpoint='roberta-base',
        to_checkpoint=str(MODELS_DIR_PATH / f'{model_name}.pt'),
        error_artifact_dir_path=error_artifact_dir_path,
        logger_list=logger_list,
        num_epochs=100,
        batch_size=4,
        max_len=256,
        num_workers=8,
        device='cuda',
        output_logits=768,
        lr=1e-4,
        t_0=100,
        eta_min=1e-6,
        train_margin_list=[0.0, 0.5, 1.0],
        train_decision_margin=1.0,
        valid_margin_list=[0.0, 0.5, 1.0],
        valid_decision_margin=0.0,
        accumulate_gradient_steps=16,
        validate_every_n_steps=4096,
    )