In [29]:
import pandas as pd
import pickle
from datasets import Dataset
import torch
from datasets import DatasetDict
from transformers import BertForSequenceClassification, BertTokenizer
from transformers import DataCollatorWithPadding
import evaluate
from transformers import TrainingArguments, Trainer
import numpy as np

In [2]:
train_df = pd.read_pickle("train_dataset")
test_df = pd.read_pickle("test_dataset")
val_df = pd.read_pickle("val_dataset")

In [3]:
dataset = DatasetDict({
    'train': Dataset.from_pandas(train_df),
    'test': Dataset.from_pandas(test_df),
    'unsupervised': Dataset.from_pandas(val_df)
})

In [5]:
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 427
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 92
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 92
    })
})

In [4]:
# define preprocess function
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)

In [7]:
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

In [8]:
tokenized_data = dataset.map(preprocess_function, batched=True, batch_size=100,  load_from_cache_file=True)

Map:   0%|          | 0/427 [00:00<?, ? examples/s]

Map:   0%|          | 0/92 [00:00<?, ? examples/s]

Map:   0%|          | 0/92 [00:00<?, ? examples/s]

In [9]:
tokenized_data['train'][0].keys()

dict_keys(['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'])

In [11]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [13]:
accuracy = evaluate.load("accuracy")

In [30]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

In [17]:
train_df['label'].unique()

array([4, 5, 6, 0, 3, 1, 2])

In [18]:
model = BertForSequenceClassification.from_pretrained("bert-base-cased", num_labels=7)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initi

In [31]:
training_args = TrainingArguments(
    output_dir="./bert_runs",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=10,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

In [32]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [33]:
torch.cuda.empty_cache()

In [34]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.856298,0.728261
2,No log,0.906006,0.76087
3,No log,0.925589,0.706522
4,No log,1.022355,0.728261
5,0.457100,1.088503,0.76087
6,0.457100,1.176671,0.73913
7,0.457100,1.236261,0.728261
8,0.457100,1.301437,0.73913
9,0.457100,1.311407,0.73913
10,0.033500,1.324309,0.73913


TrainOutput(global_step=1070, training_loss=0.22969851463197546, metrics={'train_runtime': 287.2986, 'train_samples_per_second': 14.863, 'train_steps_per_second': 3.724, 'total_flos': 126075938513100.0, 'train_loss': 0.22969851463197546, 'epoch': 10.0})

In [35]:
trainer.evaluate()

{'eval_loss': 0.8562980890274048,
 'eval_accuracy': 0.7282608695652174,
 'eval_runtime': 1.267,
 'eval_samples_per_second': 72.613,
 'eval_steps_per_second': 18.153,
 'epoch': 10.0}

In [42]:
training_args = TrainingArguments(
    output_dir="./bert_runs",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=10,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

In [43]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [44]:
torch.cuda.empty_cache()

In [45]:
trainer.train()



Epoch,Training Loss,Validation Loss,Accuracy
1,No log,1.033177,0.75
2,No log,0.907752,0.771739
3,No log,0.970062,0.793478
4,No log,1.099334,0.771739
5,No log,1.048463,0.76087
6,No log,1.082632,0.771739
7,No log,1.142383,0.771739
8,No log,1.18034,0.75
9,No log,1.202066,0.75
10,0.095900,1.202592,0.75


TrainOutput(global_step=540, training_loss=0.0892551847630077, metrics={'train_runtime': 417.6259, 'train_samples_per_second': 10.224, 'train_steps_per_second': 1.293, 'total_flos': 175062016055100.0, 'train_loss': 0.0892551847630077, 'epoch': 10.0})

In [46]:
trainer.evaluate()

{'eval_loss': 0.9077516794204712,
 'eval_accuracy': 0.7717391304347826,
 'eval_runtime': 1.7083,
 'eval_samples_per_second': 53.856,
 'eval_steps_per_second': 7.025,
 'epoch': 10.0}

In [47]:
trainer.save_model('./baseline_bert')

In [49]:
trainer.predict(tokenized_data['unsupervised'])

PredictionOutput(predictions=array([[-1.64433467e+00,  2.88456857e-01, -3.14732909e-01,
        -1.13448203e+00,  1.46027863e+00,  4.26518393e+00,
        -1.19324017e+00],
       [ 1.95076859e+00, -1.36538720e+00, -2.64445186e+00,
         2.92469710e-02, -2.04532218e+00, -6.66371226e-01,
         5.54735720e-01],
       [-4.97009665e-01,  2.11432949e-01, -1.04289520e+00,
        -1.44406736e+00,  6.81269598e+00, -8.06788981e-01,
        -1.61416578e+00],
       [ 8.10820043e-01, -9.84315991e-01, -1.22105169e+00,
        -3.64282221e-01, -3.35116088e-01,  4.44170684e-02,
         1.30879378e+00],
       [-1.16928017e+00,  3.94857287e-01, -1.66562200e+00,
        -1.23402858e+00,  1.27487934e+00,  3.08241796e+00,
        -1.07579768e+00],
       [ 3.03948950e-03,  3.77182811e-01, -1.33690202e+00,
        -1.42729068e+00,  6.27806473e+00, -1.21254778e+00,
        -1.61610413e+00],
       [-4.19007778e-01,  2.19068572e-01, -1.19383562e+00,
        -1.45633817e+00,  6.71228027e+00, -9.589