Install packages (optional if the dependencies were already installed with pip from requirements.txt)

In [1]:
# ! pip install datasets
# ! pip install accelerate
# ! pip install evaluate

Load packages

In [2]:
from transformers import (
    AutoTokenizer,
    DataCollatorWithPadding,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)
from datasets import Dataset
import evaluate
import numpy as np
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


Define helper functions

In [3]:
# Tokenizer for the DistilBERT checkpoint; model_max_length caps sequences at 512 tokens.
tokenizer = AutoTokenizer.from_pretrained(
    'distilbert/distilbert-base-uncased',
    model_max_length=512,
)

# Collator that pads batches, built on the same tokenizer.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


def preprocess_function(examples):
    """Tokenize the ``'text'`` entry of ``examples`` with truncation enabled.

    Args:
        examples: Mapping that provides a ``'text'`` entry (a batch of rows
            when called through ``map(batched=True)``).

    Returns:
        The output of the module-level ``tokenizer`` for that text.
    """
    texts = examples['text']
    encoded = tokenizer(texts, truncation=True)
    return encoded


# Accuracy metric, loaded once and reused by compute_metrics.
accuracy = evaluate.load('accuracy')


def compute_metrics(eval_pred):
    """Compute accuracy for one evaluation pass.

    Args:
        eval_pred: Pair ``(predictions, labels)``; the predictions are reduced
            to class ids by taking the argmax along axis 1.

    Returns:
        The result of ``accuracy.compute`` for the predicted ids against the
        reference labels.
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=references)

- Load dataset
- Calculate id2label and label2id dictionaries
- Label, shuffle, stratify, and split

In [4]:
# Load the labelled examples from disk.
df = pd.read_csv('data/GB-GOV-1.csv')

# Raw label values in order of first appearance (kept for reference only;
# this order is NOT the order of the encoded class ids, see below).
unique_labels = df.label.unique()

# Encode the label column as a ClassLabel feature first, then read the
# id <-> name mapping back from that feature. class_encode_column assigns ids
# over the sorted (string) label names, so building the mapping from
# df.label.unique() could pair ids with the wrong label names.
encoded_dataset = Dataset.from_pandas(df).class_encode_column("label")
class_names = encoded_dataset.features["label"].names
id2label = dict(enumerate(class_names))
label2id = {label: idx for idx, label in id2label.items()}

# Stratified, shuffled 70/30 split. The fixed seed makes the partition (and
# therefore the reported metrics) reproducible across kernel restarts.
dataset = encoded_dataset.train_test_split(
    test_size=0.3,
    stratify_by_column="label",
    shuffle=True,
    seed=42,
)

Casting to class labels: 100%|██████████| 3878/3878 [00:00<00:00, 253831.32 examples/s]


Tokenize dataset

In [5]:
# Apply preprocess_function to the dataset in batched mode.
tokenized_data = dataset.map(
    preprocess_function,
    batched=True,
)

Map: 100%|██████████| 2714/2714 [00:00<00:00, 17828.03 examples/s]
Map: 100%|██████████| 1164/1164 [00:00<00:00, 22385.93 examples/s]


Load and set up model

In [6]:
# Load the DistilBERT checkpoint for sequence classification, sized and
# labelled according to the id2label / label2id mappings.
model = AutoModelForSequenceClassification.from_pretrained(
    'distilbert/distilbert-base-uncased',
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Set up training arguments and trainer

In [7]:
training_args = TrainingArguments(
    # Where checkpoints are written.
    output_dir='models/climate-classifier',
    # Optimisation settings.
    num_train_epochs=10,
    learning_rate=1e-5,  # This can be tweaked depending on how loss progresses
    weight_decay=0.01,
    # Batch sizes: these should be tweaked to match GPU VRAM.
    per_device_train_batch_size=36,
    per_device_eval_batch_size=36,
    # Evaluate, checkpoint and log once per epoch.
    evaluation_strategy='epoch',
    save_strategy='epoch',
    logging_strategy='epoch',
    load_best_model_at_end=True,
)

# Wire the model, data splits, tokenizer, collator and metric into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=tokenized_data['train'],
    eval_dataset=tokenized_data['test'],
    compute_metrics=compute_metrics,
)

Run training

In [8]:
# Run training with the Trainer configured above; the returned TrainOutput
# is displayed as the cell result.
trainer.train()

                                                
 10%|█         | 76/760 [02:43<21:11,  1.86s/it]Checkpoint destination directory models/climate-classifier/checkpoint-76 already exists and is non-empty.Saving will proceed but saved results may be invalid.


{'eval_loss': 0.43150314688682556, 'eval_accuracy': 0.8213058419243986, 'eval_runtime': 21.0315, 'eval_samples_per_second': 55.345, 'eval_steps_per_second': 1.569, 'epoch': 1.0}


                                                  
 20%|██        | 152/760 [05:40<15:15,  1.51s/it]Checkpoint destination directory models/climate-classifier/checkpoint-152 already exists and is non-empty.Saving will proceed but saved results may be invalid.


{'eval_loss': 0.38561081886291504, 'eval_accuracy': 0.834192439862543, 'eval_runtime': 16.5845, 'eval_samples_per_second': 70.186, 'eval_steps_per_second': 1.99, 'epoch': 2.0}


                                                   
 30%|███       | 228/760 [08:31<14:30,  1.64s/it]Checkpoint destination directory models/climate-classifier/checkpoint-228 already exists and is non-empty.Saving will proceed but saved results may be invalid.


{'eval_loss': 0.3669620752334595, 'eval_accuracy': 0.8436426116838488, 'eval_runtime': 22.3838, 'eval_samples_per_second': 52.002, 'eval_steps_per_second': 1.474, 'epoch': 3.0}


                                                   
 40%|████      | 304/760 [11:44<12:53,  1.70s/it]Checkpoint destination directory models/climate-classifier/checkpoint-304 already exists and is non-empty.Saving will proceed but saved results may be invalid.


{'eval_loss': 0.3765120506286621, 'eval_accuracy': 0.8496563573883161, 'eval_runtime': 20.38, 'eval_samples_per_second': 57.115, 'eval_steps_per_second': 1.619, 'epoch': 4.0}


                                                   
 50%|█████     | 380/760 [14:46<11:05,  1.75s/it]Checkpoint destination directory models/climate-classifier/checkpoint-380 already exists and is non-empty.Saving will proceed but saved results may be invalid.


{'eval_loss': 0.38608911633491516, 'eval_accuracy': 0.8479381443298969, 'eval_runtime': 21.7681, 'eval_samples_per_second': 53.473, 'eval_steps_per_second': 1.516, 'epoch': 5.0}


                                                 
 60%|██████    | 456/760 [17:53<10:22,  2.05s/it]Checkpoint destination directory models/climate-classifier/checkpoint-456 already exists and is non-empty.Saving will proceed but saved results may be invalid.


{'eval_loss': 0.41391292214393616, 'eval_accuracy': 0.8487972508591065, 'eval_runtime': 24.7514, 'eval_samples_per_second': 47.028, 'eval_steps_per_second': 1.333, 'epoch': 6.0}


 66%|██████▌   | 500/760 [19:42<11:19,  2.61s/it]

{'loss': 0.323, 'learning_rate': 3.421052631578948e-06, 'epoch': 6.58}


                                                 
 70%|███████   | 532/760 [21:19<06:53,  1.82s/it]Checkpoint destination directory models/climate-classifier/checkpoint-532 already exists and is non-empty.Saving will proceed but saved results may be invalid.


{'eval_loss': 0.42547136545181274, 'eval_accuracy': 0.8530927835051546, 'eval_runtime': 21.8064, 'eval_samples_per_second': 53.379, 'eval_steps_per_second': 1.513, 'epoch': 7.0}


                                                 
 80%|████████  | 608/760 [24:20<04:44,  1.87s/it]

{'eval_loss': 0.43231654167175293, 'eval_accuracy': 0.8419243986254296, 'eval_runtime': 17.6946, 'eval_samples_per_second': 65.783, 'eval_steps_per_second': 1.865, 'epoch': 8.0}


                                                 
 90%|█████████ | 684/760 [27:16<02:10,  1.71s/it]

{'eval_loss': 0.4386640787124634, 'eval_accuracy': 0.8445017182130584, 'eval_runtime': 22.6799, 'eval_samples_per_second': 51.323, 'eval_steps_per_second': 1.455, 'epoch': 9.0}


                                                 
100%|██████████| 760/760 [30:05<00:00,  1.75s/it]

{'eval_loss': 0.4387195110321045, 'eval_accuracy': 0.8470790378006873, 'eval_runtime': 18.77, 'eval_samples_per_second': 62.014, 'eval_steps_per_second': 1.758, 'epoch': 10.0}


100%|██████████| 760/760 [30:06<00:00,  2.38s/it]

{'train_runtime': 1806.9874, 'train_samples_per_second': 15.019, 'train_steps_per_second': 0.421, 'train_loss': 0.26499609696237664, 'epoch': 10.0}





TrainOutput(global_step=760, training_loss=0.26499609696237664, metrics={'train_runtime': 1806.9874, 'train_samples_per_second': 15.019, 'train_steps_per_second': 0.421, 'train_loss': 0.26499609696237664, 'epoch': 10.0})