In [1]:
! pip install wget

Collecting wget
  Downloading https://files.pythonhosted.org/packages/47/6a/62e288da7bcda82b935ff0c6cfe542970f04e29c756b0e147251b2fb251f/wget-3.2.zip
Building wheels for collected packages: wget
  Building wheel for wget (setup.py) ... [?25l[?25hdone
  Created wheel for wget: filename=wget-3.2-cp36-none-any.whl size=9682 sha256=b974f223c0446b6eca534369b1daa4dcd3298d77535d2c01f13884eba7c1a5e7
  Stored in directory: /root/.cache/pip/wheels/40/15/30/7d8f7cea2902b4db79e3fea550d7d7b85ecb27ef992b618f3f
Successfully built wget
Installing collected packages: wget
Successfully installed wget-3.2


In [2]:
! git clone https://github.com/huggingface/transformers
! cd transformers && pip install .

Cloning into 'transformers'...
remote: Enumerating objects: 88, done.[K
remote: Counting objects: 100% (88/88), done.[K
remote: Compressing objects: 100% (63/63), done.[K
remote: Total 51265 (delta 39), reused 57 (delta 18), pack-reused 51177[K
Receiving objects: 100% (51265/51265), 37.95 MiB | 28.00 MiB/s, done.
Resolving deltas: 100% (35812/35812), done.
Processing /content/transformers
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting sentencepiece==0.1.91
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 22.3MB/s 
[?25hCollecting tokenizers==0.9.3
[?25l  Downloading https://files.pythonhosted.org/packages/4c/34/b39eb9994bc3c999270b69c9eea40ecc6f0e97991dba28282b9fd32d44ee

In [3]:
import wget, tarfile
import os


# ----- download dataset -----
def download_dataset(url: str, save_path: str) -> None:
    """Download a tar archive and extract it into ``data/<stem>``.

    ``<stem>`` is the archive's file name up to its first ``.``
    (``aclImdb_v1.tar.gz`` -> ``data/aclImdb_v1``). The archive is first
    written to ``save_path`` and deleted once extraction is attempted. If
    ``data/<stem>`` already exists, nothing is downloaded, so re-running the
    cell is cheap.

    Args:
        url: Address of the tar archive (any compression ``tarfile`` detects).
        save_path: Local file the archive is downloaded to before extraction.

    Raises:
        ValueError: if ``save_path`` has no usable file-name stem.
    """
    import shutil

    # Use only the file name: "./x.tar.gz".split(".")[0] would be "" and make
    # every call look like a cache hit on "data/".
    stem = os.path.basename(save_path).split(".")[0]
    if not stem:
        raise ValueError("cannot derive a dataset folder name from {!r}".format(save_path))
    extra_path = "data/" + stem

    if os.path.isdir(extra_path):
        print("data has downloaded in {}".format(extra_path))
        return

    os.makedirs("data/", exist_ok=True)
    wget.download(url, out=save_path)
    try:
        with tarfile.open(save_path) as tf:
            if hasattr(tarfile, "data_filter"):
                # Reject absolute paths, ".." members and special files
                # (Python 3.12+, backported to recent 3.8-3.11 patch releases).
                tf.extractall(extra_path, filter="data")
            else:
                tf.extractall(extra_path)
    except BaseException:
        # A half-extracted folder would be mistaken for a finished download
        # on the next run, so remove it before propagating the error.
        shutil.rmtree(extra_path, ignore_errors=True)
        raise
    finally:
        os.remove(save_path)
    print("Download success in {}".format(extra_path))


# Fetch the IMDb review archive once; later runs find data/aclImdb_v1 and skip it.
data_url = 'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'
save_file = 'aclImdb_v1.tar.gz'
download_dataset(url=data_url, save_path=save_file)

Download success in data/aclImdb_v1


In [4]:
# ---------- SA IMDB ----------
# ----- data -----
from pathlib import Path


def read_imdb_split(split_dir):
    """Load one split of the aclImdb dataset from disk.

    ``split_dir`` must contain ``pos/`` and ``neg/`` sub-directories holding
    one UTF-8 review per file.

    Args:
        split_dir: Path (``str`` or ``Path``) to the split directory.

    Returns:
        ``(texts, labels)``: parallel lists of review strings and integer
        labels, 1 for ``pos`` and 0 for ``neg``. All ``pos`` reviews come
        first, then all ``neg`` reviews. Within each, files are read in
        sorted path order, so the output (and any split made from it) does
        not depend on filesystem listing order.
    """
    split_dir = Path(split_dir)
    texts = []
    labels = []
    # Pair each folder with its label explicitly instead of comparing strings
    # with `is`, which tests object identity, not equality.
    for label_dir, label in (("pos", 1), ("neg", 0)):
        for text_file in sorted((split_dir / label_dir).iterdir()):
            texts.append(text_file.read_text(encoding="utf-8"))
            labels.append(label)

    return texts, labels


# Reviews live under data/aclImdb_v1/aclImdb/<split>/{pos,neg}/ after extraction.
train_texts, train_labels = read_imdb_split(split_dir='data/aclImdb_v1/aclImdb/train')
test_texts, test_labels = read_imdb_split(split_dir='data/aclImdb_v1/aclImdb/test')

In [5]:
# ----- split train valid -----
from sklearn.model_selection import train_test_split

# Hold out 20% of the official training set for validation.
# A fixed random_state makes the split, and therefore every validation metric
# reported below, reproducible. Stratifying keeps the pos/neg ratio the same
# in both parts.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_texts, train_labels, test_size=.2, random_state=42, stratify=train_labels
)

# ----- BertTokenizer -----
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# truncation=True cuts each review at the model's maximum input length.
# padding=True pads every sequence in a call to the longest one in that call,
# here the whole split.
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
test_encodings = tokenizer(test_texts, truncation=True, padding=True)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




In [6]:
# ----- Dataset -----
import torch


class IMDbDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer output with class labels.

    Args:
        encodings: Mapping from feature name (e.g. ``input_ids``) to a
            per-example sequence, such as the tokenizer output built above.
        labels: Sequence of integer class ids aligned with ``encodings``.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Tensorize every feature of example `idx`, then attach its label
        # under the "labels" key that the model's loss expects.
        item = {}
        for name, column in self.encodings.items():
            item[name] = torch.tensor(column[idx])
        item['labels'] = torch.tensor(self.labels[idx])
        return item

# Wrap each split's encodings and labels in the Dataset the Trainer consumes.
train_dataset, val_dataset, test_dataset = (
    IMDbDataset(train_encodings, train_labels),
    IMDbDataset(val_encodings, val_labels),
    IMDbDataset(test_encodings, test_labels),
)


In [7]:
# ----- metrics -----
from sklearn.metrics import accuracy_score

def compute_metrics(pred):
    """Metric callback: accuracy of argmax predictions against gold labels.

    Args:
        pred: Object exposing ``predictions`` (per-class scores with classes
            on the last axis) and ``label_ids`` (gold labels). This is what
            ``Trainer`` passes during evaluation, or what ``trainer.predict``
            returns.

    Returns:
        ``{'accuracy': <accuracy_score value>}``.
    """
    predicted_ids = pred.predictions.argmax(-1)
    return {'accuracy': accuracy_score(pred.label_ids, predicted_ids)}

In [8]:
# ----- Trainer -----
from transformers import BertForSequenceClassification, Trainer, TrainingArguments, set_seed
from transformers.trainer_utils import EvaluationStrategy

# Effective batch per optimizer step (per device)
#   = per_device_train_batch_size * gradient_accumulation_steps = 3 * 12 = 36.
# Logging and eval steps count optimizer steps: the log ends at global_step 1665
# ~= 3 epochs * 20000 / 36.
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=3,   # batch size per device during training
    per_device_eval_batch_size=3,    # batch size for evaluation
    warmup_steps=50,                 # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_steps=50,                # log training loss every 50 optimizer steps
    evaluation_strategy=EvaluationStrategy.STEPS,  # evaluate on val_dataset every eval_steps
    eval_steps=250,
    gradient_accumulation_steps=12,  # accumulate 12 small batches before each update
    seed=42,                         # library default, stated explicitly for reproducibility
)

# Seed before building the model. The classification head on top of BERT is
# freshly initialized (see the "not initialized from the model checkpoint"
# warning). That happens here, before the Trainer exists, so the Trainer's own
# seeding cannot make it reproducible.
set_seed(training_args.seed)
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset,            # evaluation dataset
    compute_metrics=compute_metrics
)

trainer.train()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

Step,Training Loss,Validation Loss,Accuracy
250,0.249159,0.211517,0.917
500,0.21247,0.215702,0.9224
750,0.108501,0.226248,0.926
1000,0.129472,0.219542,0.9334
1250,0.049817,0.28425,0.9308
1500,0.041864,0.299125,0.9362


TrainOutput(global_step=1665, training_loss=0.14538876645199889)

In [9]:
# predict
# ----- test data -----
# Score the official test split, which was never used for training or validation.
# NOTE(review): this uses the weights at the end of training, not the
# checkpoint with the best validation loss.
prediction = trainer.predict(test_dataset)
test_metrics = compute_metrics(prediction)
print(test_metrics)

{'accuracy': 0.93884}
