In [1]:
import pandas as pd
import glob
import torch, os
from transformers import pipeline, BertForSequenceClassification, BertTokenizerFast, TrainingArguments, Trainer, DistilBertForSequenceClassification, DistilBertTokenizerFast
from torch.utils.data import Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print('GPU')
else:
    device = torch.device("cpu")
    print('CPU')

GPU


In [3]:
train_paths = ['./content/train/california_wildfires_2018_train.tsv',
               './content/train/canada_wildfires_2016_train.tsv',
               './content/train/cyclone_idai_2019_train.tsv',
               './content/train/ecuador_earthquake_2016_train.tsv',
               './content/train/greece_wildfires_2018_train.tsv',
               './content/train/hurricane_dorian_2019_train.tsv',
               './content/train/hurricane_florence_2018_train.tsv',
               './content/train/hurricane_harvey_2017_train.tsv',
               './content/train/hurricane_irma_2017_train.tsv',
               './content/train/hurricane_maria_2017_train.tsv',
               './content/train/hurricane_matthew_2016_train.tsv',
               './content/train/italy_earthquake_aug_2016_train.tsv',
               './content/train/kaikoura_earthquake_2016_train.tsv',
               './content/train/kerala_floods_2018_train.tsv',
               './content/train/maryland_floods_2018_train.tsv',
               './content/train/midwestern_us_floods_2019_train.tsv',
               './content/train/pakistan_earthquake_2019_train.tsv',
               './content/train/puebla_mexico_earthquake_2017_train.tsv',
               './content/train/srilanka_floods_2017_train.tsv']

val_paths = ['./content/dev/california_wildfires_2018_dev.tsv',
             './content/dev/canada_wildfires_2016_dev.tsv',
             './content/dev/cyclone_idai_2019_dev.tsv',
             './content/dev/ecuador_earthquake_2016_dev.tsv',
             './content/dev/greece_wildfires_2018_dev.tsv',
             './content/dev/hurricane_dorian_2019_dev.tsv',
             './content/dev/hurricane_florence_2018_dev.tsv',
             './content/dev/hurricane_harvey_2017_dev.tsv',
             './content/dev/hurricane_irma_2017_dev.tsv',
             './content/dev/hurricane_maria_2017_dev.tsv',
             './content/dev/hurricane_matthew_2016_dev.tsv',
             './content/dev/italy_earthquake_aug_2016_dev.tsv',
             './content/dev/kaikoura_earthquake_2016_dev.tsv',
             './content/dev/kerala_floods_2018_dev.tsv',
             './content/dev/maryland_floods_2018_dev.tsv',
             './content/dev/midwestern_us_floods_2019_dev.tsv',
             './content/dev/pakistan_earthquake_2019_dev.tsv',
             './content/dev/puebla_mexico_earthquake_2017_dev.tsv',
             './content/dev/srilanka_floods_2017_dev.tsv']

test_paths = ['./content/test/california_wildfires_2018_test.tsv',
              './content/test/canada_wildfires_2016_test.tsv',
              './content/test/cyclone_idai_2019_test.tsv',
              './content/test/ecuador_earthquake_2016_test.tsv',
              './content/test/greece_wildfires_2018_test.tsv',
              './content/test/hurricane_dorian_2019_test.tsv',
              './content/test/hurricane_florence_2018_test.tsv',
              './content/test/hurricane_harvey_2017_test.tsv',
              './content/test/hurricane_irma_2017_test.tsv',
              './content/test/hurricane_maria_2017_test.tsv',
              './content/test/hurricane_matthew_2016_test.tsv',
              './content/test/italy_earthquake_aug_2016_test.tsv',
              './content/test/kaikoura_earthquake_2016_test.tsv',
              './content/test/kerala_floods_2018_test.tsv',
              './content/test/maryland_floods_2018_test.tsv',
              './content/test/midwestern_us_floods_2019_test.tsv',
              './content/test/pakistan_earthquake_2019_test.tsv',
              './content/test/puebla_mexico_earthquake_2017_test.tsv',
              './content/test/srilanka_floods_2017_test.tsv']

train_file_paths = []

for path in train_paths:
    files = glob.glob(path)
    train_file_paths.extend(files)

val_file_paths = []

for path in val_paths:
    files = glob.glob(path)
    val_file_paths.extend(files)

test_file_paths = []

for path in test_paths:
    files = glob.glob(path)
    test_file_paths.extend(files)

In [4]:
train_dfs = []

for file in train_file_paths:
    df = pd.read_csv(file, sep='\t')
    df = df.iloc[:, 1:] # removing tweet ids
    df = df[df['class_label'] != 'missing_or_found_people']
    train_dfs.append(df)

train_df = pd.concat(train_dfs, ignore_index=True)

val_dfs = [] # dataframes

for file in val_file_paths:
    df = pd.read_csv(file, sep='\t')
    df = df.iloc[:, 1:] # removing tweet ids
    df = df[df['class_label'] != 'missing_or_found_people']
    val_dfs.append(df)

val_df = pd.concat(val_dfs, ignore_index=True)

test_dfs = [] # dataframes

for file in test_file_paths:
    df = pd.read_csv(file, sep='\t')
    df = df.iloc[:, 1:] # removing tweet ids
    df = df[df['class_label'] != 'missing_or_found_people']
    test_dfs.append(df)

test_df = pd.concat(test_dfs, ignore_index=True)

In [5]:
class_labels = train_df['class_label'].unique().tolist()
class_counts = train_df['class_label'].value_counts()
class_counts

class_label
rescue_volunteering_or_donation_effort    14891
other_relevant_information                 8501
sympathy_and_support                       6250
infrastructure_and_utility_damage          5715
injured_or_dead_people                     5110
not_humanitarian                           4407
caution_and_advice                         3774
displaced_people_and_evacuations           2800
requests_or_urgent_needs                   1833
Name: count, dtype: int64

Brute force Balance

In [6]:
# brute force balance
class_labels = train_df['class_label'].unique().tolist()
class_counts = train_df['class_label'].value_counts()
# Minimum class size
min_class_size = 3000 #min(class_counts)

balanced_dfs = []  # List to hold balanced DataFrames for each class
for class_label in class_labels:
    class_df = train_df[train_df['class_label'] == class_label]
    # Sample the minimum class size, if sample size is greater than the class size, then sample with replacement
    sampled_df = class_df.sample(min_class_size, replace=min_class_size > class_df.shape[0])
    balanced_dfs.append(sampled_df)

train_df = pd.concat(balanced_dfs, ignore_index=True)
# Optional: Shuffle the balanced DataFrame
train_df = train_df.sample(frac=1, random_state=42).reset_index(drop=True)

class_labels = [s.strip() for s in class_labels]

In [7]:
class_labels = train_df['class_label'].unique().tolist()
class_counts = train_df['class_label'].value_counts()
class_counts

class_label
other_relevant_information                3000
displaced_people_and_evacuations          3000
sympathy_and_support                      3000
infrastructure_and_utility_damage         3000
not_humanitarian                          3000
caution_and_advice                        3000
requests_or_urgent_needs                  3000
rescue_volunteering_or_donation_effort    3000
injured_or_dead_people                    3000
Name: count, dtype: int64

In [8]:
id2label = {id:label for id, label in enumerate(class_labels)}
label2id = {label:id for id, label in enumerate(class_labels)}

In [9]:
from sklearn.preprocessing import LabelEncoder

