In [26]:
#from datasets import load_from_disk # import if using a huggingface dataset saved locally
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import torch
import gc
from datasets import load_dataset, DatasetDict, Dataset
#from accelerate.utils import release_memory
import numpy as np
from sklearn.metrics import balanced_accuracy_score, precision_recall_fscore_support, accuracy_score, classification_report, matthews_corrcoef
import pandas as pd

In [2]:
import os
os.environ["WANDB_PROJECT"] = "Pol-NLI-Base"

In [3]:
#modname = "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33"
modname = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
training_directory ='training_base'
#training_directory ='training_large'
#training_directory = 'training_base'
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

Device: cuda


In [4]:
ds = load_dataset("mlburnham/Pol_NLI")
#ds = ds.shuffle(seed = 1)

In [5]:
df = ds['train'].to_pandas()
dftest = ds['test'].to_pandas()
dfval = ds['validation'].to_pandas()

In [6]:
def truncate(text):
    words = text.split()
    if len(words) > 450:
        return " ".join(words[:450])
    return text

In [7]:
df['premise'] = df['premise'].apply(truncate)
dftest['premise'] = dftest['premise'].apply(truncate)
dfval['premise'] = dfval['premise'].apply(truncate)

In [8]:
ds = DatasetDict({'train': Dataset.from_pandas(df, preserve_index=False), 'validation':Dataset.from_pandas(dfval, preserve_index=False), 'test':Dataset.from_pandas(dftest, preserve_index=False)})

In [9]:
tokenizer = AutoTokenizer.from_pretrained(modname)

In [None]:
def tokenize_function(docs):
    return tokenizer(docs['premise'], docs['augmented_hypothesis'], padding = True, truncation = True)

dstok = ds.map(tokenize_function, batched = True)

dstok = dstok.rename_columns({'entailment':'label'})

id2label = {0: "entailment", 1: "not_entailment"}

model = AutoModelForSequenceClassification.from_pretrained(modname, num_labels = 2, ignore_mismatched_sizes=True)

In [25]:
training_args = TrainingArguments(output_dir=training_directory,
    logging_dir=f'{training_directory}/logs',
    #deepspeed="ds_config_zero3.json",  # if using deepspeed
    lr_scheduler_type= "linear",
    group_by_length=True,  # can increase speed with dynamic padding, by grouping similar length texts https://huggingface.co/transformers/main_classes/trainer.html
    learning_rate=9e-6 if "large" in modname else 2e-5,
    per_device_train_batch_size=4 if "large" in modname else 4,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4 if "large" in modname else 1,  # (!adapt/halve batch size accordingly). accumulates gradients over X steps, only then backward/update. decreases memory usage, but also slightly speed
    #eval_accumulation_steps=2,
    num_train_epochs=20,
    #max_steps=400,
    #warmup_steps=0,  # 1000,
    warmup_ratio=0.06,  #0.1, 0.06
    weight_decay=0.01,  #0.1,
    fp16=True,   # ! only makes sense at batch-size > 8. loads two copies of model weights, which creates overhead. https://huggingface.co/transformers/performance.html?#fp16
    fp16_full_eval=True,
    eval_strategy="epoch",
    seed=1,
    #load_best_model_at_end=True,
    #metric_for_best_model="accuracy",
    #eval_steps=50,  # evaluate after n steps if evaluation_strategy!='steps'. defaults to logging_steps
    save_strategy="epoch",  # options: "no"/"steps"/"epoch"
    #save_steps=100,  # Number of updates steps before two checkpoint saves.
    dataloader_num_workers = 12,
    #save_total_limit=1,  # If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in output_dir
    #logging_strategy="epoch",
    #report_to="all",  # "all"
    #run_name=run_name,
    #push_to_hub=True,  # does not seem to work if save_strategy="no"
    #hub_model_id=hub_model_id,
    #hub_token=config.HF_ACCESS_TOKEN,
    #hub_strategy="end",
    #hub_private_repo=True,
)

