In [1]:
from transformers import BertTokenizer
from pathlib import Path
import torch

from box import Box
import pandas as pd
import collections
import os
from tqdm import tqdm, trange
import sys
import random
import numpy as np
import apex
from sklearn.model_selection import train_test_split

import datetime

from fast_bert.modeling import BertForMultiLabelSequenceClassification
from fast_bert.data_cls import BertDataBunch, InputExample, InputFeatures, MultiLabelTextProcessor, convert_examples_to_features
from fast_bert.learner_cls import BertLearner
from fast_bert.metrics import accuracy_multilabel, accuracy_thresh, fbeta, roc_auc

In [2]:
torch.cuda.empty_cache()

In [3]:
pd.set_option('display.max_colwidth', -1)
run_start_time = datetime.datetime.today().strftime('%Y-%m-%d_%H-%M-%S')

In [4]:
DATA_PATH = Path('../data/')
LABEL_PATH = Path('../labels/')

AUG_DATA_PATH = Path('../data/data_augmentation/')

MODEL_PATH=Path('../models/')
LOG_PATH=Path('../logs/')
MODEL_PATH.mkdir(exist_ok=True)

model_state_dict = None

# BERT_PRETRAINED_PATH = Path('../../bert_models/pretrained-weights/cased_L-12_H-768_A-12/')
BERT_PRETRAINED_PATH = Path('../../bert_models/pretrained-weights/uncased_L-12_H-768_A-12/')
# BERT_PRETRAINED_PATH = Path('../../bert_fastai/pretrained-weights/uncased_L-24_H-1024_A-16/')
# FINETUNED_PATH = Path('../models/finetuned_model.bin')
FINETUNED_PATH = None
# model_state_dict = torch.load(FINETUNED_PATH)

LOG_PATH.mkdir(exist_ok=True)

OUTPUT_PATH = MODEL_PATH/'output'
OUTPUT_PATH.mkdir(exist_ok=True)

In [5]:
args = Box({
    "run_text": "multilabel toxic comments with freezable layers",
    "train_size": -1,
    "val_size": -1,
    "log_path": LOG_PATH,
    "full_data_dir": DATA_PATH,
    "data_dir": DATA_PATH,
    "task_name": "toxic_classification_lib",
    "no_cuda": False,
    "bert_model": BERT_PRETRAINED_PATH,
    "output_dir": OUTPUT_PATH,
    "max_seq_length": 512,
    "do_train": True,
    "do_eval": True,
    "do_lower_case": True,
    "train_batch_size": 8,
    "eval_batch_size": 16,
    "learning_rate": 5e-5,
    "num_train_epochs": 6,
    "warmup_proportion": 0.0,
    "no_cuda": False,
    "local_rank": -1,
    "seed": 42,
    "gradient_accumulation_steps": 1,
    "optimize_on_cpu": False,
    "fp16": True,
    "fp16_opt_level": "O1",
    "weight_decay": 0.0,
    "adam_epsilon": 1e-8,
    "max_grad_norm": 1.0,
    "max_steps": -1,
    "warmup_steps": 500,
    "logging_steps": 50,
    "eval_all_checkpoints": True,
    "overwrite_output_dir": True,
    "overwrite_cache": False,
    "seed": 42,
    "loss_scale": 128,
    "task_name": 'intent',
    "model_name": 'xlnet-base-cased',
    "model_type": 'xlnet'
})

In [6]:
import logging

logfile = str(LOG_PATH/'log-{}-{}.txt'.format(run_start_time, args["run_text"]))

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    handlers=[
        logging.FileHandler(logfile),
        logging.StreamHandler(sys.stdout)
    ])

logger = logging.getLogger()

In [7]:
logger.info(args)

