In [2]:
#from datasets import load_from_disk # import if using a huggingface dataset saved locally
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import torch
import gc
from datasets import load_dataset, DatasetDict, Dataset
#from accelerate.utils import release_memory
import numpy as np
from sklearn.metrics import balanced_accuracy_score, precision_recall_fscore_support, accuracy_score, classification_report
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
os.environ["WANDB_PROJECT"] = "Pol-NLI-Large"

In [3]:
modname = "MoritzLaurer/deberta-v3-large-zeroshot-v2.0"
modname = 'training_large/latest_checkpoint'
training_directory ='training_large'
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

Device: cuda


In [3]:
ds = load_dataset("mlburnham/Pol_NLI")
#ds = ds.shuffle(seed = 1)

In [4]:
df = ds['train'].to_pandas()
dftest = ds['test'].to_pandas()
dfval = ds['validation'].to_pandas()

In [5]:
df = pd.concat([df, dftest, dfval])

In [8]:
len(df['augmented_hypothesis'].unique())

2834

In [6]:
def truncate(text):
    words = text.split()
    if len(words) > 450:
        return " ".join(words[:450])
    return text

In [7]:
df['premise'] = df['premise'].apply(truncate)
dftest['premise'] = dftest['premise'].apply(truncate)
dfval['premise'] = dfval['premise'].apply(truncate)

In [8]:
ds = DatasetDict({'train': Dataset.from_pandas(df, preserve_index=False), 'validation':Dataset.from_pandas(dfval, preserve_index=False), 'test':Dataset.from_pandas(dftest, preserve_index=False)})

In [9]:
tokenizer = AutoTokenizer.from_pretrained(modname)

In [10]:
def tokenize_function(docs):
    #return tokenizer(docs['premise'], docs['augmented_hypothesis'], padding = 'max_length', truncation = True)
    return tokenizer(docs['premise'], docs['augmented_hypothesis'], padding = True, truncation = True)

In [11]:
dstok = ds.map(tokenize_function, batched = True)

Map: 100%|██████████| 171289/171289 [00:33<00:00, 5102.23 examples/s]
Map: 100%|██████████| 15036/15036 [00:02<00:00, 5328.73 examples/s]
Map: 100%|██████████| 15366/15366 [00:02<00:00, 5606.16 examples/s]


In [12]:
dstok = dstok.rename_columns({'entailment':'label'})

In [13]:
id2label = {0: "entailment", 1: "not_entailment"}

In [14]:
model = AutoModelForSequenceClassification.from_pretrained(modname, num_labels = 2, ignore_mismatched_sizes=True)

In [15]:
training_args = TrainingArguments(output_dir=training_directory,
    logging_dir=f'{training_directory}/logs',
    #deepspeed="ds_config_zero3.json",  # if using deepspeed
    lr_scheduler_type= "linear",
    group_by_length=False,  # can increase speed with dynamic padding, by grouping similar length texts https://huggingface.co/transformers/main_classes/trainer.html
    learning_rate=9e-6 if "large" in modname else 2e-5,
    per_device_train_batch_size=4 if "large" in modname else 8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4 if "large" in modname else 2,  # (!adapt/halve batch size accordingly). accumulates gradients over X steps, only then backward/update. decreases memory usage, but also slightly speed
    #eval_accumulation_steps=2,
    num_train_epochs=20,
    #max_steps=400,
    #warmup_steps=0,  # 1000,
    warmup_ratio=0.06,  #0.1, 0.06
    weight_decay=0.01,  #0.1,
    fp16=True,   # ! only makes sense at batch-size > 8. loads two copies of model weights, which creates overhead. https://huggingface.co/transformers/performance.html?#fp16
    fp16_full_eval=True,
    eval_strategy="epoch",
    seed=1,
    #load_best_model_at_end=True,
    #metric_for_best_model="accuracy",
    #eval_steps=50,  # evaluate after n steps if evaluation_strategy!='steps'. defaults to logging_steps
    save_strategy="epoch",  # options: "no"/"steps"/"epoch"
    #save_steps=100,  # Number of updates steps before two checkpoint saves.
    dataloader_num_workers = 12,
    #save_total_limit=1,  # If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in output_dir
    #logging_strategy="epoch",
    #report_to="all",  # "all"
    #run_name=run_name,
    #push_to_hub=True,  # does not seem to work if save_strategy="no"
    #hub_model_id=hub_model_id,
    #hub_token=config.HF_ACCESS_TOKEN,
    #hub_strategy="end",
    #hub_private_repo=True,
)

