In [1]:
import json
import os

import evaluate
import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from sklearn.metrics import accuracy_score, f1_score
from torch.nn import CrossEntropyLoss
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer

from preprocessing import climate_fever_to_claim_evidence_pairs

In [2]:
# Sanity check: the model is moved to "cuda" further down, so this is expected to print True
gpu_available = torch.cuda.is_available()
print(f"is GPU available: {gpu_available}")

is GPU available: True


In [3]:
# The raw climate-fever file is not in a shape the transformers Trainer can consume directly,
# so it is read as JSON Lines and then run through the helper from preprocessing.py

CLIMATE_FEVER_PATH = "data/climate_fever/climate-fever-dataset-r1.jsonl"

df = pd.read_json(CLIMATE_FEVER_PATH, lines=True)
preprocessed_df = climate_fever_to_claim_evidence_pairs(df)


In [4]:
# Map evidence_labels to integers so that the Trainer will know what the labels mean
label_dict = {
    "REFUTES": 0,
    "NOT_ENOUGH_INFO": 1,
    "SUPPORTS": 2
}

# Series.map turns every value that is missing from label_dict into NaN, which would
# silently make the column float instead of failing here. Check for that explicitly
# and report the offending label values.
mapped_labels = preprocessed_df["evidence_label"].map(label_dict)
unmapped_mask = mapped_labels.isna()
if unmapped_mask.any():
    unexpected_labels = sorted(
        preprocessed_df.loc[unmapped_mask, "evidence_label"].astype(str).unique()
    )
    raise ValueError(f"evidence_label values missing from label_dict: {unexpected_labels}")

# Every row is mapped at this point, so the column can safely be stored as integers
preprocessed_df["labels"] = mapped_labels.astype("int64")
preprocessed_df

Unnamed: 0,claim_id,claim,evidence_id,evidence_label,evidence,entropy,labels
0,0,Global warming is driving polar bears toward e...,Extinction risk from global warming:170,NOT_ENOUGH_INFO,"""Recent Research Shows Human Activity Driving ...",0.693147,1
1,0,Global warming is driving polar bears toward e...,Global warming:14,SUPPORTS,Environmental impacts include the extinction o...,0.000000,2
2,0,Global warming is driving polar bears toward e...,Global warming:178,NOT_ENOUGH_INFO,Rising temperatures push bees to their physiol...,0.693147,1
3,0,Global warming is driving polar bears toward e...,Habitat destruction:61,SUPPORTS,"Rising global temperatures, caused by the gree...",0.000000,2
4,0,Global warming is driving polar bears toward e...,Polar bear:1328,NOT_ENOUGH_INFO,"""Bear hunting caught in global warming debate"".",0.693147,1
...,...,...,...,...,...,...,...
7670,3134,"Over the last decade, heatwaves are five times...",Bushfires in Australia:126,SUPPORTS,Australia's climate has warmed by more than on...,0.000000,2
7671,3134,"Over the last decade, heatwaves are five times...",Effects of global warming:86,NOT_ENOUGH_INFO,"In the last 30–40 years, heat waves with high ...",0.693147,1
7672,3134,"Over the last decade, heatwaves are five times...",Global warming:155,NOT_ENOUGH_INFO,Many regions have probably already seen increa...,0.693147,1
7673,3134,"Over the last decade, heatwaves are five times...",Global warming:156,NOT_ENOUGH_INFO,"Since the 1950s, droughts and heat waves have ...",0.693147,1


In [5]:
# Wrap the preprocessed frame in a Hugging Face Dataset; the bare name below displays its summary
dataset = Dataset.from_pandas(df=preprocessed_df)
dataset

Dataset({
    features: ['claim_id', 'claim', 'evidence_id', 'evidence_label', 'evidence', 'entropy', 'labels'],
    num_rows: 7675
})

In [6]:
# Randomly reorder the rows. This matters for this dataset in particular because the rows
# for one claim sit next to each other (the same claim appears five times in a row).
# Fixing the seed lets us reproduce exactly the same ordering later.
SHUFFLE_SEED = 13
dataset = dataset.shuffle(seed=SHUFFLE_SEED)

In [7]:
# Tokenizer that matches the checkpoint fine-tuned further down
TOKENIZER_CHECKPOINT = "distilbert/distilroberta-base"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)



tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/480 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

In [8]:
# This splits our dataset so that we use 90% of it for training, and 10% for testing.
# The seed keeps the train/test partition identical between runs; without it every run
# evaluates on a different held-out set and the metrics cannot be compared.
# NOTE(review): rows are split individually and the same claim appears on several rows,
# so one claim can land in both train and test — confirm this is acceptable.
split_dataset = dataset.train_test_split(test_size=0.1, seed=13)

