### Imports

In [None]:
from __future__ import annotations

import contextlib
import os
import random
import shutil
import statistics
import typing as t
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as torch_f
import typing_extensions as t_ext
import wandb
from torch.optim import Optimizer, AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm.notebook import tqdm
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.models.auto.modeling_auto import AutoModel
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.optimization import get_cosine_schedule_with_warmup
from wandb.wandb_run import Run as WAndBRun

In [None]:
def seed_everything(seed: int):
    """Seed Python's `random`, NumPy's global RNG and torch with one value, in that order."""
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)


seed_everything(42)

### Datasets

In [None]:
class RandomSubsetPerEpochSampler(Sampler[int]):
    """Sampler that yields `samples_per_epoch` random indices on every `iter()` call.

    Indices are drawn without replacement from a shuffled permutation of
    ``range(len(data_source))``. When the permutation is exhausted a fresh one is drawn, so
    one "real" epoch (a complete pass over the data) may span several `iter()` calls, and one
    `iter()` call may straddle two passes.
    """

    @staticmethod
    def _build_index(data_source: t.Sized) -> t.List[int]:
        index = list(range(len(data_source)))
        random.shuffle(index)
        return index

    def __init__(self, data_source: t.Sized, samples_per_epoch: int):
        super().__init__(data_source)
        self._data_source = data_source
        self._samples_per_epoch = samples_per_epoch
        self._index: t.List[int] = self._build_index(data_source)
        # Number of complete passes over `data_source` handed out so far.
        self._real_epoch = 0
        # Index of the most recent `iter()` call (-1 before the first one).
        self._step = -1

    def _sample_one(self) -> int:
        sample = self._index.pop()
        # Close the pass as soon as its last index is handed out (not lazily on the next
        # draw). Otherwise `real_epoch` lags by one when an `iter()` call ends exactly on a
        # pass boundary, and a `while real_epoch < num_epochs` loop trains an extra epoch.
        if not self._index:
            self._index = self._build_index(self._data_source)
            self._real_epoch += 1
        return sample

    def __iter__(self) -> t.Iterator[int]:
        self._step += 1
        # Materialize eagerly so the counters above are updated when `iter()` returns.
        return iter([self._sample_one() for _ in range(self._samples_per_epoch)])

    def __len__(self) -> int:
        return self._samples_per_epoch

    @property
    def step(self) -> int:
        """Zero-based count of `iter()` calls made so far (-1 before the first)."""
        return self._step

    @property
    def real_epoch(self) -> int:
        """Number of complete passes over the data source consumed so far."""
        return self._real_epoch

    @property
    def frac_left(self) -> float:
        """Fraction of the current pass not yet handed out (1.0 right after a reshuffle)."""
        total = len(self._data_source)
        return len(self._index) / total if total else 0.0

    @property
    def frac_consumed(self) -> float:
        """Fraction of the current pass already handed out."""
        return 1.0 - self.frac_left


class MetricBasedRandomSubsetPerEpochSampler(RandomSubsetPerEpochSampler):
    """Random-subset sampler whose subset size can be switched according to a metric.

    ``thresholded_samples_per_epoch_list`` holds ``(threshold, samples_per_epoch)`` pairs.
    `adjust_samples_per_epoch` scans them in list order and adopts the size of every pair
    whose threshold the metric reaches, so the last reached pair wins. If none is reached,
    the size given at construction is restored.
    """

    def __init__(
            self,
            data_source: t.Sized,
            samples_per_epoch: int,
            thresholded_samples_per_epoch_list: t.List[t.Tuple[float, int]],):
        super().__init__(data_source, samples_per_epoch)
        # Fallback size used when no threshold is reached.
        self._samples_per_epoch_init = samples_per_epoch
        self._thresholded_samples_per_epoch_list = thresholded_samples_per_epoch_list

    def adjust_samples_per_epoch(self, metric_val: float):
        """Set the subset size from `metric_val` (takes effect on the next `iter()` call)."""
        reached_sizes = [
            size
            for threshold, size in self._thresholded_samples_per_epoch_list
            if metric_val >= threshold
        ]
        self._samples_per_epoch = reached_sizes[-1] if reached_sizes else self._samples_per_epoch_init