label_encoder = LabelEncoder()
train_df['class_label_num'] = label_encoder.fit_transform(train_df['class_label'])
val_df['class_label_num'] = label_encoder.transform(val_df['class_label'])
test_df['class_label_num'] = label_encoder.transform(test_df['class_label'])

In [10]:
train_df.head()

Unnamed: 0,tweet_text,class_label,class_label_num
0,The almighty Typhoon Haiyan struck the phillip...,other_relevant_information,5
1,"When it comes to hurricanes, preparation is ke...",displaced_people_and_evacuations,1
2,#earthquake Pray for safety of people plz,sympathy_and_support,8
3,Mother Nature is not to be trifled with -- wed...,infrastructure_and_utility_damage,2
4,And if u ever saw what I wrote here today I as...,not_humanitarian,4


In [11]:
train_tweets = list(train_df.tweet_text)
train_labels = list(train_df.class_label_num)

val_tweets = list(val_df.tweet_text)
val_labels = list(val_df.class_label_num)

test_tweets = list(test_df.tweet_text)
test_labels = list(test_df.class_label_num)

In [12]:
model_name = 'bert-base-uncased'
tokenizer = BertTokenizerFast.from_pretrained(model_name)

In [13]:
train_encodings = tokenizer(train_tweets, truncation=True, padding=True, max_length=512)
val_encodings = tokenizer(val_tweets, truncation=True, padding=True, max_length=512)
test_encodings = tokenizer(test_tweets, truncation=True, padding=True, max_length=512)

In [14]:
class TweetDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

    def __len__(self):
        return len(self.labels)


In [15]:
train_dataset = TweetDataset(train_encodings, train_labels)  # Use resampled labels
val_dataset = TweetDataset(val_encodings, val_labels)
test_dataset = TweetDataset(test_encodings, test_labels)

In [16]:
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=len(class_labels))
model.to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [17]:
training_args = TrainingArguments(
    output_dir='./content/bert_model_output',
    do_train=True,
    do_eval=True,
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    fp16=True,
    warmup_steps=200,
    weight_decay=0.01,
    logging_strategy='steps',
    logging_dir='./content/logs',
    logging_steps=100,
    evaluation_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=500,
    load_best_model_at_end=True,
)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None)


In [18]:
# Train the model
trainer.train()

  0%|          | 0/5064 [00:00<?, ?it/s]