In [9]:
def custom_tokenize(examples):
    """Tokenize a batch of claim/evidence pairs as "Claim: ... Evidence: ..." strings.

    Args:
        examples: batch mapping with parallel "claim" and "evidence" sequences.

    Returns:
        The output of the module-level ``tokenizer`` for the combined strings.
    """
    combined_texts = []
    for claim, evidence in zip(examples["claim"], examples["evidence"]):
        combined_texts.append(f"Claim: {claim} Evidence: {evidence}")

    # Every entry is padded/truncated to max_length. 512 is probably more padding than
    # needed; the inspection cell below can help pick a tighter value.
    # TODO: consider this properly later
    return tokenizer(
        text=combined_texts,
        truncation=True,
        padding="max_length",
        max_length=512,
    )

# Tokenize both splits in batches: "train" first, then "test"
tokenized_training_dataset, tokenized_testing_dataset = (
    split_dataset[split_name].map(custom_tokenize, batched=True)
    for split_name in ("train", "test")
)

Map:   0%|          | 0/6907 [00:00<?, ? examples/s]

Map:   0%|          | 0/768 [00:00<?, ? examples/s]

In [10]:

# Just for viewing purposes. input_ids are the token ids, and attention_mask marks which
# positions hold real tokens as opposed to padding.
# The max_length is set to 512 so every entry has been padded to be this long, which may be unnecessary
def show_tokenized_example(tokenized_dataset, index):
    """Print the text, label and tokenizer output of one row of a tokenized dataset.

    Args:
        tokenized_dataset: dataset that has been mapped through custom_tokenize.
        index: position of the row to print.
    """
    example = tokenized_dataset[index]
    for field in ("claim", "evidence", "labels", "input_ids", "attention_mask"):
        print(example[field])


show_tokenized_example(tokenized_training_dataset, 6)

# All fields of the second example now come from the test set. Previously its label was
# read from tokenized_training_dataset[2], so the printed label did not belong to the
# printed claim/evidence.
show_tokenized_example(tokenized_testing_dataset, 2)

alarmists here are taking overwhelmingly good news about global warming improving plant health and making it seem like this good news is actually bad news because healthier plants mean more pollen.
[clarification needed] Predictions measuring the effects of global warming on Australia assert that global warming will negatively impact the continent's environment, economy, and communities.
1
[0, 45699, 35, 8054, 1952, 259, 32, 602, 16089, 205, 340, 59, 720, 8232, 3927, 2195, 474, 8, 442, 24, 2045, 101, 42, 205, 340, 16, 888, 1099, 340, 142, 12732, 3451, 1266, 55, 35971, 4, 27956, 35, 646, 3998, 271, 5000, 956, 742, 24268, 19047, 14978, 5, 3038, 9, 720, 8232, 15, 1221, 18088, 14, 720, 8232, 40, 15708, 913, 5, 9183, 18, 1737, 6, 866, 6, 8, 1822, 4, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [11]:
# ignore_mismatched_sizes is only needed when the checkpoint already carries a classification
# head of a different size (e.g. climateBERT/environmental-claims, which has already been
# fine tuned and has 2 labels); it replaces that head. Leave it commented out otherwise.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert/distilroberta-base",
    num_labels=3,
    # ignore_mismatched_sizes=True
)
model = model.to("cuda")

model.safetensors:   0%|          | 0.00/331M [00:00<?, ?B/s]

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at distilbert/distilroberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
# Show which device the model's first parameter lives on
first_parameter = next(model.parameters())
print(first_parameter.device)

cuda:0


In [13]:
def calculate_metrics(eval_pred):
    """Compute accuracy and weighted F1 from a (logits, labels) pair.

    Args:
        eval_pred: unpackable pair of (logits, labels); the predicted class of each
            row is the argmax of the logits along axis 1.

    Returns:
        dict with the keys "accuracy" and "f1_score".
    """
    raw_scores, gold_labels = eval_pred
    predicted_labels = np.argmax(raw_scores, axis=1)

    return dict(
        accuracy=accuracy_score(gold_labels, predicted_labels),
        f1_score=f1_score(gold_labels, predicted_labels, average="weighted"),
    )

