In [1]:
import pandas as pd

In [2]:
# Load the labelled comments: one row per comment with its text and six 0/1 toxicity flags.
df_raw = pd.read_csv(filepath_or_buffer="./toxic.csv")
df_raw.head(5)

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,0000997932d777bf,Explanation\nWhy the edits made under my usern...,0,0,0,0,0,0
1,000103f0d9cfb60f,D'aww! He matches this background colour I'm s...,0,0,0,0,0,0
2,000113f07ec002fd,"Hey man, I'm really not trying to edit war. It...",0,0,0,0,0,0
3,0001b41b1c6bb37e,"""\nMore\nI can't make any real suggestions on ...",0,0,0,0,0,0
4,0001d958c54c6e35,"You, sir, are my hero. Any chance you remember...",0,0,0,0,0,0


In [3]:
# The first two columns are `id` and `comment_text`; every column after them is a toxicity flag.
TOXIC_LABELS = list(df_raw.columns[2:])

# Number of flags raised per comment, computed once and reused for both partitions.
flags_per_comment = df_raw[TOXIC_LABELS].sum(axis=1)
toxic_items = df_raw[flags_per_comment > 0]
ok_items = df_raw[flags_per_comment == 0]

In [4]:
# Number of rows to draw from each group (toxic and ok) when building the working set.
sample_size = 100

In [5]:
# Build a balanced working set: `sample_size` toxic comments followed by
# `sample_size` non-toxic ones. A fixed random_state makes the draw repeatable,
# so a Restart & Run All reproduces the same rows (and the same downstream metrics).
df = pd.concat([
    toxic_items.sample(sample_size, random_state=42),
    ok_items.sample(sample_size, random_state=42),
])

In [6]:
# Collapse the per-category flags into a single binary target:
# 1 when at least one flag is set on the row, 0 otherwise.
any_flag_set = df[TOXIC_LABELS].sum(axis=1).gt(0)
df["label"] = any_flag_set.astype(int)
df

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate,label
129361,b3e617c314a6a264,You little pussies. This banning is totally wo...,1,0,1,0,1,0,1
90774,f2d7f3deb6a0a4db,red flag baloney \n\nbaloney is what your moth...,1,0,0,0,0,0,1
98809,108f99abbedc8bbb,"""\nYou call that creep """"a contributor""""? You'...",1,0,0,0,1,0,1
5078,0d7b679afe6b35f1,"""\n\nyou're so immature and condescending, gro...",1,0,0,0,0,0,1
80300,d6d082903ad49b3b,Sidenote: Xizer is not a racist. Calling this ...,1,0,1,0,1,1,1
...,...,...,...,...,...,...,...,...,...
93341,f9919a97302053d2,"""\nNot to get off topic (or on topic), but I *...",0,0,0,0,0,0,0
95119,fe48afd2b2d35a1a,Raul654 has just reinserted a much gentler war...,0,0,0,0,0,0,0
49736,84fa9307a3718099,I put a copy of the thumbnail image into WikiC...,0,0,0,0,0,0,0
38442,669a784bf6142ea1,"I have no objections to that change, that was ...",0,0,0,0,0,0,0


In [7]:
from datasets import Dataset

# Convert the DataFrame to a Hugging Face Dataset and hold out 20% for evaluation.
# The explicit seed keeps the train/test assignment identical between runs.
toxic_dataset = Dataset.from_pandas(df).train_test_split(test_size=0.2, seed=42)
toxic_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'comment_text', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate', 'label', '__index_level_0__'],
        num_rows: 160
    })
    test: Dataset({
        features: ['id', 'comment_text', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate', 'label', '__index_level_0__'],
        num_rows: 40
    })
})

In [8]:
from transformers import DistilBertTokenizer

# Tokenizer matching the checkpoint that is fine-tuned later in the notebook.
bert_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")


def tokenize(data):
    """Tokenize the ``comment_text`` field of ``data`` with ``bert_tokenizer``.

    Truncation is enabled; no padding is requested here.

    Returns whatever ``bert_tokenizer`` returns for that text.
    """
    comment_text = data["comment_text"]
    return bert_tokenizer(comment_text, truncation=True)

