In [1]:
import random
import typing as t
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import typing_extensions as t_ext
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from transformers.models.auto.modeling_auto import AutoModelForSequenceClassification
from transformers.models.auto.tokenization_auto import AutoTokenizer

In [2]:
# Register tqdm's `progress_apply` on pandas objects; later cells call
# `all_df.progress_apply(...)`.
tqdm.pandas()

In [3]:
def seed_everything(seed: int):
    """Seed every random number generator this notebook relies on.

    Seeds Python's `random`, NumPy's global generator and PyTorch, so that
    repeated runs draw the same random numbers.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    # The notebook also uses torch; without this call torch stays unseeded.
    torch.manual_seed(seed)


seed_everything(42)

In [4]:
# Directories the notebook reads its CSV inputs from and writes its outputs to.
# NOTE(review): these are absolute paths; adjust them when running elsewhere.
DATASET_DIR = Path('/home/jovyan/jigsaw-toxic/data/datasets/ccc-2017-multilabel')
COMBINED_DIR = Path('/home/jovyan/jigsaw-toxic/data/datasets/combined')

In [None]:
!ls -la $DATASET_DIR

In [5]:
# `all_df`: comments with one 0/1 column per toxicity class (later cells read
# `comment_text` and the six class columns from it).
all_df = pd.read_csv(DATASET_DIR / 'train_no_leak.csv')
# NOTE(review): `all_with_leak_df` is not referenced by any other cell visible here.
all_with_leak_df = pd.read_csv(COMBINED_DIR / 'train_comment_classification_challenge_2017.csv')
# `valid_df`: comment pairs; later cells read its `more_toxic` / `less_toxic` columns.
valid_df = pd.read_csv(COMBINED_DIR / 'valid.csv')

In [None]:
# Number of toxicity classes flagged for each comment. Adding the columns
# directly is vectorized, so it avoids one Python call per row.
all_df['n_flags'] = (
    all_df['toxic']
    + all_df['severe_toxic']
    + all_df['obscene']
    + all_df['threat']
    + all_df['insult']
    + all_df['identity_hate']
)

In [None]:
# Show only the rows that carry at least one toxicity flag.
all_df[all_df['n_flags'] > 0]

In [None]:
# Persist `all_df` (now including `n_flags`) next to the source CSV.
all_df.to_csv(DATASET_DIR / 'train_no_leak_expanded.csv', index=False)

In [None]:
# Display the frame loaded from COMBINED_DIR / 'valid.csv'.
valid_df

In [None]:
class _TokenizedText(t_ext.TypedDict):
    """Tensor form of one tokenized text, as built by `_preprocess_tokenizer_output`."""
    # The tokenizer's `input_ids`, converted with `torch.tensor`.
    input_ids: torch.Tensor
    # The tokenizer's `attention_mask`, converted with `torch.tensor`.
    attention_mask: torch.Tensor


def _preprocess_tokenizer_output(output: t.Dict[str, t.Any]) -> _TokenizedText:
    """Turn a tokenizer result into tensors.

    Args:
        output: Mapping holding at least `input_ids` and `attention_mask`.

    Returns:
        A dict with the same two keys, each value wrapped in `torch.tensor`.
    """
    ids_tensor = torch.tensor(output['input_ids'])
    mask_tensor = torch.tensor(output['attention_mask'])
    return _TokenizedText(input_ids=ids_tensor, attention_mask=mask_tensor)


class ValidDataset(Dataset):
    """Dataset over comment pairs: yields both texts of a row, tokenized.

    Each item is `(idx, tokenized_more, tokenized_less)` built from the row's
    `more_toxic` and `less_toxic` columns, padded/truncated to `max_len`.
    """

    def __init__(self, df: pd.DataFrame, tokenizer: AutoTokenizer, max_len: int) -> None:
        super().__init__()
        self._max_len = max_len
        self._tokenizer = tokenizer
        self._df = df

    def __len__(self) -> int:
        return len(self._df)

    def _encode(self, text: str) -> _TokenizedText:
        """Tokenize one text to a fixed length and convert the result to tensors."""
        encoded = self._tokenizer(
            text,
            add_special_tokens=True,
            truncation=True,
            padding='max_length',
            max_length=self._max_len,
            return_attention_mask=True)  # type: ignore
        return _preprocess_tokenizer_output(encoded)

    def __getitem__(self, idx: int) -> t.Tuple[int, _TokenizedText, _TokenizedText]:
        # Rows are addressed by position, not by index label.
        row = self._df.iloc[idx]
        more_text, less_text = str(row['more_toxic']), str(row['less_toxic'])
        encoded_more = self._encode(more_text)
        encoded_less = self._encode(less_text)
        return idx, encoded_more, encoded_less


class PredictDataset(Dataset):
    """Dataset over single comments: yields the row's `comment_text`, tokenized.

    Each item is `(idx, tokenized_text)`, padded/truncated to `max_len`.
    """

    def __init__(self, df: pd.DataFrame, tokenizer: AutoTokenizer, max_len: int) -> None:
        super().__init__()
        self._max_len = max_len
        self._tokenizer = tokenizer
        self._df = df

    def __len__(self) -> int:
        return len(self._df)

    def __getitem__(self, idx: int) -> t.Tuple[int, _TokenizedText]:
        # Rows are addressed by position, not by index label.
        comment = str(self._df.iloc[idx]['comment_text'])
        encoded = self._tokenizer(
            comment,
            add_special_tokens=True,
            truncation=True,
            padding='max_length',
            max_length=self._max_len,
            return_attention_mask=True)  # type: ignore
        return idx, _preprocess_tokenizer_output(encoded)


@torch.no_grad()
def get_valid_toxicity_labels_by_model(
        valid_df: pd.DataFrame,
        model_checkpoint: str = 'unitary/toxic-bert',
        batch_size: int = 8,
        num_workers: int = 8,
        device: str = 'cuda',
        threshold: float = 0.5,
        not_availabel_tag: str = '<na>') -> t.Tuple[t.List[int], t.List[str], t.List[str]]:
    """Label both comments of every validation pair with a pretrained classifier.

    Args:
        valid_df: Frame with `more_toxic` and `less_toxic` text columns.
        model_checkpoint: Checkpoint name passed to `from_pretrained` for both
            the model and the tokenizer.
        batch_size: Batch size of the `DataLoader`.
        num_workers: Worker count of the `DataLoader`.
        device: Device the model and the input tensors are moved to.
        threshold: A class is flagged (bit `1`) when its model output is
            strictly greater than this value.
        not_availabel_tag: Placeholder written into the two label columns of
            the local copy of `valid_df`; it does not affect the return value.

    Returns:
        `(idx_list, more_labels, less_labels)`: positional row indices as plain
        ints, and for each row the space-separated 0/1 bitmap of the
        `more_toxic` and `less_toxic` comment respectively.
    """
    valid_df = valid_df.copy()
    valid_df['more_toxic_bitmap_label'] = not_availabel_tag
    valid_df['less_toxic_bitmap_label'] = not_availabel_tag
    model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint).to(device)
    dataset = ValidDataset(
        valid_df,
        tokenizer=AutoTokenizer.from_pretrained(model_checkpoint),
        max_len=256)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    idx_list, more_toxic_bitmap_label_list, less_toxic_bitmap_label_list = [], [], []
    for idx, tokenized_more, tokenized_less in tqdm(data_loader):
        # NOTE(review): element 0 of the model output is compared with
        # `threshold` as-is. If it holds raw logits and `threshold` is meant
        # as a probability, a sigmoid is missing here — confirm before changing,
        # since the saved labels were produced with the current behaviour.
        preds_more = model(tokenized_more['input_ids'].to(device), tokenized_more['attention_mask'].to(device))[0]
        preds_less = model(tokenized_less['input_ids'].to(device), tokenized_less['attention_mask'].to(device))[0]
        labels_more, labels_less = (preds_more > threshold).int(), (preds_less > threshold).int()
        # `idx` arrives as a batched tensor; `.tolist()` yields plain ints so the
        # result matches the declared `List[int]` and is usable as `.loc` labels.
        for i, lm, ll in zip(idx.tolist(), labels_more, labels_less):
            idx_list.append(i)
            more_toxic_bitmap_label_list.append(' '.join(str(x) for x in lm.flatten().tolist()))
            less_toxic_bitmap_label_list.append(' '.join(str(x) for x in ll.flatten().tolist()))
    return idx_list, more_toxic_bitmap_label_list, less_toxic_bitmap_label_list


@torch.no_grad()
def get_eval_toxicity_labels_by_model(
        eval_df: pd.DataFrame,
        cls_list: t.List[str],
        model_checkpoint: str = 'unitary/toxic-bert',
        batch_size: int = 8,
        num_workers: int = 8,
        device: str = 'cuda',
        threshold: float = 0.5,
        not_availabel_tag: str = '<na>') -> t.Tuple[t.List[int], t.List[str], t.List[str]]:
    """Label both comments of every pair in `eval_df` with a pretrained classifier.

    Args:
        eval_df: Frame with `more_toxic` and `less_toxic` text columns (the
            columns `ValidDataset` reads).
        cls_list: Class names; a NaN-filled `<cls>_score` column is added to a
            local copy of `eval_df` for each. NOTE(review): these columns are
            never filled or returned — the scoring part looks unfinished.
        model_checkpoint: Checkpoint name passed to `from_pretrained` for both
            the model and the tokenizer.
        batch_size: Batch size of the `DataLoader`.
        num_workers: Worker count of the `DataLoader`.
        device: Device the model and the input tensors are moved to.
        threshold: A class is flagged (bit `1`) when its model output is
            strictly greater than this value.
        not_availabel_tag: Unused; kept for signature compatibility.

    Returns:
        `(idx_list, more_labels, less_labels)`: positional row indices as plain
        ints, and for each row the space-separated 0/1 bitmap of the
        `more_toxic` and `less_toxic` comment respectively.
    """
    eval_df = eval_df.copy()
    for cls in cls_list:
        eval_df[f'{cls}_score'] = float('nan')
    model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint).to(device)
    # Use the frame passed by the caller, not the notebook-global `valid_df`.
    dataset = ValidDataset(
        eval_df,
        tokenizer=AutoTokenizer.from_pretrained(model_checkpoint),
        max_len=256)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    # All three result lists must exist before the loop appends to them.
    idx_list, more_toxic_bitmap_label_list, less_toxic_bitmap_label_list = [], [], []
    for idx, tokenized_more, tokenized_less in tqdm(data_loader):
        # NOTE(review): element 0 of the model output is thresholded as-is;
        # confirm whether a sigmoid is intended before comparing.
        preds_more = model(tokenized_more['input_ids'].to(device), tokenized_more['attention_mask'].to(device))[0]
        preds_less = model(tokenized_less['input_ids'].to(device), tokenized_less['attention_mask'].to(device))[0]
        labels_more, labels_less = (preds_more > threshold).int(), (preds_less > threshold).int()
        # Convert the batched index tensor to plain ints (declared `List[int]`).
        for i, lm, ll in zip(idx.tolist(), labels_more, labels_less):
            idx_list.append(i)
            more_toxic_bitmap_label_list.append(' '.join(str(x) for x in lm.flatten().tolist()))
            less_toxic_bitmap_label_list.append(' '.join(str(x) for x in ll.flatten().tolist()))
    return idx_list, more_toxic_bitmap_label_list, less_toxic_bitmap_label_list

In [None]:
# Label the validation pairs with the toxic-bert checkpoint. The labelling
# function defined above is `get_valid_toxicity_labels_by_model`.
idx_list, more_toxic_bitmap_label_list, less_toxic_bitmap_label_list = get_valid_toxicity_labels_by_model(valid_df, model_checkpoint='unitary/toxic-bert')

In [None]:
# Attach the predicted bitmap labels to a copy of the validation frame.
# `idx_list` holds the positional indices the dataset used with `iloc`, while
# `.loc` selects by index label. NOTE(review): the two agree only if `valid_df`
# has the default 0..n-1 index — confirm.
valid_labeled_df = valid_df.copy()
valid_labeled_df.loc[idx_list, 'more_toxic_bitmap_label'] = more_toxic_bitmap_label_list
valid_labeled_df.loc[idx_list, 'less_toxic_bitmap_label'] = less_toxic_bitmap_label_list

In [None]:
# Display the validation frame with the two bitmap-label columns filled in.
valid_labeled_df

In [None]:
# Persist the labelled validation pairs.
valid_labeled_df.to_csv(DATASET_DIR / 'valid_labeled.csv', index=False)

In [6]:
# Reload the labelled validation pairs from 'valid_labeled.csv'.
valid_labeled_df = pd.read_csv(DATASET_DIR / 'valid_labeled.csv')

In [None]:
# Label the validation pairs with the unbiased-toxic-roberta checkpoint. The
# labelling function defined above is `get_valid_toxicity_labels_by_model`.
ubtc_idx_list, ubtc_more_toxic_bitmap_label_list, ubtc_less_toxic_bitmap_label_list = get_valid_toxicity_labels_by_model(valid_df, model_checkpoint='unitary/unbiased-toxic-roberta')

In [8]:
# Display the labelled validation pairs.
valid_labeled_df

Unnamed: 0,less_toxic,more_toxic,certainty,more_toxic_bitmap_label,less_toxic_bitmap_label
0,"I think people say its not so absurd, because ...",Re Vandalism. . . . that is because the user ...,0.666667,1 0 0 0 0 0,0 0 0 0 0 0
1,You must have the lowest IQ of anyone on wikip...,Wha's a reliable source mike? Something that c...,1.000000,1 0 1 0 1 0,1 0 0 0 1 0
2,"Thanks for removing more of his anonymous, un...","To Bad\nyou dont have the balls to sign in, I ...",1.000000,1 0 0 0 1 0,0 0 0 0 0 0
3,"""\n\nActually, it clearly does not involve pow...",", 18 November 2006 (UTC)\n\nWell comrade. Look...",0.666667,1 0 0 0 0 0,0 0 0 0 0 0
4,You are not cool \n\nYou are the most hated f...,"""== Hey Fuck-head==\nHey, just wanted to say ""...",1.000000,1 0 1 0 1 0,1 0 1 0 1 0
...,...,...,...,...,...
10103,Heeeeeeeeeeeeeeyyyyyyyyyyyy dude! Sup!,you deleted my page \nand i am extrememly mad ...,1.000000,0 0 0 0 0 0,0 0 0 0 0 0
10104,ATTENTION:''''''\n\nANYONE WHO OPPOSES MY OPIN...,suck dick u disrespectful swearing wank!!,1.000000,1 0 1 0 1 0,0 0 0 0 0 0
10105,"Oh yah, and Bayerischermann is another guy try...",Thanks. Also thanks for freezing the article. ...,1.000000,1 0 0 0 0 0,0 0 0 0 0 0
10106,I had to chop out a section specifically compa...,"""\n\nHorrifyingly enough, """"ritualistic penis ...",0.666667,1 0 0 0 0 0,0 0 0 0 0 0


In [11]:
# Fraction of validation pairs whose two comments received identical bitmap labels.
len(valid_labeled_df[valid_labeled_df['more_toxic_bitmap_label'] == valid_labeled_df['less_toxic_bitmap_label']]) / len(valid_labeled_df)

0.36149584487534625

In [7]:
CCC2017_CLS_LIST = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

def _bitmap_to_readable(label: str) -> str:
    return ' '.join([l for l, b in zip(CCC2017_CLS_LIST, label.split(' ')) if int(b) > 0])

In [8]:
# Example: decode a bitmap label into class names.
_bitmap_to_readable('1 0 1 0 1 0')

'toxic obscene insult'

In [9]:
def build_readable_label(row: t.Dict[str, int]) -> str:
    """Return the space-separated names of the classes whose value in `row` is truthy.

    Classes are visited in `CCC2017_CLS_LIST` order; a row with no flagged
    class yields an empty string.
    """
    flagged = []
    for cls in CCC2017_CLS_LIST:
        if row[cls]:
            flagged.append(cls)
    return ' '.join(flagged)

def build_bitmap_label(row: t.Dict[str, int]) -> str:
    """Return the row's class values as a space-separated string.

    Values are emitted in `CCC2017_CLS_LIST` order, each rendered with `str`.
    """
    bits = (str(row[cls]) for cls in CCC2017_CLS_LIST)
    return ' '.join(bits)


# Add both label encodings to every row of `all_df`; the builders are passed
# directly to `progress_apply` (row-wise, with a progress bar).
all_df['bitmap_label'] = all_df.progress_apply(build_bitmap_label, axis=1)
all_df['readable_label'] = all_df.progress_apply(build_readable_label, axis=1)

  0%|          | 0/151942 [00:00<?, ?it/s]

  0%|          | 0/151942 [00:00<?, ?it/s]

In [23]:
def mine_pairs_v8(all_df: pd.DataFrame, valid_df: pd.DataFrame, samples_per_pair: int = 5) -> pd.DataFrame:
    """Build training pairs that mirror the label combinations in `valid_df`.

    For every row of `valid_df`, up to `samples_per_pair` comments are sampled
    from `all_df` whose `bitmap_label` equals the row's
    `more_toxic_bitmap_label`, and likewise for `less_toxic_bitmap_label`. The
    two samples are zipped, so a row contributes as many pairs as the smaller
    of the two samples.

    Args:
        all_df: Frame with `comment_text` and `bitmap_label` columns.
        valid_df: Frame with `more_toxic_bitmap_label` and
            `less_toxic_bitmap_label` columns.
        samples_per_pair: Upper bound on pairs generated per `valid_df` row.

    Returns:
        A frame with columns `more_toxic`, `less_toxic`,
        `more_toxic_bitmap_label` and `less_toxic_bitmap_label`.
    """
    # The candidates for a label do not change between rows, so filter
    # `all_df` once per distinct label instead of twice per validation row.
    population_by_label: t.Dict[t.Any, pd.DataFrame] = {}

    def _population(bitmap_label: t.Any) -> pd.DataFrame:
        """Return (and memoize) the rows of `all_df` carrying `bitmap_label`."""
        if bitmap_label not in population_by_label:
            population_by_label[bitmap_label] = all_df[all_df['bitmap_label'] == bitmap_label]
        return population_by_label[bitmap_label]

    row_list = []
    for _, row in tqdm(valid_df.iterrows(), total=len(valid_df)):
        more_row_population_df = _population(row['more_toxic_bitmap_label'])
        less_row_population_df = _population(row['less_toxic_bitmap_label'])
        # Sampling order (more first, then less) is kept, so the sequence of
        # random draws is unchanged by the memoization.
        more_row_df = more_row_population_df.sample(n=min(samples_per_pair, len(more_row_population_df)))
        less_row_df = less_row_population_df.sample(n=min(samples_per_pair, len(less_row_population_df)))
        for (_, more_row), (_, less_row) in zip(more_row_df.iterrows(), less_row_df.iterrows()):
            row_list.append({
                'more_toxic': str(more_row['comment_text']),
                'less_toxic': str(less_row['comment_text']),
                'more_toxic_bitmap_label': str(more_row['bitmap_label']),
                'less_toxic_bitmap_label': str(less_row['bitmap_label']),
            })
    return pd.DataFrame(row_list)

In [24]:
# Mine up to 5 training pairs per labelled validation row.
train_pair_v8_df = mine_pairs_v8(all_df=all_df, valid_df=valid_labeled_df, samples_per_pair=5)

  0%|          | 0/10108 [00:00<?, ?it/s]

In [25]:
# Display the mined pairs.
train_pair_v8_df

Unnamed: 0,more_toxic,less_toxic,more_toxic_bitmap_label,less_toxic_bitmap_label
0,Hate you \n\ni;uhbsirtubgyihihlkjngkjbnkgjnbkj...,Tuesday's With Morrie Movie Fact\n\nIn the Fil...,1 0 0 0 0 0,0 0 0 0 0 0
1,"Dean Gaffney \n\nJPS, I'm not suprised it was ...",The situation is either every longer term edit...,1 0 0 0 0 0,0 0 0 0 0 0
2,do research \n\nplease do some research before...,"posthumous adoption \n\nUh, this means that he...",1 0 0 0 0 0,0 0 0 0 0 0
3,"Stop adding unsourced, biased info to Joseph B...","Also, according to Wikipedia:Blocking policy, ...",1 0 0 0 0 0,0 0 0 0 0 0
4,"Leftcoastman \n\nSeriously, you're a whiny lit...",Wayward likes teeth in his crust!,1 0 0 0 0 0,0 0 0 0 0 0
...,...,...,...,...
50523,"}}\n\n{{Unblock|Impudent, feckless, dick. How ...",or else I WOULD DESTROY YOU!!!!!,1 0 1 0 0 0,1 0 0 1 0 0
50524,"""\nYea. It is weird. ;) Man, that does suck. H...",I'm gonna beat you to a bloody pulp then shoot...,1 0 1 0 0 0,1 0 0 1 0 0
50525,They're about to start construction on new lin...,I'ma smack ya upside da head wit a shovel \n\n...,1 0 1 0 0 0,1 0 0 1 0 0
50526,"""\n\n Era's \n\nNot every single day of WWE hi...",you are part of isis \n\nCouple of days ago u ...,1 0 1 0 0 0,1 0 0 1 0 0


In [20]:
!ls -la $DATASET_DIR

total 28213472
drwxr-xr-x  2 jovyan users        4096 Jan 27 16:07 .
drwxr-xr-x 11 jovyan users        4096 Jan 26 19:03 ..
-rw-r--r--  1 jovyan users        1699 Jan 12 11:46 label_toxicity.csv
-rw-r--r--  1 jovyan users    64981283 Jan 13 08:38 train_no_leak.csv
-rw-r--r--  1 jovyan users    67448851 Jan 13 08:40 train_no_leak_expanded.csv
-rw-r--r--  1 jovyan users    30080933 Jan  6 20:31 train_no_leak_pair.csv
-rw-r--r--  1 jovyan users    38018364 Dec 30 18:35 train_no_leak_pair_harder_1.csv
-rw-r--r--  1 jovyan users    38087288 Dec 30 18:35 train_no_leak_pair_harder_2.csv
-rw-r--r--  1 jovyan users    37854841 Dec 30 18:35 train_no_leak_pair_harder_3.csv
-rw-r--r--  1 jovyan users    37854841 Dec 30 12:17 train_no_leak_pair_harder.csv
-rw-r--r--  1 jovyan users  3418444331 Jan 12 12:03 train_no_leak_pair_v2.csv
-rw-r--r--  1 jovyan users 24790812945 Jan 12 12:02 train_no_leak_pair_v2_full.csv
-rw-r--r--  1 jovyan users    28057742 Jan 13 10:30 train_no_leak_pair_v3.csv
-rw-r--r

In [26]:
# Persist the mined pairs.
train_pair_v8_df.to_csv(DATASET_DIR / 'train_no_leak_pair_v8.csv', index=False)

In [None]:
class L:
    """Names of the six toxicity label columns, as string constants."""
    TOXIC = 'toxic'
    SEVERE_TOXIC = 'severe_toxic'
    OBSCENE = 'obscene'
    THREAT = 'threat'
    INSULT = 'insult'
    IDENTITY_HATE = 'identity_hate'


class MoreCondition:
    """Interface for rules that select rows eligible as the "more" side of a pair."""

    def mask(self, less_label_set: t.Set[str], df: pd.DataFrame) -> t.Optional[pd.Series]:
        """Return a boolean row mask over `df`, or `None` when the rule does not apply.

        Subclasses must override this; the base implementation always raises
        `NotImplementedError`.
        """
        raise NotImplementedError


class SimpleMoreCondition(MoreCondition):
    """Rule: the "more" row has every label of the "less" row plus at least one other."""

    _ALL_LABEL_SET = {
        L.TOXIC,
        L.SEVERE_TOXIC,
        L.INSULT,
        L.OBSCENE,
        L.IDENTITY_HATE,
        L.THREAT,
    }

    def mask(self, less_label_set: t.Set[str], df: pd.DataFrame) -> t.Optional[pd.Series]:
        """
        Select rows that carry all the labels from `less_label_set` and
        at least one label from `self._ALL_LABEL_SET - less_label_set`.

        When one of the two label groups is empty, only the other group's
        mask is returned; when both are empty the result is `None`.
        """
        # AND over the labels the "less" row already has.
        has_all_less: t.Optional[pd.Series] = None
        for label in less_label_set:
            flagged = df[label] == 1
            has_all_less = flagged if has_all_less is None else has_all_less & flagged

        # OR over the labels the "less" row lacks.
        has_any_extra: t.Optional[pd.Series] = None
        for label in self._ALL_LABEL_SET - less_label_set:
            flagged = df[label] == 1
            has_any_extra = flagged if has_any_extra is None else has_any_extra | flagged

        if has_all_less is None:
            # Covers both "only extra labels" and "nothing at all" (None).
            return has_any_extra
        if has_any_extra is None:
            return has_all_less
        return has_all_less & has_any_extra


class ComparisonBasedMoreCondition(MoreCondition):
    """Rule: swapping the labels `less_has` for the labels `more_has` makes a row "more".

    The rule applies only when the "less" row carries every label in
    `less_has`. Matching rows must carry the remaining labels of the "less"
    row together with every label in `more_has`.
    """

    def __init__(self, less_has: t.Set[str], more_has: t.Set[str]):
        self._more_has = more_has
        self._less_has = less_has

    def mask(self, less_label_set: t.Set[str], df: pd.DataFrame) -> t.Optional[pd.Series]:
        """Return the row mask for this rule, or `None` if it does not apply."""
        applies = all(label in less_label_set for label in self._less_has)
        if not applies:
            return None
        # Labels a candidate row has to carry: what is left of the "less"
        # labels after removing `less_has`, plus everything in `more_has`.
        required_labels = (less_label_set - self._less_has) | self._more_has
        combined: t.Optional[pd.Series] = None
        for label in required_labels:
            flagged = df[label] == 1
            combined = flagged if combined is None else combined & flagged
        assert combined is not None
        return combined


# Rules used by `mine_pairs` to find "more toxic" partners. Each entry
# `ComparisonBasedMoreCondition(less_has, more_has)` applies to rows carrying
# `less_has` and matches rows carrying `more_has` instead (plus the remaining
# labels). Commented-out entries are currently disabled.
more_condition_list = [
    # SimpleMoreCondition(),
    # Inferred directly from the valid data.
    ComparisonBasedMoreCondition({L.INSULT}, {L.SEVERE_TOXIC}),
    ComparisonBasedMoreCondition({L.IDENTITY_HATE}, {L.SEVERE_TOXIC}),
    ComparisonBasedMoreCondition({L.TOXIC}, {L.OBSCENE, L.INSULT}),
    ComparisonBasedMoreCondition({L.INSULT}, {L.IDENTITY_HATE}),
    ComparisonBasedMoreCondition({L.TOXIC}, {L.IDENTITY_HATE}),
    ComparisonBasedMoreCondition({L.THREAT}, {L.OBSCENE, L.INSULT}),
    ComparisonBasedMoreCondition({L.INSULT}, {L.SEVERE_TOXIC, L.OBSCENE}),
    ComparisonBasedMoreCondition({L.OBSCENE}, {L.IDENTITY_HATE}),
    # Inferred from the transitivity of the < relation.
    # ComparisonBasedMoreCondition({L.TOXIC}, {L.SEVERE_TOXIC}),
    # ComparisonBasedMoreCondition({L.THREAT}, {L.OBSCENE, L.SEVERE_TOXIC}),
    # ComparisonBasedMoreCondition({L.THREAT}, {L.OBSCENE, L.IDENTITY_HATE}),
    # ComparisonBasedMoreCondition({L.OBSCENE}, {L.SEVERE_TOXIC}),
    # ComparisonBasedMoreCondition({L.THREAT}, {L.SEVERE_TOXIC, L.INSULT}),
    # ComparisonBasedMoreCondition({L.THREAT}, {L.IDENTITY_HATE, L.INSULT}),
    # Inferred from common sense.
    # ComparisonBasedMoreCondition({L.TOXIC}, {L.THREAT}),
    # ComparisonBasedMoreCondition({L.OBSCENE}, {L.THREAT}),
]


def _label_set(readable_label_str: str) -> t.Set[str]:
    return set(readable_label_str.split(' ')) if readable_label_str else set()


def mine_pairs(
        df: pd.DataFrame,
        more_condition_list: t.List[MoreCondition],
        non_toxic_ratio: float = 0.1,
        max_n_flags_distance: int = 2,
        toxic_max_pairs_per_sample: int = 3,
        non_toxic_max_pairs_per_sample: int = 3) -> pd.DataFrame:
    """Mine (less toxic, more toxic) comment pairs from a labelled frame.

    Every flagged row, plus a random share of unflagged rows, is used as the
    "less" side. Its "more" partners are sampled from the rows that satisfy at
    least one condition in `more_condition_list` and whose flag count lies in
    `[n_flags, n_flags + max_n_flags_distance]`. For unflagged rows that no
    condition covers, `SimpleMoreCondition` is the fallback. Rows for which no
    condition applies are skipped.

    Args:
        df: Frame with `comment_text`, `n_flags`, `readable_label`,
            `bitmap_label` and the label columns the conditions read.
            NOTE(review): the self-exclusion `df.index != idx` presumes a
            unique index — confirm.
        more_condition_list: Conditions OR-ed together to select partners.
        non_toxic_ratio: Number of unflagged "less" rows, as a fraction of the
            number of flagged rows.
        max_n_flags_distance: Maximum number of extra flags a partner may have.
        toxic_max_pairs_per_sample: Partner cap for flagged "less" rows.
        non_toxic_max_pairs_per_sample: Partner cap for unflagged "less" rows.

    Returns:
        One row per mined pair with both texts, both readable and bitmap
        labels, and `is_subset` (1 if the less labels are a subset of the
        more labels, else 0).
    """
    simple_more_condition = SimpleMoreCondition()
    pair_row_list = []
    num_no_pairs = 0
    toxic_df = df[df['n_flags'] > 0]
    non_toxic_df = df[df['n_flags'] == 0]
    less_df = pd.concat([
        toxic_df,
        non_toxic_df.sample(frac=1.0).iloc[:int(len(toxic_df) * non_toxic_ratio)]
    ])

    # The condition mask depends only on the label set (and on whether the row
    # is unflagged, which gates the fallback), not on the individual row, so it
    # is computed once per distinct label combination.
    cond_mask_cache: t.Dict[t.Tuple[t.Any, bool], t.Optional[pd.Series]] = {}

    def _condition_mask(
            readable_label: t.Any,
            less_label_set: t.Set[str],
            is_non_toxic: bool) -> t.Optional[pd.Series]:
        """Return (and memoize) the OR of all applicable condition masks."""
        key = (readable_label, is_non_toxic)
        if key not in cond_mask_cache:
            combined: t.Optional[pd.Series] = None
            for c in more_condition_list:
                c_mask = c.mask(less_label_set, df)
                if c_mask is not None:
                    combined = c_mask if combined is None else combined | c_mask
            if combined is None and is_non_toxic:
                combined = simple_more_condition.mask(less_label_set, df)
            cond_mask_cache[key] = combined
        return cond_mask_cache[key]

    it = tqdm(less_df.iterrows(), total=len(less_df))
    for idx, less_row in it:
        readable_label = less_row['readable_label']
        n_flags = less_row['n_flags']
        less_label_set = _label_set(readable_label)
        cond_mask = _condition_mask(readable_label, less_label_set, bool(n_flags == 0))
        if cond_mask is None:
            # No condition applies to this label combination.
            continue
        # Partner must be a different row with a different label combination
        # and a flag count within the allowed distance.
        base_mask = (df.index != idx) \
            & (df['readable_label'] != readable_label) \
            & (df['n_flags'] >= n_flags) \
            & (df['n_flags'] <= n_flags + max_n_flags_distance)
        more_df = df[base_mask & cond_mask]
        if len(more_df):
            max_pairs = toxic_max_pairs_per_sample if n_flags > 0 else non_toxic_max_pairs_per_sample
            for _, more_row in more_df.sample(n=min(len(more_df), max_pairs)).iterrows():
                pair_row_list.append({
                    'less_toxic': less_row['comment_text'],
                    'more_toxic': more_row['comment_text'],
                    'less_toxic_readable_label': readable_label,
                    'less_toxic_bitmap_label': less_row['bitmap_label'],
                    'more_toxic_readable_label': more_row['readable_label'],
                    'more_toxic_bitmap_label': more_row['bitmap_label'],
                    'is_subset': int(less_label_set.issubset(_label_set(more_row['readable_label']))),
                })
        else:
            num_no_pairs += 1
        it.set_description(f'num_pairs: {len(pair_row_list)}, num_no_pairs: {num_no_pairs}')
    return pd.DataFrame(pair_row_list)

In [None]:
# Mine pairs with non-default settings: more unflagged "less" rows (25% of the
# flagged count), a wider flag-count window (3), and up to 9 partners per
# flagged row. The commented-out call below uses the defaults.
# pair_df = mine_pairs(df=all_df, more_condition_list=more_condition_list)
pair_df = mine_pairs(
    df=all_df,
    more_condition_list=more_condition_list,
    non_toxic_ratio=0.25,
    max_n_flags_distance=3,
    toxic_max_pairs_per_sample=9,
    non_toxic_max_pairs_per_sample=3)