{'loss': 2.1388, 'learning_rate': 2.425e-05, 'epoch': 0.06}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 1.9613018035888672, 'eval_accuracy': 0.2351424519788578, 'eval_f1': 0.17561170333514334, 'eval_precision': 0.43537709271347547, 'eval_recall': 0.2351424519788578, 'eval_runtime': 23.9275, 'eval_samples_per_second': 324.188, 'eval_steps_per_second': 20.27, 'epoch': 0.06}
{'loss': 1.4228, 'learning_rate': 4.9250000000000004e-05, 'epoch': 0.12}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.9819027781486511, 'eval_accuracy': 0.683769498517468, 'eval_f1': 0.6709650985681015, 'eval_precision': 0.7056729362607053, 'eval_recall': 0.683769498517468, 'eval_runtime': 24.8604, 'eval_samples_per_second': 312.022, 'eval_steps_per_second': 19.509, 'epoch': 0.12}
{'loss': 0.8705, 'learning_rate': 4.9002878289473684e-05, 'epoch': 0.18}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.899543046951294, 'eval_accuracy': 0.6927936057754287, 'eval_f1': 0.6761589078478173, 'eval_precision': 0.7061710922839609, 'eval_recall': 0.6927936057754287, 'eval_runtime': 24.1875, 'eval_samples_per_second': 320.702, 'eval_steps_per_second': 20.052, 'epoch': 0.18}
{'loss': 0.7742, 'learning_rate': 4.7974917763157895e-05, 'epoch': 0.24}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8965387344360352, 'eval_accuracy': 0.6912466159597782, 'eval_f1': 0.7011200048508109, 'eval_precision': 0.7455872016026324, 'eval_recall': 0.6912466159597782, 'eval_runtime': 24.6634, 'eval_samples_per_second': 314.515, 'eval_steps_per_second': 19.665, 'epoch': 0.24}
{'loss': 0.8003, 'learning_rate': 4.6946957236842107e-05, 'epoch': 0.3}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.9283761978149414, 'eval_accuracy': 0.6845429934252932, 'eval_f1': 0.6820980313371671, 'eval_precision': 0.7198746392364037, 'eval_recall': 0.6845429934252932, 'eval_runtime': 24.9093, 'eval_samples_per_second': 311.41, 'eval_steps_per_second': 19.471, 'epoch': 0.3}
{'loss': 0.7822, 'learning_rate': 4.591899671052632e-05, 'epoch': 0.36}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7796012759208679, 'eval_accuracy': 0.7364960680675519, 'eval_f1': 0.722276894926144, 'eval_precision': 0.7301788898207838, 'eval_recall': 0.7364960680675519, 'eval_runtime': 24.9095, 'eval_samples_per_second': 311.408, 'eval_steps_per_second': 19.471, 'epoch': 0.36}
{'loss': 0.732, 'learning_rate': 4.489103618421053e-05, 'epoch': 0.41}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7662121653556824, 'eval_accuracy': 0.7389454686089983, 'eval_f1': 0.7283649681701223, 'eval_precision': 0.7472150307731832, 'eval_recall': 0.7389454686089983, 'eval_runtime': 24.2333, 'eval_samples_per_second': 320.097, 'eval_steps_per_second': 20.014, 'epoch': 0.41}
{'loss': 0.7381, 'learning_rate': 4.3863075657894734e-05, 'epoch': 0.47}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8440464735031128, 'eval_accuracy': 0.7138068841046796, 'eval_f1': 0.7089575837642882, 'eval_precision': 0.7342117502436788, 'eval_recall': 0.7138068841046796, 'eval_runtime': 25.0127, 'eval_samples_per_second': 310.122, 'eval_steps_per_second': 19.39, 'epoch': 0.47}
{'loss': 0.6805, 'learning_rate': 4.283511513157895e-05, 'epoch': 0.53}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7758824825286865, 'eval_accuracy': 0.7319840144385716, 'eval_f1': 0.7356251247455862, 'eval_precision': 0.7506800003796024, 'eval_recall': 0.7319840144385716, 'eval_runtime': 24.8712, 'eval_samples_per_second': 311.887, 'eval_steps_per_second': 19.5, 'epoch': 0.53}
{'loss': 0.7092, 'learning_rate': 4.1807154605263163e-05, 'epoch': 0.59}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7744946479797363, 'eval_accuracy': 0.7385587211550857, 'eval_f1': 0.7379416233778328, 'eval_precision': 0.7505403578693205, 'eval_recall': 0.7385587211550857, 'eval_runtime': 24.8372, 'eval_samples_per_second': 312.314, 'eval_steps_per_second': 19.527, 'epoch': 0.59}
{'loss': 0.7007, 'learning_rate': 4.077919407894737e-05, 'epoch': 0.65}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8264136910438538, 'eval_accuracy': 0.7206394224571355, 'eval_f1': 0.7103802571238108, 'eval_precision': 0.7424387120224722, 'eval_recall': 0.7206394224571355, 'eval_runtime': 24.7406, 'eval_samples_per_second': 313.533, 'eval_steps_per_second': 19.603, 'epoch': 0.65}
{'loss': 0.7082, 'learning_rate': 3.975123355263158e-05, 'epoch': 0.71}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7330374717712402, 'eval_accuracy': 0.743973185509862, 'eval_f1': 0.7407326519489287, 'eval_precision': 0.7503357783250844, 'eval_recall': 0.743973185509862, 'eval_runtime': 23.8802, 'eval_samples_per_second': 324.83, 'eval_steps_per_second': 20.31, 'epoch': 0.71}
{'loss': 0.7414, 'learning_rate': 3.872327302631579e-05, 'epoch': 0.77}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7130571007728577, 'eval_accuracy': 0.7493876498646383, 'eval_f1': 0.7428880943387189, 'eval_precision': 0.7498394746250502, 'eval_recall': 0.7493876498646383, 'eval_runtime': 23.6286, 'eval_samples_per_second': 328.288, 'eval_steps_per_second': 20.526, 'epoch': 0.77}
{'loss': 0.6939, 'learning_rate': 3.76953125e-05, 'epoch': 0.83}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7565488219261169, 'eval_accuracy': 0.7361093206136393, 'eval_f1': 0.7280167931772894, 'eval_precision': 0.7465703625651221, 'eval_recall': 0.7361093206136393, 'eval_runtime': 23.5501, 'eval_samples_per_second': 329.383, 'eval_steps_per_second': 20.594, 'epoch': 0.83}
{'loss': 0.6858, 'learning_rate': 3.6667351973684214e-05, 'epoch': 0.89}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7785608768463135, 'eval_accuracy': 0.7273430449916205, 'eval_f1': 0.7197136163941744, 'eval_precision': 0.7486749676497336, 'eval_recall': 0.7273430449916205, 'eval_runtime': 24.9684, 'eval_samples_per_second': 310.672, 'eval_steps_per_second': 19.425, 'epoch': 0.89}
{'loss': 0.6623, 'learning_rate': 3.5639391447368425e-05, 'epoch': 0.95}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8018583059310913, 'eval_accuracy': 0.7126466417429419, 'eval_f1': 0.7064557527519492, 'eval_precision': 0.7350072028950623, 'eval_recall': 0.7126466417429419, 'eval_runtime': 25.2948, 'eval_samples_per_second': 306.664, 'eval_steps_per_second': 19.174, 'epoch': 0.95}
{'loss': 0.6458, 'learning_rate': 3.4611430921052636e-05, 'epoch': 1.01}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7576695680618286, 'eval_accuracy': 0.7385587211550857, 'eval_f1': 0.7340170347716795, 'eval_precision': 0.7555948863158217, 'eval_recall': 0.7385587211550857, 'eval_runtime': 23.5579, 'eval_samples_per_second': 329.274, 'eval_steps_per_second': 20.588, 'epoch': 1.01}
{'loss': 0.4686, 'learning_rate': 3.358347039473684e-05, 'epoch': 1.07}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7607273459434509, 'eval_accuracy': 0.7465515018692793, 'eval_f1': 0.7392754935639518, 'eval_precision': 0.7536522777689992, 'eval_recall': 0.7465515018692793, 'eval_runtime': 23.2848, 'eval_samples_per_second': 333.135, 'eval_steps_per_second': 20.829, 'epoch': 1.07}
{'loss': 0.4924, 'learning_rate': 3.255550986842105e-05, 'epoch': 1.13}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.804895281791687, 'eval_accuracy': 0.7328864251643676, 'eval_f1': 0.7329518399744394, 'eval_precision': 0.7517646755946042, 'eval_recall': 0.7328864251643676, 'eval_runtime': 23.3582, 'eval_samples_per_second': 332.089, 'eval_steps_per_second': 20.764, 'epoch': 1.13}
{'loss': 0.5052, 'learning_rate': 3.1527549342105264e-05, 'epoch': 1.18}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7603083848953247, 'eval_accuracy': 0.7401057109707361, 'eval_f1': 0.7370255430736722, 'eval_precision': 0.7565895522905857, 'eval_recall': 0.7401057109707361, 'eval_runtime': 23.3299, 'eval_samples_per_second': 332.491, 'eval_steps_per_second': 20.789, 'epoch': 1.18}
{'loss': 0.4902, 'learning_rate': 3.0499588815789475e-05, 'epoch': 1.24}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.825619101524353, 'eval_accuracy': 0.7269562975377079, 'eval_f1': 0.7284071838886658, 'eval_precision': 0.7564747265021021, 'eval_recall': 0.7269562975377079, 'eval_runtime': 23.2288, 'eval_samples_per_second': 333.939, 'eval_steps_per_second': 20.879, 'epoch': 1.24}
{'loss': 0.4506, 'learning_rate': 2.9471628289473687e-05, 'epoch': 1.3}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7735508680343628, 'eval_accuracy': 0.7465515018692793, 'eval_f1': 0.7461990412428903, 'eval_precision': 0.760228556550721, 'eval_recall': 0.7465515018692793, 'eval_runtime': 23.6812, 'eval_samples_per_second': 327.56, 'eval_steps_per_second': 20.48, 'epoch': 1.3}
{'loss': 0.4309, 'learning_rate': 2.8443667763157895e-05, 'epoch': 1.36}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7969691753387451, 'eval_accuracy': 0.747840660048988, 'eval_f1': 0.741291106050488, 'eval_precision': 0.7529818189718069, 'eval_recall': 0.747840660048988, 'eval_runtime': 23.6163, 'eval_samples_per_second': 328.46, 'eval_steps_per_second': 20.537, 'epoch': 1.36}
{'loss': 0.4599, 'learning_rate': 2.741570723684211e-05, 'epoch': 1.42}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7411457300186157, 'eval_accuracy': 0.7606033260281037, 'eval_f1': 0.7547242775881314, 'eval_precision': 0.7596980457602515, 'eval_recall': 0.7606033260281037, 'eval_runtime': 23.8888, 'eval_samples_per_second': 324.713, 'eval_steps_per_second': 20.302, 'epoch': 1.42}
{'loss': 0.4465, 'learning_rate': 2.6387746710526318e-05, 'epoch': 1.48}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8121508955955505, 'eval_accuracy': 0.7466804176872502, 'eval_f1': 0.7381231037881751, 'eval_precision': 0.7546704165283991, 'eval_recall': 0.7466804176872502, 'eval_runtime': 23.7041, 'eval_samples_per_second': 327.243, 'eval_steps_per_second': 20.461, 'epoch': 1.48}
{'loss': 0.4449, 'learning_rate': 2.5359786184210526e-05, 'epoch': 1.54}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7466075420379639, 'eval_accuracy': 0.7559623565811525, 'eval_f1': 0.750613960519202, 'eval_precision': 0.7596874321798003, 'eval_recall': 0.7559623565811525, 'eval_runtime': 24.0757, 'eval_samples_per_second': 322.191, 'eval_steps_per_second': 20.145, 'epoch': 1.54}
{'loss': 0.4201, 'learning_rate': 2.4331825657894737e-05, 'epoch': 1.6}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.755074143409729, 'eval_accuracy': 0.7550599458553564, 'eval_f1': 0.757149503859975, 'eval_precision': 0.7675458301989507, 'eval_recall': 0.7550599458553564, 'eval_runtime': 24.7379, 'eval_samples_per_second': 313.567, 'eval_steps_per_second': 19.606, 'epoch': 1.6}
{'loss': 0.4406, 'learning_rate': 2.330386513157895e-05, 'epoch': 1.66}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.770708441734314, 'eval_accuracy': 0.7429418589660951, 'eval_f1': 0.7396704318280701, 'eval_precision': 0.7567436446565737, 'eval_recall': 0.7429418589660951, 'eval_runtime': 24.3113, 'eval_samples_per_second': 319.07, 'eval_steps_per_second': 19.95, 'epoch': 1.66}
{'loss': 0.4834, 'learning_rate': 2.227590460526316e-05, 'epoch': 1.72}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7307701706886292, 'eval_accuracy': 0.7560912723991233, 'eval_f1': 0.750338107157692, 'eval_precision': 0.7566171864251497, 'eval_recall': 0.7560912723991233, 'eval_runtime': 23.9942, 'eval_samples_per_second': 323.287, 'eval_steps_per_second': 20.213, 'epoch': 1.72}
{'loss': 0.3964, 'learning_rate': 2.1247944078947368e-05, 'epoch': 1.78}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7685049772262573, 'eval_accuracy': 0.7487430707747841, 'eval_f1': 0.7440946903158471, 'eval_precision': 0.756848669143101, 'eval_recall': 0.7487430707747841, 'eval_runtime': 24.1124, 'eval_samples_per_second': 321.702, 'eval_steps_per_second': 20.114, 'epoch': 1.78}
{'loss': 0.4173, 'learning_rate': 2.021998355263158e-05, 'epoch': 1.84}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7316918969154358, 'eval_accuracy': 0.7567358514889777, 'eval_f1': 0.7538831206789541, 'eval_precision': 0.7639325275749203, 'eval_recall': 0.7567358514889777, 'eval_runtime': 23.9175, 'eval_samples_per_second': 324.324, 'eval_steps_per_second': 20.278, 'epoch': 1.84}
{'loss': 0.4296, 'learning_rate': 1.919202302631579e-05, 'epoch': 1.9}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7335436940193176, 'eval_accuracy': 0.7548021142194147, 'eval_f1': 0.7531262016463234, 'eval_precision': 0.7635593112496809, 'eval_recall': 0.7548021142194147, 'eval_runtime': 24.7611, 'eval_samples_per_second': 313.273, 'eval_steps_per_second': 19.587, 'epoch': 1.9}
{'loss': 0.4222, 'learning_rate': 1.8164062500000002e-05, 'epoch': 1.95}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7167178988456726, 'eval_accuracy': 0.7611189892999871, 'eval_f1': 0.757629336243296, 'eval_precision': 0.7618639234344838, 'eval_recall': 0.7611189892999871, 'eval_runtime': 23.9788, 'eval_samples_per_second': 323.494, 'eval_steps_per_second': 20.226, 'epoch': 1.95}
{'loss': 0.3703, 'learning_rate': 1.714638157894737e-05, 'epoch': 2.01}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8357546329498291, 'eval_accuracy': 0.7370117313394353, 'eval_f1': 0.7353611744892272, 'eval_precision': 0.756923179275088, 'eval_recall': 0.7370117313394353, 'eval_runtime': 24.1536, 'eval_samples_per_second': 321.153, 'eval_steps_per_second': 20.08, 'epoch': 2.01}
{'loss': 0.2542, 'learning_rate': 1.611842105263158e-05, 'epoch': 2.07}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7981708645820618, 'eval_accuracy': 0.7540286193115895, 'eval_f1': 0.7541956785766519, 'eval_precision': 0.7623108823625716, 'eval_recall': 0.7540286193115895, 'eval_runtime': 24.3324, 'eval_samples_per_second': 318.793, 'eval_steps_per_second': 19.932, 'epoch': 2.07}
{'loss': 0.2191, 'learning_rate': 1.5090460526315788e-05, 'epoch': 2.13}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8092733025550842, 'eval_accuracy': 0.7582828413046281, 'eval_f1': 0.7588496539971398, 'eval_precision': 0.7625126669416972, 'eval_recall': 0.7582828413046281, 'eval_runtime': 24.3788, 'eval_samples_per_second': 318.187, 'eval_steps_per_second': 19.894, 'epoch': 2.13}
{'loss': 0.2398, 'learning_rate': 1.4062500000000001e-05, 'epoch': 2.19}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8349934220314026, 'eval_accuracy': 0.7568647673069485, 'eval_f1': 0.7560057663977741, 'eval_precision': 0.764413584901392, 'eval_recall': 0.7568647673069485, 'eval_runtime': 24.3797, 'eval_samples_per_second': 318.175, 'eval_steps_per_second': 19.894, 'epoch': 2.19}
{'loss': 0.2441, 'learning_rate': 1.3034539473684213e-05, 'epoch': 2.25}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8383969664573669, 'eval_accuracy': 0.7594430836663658, 'eval_f1': 0.7565629992384117, 'eval_precision': 0.760709530498502, 'eval_recall': 0.7594430836663658, 'eval_runtime': 24.4885, 'eval_samples_per_second': 316.761, 'eval_steps_per_second': 19.805, 'epoch': 2.25}
{'loss': 0.2185, 'learning_rate': 1.200657894736842e-05, 'epoch': 2.31}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.9093047976493835, 'eval_accuracy': 0.7484852391388424, 'eval_f1': 0.7462761810380385, 'eval_precision': 0.7588494016560525, 'eval_recall': 0.7484852391388424, 'eval_runtime': 24.3007, 'eval_samples_per_second': 319.209, 'eval_steps_per_second': 19.958, 'epoch': 2.31}
{'loss': 0.2201, 'learning_rate': 1.0978618421052632e-05, 'epoch': 2.37}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8529410362243652, 'eval_accuracy': 0.7633105582054918, 'eval_f1': 0.7596892686913869, 'eval_precision': 0.760808617792133, 'eval_recall': 0.7633105582054918, 'eval_runtime': 24.0943, 'eval_samples_per_second': 321.943, 'eval_steps_per_second': 20.129, 'epoch': 2.37}
{'loss': 0.2391, 'learning_rate': 9.950657894736842e-06, 'epoch': 2.43}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8692417144775391, 'eval_accuracy': 0.7577671780327446, 'eval_f1': 0.7557921536535376, 'eval_precision': 0.7596548798834499, 'eval_recall': 0.7577671780327446, 'eval_runtime': 24.2585, 'eval_samples_per_second': 319.764, 'eval_steps_per_second': 19.993, 'epoch': 2.43}
{'loss': 0.2491, 'learning_rate': 8.922697368421053e-06, 'epoch': 2.49}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8953176736831665, 'eval_accuracy': 0.7462936702333376, 'eval_f1': 0.7441710614353529, 'eval_precision': 0.7552333873257262, 'eval_recall': 0.7462936702333376, 'eval_runtime': 24.0819, 'eval_samples_per_second': 322.109, 'eval_steps_per_second': 20.14, 'epoch': 2.49}
{'loss': 0.2074, 'learning_rate': 7.894736842105263e-06, 'epoch': 2.55}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8768405914306641, 'eval_accuracy': 0.7532551244037643, 'eval_f1': 0.7508448151233592, 'eval_precision': 0.7551323333983814, 'eval_recall': 0.7532551244037643, 'eval_runtime': 25.1484, 'eval_samples_per_second': 308.449, 'eval_steps_per_second': 19.286, 'epoch': 2.55}
{'loss': 0.2326, 'learning_rate': 6.8667763157894735e-06, 'epoch': 2.61}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8770220279693604, 'eval_accuracy': 0.7558334407631816, 'eval_f1': 0.7540232283188598, 'eval_precision': 0.7582470523769911, 'eval_recall': 0.7558334407631816, 'eval_runtime': 26.6011, 'eval_samples_per_second': 291.604, 'eval_steps_per_second': 18.232, 'epoch': 2.61}
{'loss': 0.197, 'learning_rate': 5.838815789473685e-06, 'epoch': 2.67}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.9430109262466431, 'eval_accuracy': 0.7424261956942116, 'eval_f1': 0.7411920152841047, 'eval_precision': 0.7560716052306107, 'eval_recall': 0.7424261956942116, 'eval_runtime': 26.6072, 'eval_samples_per_second': 291.538, 'eval_steps_per_second': 18.228, 'epoch': 2.67}
{'loss': 0.2569, 'learning_rate': 4.810855263157895e-06, 'epoch': 2.73}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.9028953313827515, 'eval_accuracy': 0.7499033131365218, 'eval_f1': 0.7472670205159356, 'eval_precision': 0.7566771025385328, 'eval_recall': 0.7499033131365218, 'eval_runtime': 26.6029, 'eval_samples_per_second': 291.585, 'eval_steps_per_second': 18.231, 'epoch': 2.73}
{'loss': 0.1844, 'learning_rate': 3.7828947368421055e-06, 'epoch': 2.78}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8994812965393066, 'eval_accuracy': 0.7541575351295604, 'eval_f1': 0.7556437247818125, 'eval_precision': 0.7640591374286768, 'eval_recall': 0.7541575351295604, 'eval_runtime': 26.5739, 'eval_samples_per_second': 291.903, 'eval_steps_per_second': 18.251, 'epoch': 2.78}
{'loss': 0.2343, 'learning_rate': 2.7549342105263157e-06, 'epoch': 2.84}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8941711783409119, 'eval_accuracy': 0.756478019853036, 'eval_f1': 0.7551635088423485, 'eval_precision': 0.7610148501062574, 'eval_recall': 0.756478019853036, 'eval_runtime': 24.7287, 'eval_samples_per_second': 313.683, 'eval_steps_per_second': 19.613, 'epoch': 2.84}
{'loss': 0.2328, 'learning_rate': 1.7269736842105266e-06, 'epoch': 2.9}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.9001975655555725, 'eval_accuracy': 0.754544282583473, 'eval_f1': 0.7541927395568432, 'eval_precision': 0.7615503003319412, 'eval_recall': 0.754544282583473, 'eval_runtime': 24.6439, 'eval_samples_per_second': 314.763, 'eval_steps_per_second': 19.68, 'epoch': 2.9}
{'loss': 0.239, 'learning_rate': 6.990131578947369e-07, 'epoch': 2.96}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8927129507064819, 'eval_accuracy': 0.7577671780327446, 'eval_f1': 0.7566188111542367, 'eval_precision': 0.7627854759109136, 'eval_recall': 0.7577671780327446, 'eval_runtime': 26.6062, 'eval_samples_per_second': 291.549, 'eval_steps_per_second': 18.229, 'epoch': 2.96}
{'train_runtime': 2420.6782, 'train_samples_per_second': 33.462, 'train_steps_per_second': 2.092, 'train_loss': 0.5103363904154508, 'epoch': 3.0}