In [9]:
# Run `tokenize` over both splits; the result gains `input_ids` and `attention_mask` columns.
toxic_dataset_tokenized = toxic_dataset.map(tokenize)

Map:   0%|          | 0/160 [00:00<?, ? examples/s]

Map:   0%|          | 0/40 [00:00<?, ? examples/s]

In [10]:
# Peek at the first rows of the tokenized test split as a DataFrame.
pd.DataFrame(toxic_dataset_tokenized["test"]).head()

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate,label,__index_level_0__,input_ids,attention_mask
0,9700072662fdc436,"""\n\nLobojo wrote: """"I gave you links above to...",0,0,0,0,0,0,0,56530,"[101, 1000, 8840, 5092, 5558, 2626, 1024, 1000...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1,1f7854ce5385a645,fuck you george \n\nakhtak sharmota,1,1,1,0,1,0,1,11881,"[101, 6616, 2017, 2577, 17712, 22893, 2243, 21...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]"
2,04be53731563cef9,You should block this idiot for life!,1,0,1,0,1,0,1,96614,"[101, 2017, 2323, 3796, 2023, 10041, 2005, 216...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]"
3,d2d884c509460b8e,"""\nObservation of Barack's bio or The Wiki-Oba...",0,0,0,0,0,0,0,135135,"[101, 1000, 8089, 1997, 13857, 1005, 1055, 160...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
4,108f99abbedc8bbb,"""\nYou call that creep """"a contributor""""? You'...",1,0,0,0,1,0,1,98809,"[101, 1000, 2017, 2655, 2008, 19815, 1000, 100...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."


In [11]:
from transformers import DataCollatorWithPadding

# `tokenize` does not pad, so tokenized rows differ in length;
# this collator is handed to the Trainer to pad rows when batches are assembled.
collator = DataCollatorWithPadding(bert_tokenizer)

In [12]:
import evaluate
import numpy as np  # imported here so this cell does not depend on a later cell for `np`

metric = evaluate.load("accuracy")


def calc_metrics(evaluation):
    """Compute accuracy for one evaluation pass.

    Args:
        evaluation: A ``(logits, labels)`` pair, unpacked positionally.

    Returns:
        The result of ``metric.compute`` on the predicted class ids
        (arg-max over the last logits axis) against ``labels``.
    """
    logits, labels = evaluation
    # The highest-scoring class along the last axis is the prediction.
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

In [13]:
from transformers import TrainingArguments, Trainer
import numpy as np

In [14]:
# Training hyperparameters consumed by TrainingArguments below.
batch_size = 32  # used for both the train and the eval per-device batch size
epochs = 1       # number of passes over the training split

In [15]:
train_args = TrainingArguments(
    # Where checkpoints and logs are written.
    output_dir="./output/models",
    logging_dir="./output/log",
    # Training length and batch sizes.
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Evaluate, checkpoint and log on the same per-epoch schedule.
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    # Reload the best checkpoint once training finishes.
    load_best_model_at_end=True,
)

In [16]:
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# Pretrained DistilBERT body with a 2-class classification head.
bert_classification_model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=2,
)

# Human-readable class names stored on the model config (both directions).
model_config = bert_classification_model.config
model_config.id2label = {0: "Ok", 1: "Toxic"}
model_config.label2id = {"Ok": 0, "Toxic": 1}

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:
# Wire model, data, batching and metrics together for training and evaluation.
trainer = Trainer(
    model=bert_classification_model,
    args=train_args,
    # Data: tokenized splits, padded per batch by the collator.
    train_dataset=toxic_dataset_tokenized["train"],
    eval_dataset=toxic_dataset_tokenized["test"],
    data_collator=collator,
    processing_class=bert_tokenizer,
    # Accuracy is reported on every evaluation pass.
    compute_metrics=calc_metrics,
)

In [18]:
# Baseline: score the model on the test split before any fine-tuning has run.
trainer.evaluate()