In [27]:
def compute_metrics_standard(eval_pred, label_text_alphabetical=list(model.config.id2label.values())):
    labels = eval_pred.label_ids
    pred_logits = eval_pred.predictions
    preds_max = np.argmax(pred_logits, axis=1)  # argmax on each row (axis=1) in the tensor

    # metrics
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(labels, preds_max, average='macro')  # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(labels, preds_max, average='micro')  # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html
    acc_balanced = balanced_accuracy_score(labels, preds_max)
    acc_not_balanced = accuracy_score(labels, preds_max)
    mcc = matthews_corrcoef(labels, preds_max)

    metrics = {'MCC': mcc,
            'f1_macro': f1_macro,
            'f1_micro': f1_micro,
            'accuracy_balanced': acc_balanced,
            'accuracy': acc_not_balanced,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'precision_micro': precision_micro,
            'recall_micro': recall_micro,
            #'label_gold_raw': labels,
            #'label_predicted_raw': preds_max
            }
    print("Aggregate metrics: ", {key: metrics[key] for key in metrics if key not in ["label_gold_raw", "label_predicted_raw"]} )  # print metrics but without label lists
    print("Detailed metrics: ", classification_report(
        labels, preds_max, labels=np.sort(pd.factorize(label_text_alphabetical, sort=True)[0]),
        target_names=label_text_alphabetical, sample_weight=None,
        digits=2, output_dict=True, zero_division='warn'),
    "\n")

    return metrics

In [28]:
trainer = Trainer(
    model=model,
    #model_init=model_init,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dstok['train'],
    eval_dataset=dstok['validation'],
    compute_metrics=lambda x: compute_metrics_standard(x, label_text_alphabetical=list(model.config.id2label.values()))  #compute_metrics,
    #data_collator=data_collator,  # for weighted sampling per dataset; for dynamic padding probably not necessary because done by default  https://huggingface.co/course/chapter3/3?fw=pt
)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


In [13]:
if device == "cuda":
    # free memory
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    release_memory(model)
    #del (model, trainer)

In [29]:
trainer.train()
#trainer.train(resume_from_checkpoint = 'training_base/checkpoint-157664')

wandb: Currently logged in as: mlburnham. Use `wandb login --relogin` to force relogin


Epoch,Training Loss,Validation Loss,Mcc,F1 Macro,F1 Micro,Accuracy Balanced,Accuracy,Precision Macro,Recall Macro,Precision Micro,Recall Micro
1,0.259,0.289279,0.858225,0.928839,0.930234,0.93118,0.930234,0.927055,0.93118,0.930234,0.930234
2,0.2145,0.303717,0.878449,0.939175,0.940809,0.938183,0.940809,0.940269,0.938183,0.940809,0.940809
3,0.1718,0.396141,0.863504,0.931616,0.933559,0.929992,0.933559,0.93352,0.929992,0.933559,0.933559
4,0.1204,0.386189,0.879039,0.939086,0.941008,0.936174,0.941008,0.94289,0.936174,0.941008,0.941008
5,0.077,0.352744,0.889709,0.944841,0.946262,0.944309,0.946262,0.945401,0.944309,0.946262,0.946262
6,0.0534,0.475702,0.868469,0.933148,0.935555,0.928637,0.935555,0.939906,0.928637,0.935555,0.935555
7,0.064,0.528699,0.868221,0.93349,0.935688,0.930056,0.935688,0.938203,0.930056,0.935688,0.935688
8,0.0373,0.529357,0.882424,0.941016,0.942737,0.939037,0.942737,0.943397,0.939037,0.942737,0.942737
9,0.0465,0.507311,0.888016,0.943957,0.945464,0.942934,0.945464,0.945084,0.942934,0.945464,0.945464


Aggregate metrics:  {'MCC': 0.8582253579934088, 'f1_macro': 0.9288394102617632, 'f1_micro': 0.9302341048151104, 'accuracy_balanced': 0.9311803369483882, 'accuracy': 0.9302341048151104, 'precision_macro': 0.9270549361757414, 'recall_macro': 0.9311803369483882, 'precision_micro': 0.9302341048151104, 'recall_micro': 0.9302341048151104}


  labels, preds_max, labels=np.sort(pd.factorize(label_text_alphabetical, sort=True)[0]),