TrainOutput(global_step=5064, training_loss=0.5103363904154508, metrics={'train_runtime': 2420.6782, 'train_samples_per_second': 33.462, 'train_steps_per_second': 2.092, 'train_loss': 0.5103363904154508, 'epoch': 3.0})

In [19]:
from sklearn.metrics import classification_report
predictions = trainer.predict(test_dataset)
labels = predictions.label_ids
preds = predictions.predictions.argmax(-1)

print(classification_report(labels, preds, target_names=class_labels))

  0%|          | 0/943 [00:00<?, ?it/s]

                                        precision    recall  f1-score   support

            other_relevant_information       0.56      0.81      0.66      1070
      displaced_people_and_evacuations       0.85      0.92      0.88       790
                  sympathy_and_support       0.83      0.78      0.80      1617
     infrastructure_and_utility_damage       0.92      0.93      0.93      1447
                      not_humanitarian       0.55      0.68      0.61      1245
                    caution_and_advice       0.64      0.37      0.47      2407
              requests_or_urgent_needs       0.39      0.78      0.52       521
rescue_volunteering_or_donation_effort       0.86      0.82      0.84      4219
                injured_or_dead_people       0.84      0.81      0.82      1772

                              accuracy                           0.75     15088
                             macro avg       0.72      0.77      0.73     15088
                          weighted avg

