In [1]:
import torch 
import pickle
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from torch.nn import CrossEntropyLoss, MSELoss

from tqdm import tqdm_notebook, trange
import os
from pytorch_transformers import BertTokenizer, BertModel, BertForMaskedLM, BertForSequenceClassification
from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule

from multiprocessing import Pool, cpu_count
from tools import *
import examples_to_features

# if you want to have more information on what's happening, activate the logger as follows
import logging

# Configure the root logger at INFO level so informational messages are shown.
logging.basicConfig(level=logging.INFO)

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# The input data directory. It should contain the .tsv files (or other data files) for the task.
DATA_DIR = "data/"

# BERT pre-trained model, selected from the list: bert-base-uncased,
# bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased,
# bert-base-multilingual-cased, bert-base-chinese
BERT_MODEL = 'bert-base-cased'

# The name of the task to train. I'm going to name this 'yelp'.
TASK_NAME = 'yelp'

# The output directory where the fine-tuned model and the checkpoints will be written
OUTPUT_DIR = f'outputs/{TASK_NAME}/'

# The directory where the evaluation reports will be written to
REPORTS_DIR = f'reports/{TASK_NAME}_evaluation_report/'

# Intended location for cached pre-trained model files.
# NOTE(review): this only takes effect where it is explicitly passed to a loader — confirm usage.
CACHE_DIR = 'cache/'

# The maximum total input sequence length after WordPiece tokenization.
# Sequences longer than this will be truncated, and sequences shorter than this will be padded.
MAX_SEQ_LENGTH = 128

# Training / evaluation hyper-parameters.
TRAIN_BATCH_SIZE = 24
EVAL_BATCH_SIZE = 32
LEARNING_RATE = 2e-5
NUM_TRAIN_EPOCHS = 1
RANDOM_SEED = 42
GRADIENT_ACCUMULATION_STEPS = 1
# presumably the fraction of optimization steps used for learning-rate warm-up — TODO confirm
WARMUP_PROPORTION = 0.1
OUTPUT_MODE = 'classification'

# File names for the saved model configuration and weights.
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"

In [3]:
# Lower-case aliases of the corresponding configuration constants.
output_mode, cache_dir = OUTPUT_MODE, CACHE_DIR

In [4]:
# Allocate a numbered sub-directory (report_0, report_1, ...) for this run's reports.
# The index is the number of entries already present in the directory REPORTS_DIR points to.
os.makedirs(REPORTS_DIR, exist_ok=True)  # the base directory may not exist yet

# Detect the case where REPORTS_DIR already points at an empty run directory
# (e.g. this cell was just executed); in that case keep it rather than nesting another one.
_current_name = os.path.basename(os.path.normpath(REPORTS_DIR))
_is_empty_run_dir = (
    _current_name.startswith('report_')
    and _current_name[len('report_'):].isdigit()
    and not os.listdir(REPORTS_DIR)
)

if not _is_empty_run_dir:
    # An existing-but-empty base directory also gets a report_0 sub-directory, so every
    # run's reports live in their own folder. os.path.join avoids a doubled separator
    # when REPORTS_DIR already ends with '/'.
    REPORTS_DIR = os.path.join(REPORTS_DIR, f'report_{len(os.listdir(REPORTS_DIR))}')
    os.makedirs(REPORTS_DIR)

In [5]:
# Create the output directory when it is missing; refuse to continue when it
# already exists and contains anything.
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)
elif os.listdir(OUTPUT_DIR):
    raise ValueError("Output directory ({}) already exists and is not empty.".format(OUTPUT_DIR))

In [6]:
# Use our BinaryClassificationProcessor to load the data and get everything ready for the tokenization step.
# NOTE(review): BinaryClassificationProcessor is not imported by name; presumably it is
# provided by `from tools import *` — confirm.
processor = BinaryClassificationProcessor()
train_examples = processor.get_train_examples(DATA_DIR)
# Stored once because the count is reused for the step calculation and the progress-bar total.
train_examples_len = len(train_examples)


In [7]:
# Labels reported by the processor, and how many distinct entries that list has.
label_list = processor.get_labels() # [0,1] for binary classification
num_labels = len(label_list)

In [8]:
# Optimizer updates per epoch: examples divided by batch size and by the number of
# batches accumulated per update, truncated to an integer.
_updates_per_epoch = int(train_examples_len / TRAIN_BATCH_SIZE / GRADIENT_ACCUMULATION_STEPS)

# Total optimizer updates over the whole training run.
num_train_optimization_steps = _updates_per_epoch * NUM_TRAIN_EPOCHS

In [9]:
# Load the pre-trained tokenizer (vocabulary) for the model selected in BERT_MODEL,
# so the tokenizer always matches the configured model instead of a hard-coded name.
# Lower-casing is enabled only for "uncased" checkpoints; for the configured
# 'bert-base-cased' this evaluates to False, as before.
# NOTE: the variable keeps its existing spelling (`tokenzer`) because other cells reference it.
tokenzer = BertTokenizer.from_pretrained(BERT_MODEL, do_lower_case='uncased' in BERT_MODEL)

INFO:pytorch_transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at /home/achintya/.cache/torch/pytorch_transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1


In [10]:
# Map each label to its position in label_list.
label_map = dict(zip(label_list, range(len(label_list))))

# Pair every training example with the arguments shared by all conversions, producing one
# (example, label_map, max_seq_length, tokenizer, output_mode) tuple per example.
_shared_conversion_args = (label_map, MAX_SEQ_LENGTH, tokenzer, OUTPUT_MODE)
train_examples_for_processing = [(example, *_shared_conversion_args) for example in train_examples]

In [11]:
# Number of worker processes used for the example -> feature conversion.
process_count = 1
if __name__ == '__main__':
    print(
        f'Preparing to convert {train_examples_len} examples..',
        f'Spawning {process_count} processes..',
        sep='\n',
    )
    with Pool(process_count) as p:
        # imap yields results lazily and in input order, so the progress bar advances per example.
        _converted = p.imap(
            examples_to_features.convert_example_to_feature,
            train_examples_for_processing,
        )
        _progress = tqdm_notebook(_converted, total=train_examples_len)
        train_features = list(_progress)


Preparing to convert 140000 examples..
Spawning 1 processes..


HBox(children=(IntProgress(value=0, max=140000), HTML(value='')))




In [12]:
# Persist the converted training features to disk as a pickle file inside DATA_DIR.
# os.path.join keeps the path correct whether or not DATA_DIR ends with a separator.
train_features_path = os.path.join(DATA_DIR, "train_features.pkl")
with open(train_features_path, "wb") as f:
    pickle.dump(train_features, f)