In [1]:
from __future__ import annotations

import os
import random
import typing as t

import numpy as np
import pandas as pd
import torch
import typing_extensions as t_ext
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from transformers.models.auto.modeling_auto import AutoModel
from transformers.models.auto.tokenization_auto import AutoTokenizer

In [2]:
def seed_everything(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and ``torch.manual_seed`` with one value.

    The seeders are applied in a fixed order: ``random`` → NumPy → PyTorch.
    """
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)


seed_everything(42)

In [3]:
# Tokenizer output reduced to the two tensors the model consumes.
# Per-sample tensors are stacked over text chunks (see PredictDataset) and then
# concatenated over samples by predict_collate_fn.
_TokenizedText = t_ext.TypedDict('_TokenizedText', {
    'input_ids': torch.Tensor,
    'attention_mask': torch.Tensor,
})


def _preprocess_tokenizer_output(output: t.Dict[str, t.Any]) -> _TokenizedText:
    return {
        'input_ids': torch.tensor(output['input_ids']),
        'attention_mask': torch.tensor(output['attention_mask']),
    }


def _split_str_to_chunk_list(s: str, chunk_size: int) -> t.List[str]:
    chunk_list = []
    chunk = []
    for token in s.split(' '):
        chunk.append(token)
        if len(chunk) >= chunk_size:
            chunk_list.append(' '.join(chunk))
            chunk.clear()
    if chunk:
        chunk_list.append(' '.join(chunk))
    return chunk_list


def predict_collate_fn(
        sample_list: t.List[t.Tuple[str, _TokenizedText]]
        ) -> t.Tuple[t.List[str], _TokenizedText, t.List[slice]]:
    """Collate ``(comment_id, tokenized_chunks)`` samples into one flat batch.

    Every sample may contain a different number of chunks (rows along dim 0).
    All rows are concatenated along dim 0, and for each sample a ``slice`` into
    that concatenation is returned so per-comment results can be recovered.

    Returns:
        ``(comment_ids, collated_tensors, row_slices)``, aligned by position.
    """
    comment_ids = [comment_id for comment_id, _ in sample_list]
    input_id_blocks = [encoded['input_ids'] for _, encoded in sample_list]
    mask_blocks = [encoded['attention_mask'] for _, encoded in sample_list]

    row_slices: t.List[slice] = []
    start = 0
    for block in input_id_blocks:
        stop = start + block.shape[0]
        row_slices.append(slice(start, stop))
        start = stop

    collated: _TokenizedText = {
        'input_ids': torch.cat(input_id_blocks, dim=0),
        'attention_mask': torch.cat(mask_blocks, dim=0),
    }
    return comment_ids, collated, row_slices


class PredictDataset(Dataset):
    """Map-style dataset yielding ``(comment_id, tokenized_chunks)`` for inference.

    Each row's ``text`` is split into chunks of at most ``words_per_chunk``
    space-separated words. Every chunk is tokenized to exactly ``max_len``
    tokens (truncated, then padded to ``max_length``), so both tensors in
    ``tokenized_chunks`` have shape ``(n_chunks, max_len)`` with ``n_chunks >= 1``.

    NOTE(review): chunking counts *words*, but ``max_len`` counts *subword
    tokens*. A subword tokenizer usually emits more than one token per word, so
    with the default ``words_per_chunk=max_len`` the tail of a long chunk is
    silently dropped by ``truncation=True``. Pass a smaller ``words_per_chunk``
    so the whole text reaches the model. Keep it consistent with however the
    model was trained — TODO confirm.
    """

    def __init__(
            self,
            df: pd.DataFrame,
            tokenizer: AutoTokenizer,
            max_len: int,
            words_per_chunk: t.Optional[int] = None) -> None:
        """
        Args:
            df: Frame with at least ``comment_id`` and ``text`` columns.
            tokenizer: A Hugging Face tokenizer instance (callable on a list of strings).
            max_len: Token length every chunk is truncated/padded to.
            words_per_chunk: Words per text chunk. Defaults to ``max_len``, which
                is the original behaviour.
        """
        super().__init__()
        self._df = df
        self._tokenizer = tokenizer
        self._max_len = max_len
        self._words_per_chunk = max_len if words_per_chunk is None else words_per_chunk

    def __len__(self) -> int:
        return len(self._df)

    def __getitem__(self, idx: int) -> t.Tuple[str, _TokenizedText]:
        record = self._df.iloc[idx]
        # NOTE(review): a missing text (NaN from read_csv) becomes the literal string 'nan'.
        comment_id, text = str(record['comment_id']), str(record['text'])
        chunks = _split_str_to_chunk_list(text, chunk_size=self._words_per_chunk)

        # One batched tokenizer call instead of one call per chunk. padding='max_length'
        # makes every row exactly max_len long, so the nested lists convert directly to
        # (n_chunks, max_len) tensors — the same result as stacking per-chunk tensors.
        tokenized_text = _preprocess_tokenizer_output(self._tokenizer(  # type: ignore
            chunks,
            add_special_tokens=True,
            truncation=True,
            padding='max_length',
            max_length=self._max_len,
            return_attention_mask=True))

        return comment_id, tokenized_text

In [4]:
class Model(torch.nn.Module):
    """Pretrained transformer backbone followed by a small Tanh-bounded regression head.

    The backbone's pooled output feeds an MLP (``output_logits -> 128 -> 1``)
    whose final ``Tanh`` keeps each score in ``[-1, 1]``.

    The attribute names ``bert`` / ``regressor`` and the layer order inside
    ``regressor`` determine the state-dict keys. Changing them would stop saved
    checkpoints from loading.
    """

    def __init__(self, checkpoint: str, output_logits: int):
        super().__init__()
        # return_dict=False: the backbone returns a plain tuple, unpacked in forward().
        self.bert = AutoModel.from_pretrained(checkpoint, return_dict=False)
        hidden_width = 128
        head_layers = [
            torch.nn.Linear(output_logits, hidden_width),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_width, 1),
            torch.nn.Tanh(),
        ]
        self.regressor = torch.nn.Sequential(*head_layers)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        backbone_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Expects exactly (sequence_output, pooled_output); only the pooled part is scored.
        _sequence_output, pooled_output = backbone_outputs
        return self.regressor(pooled_output)

In [5]:
def do_prediction_iteration(
        data_loader: DataLoader,
        model: Model,
        device: str) -> pd.DataFrame:
    """Score every comment yielded by ``data_loader``.

    Batches are expected in the ``(comment_ids, tokenized, row_slices)`` format
    produced by ``predict_collate_fn``. The model scores all chunk rows at once,
    and each comment's score is the maximum over its own rows.

    Returns:
        DataFrame with columns ``comment_id`` and ``score``, one row per
        comment, in loader order.
    """
    model.eval()
    records = []
    with torch.no_grad():
        for comment_ids, batch, row_slices in tqdm(data_loader, desc='Prediction'):
            chunk_scores = model(
                batch['input_ids'].to(device),
                batch['attention_mask'].to(device),
            )
            # Collapse each comment's chunk rows to their max, then line comments up along dim 0.
            per_comment = torch.cat(
                [chunk_scores[rows].max(dim=0, keepdim=True).values for rows in row_slices],
                dim=0,
            )
            records.extend(
                {'comment_id': comment_id, 'score': score}
                for comment_id, score in zip(comment_ids, per_comment.flatten().tolist())
            )
    return pd.DataFrame(records)

In [6]:
def main(
    input_csv_path: str,
    output_csv_path: str,
    batch_size: int,
    model_checkpoint: str,
    tokenizer_checkpoint: str,
    max_len: int,
    output_logits: int,
    num_workers: int,
    device: str,
):
    """Score every comment in ``input_csv_path`` and write ``comment_id,score`` to ``output_csv_path``.

    Args:
        input_csv_path: CSV with ``comment_id`` and ``text`` columns.
        output_csv_path: Destination CSV. Parent directories are created if missing.
        batch_size: Comments (not chunks) per DataLoader batch.
        model_checkpoint: ``state_dict`` file for the full ``Model`` (backbone + head).
        tokenizer_checkpoint: Hub id / directory used for both the tokenizer and
            the backbone architecture.
        max_len: Token length of each text chunk.
        output_logits: Input width of the regression head.
        num_workers: DataLoader worker processes.
        device: Torch device string, e.g. ``'cuda'`` or ``'cpu'``.
    """
    # Read the cheap input first so a bad path fails before the expensive model load.
    in_df = pd.read_csv(input_csv_path)

    model = Model(checkpoint=tokenizer_checkpoint, output_logits=output_logits).to(device)
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    # map_location='cpu' avoids a second full copy of the weights on the GPU;
    # load_state_dict copies the values into the already-placed parameters.
    model.load_state_dict(torch.load(model_checkpoint, map_location='cpu'))

    dataset = PredictDataset(
        df=in_df,
        tokenizer=AutoTokenizer.from_pretrained(tokenizer_checkpoint),
        max_len=max_len)
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=predict_collate_fn,  # type: ignore
        num_workers=num_workers,
        pin_memory=device.startswith('cuda'))
    out_df = do_prediction_iteration(data_loader=data_loader, model=model, device=device)

    # Every input comment must receive exactly one score.
    assert len(out_df) == len(in_df), f'scored {len(out_df)} of {len(in_df)} comments'

    out_dir = os.path.dirname(output_csv_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    out_df.to_csv(output_csv_path, index=False)

In [7]:
IS_KAGGLE = False
MAX_LEN = 256  # token length of every text chunk fed to the model
OUTPUT_LOGITS = 768  # input width of the regression head; presumably the backbone's hidden size

if IS_KAGGLE:
    INPUT_CSV_PATH = '/kaggle/input/jigsaw-toxic-severity-rating/comments_to_score.csv'
    OUTPUT_CSV_PATH = '/kaggle/working/submission.csv'
    MODEL_PATH = '/kaggle/input/jt-models-unintended-bias-in-toxicity/margin-ranking-unintended-bias-in-toxicity-classification-v1-unbiased-toxic-roberta.pt'
    TOKENIZER_PATH = '/kaggle/input/unbiased-toxic-roberta/unbiased-toxic-roberta_update/unbiased-toxic-roberta'
    BATCH_SIZE = 16
    NUM_WORKERS = 2
else:
    INPUT_CSV_PATH = '/home/jovyan/jigsaw-toxic/data/jigsaw-toxic-severity-rating/comments_to_score.csv'
    OUTPUT_CSV_PATH = '/home/jovyan/jigsaw-toxic/output/ruddit.csv'
    MODEL_PATH = '/home/jovyan/jigsaw-toxic/models/margin-ranking-unintended-bias-in-toxicity-classification-v1-unbiased-toxic-roberta.pt'
    TOKENIZER_PATH = 'unitary/unbiased-toxic-roberta'
    BATCH_SIZE = 8
    NUM_WORKERS = 8

In [None]:
# Fall back to CPU when no GPU is visible (e.g. a session with the accelerator
# switched off) instead of crashing on the first .to('cuda').
main(
    input_csv_path=INPUT_CSV_PATH,
    output_csv_path=OUTPUT_CSV_PATH,
    batch_size=BATCH_SIZE,
    model_checkpoint=MODEL_PATH,
    tokenizer_checkpoint=TOKENIZER_PATH,
    num_workers=NUM_WORKERS,
    output_logits=OUTPUT_LOGITS,
    max_len=MAX_LEN,
    device='cuda' if torch.cuda.is_available() else 'cpu')