In [20]:
model_name = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)

train_encodings = tokenizer(train_tweets, truncation=True, padding=True, max_length=512)
val_encodings = tokenizer(val_tweets, truncation=True, padding=True, max_length=512)
test_encodings = tokenizer(test_tweets, truncation=True, padding=True, max_length=512)

train_dataset = TweetDataset(train_encodings, train_labels)  # Use resampled labels
val_dataset = TweetDataset(val_encodings, val_labels)
test_dataset = TweetDataset(test_encodings, test_labels)

model = DistilBertForSequenceClassification.from_pretrained(model_name, num_labels=len(class_labels))
model.to(device)

Downloading tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Downloading vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.weight', 'pre_classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

In [21]:
training_args = TrainingArguments(
    output_dir='./content/bert_model_output',
    do_train=True,
    do_eval=True,
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    fp16=True,
    warmup_steps=200,
    weight_decay=0.01,
    logging_strategy='steps',
    logging_dir='./content/logs',
    logging_steps=100,
    evaluation_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=2000,
    load_best_model_at_end=True,
)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None)


  0%|          | 0/5064 [00:00<?, ?it/s]

{'loss': 2.1419, 'learning_rate': 2.5e-05, 'epoch': 0.06}


  0%|          | 0/485 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 1.9131654500961304, 'eval_accuracy': 0.35632332087147095, 'eval_f1': 0.3015517081734118, 'eval_precision': 0.5202917884999859, 'eval_recall': 0.35632332087147095, 'eval_runtime': 13.3391, 'eval_samples_per_second': 581.523, 'eval_steps_per_second': 36.359, 'epoch': 0.06}
{'loss': 1.288, 'learning_rate': 4.975e-05, 'epoch': 0.12}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.9364122748374939, 'eval_accuracy': 0.7007863864896223, 'eval_f1': 0.6904193448164562, 'eval_precision': 0.7105824186038844, 'eval_recall': 0.7007863864896223, 'eval_runtime': 13.0385, 'eval_samples_per_second': 594.932, 'eval_steps_per_second': 37.198, 'epoch': 0.12}
{'loss': 0.9252, 'learning_rate': 4.898231907894737e-05, 'epoch': 0.18}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8912947177886963, 'eval_accuracy': 0.7033647028490396, 'eval_f1': 0.6962402713682452, 'eval_precision': 0.7153265625766249, 'eval_recall': 0.7033647028490396, 'eval_runtime': 13.0937, 'eval_samples_per_second': 592.421, 'eval_steps_per_second': 37.041, 'epoch': 0.18}
{'loss': 0.7781, 'learning_rate': 4.795435855263158e-05, 'epoch': 0.24}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.9506560564041138, 'eval_accuracy': 0.6899574577800696, 'eval_f1': 0.6950076346493601, 'eval_precision': 0.7307416475341719, 'eval_recall': 0.6899574577800696, 'eval_runtime': 13.0122, 'eval_samples_per_second': 596.133, 'eval_steps_per_second': 37.273, 'epoch': 0.24}
{'loss': 0.8103, 'learning_rate': 4.6926398026315795e-05, 'epoch': 0.3}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8809457421302795, 'eval_accuracy': 0.70491169266469, 'eval_f1': 0.7095836074588076, 'eval_precision': 0.7377961123619305, 'eval_recall': 0.70491169266469, 'eval_runtime': 12.6706, 'eval_samples_per_second': 612.206, 'eval_steps_per_second': 38.278, 'epoch': 0.3}
{'loss': 0.78, 'learning_rate': 4.58984375e-05, 'epoch': 0.36}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7902679443359375, 'eval_accuracy': 0.7336599200721928, 'eval_f1': 0.7241894031399231, 'eval_precision': 0.730452141047995, 'eval_recall': 0.7336599200721928, 'eval_runtime': 12.8466, 'eval_samples_per_second': 603.82, 'eval_steps_per_second': 37.753, 'epoch': 0.36}
{'loss': 0.7411, 'learning_rate': 4.487047697368421e-05, 'epoch': 0.41}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7969747185707092, 'eval_accuracy': 0.7286322031713292, 'eval_f1': 0.7247663481602811, 'eval_precision': 0.7477564582566343, 'eval_recall': 0.7286322031713292, 'eval_runtime': 12.649, 'eval_samples_per_second': 613.252, 'eval_steps_per_second': 38.343, 'epoch': 0.41}
{'loss': 0.7734, 'learning_rate': 4.384251644736842e-05, 'epoch': 0.47}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8417895436286926, 'eval_accuracy': 0.707747840660049, 'eval_f1': 0.7025951067576471, 'eval_precision': 0.7305990870183083, 'eval_recall': 0.707747840660049, 'eval_runtime': 12.4041, 'eval_samples_per_second': 625.357, 'eval_steps_per_second': 39.1, 'epoch': 0.47}
{'loss': 0.6953, 'learning_rate': 4.2814555921052634e-05, 'epoch': 0.53}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8312391042709351, 'eval_accuracy': 0.7101972412014954, 'eval_f1': 0.7167889584761664, 'eval_precision': 0.7449696230050482, 'eval_recall': 0.7101972412014954, 'eval_runtime': 12.5147, 'eval_samples_per_second': 619.833, 'eval_steps_per_second': 38.755, 'epoch': 0.53}
{'loss': 0.695, 'learning_rate': 4.1786595394736846e-05, 'epoch': 0.59}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7669773101806641, 'eval_accuracy': 0.737269562975377, 'eval_f1': 0.7355745179302609, 'eval_precision': 0.7489052302407762, 'eval_recall': 0.737269562975377, 'eval_runtime': 12.3823, 'eval_samples_per_second': 626.461, 'eval_steps_per_second': 39.169, 'epoch': 0.59}
{'loss': 0.6834, 'learning_rate': 4.075863486842105e-05, 'epoch': 0.65}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8113236427307129, 'eval_accuracy': 0.7169008637359804, 'eval_f1': 0.7084835214892135, 'eval_precision': 0.735033037839787, 'eval_recall': 0.7169008637359804, 'eval_runtime': 12.6862, 'eval_samples_per_second': 611.453, 'eval_steps_per_second': 38.231, 'epoch': 0.65}
{'loss': 0.7133, 'learning_rate': 3.973067434210527e-05, 'epoch': 0.71}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7460393905639648, 'eval_accuracy': 0.7359804047956684, 'eval_f1': 0.7361709734573462, 'eval_precision': 0.7496300722141201, 'eval_recall': 0.7359804047956684, 'eval_runtime': 12.7948, 'eval_samples_per_second': 606.263, 'eval_steps_per_second': 37.906, 'epoch': 0.71}
{'loss': 0.7383, 'learning_rate': 3.870271381578947e-05, 'epoch': 0.77}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.706682026386261, 'eval_accuracy': 0.7509346396802887, 'eval_f1': 0.7411404783158535, 'eval_precision': 0.7530663092027522, 'eval_recall': 0.7509346396802887, 'eval_runtime': 12.7922, 'eval_samples_per_second': 606.387, 'eval_steps_per_second': 37.914, 'epoch': 0.77}
{'loss': 0.6896, 'learning_rate': 3.7674753289473685e-05, 'epoch': 0.83}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7636823654174805, 'eval_accuracy': 0.7336599200721928, 'eval_f1': 0.7289351484864164, 'eval_precision': 0.7481006820160768, 'eval_recall': 0.7336599200721928, 'eval_runtime': 12.4722, 'eval_samples_per_second': 621.944, 'eval_steps_per_second': 38.887, 'epoch': 0.83}
{'loss': 0.6917, 'learning_rate': 3.6646792763157896e-05, 'epoch': 0.89}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7748085856437683, 'eval_accuracy': 0.7261828026298827, 'eval_f1': 0.7220705061998156, 'eval_precision': 0.7486405063802911, 'eval_recall': 0.7261828026298827, 'eval_runtime': 12.6583, 'eval_samples_per_second': 612.8, 'eval_steps_per_second': 38.315, 'epoch': 0.89}
{'loss': 0.6538, 'learning_rate': 3.561883223684211e-05, 'epoch': 0.95}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7636112570762634, 'eval_accuracy': 0.7303081088049503, 'eval_f1': 0.7248142633891055, 'eval_precision': 0.7459500496934476, 'eval_recall': 0.7303081088049503, 'eval_runtime': 12.6443, 'eval_samples_per_second': 613.476, 'eval_steps_per_second': 38.357, 'epoch': 0.95}
{'loss': 0.645, 'learning_rate': 3.459087171052632e-05, 'epoch': 1.01}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.729108989238739, 'eval_accuracy': 0.7462936702333376, 'eval_f1': 0.7397683981676452, 'eval_precision': 0.759934265512884, 'eval_recall': 0.7462936702333376, 'eval_runtime': 12.9903, 'eval_samples_per_second': 597.14, 'eval_steps_per_second': 37.336, 'epoch': 1.01}
{'loss': 0.4808, 'learning_rate': 3.356291118421053e-05, 'epoch': 1.07}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8013255000114441, 'eval_accuracy': 0.7349490782519015, 'eval_f1': 0.7311247909940749, 'eval_precision': 0.7522971125073361, 'eval_recall': 0.7349490782519015, 'eval_runtime': 12.982, 'eval_samples_per_second': 597.518, 'eval_steps_per_second': 37.359, 'epoch': 1.07}
{'loss': 0.5143, 'learning_rate': 3.2534950657894735e-05, 'epoch': 1.13}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7633956670761108, 'eval_accuracy': 0.7461647544153668, 'eval_f1': 0.7447752487401544, 'eval_precision': 0.7543370639861077, 'eval_recall': 0.7461647544153668, 'eval_runtime': 13.1144, 'eval_samples_per_second': 591.487, 'eval_steps_per_second': 36.982, 'epoch': 1.13}
{'loss': 0.5056, 'learning_rate': 3.150699013157895e-05, 'epoch': 1.18}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7702411413192749, 'eval_accuracy': 0.7442310171458038, 'eval_f1': 0.7440081725670555, 'eval_precision': 0.7609343862974124, 'eval_recall': 0.7442310171458038, 'eval_runtime': 12.9232, 'eval_samples_per_second': 600.237, 'eval_steps_per_second': 37.529, 'epoch': 1.18}
{'loss': 0.5167, 'learning_rate': 3.047902960526316e-05, 'epoch': 1.24}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7821760773658752, 'eval_accuracy': 0.7375273946113188, 'eval_f1': 0.7376685527839122, 'eval_precision': 0.7572541135474696, 'eval_recall': 0.7375273946113188, 'eval_runtime': 12.9412, 'eval_samples_per_second': 599.403, 'eval_steps_per_second': 37.477, 'epoch': 1.24}
{'loss': 0.4526, 'learning_rate': 2.945106907894737e-05, 'epoch': 1.3}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7773452401161194, 'eval_accuracy': 0.737398478793348, 'eval_f1': 0.7403696355647867, 'eval_precision': 0.7580766277451305, 'eval_recall': 0.737398478793348, 'eval_runtime': 12.9818, 'eval_samples_per_second': 597.528, 'eval_steps_per_second': 37.36, 'epoch': 1.3}
{'loss': 0.4385, 'learning_rate': 2.8443667763157895e-05, 'epoch': 1.36}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7570482492446899, 'eval_accuracy': 0.7537707876756478, 'eval_f1': 0.7482035257935071, 'eval_precision': 0.7534513898509602, 'eval_recall': 0.7537707876756478, 'eval_runtime': 13.0276, 'eval_samples_per_second': 595.427, 'eval_steps_per_second': 37.229, 'epoch': 1.36}
{'loss': 0.4869, 'learning_rate': 2.741570723684211e-05, 'epoch': 1.42}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7178180813789368, 'eval_accuracy': 0.7617635683898414, 'eval_f1': 0.7591468232039045, 'eval_precision': 0.7628213945026088, 'eval_recall': 0.7617635683898414, 'eval_runtime': 13.1341, 'eval_samples_per_second': 590.599, 'eval_steps_per_second': 36.927, 'epoch': 1.42}
{'loss': 0.4615, 'learning_rate': 2.6387746710526318e-05, 'epoch': 1.48}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7831714749336243, 'eval_accuracy': 0.7426840273301534, 'eval_f1': 0.7388926935392288, 'eval_precision': 0.7526822923895383, 'eval_recall': 0.7426840273301534, 'eval_runtime': 13.1628, 'eval_samples_per_second': 589.313, 'eval_steps_per_second': 36.846, 'epoch': 1.48}
{'loss': 0.4401, 'learning_rate': 2.5359786184210526e-05, 'epoch': 1.54}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7372922301292419, 'eval_accuracy': 0.7557045249452108, 'eval_f1': 0.7537190964791115, 'eval_precision': 0.7593282285475111, 'eval_recall': 0.7557045249452108, 'eval_runtime': 13.103, 'eval_samples_per_second': 592.0, 'eval_steps_per_second': 37.014, 'epoch': 1.54}
{'loss': 0.4427, 'learning_rate': 2.4331825657894737e-05, 'epoch': 1.6}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7510848045349121, 'eval_accuracy': 0.7522237978599974, 'eval_f1': 0.7520708815441424, 'eval_precision': 0.7596503256850423, 'eval_recall': 0.7522237978599974, 'eval_runtime': 13.5921, 'eval_samples_per_second': 570.701, 'eval_steps_per_second': 35.683, 'epoch': 1.6}
{'loss': 0.4572, 'learning_rate': 2.330386513157895e-05, 'epoch': 1.66}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.759259045124054, 'eval_accuracy': 0.7426840273301534, 'eval_f1': 0.7390927748277205, 'eval_precision': 0.7549339470146514, 'eval_recall': 0.7426840273301534, 'eval_runtime': 12.4847, 'eval_samples_per_second': 621.319, 'eval_steps_per_second': 38.847, 'epoch': 1.66}
{'loss': 0.5048, 'learning_rate': 2.227590460526316e-05, 'epoch': 1.72}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7471117377281189, 'eval_accuracy': 0.746809333505221, 'eval_f1': 0.7476448840504191, 'eval_precision': 0.7558483534572801, 'eval_recall': 0.746809333505221, 'eval_runtime': 12.6135, 'eval_samples_per_second': 614.977, 'eval_steps_per_second': 38.451, 'epoch': 1.72}
{'loss': 0.4205, 'learning_rate': 2.1247944078947368e-05, 'epoch': 1.78}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7381922602653503, 'eval_accuracy': 0.753512956039706, 'eval_f1': 0.7495926575697672, 'eval_precision': 0.7542828927743684, 'eval_recall': 0.753512956039706, 'eval_runtime': 12.6572, 'eval_samples_per_second': 612.853, 'eval_steps_per_second': 38.318, 'epoch': 1.78}
{'loss': 0.4378, 'learning_rate': 2.021998355263158e-05, 'epoch': 1.84}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7581716775894165, 'eval_accuracy': 0.749774397318551, 'eval_f1': 0.748103491636849, 'eval_precision': 0.7578016849961453, 'eval_recall': 0.749774397318551, 'eval_runtime': 12.5923, 'eval_samples_per_second': 616.013, 'eval_steps_per_second': 38.516, 'epoch': 1.84}
{'loss': 0.4279, 'learning_rate': 1.919202302631579e-05, 'epoch': 1.9}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7848740816116333, 'eval_accuracy': 0.7425551115121826, 'eval_f1': 0.7407254012409111, 'eval_precision': 0.754804863213075, 'eval_recall': 0.7425551115121826, 'eval_runtime': 12.6152, 'eval_samples_per_second': 614.894, 'eval_steps_per_second': 38.446, 'epoch': 1.9}
{'loss': 0.4429, 'learning_rate': 1.8164062500000002e-05, 'epoch': 1.95}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7355758547782898, 'eval_accuracy': 0.7522237978599974, 'eval_f1': 0.7477127132567268, 'eval_precision': 0.7502370185253144, 'eval_recall': 0.7522237978599974, 'eval_runtime': 12.4854, 'eval_samples_per_second': 621.287, 'eval_steps_per_second': 38.845, 'epoch': 1.95}
{'loss': 0.3741, 'learning_rate': 1.7136101973684213e-05, 'epoch': 2.01}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.7974388003349304, 'eval_accuracy': 0.7434575222379786, 'eval_f1': 0.7406971413311083, 'eval_precision': 0.7546049949466774, 'eval_recall': 0.7434575222379786, 'eval_runtime': 12.4931, 'eval_samples_per_second': 620.903, 'eval_steps_per_second': 38.821, 'epoch': 2.01}
{'loss': 0.2856, 'learning_rate': 1.610814144736842e-05, 'epoch': 2.07}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8066421747207642, 'eval_accuracy': 0.7482274075029006, 'eval_f1': 0.7474551285581336, 'eval_precision': 0.7542530338615103, 'eval_recall': 0.7482274075029006, 'eval_runtime': 12.6074, 'eval_samples_per_second': 615.272, 'eval_steps_per_second': 38.469, 'epoch': 2.07}
{'loss': 0.2337, 'learning_rate': 1.5080180921052633e-05, 'epoch': 2.13}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8196850419044495, 'eval_accuracy': 0.7538997034936187, 'eval_f1': 0.7531076287158235, 'eval_precision': 0.7565679698795539, 'eval_recall': 0.7538997034936187, 'eval_runtime': 12.6529, 'eval_samples_per_second': 613.06, 'eval_steps_per_second': 38.331, 'epoch': 2.13}
{'loss': 0.2718, 'learning_rate': 1.4052220394736842e-05, 'epoch': 2.19}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8673403263092041, 'eval_accuracy': 0.7431996906020368, 'eval_f1': 0.74191867119257, 'eval_precision': 0.7535128345459882, 'eval_recall': 0.7431996906020368, 'eval_runtime': 12.5482, 'eval_samples_per_second': 618.176, 'eval_steps_per_second': 38.651, 'epoch': 2.19}
{'loss': 0.2914, 'learning_rate': 1.3024259868421054e-05, 'epoch': 2.25}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8684523701667786, 'eval_accuracy': 0.7451334278715999, 'eval_f1': 0.7454872970672863, 'eval_precision': 0.7543948402031057, 'eval_recall': 0.7451334278715999, 'eval_runtime': 13.2038, 'eval_samples_per_second': 587.481, 'eval_steps_per_second': 36.732, 'epoch': 2.25}
{'loss': 0.2673, 'learning_rate': 1.1996299342105264e-05, 'epoch': 2.31}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8973827362060547, 'eval_accuracy': 0.7395900476988526, 'eval_f1': 0.7385817334915146, 'eval_precision': 0.7510792058098857, 'eval_recall': 0.7395900476988526, 'eval_runtime': 12.9024, 'eval_samples_per_second': 601.207, 'eval_steps_per_second': 37.59, 'epoch': 2.31}
{'loss': 0.2683, 'learning_rate': 1.0968338815789473e-05, 'epoch': 2.37}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.863448441028595, 'eval_accuracy': 0.7518370504060848, 'eval_f1': 0.7456531330016367, 'eval_precision': 0.7492501782616173, 'eval_recall': 0.7518370504060848, 'eval_runtime': 13.0291, 'eval_samples_per_second': 595.36, 'eval_steps_per_second': 37.224, 'epoch': 2.37}
{'loss': 0.2665, 'learning_rate': 9.940378289473685e-06, 'epoch': 2.43}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8835780024528503, 'eval_accuracy': 0.7431996906020368, 'eval_f1': 0.7447844229188976, 'eval_precision': 0.7543101882205093, 'eval_recall': 0.7431996906020368, 'eval_runtime': 13.2879, 'eval_samples_per_second': 583.765, 'eval_steps_per_second': 36.499, 'epoch': 2.43}
{'loss': 0.2827, 'learning_rate': 8.912417763157894e-06, 'epoch': 2.49}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8548880815505981, 'eval_accuracy': 0.7465515018692793, 'eval_f1': 0.7451989340404944, 'eval_precision': 0.7500988147282833, 'eval_recall': 0.7465515018692793, 'eval_runtime': 13.5576, 'eval_samples_per_second': 572.151, 'eval_steps_per_second': 35.773, 'epoch': 2.49}
{'loss': 0.2402, 'learning_rate': 7.884457236842106e-06, 'epoch': 2.55}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8738114833831787, 'eval_accuracy': 0.7450045120536289, 'eval_f1': 0.7439392021783975, 'eval_precision': 0.7502141683424601, 'eval_recall': 0.7450045120536289, 'eval_runtime': 13.5059, 'eval_samples_per_second': 574.343, 'eval_steps_per_second': 35.91, 'epoch': 2.55}
{'loss': 0.2451, 'learning_rate': 6.856496710526317e-06, 'epoch': 2.61}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8666008710861206, 'eval_accuracy': 0.7473249967771045, 'eval_f1': 0.7457022415056196, 'eval_precision': 0.7502982681432623, 'eval_recall': 0.7473249967771045, 'eval_runtime': 13.4071, 'eval_samples_per_second': 578.573, 'eval_steps_per_second': 36.175, 'epoch': 2.61}
{'loss': 0.2306, 'learning_rate': 5.828536184210527e-06, 'epoch': 2.67}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.912369966506958, 'eval_accuracy': 0.7404924584246487, 'eval_f1': 0.7379659022670713, 'eval_precision': 0.7483462720439309, 'eval_recall': 0.7404924584246487, 'eval_runtime': 14.0178, 'eval_samples_per_second': 553.367, 'eval_steps_per_second': 34.599, 'epoch': 2.67}
{'loss': 0.2709, 'learning_rate': 4.8005756578947365e-06, 'epoch': 2.73}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8849219679832458, 'eval_accuracy': 0.7450045120536289, 'eval_f1': 0.7439181874963987, 'eval_precision': 0.7503744690362306, 'eval_recall': 0.7450045120536289, 'eval_runtime': 13.6408, 'eval_samples_per_second': 568.662, 'eval_steps_per_second': 35.555, 'epoch': 2.73}
{'loss': 0.2284, 'learning_rate': 3.772615131578947e-06, 'epoch': 2.78}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8945534825325012, 'eval_accuracy': 0.7457780069614541, 'eval_f1': 0.7470246941997027, 'eval_precision': 0.7551329043901039, 'eval_recall': 0.7457780069614541, 'eval_runtime': 13.3253, 'eval_samples_per_second': 582.124, 'eval_steps_per_second': 36.397, 'epoch': 2.78}
{'loss': 0.2543, 'learning_rate': 2.744654605263158e-06, 'epoch': 2.84}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8829323649406433, 'eval_accuracy': 0.7479695758669589, 'eval_f1': 0.7465020356814143, 'eval_precision': 0.7514727025601909, 'eval_recall': 0.7479695758669589, 'eval_runtime': 13.0847, 'eval_samples_per_second': 592.83, 'eval_steps_per_second': 37.066, 'epoch': 2.84}
{'loss': 0.2379, 'learning_rate': 1.7166940789473684e-06, 'epoch': 2.9}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8854002356529236, 'eval_accuracy': 0.7487430707747841, 'eval_f1': 0.749469282906591, 'eval_precision': 0.7556845279348615, 'eval_recall': 0.7487430707747841, 'eval_runtime': 13.1889, 'eval_samples_per_second': 588.147, 'eval_steps_per_second': 36.773, 'epoch': 2.9}
{'loss': 0.2657, 'learning_rate': 6.887335526315789e-07, 'epoch': 2.96}


  0%|          | 0/485 [00:00<?, ?it/s]

