In [1]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import torch
import gc
from datasets import load_dataset, DatasetDict, Dataset
import numpy as np
from sklearn.metrics import balanced_accuracy_score, precision_recall_fscore_support, accuracy_score, classification_report, matthews_corrcoef
import pandas as pd
import wandb
import os

In [None]:
import os
os.environ["WANDB_PROJECT"] = "DEBATE-DeBERTa-Large"

modname = "MoritzLaurer/deberta-v3-large-zeroshot-v2.0"
training_directory ='training_large'
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

In [None]:
ds = load_dataset("mlburnham/Pol_NLI")

df = ds['train'].to_pandas()
dftest = ds['test'].to_pandas()
dfval = ds['validation'].to_pandas()

def truncate(text):
    words = text.split()
    if len(words) > 450:
        return " ".join(words[:450])
    return text

df['premise'] = df['premise'].apply(truncate)
dftest['premise'] = dftest['premise'].apply(truncate)
dfval['premise'] = dfval['premise'].apply(truncate)

ds = DatasetDict({'train': Dataset.from_pandas(df, preserve_index=False), 'validation':Dataset.from_pandas(dfval, preserve_index=False), 'test':Dataset.from_pandas(dftest, preserve_index=False)})

In [12]:
tokenizer = AutoTokenizer.from_pretrained(modname)
id2label = {0: "entailment", 1: "not_entailment"}

def tokenize_function(docs):
    return tokenizer(docs['premise'], docs['augmented_hypothesis'], padding = True, truncation = True)

def model_init():
  return AutoModelForSequenceClassification.from_pretrained(modname, 
                                                           num_labels=2,
                                                           ignore_mismatched_sizes=True,
                                                           id2label = id2label)

dstok = ds.map(tokenize_function, batched = True)
dstok = dstok.rename_columns({'entailment':'label'})

training_args = TrainingArguments(output_dir=training_directory,
    logging_dir=f'{training_directory}/logs',
    lr_scheduler_type= "linear",
    group_by_length=True,
    learning_rate=9e-6 if "large" in modname else 2e-5,
    per_device_train_batch_size=4 if "large" in modname else 16,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4 if "large" in modname else 1,  
    num_train_epochs=10,
    warmup_ratio=0.06,
    weight_decay=0.01,
    fp16=True,
    eval_strategy="epoch",
    seed=1,
    save_strategy="epoch",
    dataloader_num_workers = 12,
)

def compute_metrics_standard(eval_pred, label_text_alphabetical=list(id2label.values())):
    labels = eval_pred.label_ids
    pred_logits = eval_pred.predictions
    preds_max = np.argmax(pred_logits, axis=1)

    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(labels, preds_max, average='macro')
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(labels, preds_max, average='micro')
    acc_balanced = balanced_accuracy_score(labels, preds_max)
    acc_not_balanced = accuracy_score(labels, preds_max)
    mcc = matthews_corrcoef(labels, preds_max)

    metrics = {'MCC': mcc,
            'f1_macro': f1_macro,
            'f1_micro': f1_micro,
            'accuracy_balanced': acc_balanced,
            'accuracy': acc_not_balanced,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'precision_micro': precision_micro,
            'recall_micro': recall_micro,
            }
    print("Aggregate metrics: ", {key: metrics[key] for key in metrics if key not in ["label_gold_raw", "label_predicted_raw"]} )
    print("Detailed metrics: ", classification_report(
        labels, preds_max, labels=np.sort(pd.factorize(label_text_alphabetical, sort=True)[0]),
        target_names=label_text_alphabetical, sample_weight=None,
        digits=2, output_dict=True, zero_division='warn'),
    "\n")

    return metrics

trainer = Trainer(
    model_init=model_init,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dstok['train'],
    eval_dataset=dstok['validation'],
    compute_metrics=lambda x: compute_metrics_standard(x, label_text_alphabetical=list(id2label.values()))
)

Map:   0%|          | 0/201691 [00:00<?, ? examples/s]

Map:   0%|          | 0/15036 [00:00<?, ? examples/s]

Map:   0%|          | 0/15366 [00:00<?, ? examples/s]

In [None]:
trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmlburnham[0m. Use [1m`wandb login --relogin`[0m to force relogin


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Epoch,Training Loss,Validation Loss
