In [1]:
import tensorflow as tf
import pandas as pd
import numpy as np
import time
from transformers import InputExample, InputFeatures, DistilBertTokenizerFast
from tqdm.notebook import tqdm

from pandarallel import pandarallel
pandarallel.initialize()

tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')

INFO: Pandarallel will run on 4 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.


In [2]:
def convert_label(star_rating):
    """Translate a star rating into a sentiment label.

    Returns "POSITIVE" when the rating is 4 or higher, "NEUTRAL" when it is
    exactly 3, and "NEGATIVE" for every other value.
    """
    if star_rating >= 4:
        return "POSITIVE"
    if star_rating == 3:
        return "NEUTRAL"
    return "NEGATIVE"

In [3]:
# Read test_small.csv from S3 into a DataFrame, then derive a sentiment label
# for every row from its `stars` value. `parallel_apply` comes from pandarallel,
# which is initialised in the first cell.
test = pd.read_csv("s3://yelp-dataset-pt-9/spencer/data/sentiment/en/test_small.csv")
test['labels'] = test['stars'].parallel_apply(convert_label)

In [4]:
# Chunk count for the split below: one chunk per full 100 rows of `test`,
# plus one extra chunk.
num_chunks = 1 + len(test) // 100

In [5]:
# Split `test` into `num_chunks` pieces so they can be encoded one at a time.
test_chunks = np.array_split(test, num_chunks)

In [6]:
def batch_encodings(rows, max_length=256):
    """Tokenize a frame of rows and wrap each one as an ``InputFeatures``.

    Args:
        rows: DataFrame-like object with a ``text`` column (sent to the
            tokenizer) and a ``labels`` column (attached to each feature).
        max_length: Forwarded to the tokenizer as ``max_length``. Defaults to
            256, the value that used to be hard-coded in this function.

    Returns:
        A list with one ``InputFeatures`` per row, in row order. An empty
        input yields an empty list without calling the tokenizer.
    """
    texts = rows['text'].tolist()
    labels = rows['labels'].tolist()

    # Nothing to encode: skip the tokenizer call entirely.
    if not texts:
        return []

    # NOTE(review): `pad_to_max_length` is reportedly deprecated in newer
    # transformers releases in favour of `padding=`/`truncation=`; kept as-is
    # because the installed version is not visible here -- confirm before changing.
    batch_encoding = tokenizer.batch_encode_plus(
        texts, max_length=max_length, pad_to_max_length=True, return_token_type_ids=True
    )

    features = []
    for i, label in enumerate(labels):
        # Pick the i-th entry of every field the tokenizer returned.
        inputs = {k: batch_encoding[k][i] for k in batch_encoding}
        features.append(InputFeatures(**inputs, label=label))

    return features

In [7]:
def run_batch_encodings(chunks):
    """Encode every chunk with ``batch_encodings`` and concatenate the results.

    Iterates over ``chunks`` under a tqdm progress bar and returns one flat
    list holding the per-chunk features in chunk order.
    """
    collected = []
    for piece in tqdm(chunks):
        collected.extend(batch_encodings(piece))
    return collected

In [8]:
# Time the chunked encoding path. perf_counter is the clock meant for measuring
# elapsed durations; unlike time.time() it is unaffected by system clock changes.
start = time.perf_counter()
features1 = run_batch_encodings(test_chunks)
time.perf_counter() - start

HBox(children=(FloatProgress(value=0.0, max=2001.0), HTML(value='')))




69.1996340751648

In [9]:
# Time encoding the whole frame in a single call. perf_counter is the clock meant
# for measuring elapsed durations; unlike time.time() it is unaffected by system
# clock changes.
start = time.perf_counter()
features2 = batch_encodings(test)
time.perf_counter() - start

109.43570113182068

In [10]:
# Check that the chunked run (features1) and the single-call run (features2)
# compare equal.
features1 == features2

True

In [11]:
# Repeat measurement of the chunked encoding path. perf_counter is the clock meant
# for measuring elapsed durations; unlike time.time() it is unaffected by system
# clock changes.
start = time.perf_counter()
features1 = run_batch_encodings(test_chunks)
time.perf_counter() - start

HBox(children=(FloatProgress(value=0.0, max=2001.0), HTML(value='')))




84.97488117218018

In [12]:
# Repeat measurement of the single-call encoding path. perf_counter is the clock
# meant for measuring elapsed durations; unlike time.time() it is unaffected by
# system clock changes.
start = time.perf_counter()
features2 = batch_encodings(test)
time.perf_counter() - start

164.25626611709595

In [13]:
# Check again, after the repeat runs, that the chunked (features1) and
# single-call (features2) results compare equal.
features1 == features2

True