Detailed metrics:  {'entailment': {'precision': 0.9012439320388349, 'recall': 0.9372140716201294, 'f1-score': 0.9188771170056453, 'support': 6339.0}, 'not_entailment': {'precision': 0.952865940312648, 'recall': 0.9251466022766471, 'f1-score': 0.9388017035178811, 'support': 8697.0}, 'accuracy': 0.9302341048151104, 'macro avg': {'precision': 0.9270549361757414, 'recall': 0.9311803369483882, 'f1-score': 0.9288394102617632, 'support': 15036.0}, 'weighted avg': {'precision': 0.9311027113656075, 'recall': 0.9302341048151104, 'f1-score': 0.9304017331866054, 'support': 15036.0}} 

Aggregate metrics:  {'MCC': 0.8784490297347349, 'f1_macro': 0.9391753102131721, 'f1_micro': 0.9408087257249268, 'accuracy_balanced': 0.9381828495239177, 'accuracy': 0.9408087257249268, 'precision_macro': 0.9402686564982596, 'recall_macro': 0.9381828495239177, 'precision_micro': 0.9408087257249268, 'recall_micro': 0.9408087257249268}


  labels, preds_max, labels=np.sort(pd.factorize(label_text_alphabetical, sort=True)[0]),


Detailed metrics:  {'entailment': {'precision': 0.9371089363067544, 'recall': 0.9214387127307146, 'f1-score': 0.9292077632834871, 'support': 6339.0}, 'not_entailment': {'precision': 0.9434283766897649, 'recall': 0.9549269863171208, 'f1-score': 0.9491428571428572, 'support': 8697.0}, 'accuracy': 0.9408087257249268, 'macro avg': {'precision': 0.9402686564982596, 'recall': 0.9381828495239177, 'f1-score': 0.9391753102131721, 'support': 15036.0}, 'weighted avg': {'precision': 0.9407641752673185, 'recall': 0.9408087257249268, 'f1-score': 0.9407384570381387, 'support': 15036.0}} 

Aggregate metrics:  {'MCC': 0.8635043434451539, 'f1_macro': 0.9316156739217735, 'f1_micro': 0.933559457302474, 'accuracy_balanced': 0.9299916073349379, 'accuracy': 0.933559457302474, 'precision_macro': 0.9335199445926845, 'recall_macro': 0.9299916073349379, 'precision_micro': 0.933559457302474, 'recall_micro': 0.933559457302474}


  labels, preds_max, labels=np.sort(pd.factorize(label_text_alphabetical, sort=True)[0]),


Detailed metrics:  {'entailment': {'precision': 0.9333008763388511, 'recall': 0.9072408897302414, 'f1-score': 0.9200863930885529, 'support': 6339.0}, 'not_entailment': {'precision': 0.933739012846518, 'recall': 0.9527423249396344, 'f1-score': 0.943144954754994, 'support': 8697.0}, 'accuracy': 0.933559457302474, 'macro avg': {'precision': 0.9335199445926845, 'recall': 0.9299916073349379, 'f1-score': 0.9316156739217735, 'support': 15036.0}, 'weighted avg': {'precision': 0.9335542996700015, 'recall': 0.933559457302474, 'f1-score': 0.9334237375161293, 'support': 15036.0}} 

Aggregate metrics:  {'MCC': 0.8790385115953533, 'f1_macro': 0.9390860406868684, 'f1_micro': 0.9410082468741686, 'accuracy_balanced': 0.9361739808228446, 'accuracy': 0.9410082468741686, 'precision_macro': 0.9428901876551959, 'recall_macro': 0.9361739808228446, 'precision_micro': 0.9410082468741686, 'recall_micro': 0.9410082468741686}


  labels, preds_max, labels=np.sort(pd.factorize(label_text_alphabetical, sort=True)[0]),


Detailed metrics:  {'entailment': {'precision': 0.9523730501161632, 'recall': 0.9053478466635116, 'f1-score': 0.9282652648604933, 'support': 6339.0}, 'not_entailment': {'precision': 0.9334073251942286, 'recall': 0.9670001149821777, 'f1-score': 0.9499068165132434, 'support': 8697.0}, 'accuracy': 0.9410082468741686, 'macro avg': {'precision': 0.9428901876551959, 'recall': 0.9361739808228446, 'f1-score': 0.9390860406868684, 'support': 15036.0}, 'weighted avg': {'precision': 0.9414030508047729, 'recall': 0.9410082468741686, 'f1-score': 0.9407829939589216, 'support': 15036.0}} 