08/02/2019 09:38:23 - INFO - root -   {'run_text': 'multilabel toxic comments with freezable layers', 'train_size': -1, 'val_size': -1, 'log_path': PosixPath('../logs'), 'full_data_dir': PosixPath('../data'), 'data_dir': PosixPath('../data'), 'task_name': 'intent', 'no_cuda': False, 'bert_model': PosixPath('../../bert_models/pretrained-weights/uncased_L-12_H-768_A-12'), 'output_dir': PosixPath('../models/output'), 'max_seq_length': 512, 'do_train': True, 'do_eval': True, 'do_lower_case': True, 'train_batch_size': 8, 'eval_batch_size': 16, 'learning_rate': 5e-05, 'num_train_epochs': 6, 'warmup_proportion': 0.0, 'local_rank': -1, 'seed': 42, 'gradient_accumulation_steps': 1, 'optimize_on_cpu': False, 'fp16': True, 'fp16_opt_level': 'O1', 'weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'max_steps': -1, 'warmup_steps': 500, 'logging_steps': 50, 'eval_all_checkpoints': True, 'overwrite_output_dir': True, 'overwrite_cache': False, 'loss_scale': 128, 'model_name': 'xlnet-bas

In [8]:
# tokenizer = BertTokenizer.from_pretrained(BERT_PRETRAINED_PATH, do_lower_case=args['do_lower_case'])

In [None]:
device = torch.device('cuda')
if torch.cuda.device_count() > 1:
    args.multi_gpu = True
else:
    args.multi_gpu = False

In [None]:
label_cols = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

In [None]:
from fast_bert.prediction import BertClassificationPredictor

In [None]:
predictor = BertClassificationPredictor(args.output_dir/'model_out', args.output_dir, LABEL_PATH, 
                                        multi_label=True, model_type='xlnet', do_lower_case=False)

08/02/2019 09:38:23 - INFO - transformers.tokenization_utils -   Model name '../models/output/model_out' not found in model shortcut name list (xlnet-base-cased, xlnet-large-cased). Assuming '../models/output/model_out' is a path or url to a directory containing tokenizer files.
08/02/2019 09:38:23 - INFO - transformers.tokenization_utils -   loading file ../models/output/model_out/added_tokens.json
08/02/2019 09:38:23 - INFO - transformers.tokenization_utils -   loading file ../models/output/model_out/special_tokens_map.json
08/02/2019 09:38:23 - INFO - transformers.tokenization_utils -   loading file ../models/output/model_out/spiece.model
08/02/2019 09:38:23 - INFO - transformers.modeling_utils -   loading configuration file ../models/output/model_out/config.json
08/02/2019 09:38:23 - INFO - transformers.modeling_utils -   Model config {
  "attn_type": "bi",
  "bi_data": false,
  "clamp_len": -1,
  "d_head": 64,
  "d_inner": 3072,
  "d_model": 768,
  "dropout": 0.1,
  "end_n_top": 5

In [None]:
output = predictor.predict_batch(list(pd.read_csv("../data/test.csv")['comment_text'].values))

08/02/2019 09:38:37 - INFO - root -   Writing example 0 of 153164
08/02/2019 09:38:43 - INFO - root -   Writing example 10000 of 153164
08/02/2019 09:38:50 - INFO - root -   Writing example 20000 of 153164
08/02/2019 09:38:57 - INFO - root -   Writing example 30000 of 153164
08/02/2019 09:39:04 - INFO - root -   Writing example 40000 of 153164
08/02/2019 09:39:11 - INFO - root -   Writing example 50000 of 153164
08/02/2019 09:39:18 - INFO - root -   Writing example 60000 of 153164
08/02/2019 09:39:26 - INFO - root -   Writing example 70000 of 153164
08/02/2019 09:39:33 - INFO - root -   Writing example 80000 of 153164
08/02/2019 09:39:39 - INFO - root -   Writing example 90000 of 153164
08/02/2019 09:39:47 - INFO - root -   Writing example 100000 of 153164
08/02/2019 09:39:54 - INFO - root -   Writing example 110000 of 153164
08/02/2019 09:40:01 - INFO - root -   Writing example 120000 of 153164
08/02/2019 09:40:08 - INFO - root -   Writing example 130000 of 153164
08/02/2019 09:40:16 

In [None]:
pd.DataFrame(output).to_csv('../data/output_xlnet.csv')

In [15]:
results = pd.read_csv('../data/output_xlnet.csv')

In [48]:
preds = pd.DataFrame([{item[0]: item[1] for item in pred} for pred in output])

In [50]:
preds.head()

Unnamed: 0,identity_hate,insult,obscene,severe_toxic,threat,toxic
0,0.787239,0.97021,0.990423,0.316317,0.015324,0.996634
1,7.1e-05,0.000166,0.00017,6.4e-05,4.1e-05,0.000707
2,7.3e-05,0.000178,0.000183,7e-05,4.5e-05,0.0006
3,7.3e-05,0.000179,0.000185,7.1e-05,4.5e-05,0.000594
4,7.3e-05,0.000175,0.00018,6.8e-05,4.4e-05,0.000619


In [61]:
test_df = pd.read_csv("../data/train.csv")
test_df.head()

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,0000997932d777bf,"Explanation\nWhy the edits made under my username Hardcore Metallica Fan were reverted? They weren't vandalisms, just closure on some GAs after I voted at New York Dolls FAC. And please don't remove the template from the talk page since I'm retired now.89.205.38.27",0,0,0,0,0,0
1,000103f0d9cfb60f,"D'aww! He matches this background colour I'm seemingly stuck with. Thanks. (talk) 21:51, January 11, 2016 (UTC)",0,0,0,0,0,0
2,000113f07ec002fd,"Hey man, I'm really not trying to edit war. It's just that this guy is constantly removing relevant information and talking to me through edits instead of my talk page. He seems to care more about the formatting than the actual info.",0,0,0,0,0,0
3,0001b41b1c6bb37e,"""\nMore\nI can't make any real suggestions on improvement - I wondered if the section statistics should be later on, or a subsection of """"types of accidents"""" -I think the references may need tidying so that they are all in the exact same format ie date format etc. I can do that later on, if no-one else does first - if you have any preferences for formatting style on references or want to do it yourself please let me know.\n\nThere appears to be a backlog on articles for review so I guess there may be a delay until a reviewer turns up. It's listed in the relevant form eg Wikipedia:Good_article_nominations#Transport """,0,0,0,0,0,0
4,0001d958c54c6e35,"You, sir, are my hero. Any chance you remember what page that's on?",0,0,0,0,0,0


In [52]:
output_df = pd.merge(test_df, preds, how='left', left_index=True, right_index=True)
del output_df['comment_text']

In [56]:
columns = ['id','toxic','severe_toxic','obscene','threat','insult','identity_hate']

In [57]:
output_df = output_df[columns]

In [58]:
output_df.to_csv('../data/output_xlnet.csv', index=None)

In [59]:
pd.read_csv('../data/output_xlnet.csv', index_col='id')

Unnamed: 0_level_0,toxic,severe_toxic,obscene,threat,insult,identity_hate
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
00001cee341fdb12,0.996634,0.316317,0.990423,0.015324,0.970210,0.787239
0000247867823ef7,0.000707,0.000064,0.000170,0.000041,0.000166,0.000071
00013b17ad220c46,0.000600,0.000070,0.000183,0.000045,0.000178,0.000073
00017563c3f7919a,0.000594,0.000071,0.000185,0.000045,0.000179,0.000073
00017695ad8997eb,0.000619,0.000068,0.000180,0.000044,0.000175,0.000073
0001ea8717f6de06,0.000590,0.000071,0.000186,0.000046,0.000180,0.000074
00024115d4cbde0f,0.000593,0.000071,0.000185,0.000045,0.000179,0.000074
000247e83dcc1211,0.226302,0.000131,0.002582,0.000855,0.003960,0.001135
00025358d4737918,0.000625,0.000068,0.000179,0.000044,0.000174,0.000073
00026d1092fe71cc,0.000596,0.000070,0.000184,0.000045,0.000178,0.000073