class PairDataset(Dataset):
    """Dataset of (more toxic, less toxic) comment pairs mapped to their feature vectors.

    Features are the ``feature_df`` columns whose names start with ``score_``. Each item is a
    tuple of two float32 tensors of length ``num_features``: the features of the pair's
    ``more_toxic_id`` comment and of its ``less_toxic_id`` comment, both matched on
    ``feature_df['comment_id']``. If a ``comment_id`` occurs more than once, its first row is
    used.
    """

    def __init__(
            self,
            pair_df: pd.DataFrame,
            feature_df: pd.DataFrame) -> None:
        super().__init__()
        self._pair_df = pair_df
        self._feature_df = feature_df
        self._feature_col_list = [col for col in list(self._feature_df.columns) if col.startswith('score_')]
        print(self._feature_col_list)
        # Build the id -> feature-row lookup once. Filtering `feature_df` inside every
        # __getitem__ cost O(len(feature_df)) per item, i.e. quadratic work per epoch.
        unique_feature_df = feature_df.drop_duplicates(subset='comment_id', keep='first')
        self._feature_matrix = unique_feature_df[self._feature_col_list].to_numpy(dtype=np.float32)
        self._row_by_comment_id = {
            comment_id: row for row, comment_id in enumerate(unique_feature_df['comment_id'].tolist())}
        # Read ids column-wise so a mixed-dtype pair_df row cannot upcast them.
        self._more_ids = pair_df['more_toxic_id'].to_numpy()
        self._less_ids = pair_df['less_toxic_id'].to_numpy()

    @property
    def num_features(self) -> int:
        return len(self._feature_col_list)

    def __len__(self) -> int:
        return len(self._pair_df)

    def _features_for(self, comment_id: t.Any) -> torch.Tensor:
        """Return a fresh float32 feature tensor for `comment_id` (KeyError if it is unknown)."""
        row = self._row_by_comment_id[comment_id]
        return torch.tensor(self._feature_matrix[row], dtype=torch.float32)

    def __getitem__(self, idx: int) -> t.Tuple[torch.Tensor, torch.Tensor]:
        return self._features_for(self._more_ids[idx]), self._features_for(self._less_ids[idx])


### Model

