<a href="https://colab.research.google.com/github/HaiLongEmb/zhltest/blob/master/demo_fastbert_sample.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
#@title
# Fetch the NVIDIA apex sources and install them from the local checkout.
!git clone https://github.com/NVIDIA/apex.git
%cd apex
# The two --global-option flags pass --cpp_ext and --cuda_ext to the build.
!pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
# Step back out of the checkout so later relative paths start from the parent directory.
%cd ..

import torch
import apex
from fastai.text import *
import datetime
# Timestamp taken once at start-up; it is embedded in the log file name.
run_start_time = datetime.datetime.today().strftime('%Y-%m-%d_%H-%M-%S')

# Relative output locations for log files and models.
LOG_PATH = Path('logs/')
MODEL_PATH = Path('models/')

# Idempotent creation: safe to re-run the cell, and no gap between the
# existence check and the mkdir call.
LOG_PATH.mkdir(parents=True, exist_ok=True)
import logging

# Run configuration; later cells look these values up by key.
args = dict(
    run_text="my_test",
    max_seq_length=512,
    do_lower_case=True,
    train_batch_size=16,
    learning_rate=6e-5,
    num_train_epochs=12.0,
    warmup_proportion=0.002,
    local_rank=-1,
    gradient_accumulation_steps=1,
    fp16=True,
    loss_scale=128,
)

# Imported explicitly so sys.stdout below does not depend on a star import
# happening to expose `sys`.
import sys

# One log file per run, named from the start timestamp and the run label.
logfile = str(LOG_PATH/'log-{}-{}.txt'.format(run_start_time, args["run_text"]))

# Send INFO-and-above records both to the log file and to stdout.
# NOTE(review): logging.basicConfig does nothing when the root logger already
# has handlers (e.g. if this cell is run a second time) — confirm the file
# handler is actually attached in the target runtime.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    handlers=[
        logging.FileHandler(logfile),
        logging.StreamHandler(sys.stdout)
    ])

# Root logger; it is handed to the learner and used for logging in later cells.
logger = logging.getLogger()

# Device object passed to the learner; always the CUDA device.
device = torch.device('cuda')

# True only when more than one CUDA device is visible.
multi_gpu = torch.cuda.device_count() > 1

In [0]:
# Clone the demo repository; the data and label paths used in later cells point into it.
!git clone https://github.com/wshuyi/demo-multi-label-classification-bert.git

In [0]:
# Install the fast-bert package imported in the following cell.
# NOTE(review): the version is unpinned, while later cells import
# pytorch_pretrained_bert and call BertLearner.from_pretrained_model with
# is_fp16/loss_scale — presumably an older fast-bert API; confirm and pin.
!pip install fast-bert

In [0]:
from fast_bert.data import *
from fast_bert.learner import *
from fast_bert.metrics import *
from pytorch_pretrained_bert.tokenization import BertTokenizer

In [0]:
# Data and label directories inside the cloned demo repository.
DATA_PATH = Path('demo-multi-label-classification-bert/sample/data/')
LABEL_PATH = Path('demo-multi-label-classification-bert/sample/labels/')

# Name of the pretrained BERT model used for the tokenizer and the learner.
BERT_PRETRAINED_MODEL = "bert-base-uncased"

# Set the configuration values this run depends on, in one step.
args.update({
    "do_lower_case": True,
    "train_batch_size": 16,
    "learning_rate": 6e-5,
    "max_seq_length": 512,
    "fp16": True,
})

In [0]:
# Load the tokenizer for the chosen pretrained model, using the configured
# lower-casing setting.
tokenizer = BertTokenizer.from_pretrained(
    BERT_PRETRAINED_MODEL,
    do_lower_case=args['do_lower_case'],
)

In [0]:
# Label column names handed to the data bunch as `label_col`.
label_cols = [
    "toxic",
    "severe_toxic",
    "obscene",
    "threat",
    "insult",
    "identity_hate",
]

In [0]:
# Build the data bunch from the CSV files under DATA_PATH / LABEL_PATH,
# configured for multi-label classification.
databunch = BertDataBunch(
    DATA_PATH,
    LABEL_PATH,
    tokenizer,
    train_file='train.csv',
    val_file='valid.csv',
    test_data='test.csv',
    label_file="labels.csv",
    text_col="comment_text",
    label_col=label_cols,
    bs=args['train_batch_size'],
    maxlen=args['max_seq_length'],
    multi_gpu=multi_gpu,
    multi_label=True,
)

In [0]:
# Metrics passed to the learner: a single entry pairing a display name with
# the multi-label accuracy function.
metrics = [
    {'name': 'accuracy', 'function': accuracy_multilabel},
]

In [0]:
# Record the run configuration dict in the log before training starts.
logger.info(args)

In [0]:
# Create the learner from the pretrained model, wiring in the data bunch,
# metrics, device and logger along with the fp16 / multi-GPU settings.
learner = BertLearner.from_pretrained_model(
    databunch,
    BERT_PRETRAINED_MODEL,
    metrics,
    device,
    logger,
    is_fp16=args['fp16'],
    loss_scale=args['loss_scale'],
    multi_gpu=multi_gpu,
    multi_label=True,
)

In [0]:
# Train with the configured learning rate and the "warmup_linear" schedule.
# NOTE(review): the first argument (presumably the epoch count) is hard-coded
# to 4, while args["num_train_epochs"] is 12.0 and args["warmup_proportion"]
# is not passed here — confirm which values are intended.
learner.fit(4, lr=args['learning_rate'], schedule_type="warmup_linear")