Aggregate metrics:  {'MCC': 0.8897086397503311, 'f1_macro': 0.9448405317681094, 'f1_micro': 0.9462623038042033, 'accuracy_balanced': 0.9443085753795242, 'accuracy': 0.9462623038042033, 'precision_macro': 0.9454007347091056, 'recall_macro': 0.9443085753795242, 'precision_micro': 0.9462623038042033, 'recall_micro': 0.9462623038042033}


  labels, preds_max, labels=np.sort(pd.factorize(label_text_alphabetical, sort=True)[0]),


Detailed metrics:  {'entailment': {'precision': 0.9401559764443737, 'recall': 0.9318504495977283, 'f1-score': 0.9359847884645857, 'support': 6339.0}, 'not_entailment': {'precision': 0.9506454929738375, 'recall': 0.95676670116132, 'f1-score': 0.9536962750716332, 'support': 8697.0}, 'accuracy': 0.9462623038042033, 'macro avg': {'precision': 0.9454007347091056, 'recall': 0.9443085753795242, 'f1-score': 0.9448405317681094, 'support': 15036.0}, 'weighted avg': {'precision': 0.9462232367035347, 'recall': 0.9462623038042033, 'f1-score': 0.9462293215200188, 'support': 15036.0}} 

Aggregate metrics:  {'MCC': 0.8684692531945615, 'f1_macro': 0.9331477873712155, 'f1_micro': 0.9355546687948922, 'accuracy_balanced': 0.9286367983998921, 'accuracy': 0.9355546687948922, 'precision_macro': 0.9399055602318238, 'recall_macro': 0.9286367983998921, 'precision_micro': 0.9355546687948922, 'recall_micro': 0.9355546687948922}


  labels, preds_max, labels=np.sort(pd.factorize(label_text_alphabetical, sort=True)[0]),


Detailed metrics:  {'entailment': {'precision': 0.9594455852156057, 'recall': 0.8845243729294842, 'f1-score': 0.9204629401625215, 'support': 6339.0}, 'not_entailment': {'precision': 0.9203655352480418, 'recall': 0.9727492238703, 'f1-score': 0.9458326345799094, 'support': 8697.0}, 'accuracy': 0.9355546687948922, 'macro avg': {'precision': 0.9399055602318238, 'recall': 0.9286367983998921, 'f1-score': 0.9331477873712155, 'support': 15036.0}, 'weighted avg': {'precision': 0.936841222714415, 'recall': 0.9355546687948922, 'f1-score': 0.9351370710715414, 'support': 15036.0}} 

Aggregate metrics:  {'MCC': 0.868220610524483, 'f1_macro': 0.933490332029201, 'f1_micro': 0.9356876828943868, 'accuracy_balanced': 0.9300563086171714, 'accuracy': 0.9356876828943868, 'precision_macro': 0.9382025175745834, 'recall_macro': 0.9300563086171714, 'precision_micro': 0.9356876828943868, 'recall_micro': 0.9356876828943868}


  labels, preds_max, labels=np.sort(pd.factorize(label_text_alphabetical, sort=True)[0]),


Detailed metrics:  {'entailment': {'precision': 0.9503688799463448, 'recall': 0.8941473418520272, 'f1-score': 0.9214012842396163, 'support': 6339.0}, 'not_entailment': {'precision': 0.9260361552028219, 'recall': 0.9659652753823157, 'f1-score': 0.9455793798187855, 'support': 8697.0}, 'accuracy': 0.9356876828943868, 'macro avg': {'precision': 0.9382025175745834, 'recall': 0.9300563086171714, 'f1-score': 0.933490332029201, 'support': 15036.0}, 'weighted avg': {'precision': 0.9362945445450134, 'recall': 0.9356876828943868, 'f1-score': 0.9353861803058595, 'support': 15036.0}} 

