# Train a BERT text classifier with the Transformers library (code only)

This is a code-only version of the notebook [train_bert_for_text_classification.ipynb](https://github.com/TurkuNLP/Text_Mining_Course/blob/master/train_bert_for_text_classification.ipynb), containing just the code for training BERT. See that notebook for explanations.

In [None]:
# Install the Hugging Face libraries used by the next cell (--quiet keeps the
# cell output short). Versions are not pinned, so the latest releases are installed.
!pip --quiet install transformers
!pip --quiet install datasets

In [None]:
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers import TrainingArguments
from transformers import Trainer
from datasets import load_dataset


# Run configuration: which checkpoint and dataset to use, plus training hyperparameters.
MODEL_NAME = 'bert-base-cased'  # checkpoint name given to from_pretrained()
DATASET = ('glue', 'sst2')      # unpacked as positional args to load_dataset()
LEARNING_RATE = 2e-5
BATCH_SIZE = 16                 # used as the per-device *training* batch size
TRAIN_EPOCHS = 3

# Download / load every split of the configured dataset.
dataset = load_dataset(*DATASET)

# The label count is taken from the full training split *before* it is
# subsampled below, so the filter cannot reduce the number of classes.
num_labels = len(set(dataset['train']['label']))

# Keep only every 10th training example (indices 0, 10, 20, ...),
# presumably to keep training time short.
dataset['train'] = dataset['train'].filter(
    lambda _example, position: position % 10 == 0,
    with_indices=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def encode_dataset(d):
    """Tokenize the ``'sentence'`` field of a dataset example or batch.

    ``d['sentence']`` may be a single string or a list of strings; the
    tokenizer accepts both. No padding arguments are passed here.
    """
    return tokenizer(d['sentence'])


# batched=True hands the tokenizer many sentences per call instead of one,
# which is much faster and produces the same per-example encodings.
encoded_dataset = dataset.map(encode_dataset, batched=True)

# Load the pretrained checkpoint with a sequence-classification head that has
# num_labels outputs (num_labels is counted from the training labels above).
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=num_labels)

# Training configuration: no checkpoints are saved (save_strategy='no'), and
# evaluation and logging both run once per epoch.
train_args = TrainingArguments(
    'output_dir',
    save_strategy='no',
    # `eval_strategy` is the current name of the former `evaluation_strategy`
    # argument, which recent transformers releases no longer accept.
    eval_strategy='epoch',
    logging_strategy='epoch',
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    num_train_epochs=TRAIN_EPOCHS,
)

def compute_accuracy(pred):
    """Return classification accuracy for a Trainer evaluation.

    Args:
        pred: object with ``predictions`` (per-class scores, assumed to be
            (n_examples, n_classes), or a tuple whose first element is those
            scores) and ``label_ids`` (gold labels, one per example).

    Returns:
        ``{'accuracy': float}``: the fraction of examples whose arg-max
        prediction equals the gold label.
    """
    logits = pred.predictions
    # When a model returns extra outputs besides the logits, the predictions
    # arrive as a tuple whose first element is the logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    y_pred = logits.argmax(axis=-1)
    y_true = pred.label_ids
    # Vectorized mean instead of the builtin sum() over a NumPy array; float()
    # turns the NumPy scalar into a plain Python number for logging.
    return {'accuracy': float((y_pred == y_true).mean())}

# Wire model, arguments, data, and metric together, then fine-tune.
# The encoded dataset is unpadded, so the tokenizer is passed as
# `processing_class` (the replacement for the deprecated `tokenizer=` argument);
# per the Trainer docs this makes it default to a padding data collator.
trainer = Trainer(
    model,
    train_args,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset['validation'],
    processing_class=tokenizer,
    compute_metrics=compute_accuracy,
)

trainer.train()