In [16]:
def compute_metrics_standard(eval_pred, label_text_alphabetical=list(model.config.id2label.values())):
    labels = eval_pred.label_ids
    pred_logits = eval_pred.predictions
    preds_max = np.argmax(pred_logits, axis=1)  # argmax on each row (axis=1) in the tensor

    # metrics
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(labels, preds_max, average='macro')  # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(labels, preds_max, average='micro')  # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html
    acc_balanced = balanced_accuracy_score(labels, preds_max)
    acc_not_balanced = accuracy_score(labels, preds_max)

    metrics = {'f1_macro': f1_macro,
            'f1_micro': f1_micro,
            'accuracy_balanced': acc_balanced,
            'accuracy': acc_not_balanced,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'precision_micro': precision_micro,
            'recall_micro': recall_micro,
            #'label_gold_raw': labels,
            #'label_predicted_raw': preds_max
            }
    print("Aggregate metrics: ", {key: metrics[key] for key in metrics if key not in ["label_gold_raw", "label_predicted_raw"]} )  # print metrics but without label lists
    print("Detailed metrics: ", classification_report(
        labels, preds_max, labels=np.sort(pd.factorize(label_text_alphabetical, sort=True)[0]),
        target_names=label_text_alphabetical, sample_weight=None,
        digits=2, output_dict=True, zero_division='warn'),
    "\n")

    return metrics

In [17]:
trainer = Trainer(
    model=model,
    #model_init=model_init,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dstok['train'],
    eval_dataset=dstok['validation'],
    compute_metrics=lambda x: compute_metrics_standard(x, label_text_alphabetical=list(model.config.id2label.values()))  #compute_metrics,
    #data_collator=data_collator,  # for weighted sampling per dataset; for dynamic padding probably not necessary because done by default  https://huggingface.co/course/chapter3/3?fw=pt
)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


In [None]:
#trainer.train()
trainer.train(resume_from_checkpoint = 'training_large/latest_checkpoint')

  torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
wandb: Currently logged in as: mlburnham. Use `wandb login --relogin` to force relogin


  checkpoint_rng_state = torch.load(rng_file)


Epoch,Training Loss,Validation Loss,F1 Macro,F1 Micro,Accuracy Balanced,Accuracy,Precision Macro,Recall Macro,Precision Micro,Recall Micro
9,0.0154,0.336436,0.958684,0.95983,0.957256,0.95983,0.960308,0.957256,0.95983,0.95983
10,0.0103,0.435559,0.950781,0.952248,0.948499,0.952248,0.953582,0.948499,0.952248,0.952248
11,0.0062,0.409473,0.951548,0.952913,0.949951,0.952913,0.953397,0.949951,0.952913,0.952913


Aggregate metrics:  {'f1_macro': 0.9586839037559856, 'f1_micro': 0.959829741952647, 'accuracy_balanced': 0.9572557427285473, 'accuracy': 0.959829741952647, 'precision_macro': 0.9603081722406619, 'recall_macro': 0.9572557427285473, 'precision_micro': 0.959829741952647, 'recall_micro': 0.959829741952647}


  labels, preds_max, labels=np.sort(pd.factorize(label_text_alphabetical, sort=True)[0]),


Detailed metrics:  {'entailment': {'precision': 0.9630227676408849, 'recall': 0.9408424041646948, 'f1-score': 0.951803383338653, 'support': 6339.0}, 'not_entailment': {'precision': 0.9575935768404388, 'recall': 0.9736690812923997, 'f1-score': 0.9655644241733181, 'support': 8697.0}, 'accuracy': 0.959829741952647, 'macro avg': {'precision': 0.9603081722406619, 'recall': 0.9572557427285473, 'f1-score': 0.9586839037559856, 'support': 15036.0}, 'weighted avg': {'precision': 0.9598824595541943, 'recall': 0.959829741952647, 'f1-score': 0.9597629318980494, 'support': 15036.0}} 