In [14]:
# Hyperparameters for the fine-tuning run, grouped by purpose
training_args = TrainingArguments(
    # Where results go
    output_dir="./results/distilroberta-base/climate_fever/test01",
    push_to_hub=False,
    # Optimisation
    num_train_epochs=3,
    per_device_train_batch_size=8,
    learning_rate=3e-5,
    warmup_ratio=0.05,                  # Short warm-up so the model can adapt a little
    fp16=True,                          # 16-bit floats instead of 32 - makes computation faster
    # gradient_accumulation_steps=2     # Might help with OOM errors, if we have them
    # Evaluation and logging cadence (in optimizer steps)
    evaluation_strategy="steps",
    eval_steps=200,
    logging_strategy="steps",
    logging_steps=50,
    # Checkpointing is switched off
    save_strategy="no",
    save_steps=500,                     # NOTE(review): presumably ignored while save_strategy="no" - confirm
)

In [15]:
# Wire the model, hyperparameters, tokenized splits and metric function into one Trainer
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=tokenized_training_dataset,
    eval_dataset=tokenized_testing_dataset,
    compute_metrics=calculate_metrics,
)

In [16]:
# Run fine-tuning; the returned training summary is displayed as the cell output
train_output = trainer.train()
train_output

  0%|          | 0/2592 [00:00<?, ?it/s]