{'eval_loss': 0.6714123487472534,
 'eval_model_preparation_time': 0.0011,
 'eval_accuracy': 0.6,
 'eval_runtime': 7.2352,
 'eval_samples_per_second': 5.529,
 'eval_steps_per_second': 0.276}

In [19]:
# Fine-tune for `epochs` epochs; metrics are reported once per epoch.
trainer.train()

Epoch,Training Loss,Validation Loss,Model Preparation Time,Accuracy
1,0.6686,0.613915,0.0011,0.8


TrainOutput(global_step=5, training_loss=0.668589735031128, metrics={'train_runtime': 92.0527, 'train_samples_per_second': 1.738, 'train_steps_per_second': 0.054, 'total_flos': 17560209534336.0, 'train_loss': 0.668589735031128, 'epoch': 1.0})

In [20]:
# Persist the trained model; no path is passed, so the Trainer's default location is used
# (presumably `output_dir` from `train_args` — confirm against the Trainer docs).
trainer.save_model()

In [23]:
import os

# Never commit a Hub token to the notebook: read it from the environment instead.
# When HF_TOKEN is unset this is None, in which case push_to_hub falls back to the
# credentials cached by `huggingface-cli login`.
api_key = os.environ.get("HF_TOKEN")

# Target repository on the Hub, in "<namespace>/<repo-name>" form.
path = 'foo/bar'

In [24]:
# Upload the fine-tuned model to the Hub repository named by `path`.
trainer.model.push_to_hub(repo_id=path, token=api_key)

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/ZhengjunHUO/distilbert-toxicity-classifier/commit/9132a8460eb2b1f2167f904c7660a91500085c76', commit_message='Upload DistilBertForSequenceClassification', commit_description='', oid='9132a8460eb2b1f2167f904c7660a91500085c76', pr_url=None, repo_url=RepoUrl('https://huggingface.co/ZhengjunHUO/distilbert-toxicity-classifier', endpoint='https://huggingface.co', repo_type='model', repo_id='ZhengjunHUO/distilbert-toxicity-classifier'), pr_revision=None, pr_num=None)

In [26]:
# Upload the tokenizer next to the model so the repository is self-contained.
# `Trainer.tokenizer` is deprecated; `processing_class` is the attribute that
# holds the tokenizer passed to the Trainer constructor.
trainer.processing_class.push_to_hub(
    repo_id=path, token=api_key
)

Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.


CommitInfo(commit_url='https://huggingface.co/ZhengjunHUO/distilbert-toxicity-classifier/commit/d65f388c52f34e28a4bd727281fc8513d620704e', commit_message='Upload tokenizer', commit_description='', oid='d65f388c52f34e28a4bd727281fc8513d620704e', pr_url=None, repo_url=RepoUrl('https://huggingface.co/ZhengjunHUO/distilbert-toxicity-classifier', endpoint='https://huggingface.co', repo_type='model', repo_id='ZhengjunHUO/distilbert-toxicity-classifier'), pr_revision=None, pr_num=None)

In [27]:
from transformers import pipeline

# Wrap the fine-tuned model in an inference pipeline.
# `use_fast` is intentionally not passed: it only influences how a tokenizer is
# *loaded* from a name, and an already-instantiated tokenizer is supplied here.
t_classifier = pipeline(
    'text-classification',
    model=trainer.model,
    tokenizer=bert_tokenizer,
    top_k=None,  # return the score of every label, not just the best one
)

Device set to use cpu


In [28]:
# Smoke test: classify one hand-written sentence with the fine-tuned model.
t_classifier('whoever wrote this is a waste of space')

[[{'label': 'Toxic', 'score': 0.5510469675064087},
  {'label': 'Ok', 'score': 0.4489530026912689}]]

In [None]:
# Example (kept commented out): load the published model from the Hub in a fresh session.
# from transformers import pipeline

# clf = pipeline(
#   'text-classification',
#   "ZhengjunHUO/distilbert-toxicity-classifier", 
#   use_fast=True, 
#   #return_all_scores=True
# )