Aggregate metrics:  {'MCC': 0.8824237243750158, 'f1_macro': 0.9410155574135199, 'f1_micro': 0.9427374301675978, 'accuracy_balanced': 0.939037434289971, 'accuracy': 0.9427374301675978, 'precision_macro': 0.9433970594097364, 'recall_macro': 0.939037434289971, 'precision_micro': 0.9427374301675978, 'recall_micro': 0.9427374301675978}


  labels, preds_max, labels=np.sort(pd.factorize(label_text_alphabetical, sort=True)[0]),


Detailed metrics:  {'entailment': {'precision': 0.9469647519582245, 'recall': 0.9154440763527371, 'f1-score': 0.9309376754632229, 'support': 6339.0}, 'not_entailment': {'precision': 0.9398293668612483, 'recall': 0.9626307922272048, 'f1-score': 0.9510934393638171, 'support': 8697.0}, 'accuracy': 0.9427374301675978, 'macro avg': {'precision': 0.9433970594097364, 'recall': 0.939037434289971, 'f1-score': 0.9410155574135199, 'support': 15036.0}, 'weighted avg': {'precision': 0.9428375609374476, 'recall': 0.9427374301675978, 'f1-score': 0.9425960073761963, 'support': 15036.0}} 

Aggregate metrics:  {'MCC': 0.8880162279260374, 'f1_macro': 0.9439567629366055, 'f1_micro': 0.945464219207236, 'accuracy_balanced': 0.9429343397348423, 'accuracy': 0.945464219207236, 'precision_macro': 0.9450844912657135, 'recall_macro': 0.9429343397348423, 'precision_micro': 0.945464219207236, 'recall_micro': 0.945464219207236}


  labels, preds_max, labels=np.sort(pd.factorize(label_text_alphabetical, sort=True)[0]),


Detailed metrics:  {'entailment': {'precision': 0.9428663135933237, 'recall': 0.9268023347531157, 'f1-score': 0.9347653142402546, 'support': 6339.0}, 'not_entailment': {'precision': 0.9473026689381033, 'recall': 0.9590663447165689, 'f1-score': 0.9531482116329563, 'support': 8697.0}, 'accuracy': 0.945464219207236, 'macro avg': {'precision': 0.9450844912657135, 'recall': 0.9429343397348423, 'f1-score': 0.9439567629366055, 'support': 15036.0}, 'weighted avg': {'precision': 0.94543235392543, 'recall': 0.945464219207236, 'f1-score': 0.9453981992245807, 'support': 15036.0}} 



KeyboardInterrupt: 

In [27]:
trainer.evaluate(eval_dataset = dstok['test'])

Aggregate metrics:  {'f1_macro': 0.9044187877129376, 'f1_micro': 0.9071485623003195, 'accuracy_balanced': 0.901032727240402, 'accuracy': 0.9071485623003195, 'precision_macro': 0.9096203014276345, 'recall_macro': 0.901032727240402, 'precision_micro': 0.9071485623003195, 'recall_micro': 0.9071485623003195}
Detailed metrics:  {'entailment': {'precision': 0.9220152976388427, 'recall': 0.8569000154535621, 'f1-score': 0.8882659191029235, 'support': 6471.0}, 'not_entailment': {'precision': 0.8972253052164262, 'recall': 0.9451654390272419, 'f1-score': 0.9205716563229517, 'support': 8553.0}, 'accuracy': 0.9071485623003195, 'macro avg': {'precision': 0.9096203014276345, 'recall': 0.901032727240402, 'f1-score': 0.9044187877129376, 'support': 15024.0}, 'weighted avg': {'precision': 0.9079026242370237, 'recall': 0.9071485623003195, 'f1-score': 0.9066572243773445, 'support': 15024.0}} 



  labels, preds_max, labels=np.sort(pd.factorize(label_text_alphabetical, sort=True)[0]),


{'eval_loss': 1.1247555017471313,
 'eval_f1_macro': 0.9044187877129376,
 'eval_f1_micro': 0.9071485623003195,
 'eval_accuracy_balanced': 0.901032727240402,
 'eval_accuracy': 0.9071485623003195,
 'eval_precision_macro': 0.9096203014276345,
 'eval_recall_macro': 0.901032727240402,
 'eval_precision_micro': 0.9071485623003195,
 'eval_recall_micro': 0.9071485623003195}