Aggregate metrics:  {'f1_macro': 0.9507807712956857, 'f1_micro': 0.9522479382814578, 'accuracy_balanced': 0.9484990309228052, 'accuracy': 0.9522479382814578, 'precision_macro': 0.9535823523568522, 'recall_macro': 0.9484990309228052, 'precision_micro': 0.9522479382814578, 'recall_micro': 0.9522479382814578}


  labels, preds_max, labels=np.sort(pd.factorize(label_text_alphabetical, sort=True)[0]),


Detailed metrics:  {'entailment': {'precision': 0.9606621865267989, 'recall': 0.9245937845085975, 'f1-score': 0.9422829581993569, 'support': 6339.0}, 'not_entailment': {'precision': 0.9465025181869055, 'recall': 0.9724042773370127, 'f1-score': 0.9592785843920145, 'support': 8697.0}, 'accuracy': 0.9522479382814578, 'macro avg': {'precision': 0.9535823523568522, 'recall': 0.9484990309228052, 'f1-score': 0.9507807712956857, 'support': 15036.0}, 'weighted avg': {'precision': 0.9524720671099292, 'recall': 0.9522479382814578, 'f1-score': 0.9521134291356128, 'support': 15036.0}} 

Aggregate metrics:  {'f1_macro': 0.9515477078223584, 'f1_micro': 0.9529130087789306, 'accuracy_balanced': 0.9499507557398172, 'accuracy': 0.9529130087789306, 'precision_macro': 0.9533965186971707, 'recall_macro': 0.9499507557398172, 'precision_micro': 0.9529130087789306, 'recall_micro': 0.9529130087789306}


  labels, preds_max, labels=np.sort(pd.factorize(label_text_alphabetical, sort=True)[0]),


Detailed metrics:  {'entailment': {'precision': 0.956099141422323, 'recall': 0.9310616816532576, 'f1-score': 0.9434143222506394, 'support': 6339.0}, 'not_entailment': {'precision': 0.9506938959720185, 'recall': 0.968839829826377, 'f1-score': 0.9596810933940775, 'support': 8697.0}, 'accuracy': 0.9529130087789306, 'macro avg': {'precision': 0.9533965186971707, 'recall': 0.9499507557398172, 'f1-score': 0.9515477078223584, 'support': 15036.0}, 'weighted avg': {'precision': 0.9529726836089885, 'recall': 0.9529130087789306, 'f1-score': 0.9528232148174445, 'support': 15036.0}} 



In [21]:
trainer.evaluate(eval_dataset = dstok['test'])

Aggregate metrics:  {'f1_macro': 0.9483667829600351, 'f1_micro': 0.950475074840557, 'accuracy_balanced': 0.9449757064506679, 'accuracy': 0.950475074840557, 'precision_macro': 0.9526335862387606, 'recall_macro': 0.9449757064506679, 'precision_micro': 0.950475074840557, 'recall_micro': 0.950475074840557}


  labels, preds_max, labels=np.sort(pd.factorize(label_text_alphabetical, sort=True)[0]),


Detailed metrics:  {'entailment': {'precision': 0.9623430962343096, 'recall': 0.9147311485841553, 'f1-score': 0.9379332843976838, 'support': 6286.0}, 'not_entailment': {'precision': 0.9429240762432116, 'recall': 0.9752202643171806, 'f1-score': 0.9588002815223864, 'support': 9080.0}, 'accuracy': 0.950475074840557, 'macro avg': {'precision': 0.9526335862387606, 'recall': 0.9449757064506679, 'f1-score': 0.9483667829600351, 'support': 15366.0}, 'weighted avg': {'precision': 0.9508681058972558, 'recall': 0.950475074840557, 'f1-score': 0.9502639061530072, 'support': 15366.0}} 



{'eval_loss': 0.23196956515312195,
 'eval_f1_macro': 0.9483667829600351,
 'eval_f1_micro': 0.950475074840557,
 'eval_accuracy_balanced': 0.9449757064506679,
 'eval_accuracy': 0.950475074840557,
 'eval_precision_macro': 0.9526335862387606,
 'eval_recall_macro': 0.9449757064506679,
 'eval_precision_micro': 0.950475074840557,
 'eval_recall_micro': 0.950475074840557}

In [20]:
trainer.save_model("training_large/test")