{'eval_loss': 0.8810961842536926, 'eval_accuracy': 0.7490009024107258, 'eval_f1': 0.7490307792430112, 'eval_precision': 0.7542143741854624, 'eval_recall': 0.7490009024107258, 'eval_runtime': 13.2827, 'eval_samples_per_second': 583.991, 'eval_steps_per_second': 36.514, 'epoch': 2.96}
{'train_runtime': 1235.8153, 'train_samples_per_second': 65.544, 'train_steps_per_second': 4.098, 'train_loss': 0.5237516046135346, 'epoch': 3.0}


TrainOutput(global_step=5064, training_loss=0.5237516046135346, metrics={'train_runtime': 1235.8153, 'train_samples_per_second': 65.544, 'train_steps_per_second': 4.098, 'train_loss': 0.5237516046135346, 'epoch': 3.0})

In [22]:
from sklearn.metrics import classification_report
predictions = trainer.predict(test_dataset)
labels = predictions.label_ids
preds = predictions.predictions.argmax(-1)

print(classification_report(labels, preds, target_names=class_labels))

  0%|          | 0/943 [00:00<?, ?it/s]

                                        precision    recall  f1-score   support

            other_relevant_information       0.62      0.78      0.69      1070
      displaced_people_and_evacuations       0.82      0.93      0.88       790
                  sympathy_and_support       0.82      0.80      0.81      1617
     infrastructure_and_utility_damage       0.93      0.93      0.93      1447
                      not_humanitarian       0.53      0.70      0.61      1245
                    caution_and_advice       0.61      0.42      0.50      2407
              requests_or_urgent_needs       0.39      0.76      0.51       521
rescue_volunteering_or_donation_effort       0.88      0.78      0.83      4219
                injured_or_dead_people       0.82      0.83      0.82      1772

                              accuracy                           0.75     15088
                             macro avg       0.71      0.77      0.73     15088
                          weighted avg