{'loss': 1.0084, 'grad_norm': inf, 'learning_rate': 1.1307692307692307e-05, 'epoch': 0.06}
{'loss': 0.8688, 'grad_norm': 5.010277271270752, 'learning_rate': 2.2846153846153845e-05, 'epoch': 0.12}
{'loss': 0.9078, 'grad_norm': 3.1921093463897705, 'learning_rate': 2.976848090982941e-05, 'epoch': 0.17}
{'loss': 0.873, 'grad_norm': 5.853714466094971, 'learning_rate': 2.9159220146222584e-05, 'epoch': 0.23}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.8654060363769531, 'eval_accuracy': 0.6510416666666666, 'eval_f1_score': 0.513439800210305, 'eval_runtime': 7.2921, 'eval_samples_per_second': 105.319, 'eval_steps_per_second': 13.165, 'epoch': 0.23}
{'loss': 0.9116, 'grad_norm': 5.128660678863525, 'learning_rate': 2.854995938261576e-05, 'epoch': 0.29}
{'loss': 0.8347, 'grad_norm': 3.7346599102020264, 'learning_rate': 2.7940698619008936e-05, 'epoch': 0.35}
{'loss': 0.7984, 'grad_norm': 5.775193214416504, 'learning_rate': 2.7331437855402112e-05, 'epoch': 0.41}
{'loss': 0.796, 'grad_norm': 4.681161403656006, 'learning_rate': 2.672217709179529e-05, 'epoch': 0.46}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.8051298260688782, 'eval_accuracy': 0.6731770833333334, 'eval_f1_score': 0.5885339458577498, 'eval_runtime': 7.183, 'eval_samples_per_second': 106.919, 'eval_steps_per_second': 13.365, 'epoch': 0.46}
{'loss': 0.932, 'grad_norm': 4.253560543060303, 'learning_rate': 2.6112916328188468e-05, 'epoch': 0.52}
{'loss': 0.8391, 'grad_norm': 3.0514910221099854, 'learning_rate': 2.5503655564581644e-05, 'epoch': 0.58}
{'loss': 0.7885, 'grad_norm': 4.587184429168701, 'learning_rate': 2.489439480097482e-05, 'epoch': 0.64}
{'loss': 0.8192, 'grad_norm': 4.415180206298828, 'learning_rate': 2.4285134037367995e-05, 'epoch': 0.69}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.7410104870796204, 'eval_accuracy': 0.6848958333333334, 'eval_f1_score': 0.6063317587209301, 'eval_runtime': 7.2039, 'eval_samples_per_second': 106.609, 'eval_steps_per_second': 13.326, 'epoch': 0.69}
{'loss': 0.8288, 'grad_norm': 7.331613063812256, 'learning_rate': 2.367587327376117e-05, 'epoch': 0.75}
{'loss': 0.7938, 'grad_norm': 5.2924299240112305, 'learning_rate': 2.3066612510154347e-05, 'epoch': 0.81}
{'loss': 0.9149, 'grad_norm': 8.386813163757324, 'learning_rate': 2.2469536961819658e-05, 'epoch': 0.87}
{'loss': 0.7417, 'grad_norm': 9.535518646240234, 'learning_rate': 2.1860276198212834e-05, 'epoch': 0.93}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.6894256472587585, 'eval_accuracy': 0.7083333333333334, 'eval_f1_score': 0.6854872282267186, 'eval_runtime': 7.2068, 'eval_samples_per_second': 106.566, 'eval_steps_per_second': 13.321, 'epoch': 0.93}
{'loss': 0.7497, 'grad_norm': 9.613948822021484, 'learning_rate': 2.125101543460601e-05, 'epoch': 0.98}
{'loss': 0.683, 'grad_norm': 10.01656436920166, 'learning_rate': 2.064175467099919e-05, 'epoch': 1.04}
{'loss': 0.6724, 'grad_norm': 9.98344612121582, 'learning_rate': 2.0032493907392365e-05, 'epoch': 1.1}
{'loss': 0.6564, 'grad_norm': 16.286937713623047, 'learning_rate': 1.942323314378554e-05, 'epoch': 1.16}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.7154312133789062, 'eval_accuracy': 0.6888020833333334, 'eval_f1_score': 0.6896046005664443, 'eval_runtime': 7.2044, 'eval_samples_per_second': 106.602, 'eval_steps_per_second': 13.325, 'epoch': 1.16}
{'loss': 0.6683, 'grad_norm': 7.677107334136963, 'learning_rate': 1.8813972380178717e-05, 'epoch': 1.22}
{'loss': 0.6629, 'grad_norm': 13.988689422607422, 'learning_rate': 1.8204711616571893e-05, 'epoch': 1.27}
{'loss': 0.6948, 'grad_norm': 12.44792366027832, 'learning_rate': 1.759545085296507e-05, 'epoch': 1.33}
{'loss': 0.692, 'grad_norm': 26.490665435791016, 'learning_rate': 1.6986190089358245e-05, 'epoch': 1.39}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.6374222636222839, 'eval_accuracy': 0.72265625, 'eval_f1_score': 0.7068562644419499, 'eval_runtime': 7.2069, 'eval_samples_per_second': 106.564, 'eval_steps_per_second': 13.321, 'epoch': 1.39}
{'loss': 0.6516, 'grad_norm': 9.206770896911621, 'learning_rate': 1.637692932575142e-05, 'epoch': 1.45}
{'loss': 0.5715, 'grad_norm': 15.929777145385742, 'learning_rate': 1.5767668562144597e-05, 'epoch': 1.5}
{'loss': 0.6169, 'grad_norm': 9.45015811920166, 'learning_rate': 1.5158407798537775e-05, 'epoch': 1.56}
{'loss': 0.6874, 'grad_norm': 14.898390769958496, 'learning_rate': 1.454914703493095e-05, 'epoch': 1.62}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.669572651386261, 'eval_accuracy': 0.7213541666666666, 'eval_f1_score': 0.7202807515681394, 'eval_runtime': 7.2211, 'eval_samples_per_second': 106.355, 'eval_steps_per_second': 13.294, 'epoch': 1.62}
{'loss': 0.6384, 'grad_norm': 7.988715171813965, 'learning_rate': 1.3939886271324126e-05, 'epoch': 1.68}
{'loss': 0.5995, 'grad_norm': 11.535401344299316, 'learning_rate': 1.3330625507717304e-05, 'epoch': 1.74}
{'loss': 0.5428, 'grad_norm': 23.513404846191406, 'learning_rate': 1.272136474411048e-05, 'epoch': 1.79}
{'loss': 0.6723, 'grad_norm': 10.822189331054688, 'learning_rate': 1.2112103980503656e-05, 'epoch': 1.85}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.6918346285820007, 'eval_accuracy': 0.7057291666666666, 'eval_f1_score': 0.71149112155428, 'eval_runtime': 7.1842, 'eval_samples_per_second': 106.901, 'eval_steps_per_second': 13.363, 'epoch': 1.85}
{'loss': 0.5825, 'grad_norm': 15.705116271972656, 'learning_rate': 1.1502843216896832e-05, 'epoch': 1.91}
{'loss': 0.5912, 'grad_norm': 20.418243408203125, 'learning_rate': 1.0893582453290008e-05, 'epoch': 1.97}
{'loss': 0.5318, 'grad_norm': 11.915464401245117, 'learning_rate': 1.0284321689683186e-05, 'epoch': 2.03}
{'loss': 0.497, 'grad_norm': 10.664161682128906, 'learning_rate': 9.675060926076362e-06, 'epoch': 2.08}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.7675027847290039, 'eval_accuracy': 0.6731770833333334, 'eval_f1_score': 0.6831170551224152, 'eval_runtime': 7.4853, 'eval_samples_per_second': 102.601, 'eval_steps_per_second': 12.825, 'epoch': 2.08}
{'loss': 0.5868, 'grad_norm': 15.373984336853027, 'learning_rate': 9.077985377741674e-06, 'epoch': 2.14}
{'loss': 0.4218, 'grad_norm': 16.884315490722656, 'learning_rate': 8.46872461413485e-06, 'epoch': 2.2}
{'loss': 0.4751, 'grad_norm': 35.20327377319336, 'learning_rate': 7.859463850528026e-06, 'epoch': 2.26}
{'loss': 0.4904, 'grad_norm': 42.923118591308594, 'learning_rate': 7.2502030869212026e-06, 'epoch': 2.31}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.687366783618927, 'eval_accuracy': 0.7252604166666666, 'eval_f1_score': 0.7301023230662705, 'eval_runtime': 7.212, 'eval_samples_per_second': 106.489, 'eval_steps_per_second': 13.311, 'epoch': 2.31}
{'loss': 0.4496, 'grad_norm': 26.484174728393555, 'learning_rate': 6.640942323314379e-06, 'epoch': 2.37}
{'loss': 0.5606, 'grad_norm': 29.686994552612305, 'learning_rate': 6.031681559707555e-06, 'epoch': 2.43}
{'loss': 0.5242, 'grad_norm': 9.505236625671387, 'learning_rate': 5.422420796100731e-06, 'epoch': 2.49}
{'loss': 0.4063, 'grad_norm': 34.38750457763672, 'learning_rate': 4.813160032493908e-06, 'epoch': 2.55}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.67450350522995, 'eval_accuracy': 0.7486979166666666, 'eval_f1_score': 0.747175898235068, 'eval_runtime': 7.2865, 'eval_samples_per_second': 105.4, 'eval_steps_per_second': 13.175, 'epoch': 2.55}
{'loss': 0.3978, 'grad_norm': 20.847537994384766, 'learning_rate': 4.203899268887083e-06, 'epoch': 2.6}
{'loss': 0.4521, 'grad_norm': 33.10408020019531, 'learning_rate': 3.59463850528026e-06, 'epoch': 2.66}
{'loss': 0.4626, 'grad_norm': 20.268428802490234, 'learning_rate': 2.9853777416734364e-06, 'epoch': 2.72}
{'loss': 0.5221, 'grad_norm': 35.51530838012695, 'learning_rate': 2.3761169780666128e-06, 'epoch': 2.78}


  0%|          | 0/96 [00:00<?, ?it/s]