In [None]:
class Model(torch.nn.Module):
    """Linear scorer: projects a feature vector onto a single scalar score."""

    def __init__(self, num_features: int):
        super().__init__()
        # The attribute name `fc` determines the state_dict keys (`fc.weight`, `fc.bias`)
        # of saved checkpoints, so it must not change.
        self.fc = torch.nn.Linear(in_features=num_features, out_features=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer; the last dimension of `x` must equal `num_features`."""
        projected = self.fc(x)
        return projected

### Metrics

In [None]:
class _Metric:

    def compute(self) -> float:
        raise NotImplementedError()

    def reset(self):
        raise NotImplementedError()

    def compute_and_reset(self) -> float:
        value = self.compute()
        self.reset()
        return value


class Accuracy(_Metric):
    """Running fraction of tensor elements equal to 1 (e.g. ``True``) over all updates."""

    def __init__(self):
        self._num_correct = 0
        self._num_total = 0

    def update(self, is_correct_tensor: torch.Tensor):
        """Accumulate `is_correct_tensor`; elements equal to 1 count as correct."""
        self._num_correct += int((is_correct_tensor == 1).sum().item())
        # Each element is either == 1 or != 1, so the element count is the total.
        self._num_total += is_correct_tensor.numel()

    def compute(self) -> float:
        """Return correct / total; asserts that at least one element has been seen."""
        assert self._num_total > 0
        return self._num_correct / self._num_total

    def reset(self):
        self._num_total = 0
        self._num_correct = 0


class _FloatListMetric(_Metric):
    """Base for metrics computed over every scalar seen since the last reset."""

    def __init__(self):
        # Flat buffer of all values passed to `update`, in arrival order.
        self._value_list: t.List[float] = []

    def update(self, value_tensor: torch.Tensor):
        """Append every element of `value_tensor` (any shape) as a Python number."""
        self._value_list += value_tensor.reshape(-1).tolist()

    def reset(self):
        del self._value_list[:]


class FloatListMean(_FloatListMetric):
    """Arithmetic mean of the collected values (``statistics.mean``; raises when empty)."""

    def compute(self) -> float:
        values = self._value_list
        return statistics.mean(values)


class FloatListStd(_FloatListMetric):
    """Sample standard deviation of the collected values (``statistics.stdev``; needs >= 2 values)."""

    def compute(self) -> float:
        values = self._value_list
        return statistics.stdev(values)

### Loggers

In [None]:
_C = t.TypeVar('_C')


class ContextManagerList(t.Generic[_C]):
    """Enter a list of context managers together, like a dynamic multi-item ``with``.

    ``__enter__`` returns the entered values in input order. On exit the managers are exited
    in reverse (LIFO) order, and every one is exited even if another's ``__exit__`` raises.
    If entering one manager fails, those already entered are exited before the error
    propagates.
    """

    def __init__(self, cm_list: t.List[t.ContextManager[_C]]):
        self._cm_list = cm_list
        self._exit_stack: t.Optional[contextlib.ExitStack] = None

    def __enter__(self) -> t.List[_C]:
        with contextlib.ExitStack() as stack:
            entered = [stack.enter_context(cm) for cm in self._cm_list]
            # All entered successfully: move the exit callbacks out of this temporary stack
            # so they run in __exit__ rather than at the end of this `with`.
            self._exit_stack = stack.pop_all()
        return entered

    def __exit__(self, *exc_info) -> bool:
        stack, self._exit_stack = self._exit_stack, None
        if stack is None:
            return False
        # Propagates suppression: True only if one of the managers suppressed the exception.
        return stack.__exit__(*exc_info)


class Logger:
    """Base class for parameter/metric sinks, usable as a context manager.

    Entering returns the logger itself and exiting does nothing by default. Subclasses must
    implement `log_params` and `log_metrics`.
    """

    def __enter__(self) -> Logger:
        return self

    def __exit__(self, *args, **kwargs):
        return None

    def log_params(self, params: t.Dict[str, t.Any]):
        """Record run hyper-parameters."""
        raise NotImplementedError()

    def log_metrics(self, step: int, metrics: t.Dict[str, float]):
        """Record scalar metrics for `step`."""
        raise NotImplementedError()


class StdOutLogger(Logger):
    """Logger that prints parameters and metrics to stdout, ordered by name."""

    def log_params(self, params: t.Dict[str, t.Any]):
        print('Using params:')
        for name, value in sorted(params.items(), key=lambda item: item[0]):
            print(f'\t{name} = {value}')

    def log_metrics(self, step: int, metrics: t.Dict[str, float]):
        print(f'Step {step} metrics:')
        for name, value in sorted(metrics.items(), key=lambda item: item[0]):
            print(f'\t{name} = {value:.8f}')


class TensorBoardLogger(Logger):
    """Logger that writes scalar metrics to TensorBoard event files.

    Args:
        log_dir: directory handed to ``SummaryWriter``.
        metric_whitelist: if given, only metrics whose key is in this set are written;
            ``None`` writes every metric.
    """

    def __init__(self, log_dir: str, metric_whitelist: t.Optional[t.Set[str]] = None) -> None:
        self._metric_whitelist = metric_whitelist
        self._writer = SummaryWriter(log_dir=log_dir)

    def __exit__(self, *args, **kwargs):
        # SummaryWriter buffers events and flushes them only periodically. Close it when the
        # `with` block ends so the final scalars reach disk; returning None never suppresses.
        self._writer.close()

    def log_params(self, params: t.Dict[str, t.Any]):
        pass  # TODO: handle hyperparams properly.

    def log_metrics(self, step: int, metrics: t.Dict[str, float]):
        for metric_key, metric_val in metrics.items():
            if self._metric_whitelist is None or metric_key in self._metric_whitelist:
                self._writer.add_scalar(tag=metric_key, scalar_value=metric_val, global_step=step)


class WAndBLogger(Logger):
    """Logger that sends parameters and metrics to a Weights & Biases run.

    Logs in with `api_key` on construction. The run is started in ``__enter__`` and finished
    in ``__exit__``, so the logger must be used as a context manager before logging.
    """

    def __init__(self, user_name: str, api_key: str, project: str, run_id: str):
        wandb.login(key=api_key)
        self._user_name = user_name
        self._project = project
        self._run_id = run_id
        self._run: t.Optional[WAndBRun] = None

    @property
    def run(self) -> WAndBRun:
        """The active run; only valid between ``__enter__`` and ``__exit__``."""
        assert self._run is not None
        return self._run

    def __enter__(self) -> WAndBLogger:
        # `wandb.init` takes the run identifier as `id`; it has no `run_id` keyword.
        self._run = wandb.init(project=self._project, entity=self._user_name, id=self._run_id)
        return self

    def __exit__(self, *args, **kwargs):
        # Finish the run so it is closed out instead of staying open in the kernel.
        if self._run is not None:
            self._run.finish()
            self._run = None

    def log_params(self, params: t.Dict[str, t.Any]):
        self.run.config.update(params)

    def log_metrics(self, step: int, metrics: t.Dict[str, float]):
        self.run.log(step=step, data=metrics)


### Loss

In [None]:
def margin_ranking_loss(
        more_scores: torch.Tensor,
        less_scores: torch.Tensor,
        margin: t.Union[float, torch.Tensor],
        device: t.Optional[t.Union[str, torch.device]] = None,) -> torch.Tensor:
    """Mean hinge loss ``max(0, less - more + margin)`` over all elements.

    Args:
        more_scores: scores of the items that should rank higher.
        less_scores: scores of the items that should rank lower (broadcastable with
            `more_scores`).
        margin: required gap between the two scores (a float or a broadcastable tensor).
        device: device for the zero floor; defaults to ``more_scores.device``.

    Returns:
        A 0-dim tensor with the mean loss.
    """
    zero = torch.tensor(0.0, device=more_scores.device if device is None else device)
    # torch.maximum (not clamp/relu) is kept so that gradients at exact ties are unchanged.
    return torch.maximum(zero, less_scores - more_scores + margin).mean()

### Iteration functions

In [None]:
def do_train_iteration(
        data_loader: DataLoader,
        model: Model,
        device: str,
        optimizer: Optimizer,
        scheduler: CosineAnnealingWarmRestarts,
        train_margin_list: t.Optional[t.List[float]] = None,
        train_decision_margin: float = 0.3,
        accumulate_gradient_steps: int = 1) -> t.Tuple[float, t.Dict[str, float]]:
    """Train for one pass over `data_loader` (one sampler iteration) on pairwise ranking.

    Each batch yields ``(x_more, x_less)``. Both sides are scored by `model`, and
    `margin_ranking_loss` with `train_decision_margin` pushes ``more`` above ``less``.
    Gradients are averaged over `accumulate_gradient_steps` batches before each optimizer and
    scheduler step.

    Args:
        data_loader: loader whose sampler is a ``RandomSubsetPerEpochSampler`` (read only for
            the progress-bar epoch display).
        model: scoring model; switched to train mode here.
        device: device the batches are moved to.
        optimizer: stepped every `accumulate_gradient_steps` batches.
        scheduler: LR scheduler stepped together with the optimizer (anything with ``step()``).
        train_margin_list: margins ``m`` at which pairwise accuracy ``more - less > m`` is
            tracked; defaults to ``[train_decision_margin]``.
        train_decision_margin: margin used in the loss; must be in `train_margin_list`.
        accumulate_gradient_steps: number of batches per optimizer step. Trailing batches
            that do not complete a group are discarded at the start of the next call.

    Returns:
        ``(accuracy at train_decision_margin, metrics)``. ``metrics`` holds ``train_loss``,
        ``train_score_mean``, ``train_score_std`` and one ``train_accuracy_<m>`` per margin.
    """
    sampler: RandomSubsetPerEpochSampler = t.cast(RandomSubsetPerEpochSampler, data_loader.sampler)
    loss_metric = FloatListMean()
    train_score_mean = FloatListMean()
    train_score_std = FloatListStd()
    if train_margin_list is None:
        train_margin_list = [train_decision_margin]
    assert train_decision_margin in train_margin_list
    train_accuracy_dict = {m: Accuracy() for m in train_margin_list}

    # O(1)-per-step running sums for the progress bar. The list-based metrics above are
    # exact but cost O(n) to compute, so they are evaluated only once after the loop.
    # Recomputing them on every progress-bar refresh made each pass quadratic.
    loss_sum = 0.0
    score_sum = 0.0
    score_sq_sum = 0.0
    score_count = 0

    model.train()
    # Drop any partial accumulation left by a previous call, so every optimizer step
    # averages exactly `accumulate_gradient_steps` batches.
    optimizer.zero_grad()
    data_iter = tqdm(enumerate(data_loader), desc='Training', total=len(data_loader))
    for step, (x_more, x_less) in data_iter:
        score_more = model(x_more.to(device))
        score_less = model(x_less.to(device))
        loss = margin_ranking_loss(score_more, score_less, margin=train_decision_margin, device=device)
        # Average (rather than sum) gradients across the accumulated batches.
        (loss / accumulate_gradient_steps).backward()

        if (step + 1) % accumulate_gradient_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

        with torch.no_grad():
            for scores in (score_more.cpu(), score_less.cpu()):
                train_score_mean.update(scores)
                train_score_std.update(scores)
                scores_f64 = scores.double()
                score_sum += scores_f64.sum().item()
                score_sq_sum += (scores_f64 * scores_f64).sum().item()
                score_count += scores_f64.numel()

            loss_cpu = loss.cpu()
            loss_metric.update(loss_cpu)
            loss_sum += loss_cpu.item()
            for m, a in train_accuracy_dict.items():
                a.update(score_more - score_less > m)

        running_mean = score_sum / score_count
        # Sample variance from running sums (display only; clamped against rounding).
        running_var = (score_sq_sum - score_sum * running_mean) / (score_count - 1) if score_count > 1 else 0.0
        running_std = max(running_var, 0.0) ** 0.5
        epoch_str = f'epoch: {sampler.real_epoch} [{sampler.frac_consumed:.4f}]'
        accuracy_str = ', '.join([f'acc_{m}: {train_accuracy_dict[m].compute():.4f}' for m in sorted(train_accuracy_dict.keys())])
        data_iter.set_description(
            f'{epoch_str} loss: {loss_sum / (step + 1):.4f}, {accuracy_str} '
            f'score_mean: {running_mean:.6f}, score_std: {running_std:.6f}')

    train_metrics_to_track = {f'train_accuracy_{m}': a.compute() for m, a in train_accuracy_dict.items()}
    loss_val = loss_metric.compute_and_reset()
    # Use the decision margin, not a hard-coded 0.0 key (that raised KeyError when 0.0 was
    # not among the tracked margins, e.g. with the default arguments).
    return train_accuracy_dict[train_decision_margin].compute(), {
        'train_loss': loss_val,
        'train_score_mean': train_score_mean.compute(),
        'train_score_std': train_score_std.compute(),
        **train_metrics_to_track,
    }

### Main function

In [None]:
def main(
        pair_df: pd.DataFrame,
        feature_df: pd.DataFrame,
        to_checkpoint: str,
        logger_list: t.List[Logger],
        num_epochs: int,
        batch_size: int,
        num_workers: int,
        device: str,
        lr: float,
        num_warmup_steps: int,
        train_margin_list: t.List[float],
        train_decision_margin: float,
        validate_every_n_samples_init: t.Optional[int] = None,
        accumulate_gradient_steps: int = 1,):
    """Train the linear ranking model on `pair_df` and checkpoint it when accuracy improves.

    Training runs in sampler iterations of `validate_every_n_samples_init` samples (the
    whole dataset when that is falsy). It keeps looping while ``sampler.real_epoch <
    num_epochs``. After every iteration, the training metrics go to each logger in
    `logger_list`. If the accuracy returned by `do_train_iteration` beats the best so far,
    the model ``state_dict`` is saved to `to_checkpoint`.
    """
    train_dataset = PairDataset(
        pair_df=pair_df,
        feature_df=feature_df)
    samples_per_epoch = validate_every_n_samples_init if validate_every_n_samples_init else len(train_dataset)
    sampler = RandomSubsetPerEpochSampler(
        data_source=train_dataset,
        samples_per_epoch=samples_per_epoch)
    train_data_loader = DataLoader(
        dataset=train_dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned host memory only matters for host-to-CUDA copies.
        pin_memory=str(device).startswith('cuda'))
    model = Model(num_features=train_dataset.num_features).to(device)
    optimizer = AdamW(model.parameters(), lr=lr)
    # One sampler iteration covers `samples_per_epoch` samples, so `num_epochs` full passes
    # take ceil(num_epochs * len(dataset) / samples_per_epoch) iterations. Sizing the
    # schedule by `num_epochs` alone ended it early whenever an iteration is a fraction of
    # an epoch. Each `do_train_iteration` call restarts its step counter, so it makes
    # len(loader) // accumulate_gradient_steps optimizer steps.
    num_iterations = (num_epochs * len(train_dataset) + samples_per_epoch - 1) // samples_per_epoch
    scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=len(train_data_loader) // accumulate_gradient_steps * num_iterations)

    best_accuracy = 0.0
    while sampler.real_epoch < num_epochs:
        accuracy, train_metrics_to_track = do_train_iteration(
            data_loader=train_data_loader,
            model=model,
            device=device,
            optimizer=optimizer,
            scheduler=scheduler,
            train_margin_list=train_margin_list,
            train_decision_margin=train_decision_margin,
            accumulate_gradient_steps=accumulate_gradient_steps)
        for logger in logger_list:
            logger.log_metrics(step=sampler.step, metrics={**train_metrics_to_track,})
        if accuracy > best_accuracy:
            print(f'Best accuracy improved from {best_accuracy} to {accuracy}. Saving the model.')
            torch.save(model.state_dict(), to_checkpoint)
            best_accuracy = accuracy

### Parameter definitions

In [None]:
# Parameters

IS_KAGGLE = False

# All environment-dependent locations are chosen in this one branch.
if IS_KAGGLE:
    ROOT_DIR_PATH = Path('/kaggle')
    DATA_DIR_PATH = ROOT_DIR_PATH / 'input'
    TENSORBOARD_DIR_PATH = ROOT_DIR_PATH / 'working/tensorboard'
    ARTIFACT_DIR_PATH = ROOT_DIR_PATH / 'working/artifacts'
    MODELS_DIR_PATH = ROOT_DIR_PATH / 'working/models'
else:
    ROOT_DIR_PATH = Path('/home/jovyan/jigsaw-toxic')
    DATA_DIR_PATH = ROOT_DIR_PATH / 'data/datasets'
    TENSORBOARD_DIR_PATH = Path('/home/jovyan/tensorboard')
    ARTIFACT_DIR_PATH = ROOT_DIR_PATH / 'artifacts'
    MODELS_DIR_PATH = ROOT_DIR_PATH / 'models'

DATASET_DIR_PATH = DATA_DIR_PATH / 'external_20220207_stacking'
PAIR_CSV_PATH = DATASET_DIR_PATH / 'pair.csv'
FEATURE_CSV_PATH = DATASET_DIR_PATH / 'feature.csv'

TASK_NAME = 'stacking-ranking'
DATASET_NAME = 'external_20220207_stacking'
RUN_NAME = 'v1'

MODEL_NAME = f'{DATASET_NAME}-{RUN_NAME}'
NOTEBOOK_CHECKPOINT_DIR_PATH = Path(f'.checkpoints/{TASK_NAME}')

In [None]:
if not IS_KAGGLE:
    # Outside Kaggle, keep a copy of this notebook in the checkpoint directory under the run name.
    NOTEBOOK_CHECKPOINT_DIR_PATH.mkdir(parents=True, exist_ok=True)
    snapshot_path = NOTEBOOK_CHECKPOINT_DIR_PATH / f'{RUN_NAME}.ipynb'
    shutil.copyfile('stacking-ranking.ipynb', snapshot_path)

In [None]:
# Create potentially missing directories (pure-Python equivalent of `mkdir -p`).
# Unlike the `!mkdir -p $VAR` shell magic, this survives paths containing spaces and runs
# outside IPython as well.
MODELS_DIR_PATH.mkdir(parents=True, exist_ok=True)
ARTIFACT_DIR_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
# Read dataframes: the ranking pairs first, then the per-comment feature table.
pair_df, feature_df = (
    t.cast(pd.DataFrame, pd.read_csv(csv_path))
    for csv_path in (PAIR_CSV_PATH, FEATURE_CSV_PATH)
)

In [None]:
# Display the loaded pair dataframe as this cell's output.
pair_df

In [None]:
# Display the loaded feature dataframe as this cell's output.
feature_df

### Entrypoint

In [None]:
# Loading the W&B API key is only needed when a WAndBLogger is added to the logger list below.
# if IS_KAGGLE:
#     from kaggle_secrets import UserSecretsClient
#     user_secrets = UserSecretsClient()
#     wandb_api_key = user_secrets.get_secret('wandb-api-token')
# else:
#     with open(os.path.join(os.path.dirname(os.getcwd()), 'deploy/secrets/wandb_api_token.txt')) as f:
#         wandb_api_key = f.read()

BATCH_SIZE = 128
ACCUMULATE_GRAD_STEPS = 1
VALID_CYCLES_PER_EPOCH = 1
# Fall back to the CPU so the notebook also runs on machines without a CUDA device.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

with ContextManagerList([
            # StdOutLogger(),
            TensorBoardLogger(
                log_dir=str(TENSORBOARD_DIR_PATH / f'jt-{MODEL_NAME}'),
                # 'train_accuracy_0.0' matches the 0.0 entry of train_margin_list below.
                metric_whitelist={
                    'train_loss',
                    'train_accuracy_0.0',
                },
            ),
        ]) as logger_list:
    main(
        pair_df=pair_df,
        feature_df=feature_df,
        to_checkpoint=str(MODELS_DIR_PATH / f'{MODEL_NAME}.pt'),
        logger_list=logger_list,
        num_epochs=50,
        batch_size=BATCH_SIZE,
        num_workers=2,
        device=DEVICE,
        lr=1e-3,
        # Roughly the number of optimizer steps in one sampler iteration.
        num_warmup_steps=len(pair_df) // (BATCH_SIZE * ACCUMULATE_GRAD_STEPS * VALID_CYCLES_PER_EPOCH),
        train_margin_list=[0.0],
        train_decision_margin=0.0,
        accumulate_gradient_steps=ACCUMULATE_GRAD_STEPS,
        validate_every_n_samples_init=len(pair_df) // VALID_CYCLES_PER_EPOCH,
    )