{'eval_loss': 0.6949567794799805, 'eval_accuracy': 0.73046875, 'eval_f1_score': 0.7337111283516844, 'eval_runtime': 7.24, 'eval_samples_per_second': 106.077, 'eval_steps_per_second': 13.26, 'epoch': 2.78}
{'loss': 0.4799, 'grad_norm': 33.258663177490234, 'learning_rate': 1.7668562144597887e-06, 'epoch': 2.84}
{'loss': 0.4615, 'grad_norm': 27.102725982666016, 'learning_rate': 1.157595450852965e-06, 'epoch': 2.89}
{'loss': 0.4501, 'grad_norm': 4.491292953491211, 'learning_rate': 5.483346872461414e-07, 'epoch': 2.95}
{'train_runtime': 757.0678, 'train_samples_per_second': 27.37, 'train_steps_per_second': 3.424, 'train_loss': 0.6517716219395767, 'epoch': 3.0}


TrainOutput(global_step=2592, training_loss=0.6517716219395767, metrics={'train_runtime': 757.0678, 'train_samples_per_second': 27.37, 'train_steps_per_second': 3.424, 'total_flos': 2744905918178304.0, 'train_loss': 0.6517716219395767, 'epoch': 3.0})

In [17]:
# Probably not worth running this cell until the model is worth keeping.
# Please remember to delete model.safetensors BEFORE adding to git. Causes issues...
trainer.save_model(output_dir="./results/distilroberta-base/climate_fever/test01")

In [None]:
# Metrics are not included in the saved model so we need to save them separately
metrics = trainer.evaluate()

# The save_model cell may have been skipped, so do not rely on it to have created the
# results directory: create it here if it is missing.
metrics_dir = "./results/distilroberta-base/climate_fever/test01"
os.makedirs(metrics_dir, exist_ok=True)

with open(os.path.join(metrics_dir, "eval_metrics.json"), "w", encoding="utf-8") as output_file:
    json.dump(metrics, output_file)

  0%|          | 0/96 [00:00<?, ?it/s]

: 