<a href="https://colab.research.google.com/github/MoritzLaurer/zeroshot-classifier/blob/main/4_train_eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Training and Evaluation
This script takes the prepared data from the other scripts and trains and evaluates a model.

### Install and setup
Note that this notebook was written and tested in Google Colab to increase reproducibility & accessibility. The final run, however, was executed on a university HPC system on an A100 GPU for convenience. With a Google Colab Pro subscription, you should also be able to run the notebook. It will take several hours.  

In [2]:
USING_COLAB = False  # True when running in Google Colab: enables the Colab-only cells and smaller batch sizes
DO_TRAIN = True  # True: fine-tune a base model; False: only evaluate an already trained NLI model
UPLOAD_TO_HUB = True  # True: push the trained model to the Hugging Face Hub at the end (requires DO_TRAIN)

# set global seed for reproducibility and against seed hacking
SEED_GLOBAL = 42

# date tag used in the W&B project name, the README file name and the saved model path
DATE = 20231109

if USING_COLAB:
    # The install commands below are wrapped in a string literal, so they are inert as written.
    # To install the pinned dependencies in Colab, remove the surrounding triple quotes.
    """!pip install transformers[sentencepiece]~=4.33.0 -qq
    !pip install datasets~=2.14.0 -qq
    !pip install accelerate~=0.23.0 -qq
    !pip install wandb~=0.15.0 -qq
    !pip install mdutils~=1.6.0 -qq
    !pip install scikit-learn~=1.2.0 -qq"""

In [3]:
## load packages
import pandas as pd
import numpy as np
import os
from datasets import load_dataset
import re
import time
import random
import tqdm

import torch
from torch.utils.data import DataLoader

import transformers
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import TrainingArguments, Trainer
from datasets import ClassLabel
from datasets import load_dataset, load_metric, Dataset, DatasetDict, concatenate_datasets, list_metrics

from sklearn.metrics import balanced_accuracy_score, precision_recall_fscore_support, accuracy_score, classification_report

import gc
from accelerate.utils import release_memory

import wandb
import json
from datetime import datetime

from mdutils import MdUtils

# Seed every RNG that is used before the Trainer exists (Python, NumPy and PyTorch).
# The model is instantiated before the Trainer is created, so seeding only NumPy here
# leaves the randomly initialised classification head dependent on an unseeded torch RNG.
random.seed(SEED_GLOBAL)
np.random.seed(SEED_GLOBAL)
torch.manual_seed(SEED_GLOBAL)


In [4]:
if USING_COLAB:
    # for details on the attached GPU, run `!nvidia-smi` in a separate cell
    from psutil import virtual_memory

    # total memory of the runtime, converted from bytes to (decimal) gigabytes
    ram_gb = virtual_memory().total / 1e9
    print(f'\n\nYour runtime has {ram_gb:.1f} gigabytes of available RAM\n')



Your runtime has 54.8 gigabytes of available RAM



In [5]:
if USING_COLAB:
    ## connect to google drive
    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)

    # switch the working directory to the project folder on Drive
    # NOTE(review): this path reflects the author's Drive layout - adapt it to your own setup
    print(os.getcwd())
    os.chdir("/content/drive/My Drive/PhD/zero-shot-models")

print(os.getcwd())

# local config.py file with the access tokens used below (config.HF_ACCESS_TOKEN, config.WANDB_ACCESS_TOKEN)
# it is imported after the optional chdir above, presumably so that it is found in the project folder on Colab
import config

Mounted at /content/drive
/content
/content/drive/My Drive/PhD/zero-shot-models


### Load data

In [6]:
# fetch the prepared NLI-format datasets from the Hugging Face Hub
hf_token = config.HF_ACCESS_TOKEN

# the training data and the concatenated test data each live in a single "train" split
dataset_train = load_dataset("MoritzLaurer/dataset_train_nli", token=hf_token)["train"]
dataset_test_concat_nli = load_dataset("MoritzLaurer/dataset_test_concat_nli", token=hf_token)["train"]

# the disaggregated test data is kept whole: one split per task, iterated over during evaluation
dataset_test_disaggregated = load_dataset("MoritzLaurer/dataset_test_disaggregated_nli", token=hf_token)


Downloading readme:   0%|          | 0.00/691 [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/187M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/987881 [00:00<?, ? examples/s]

Downloading readme:   0%|          | 0.00/649 [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/8.73M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/59140 [00:00<?, ? examples/s]

Downloading readme:   0%|          | 0.00/4.97k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/36 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.17M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.20M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.82M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/239k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/247k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/303k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/505k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/365k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/109k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/98.1k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.00M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.62M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.54M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/93.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/660k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/962k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/167k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/317k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.21M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.31M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/20.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/387k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/334k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/317k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/68.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/575k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/70.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/71.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.05M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.88M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.44M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.35M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.75M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.87M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/318k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/36 [00:00<?, ?it/s]

Generating mnli_m split:   0%|          | 0/9815 [00:00<?, ? examples/s]

Generating mnli_mm split:   0%|          | 0/9832 [00:00<?, ? examples/s]

Generating fevernli split:   0%|          | 0/19652 [00:00<?, ? examples/s]

Generating anli_r1 split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating anli_r2 split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating anli_r3 split:   0%|          | 0/1200 [00:00<?, ? examples/s]

Generating wanli split:   0%|          | 0/5000 [00:00<?, ? examples/s]

Generating lingnli split:   0%|          | 0/4893 [00:00<?, ? examples/s]

Generating wellformedquery split:   0%|          | 0/5934 [00:00<?, ? examples/s]

Generating rottentomatoes split:   0%|          | 0/2132 [00:00<?, ? examples/s]

Generating amazonpolarity split:   0%|          | 0/20000 [00:00<?, ? examples/s]

Generating imdb split:   0%|          | 0/20000 [00:00<?, ? examples/s]

Generating yelpreviews split:   0%|          | 0/20000 [00:00<?, ? examples/s]

Generating hatexplain split:   0%|          | 0/2922 [00:00<?, ? examples/s]

Generating massive split:   0%|          | 0/175466 [00:00<?, ? examples/s]

Generating banking77 split:   0%|          | 0/221760 [00:00<?, ? examples/s]

Generating emotiondair split:   0%|          | 0/12000 [00:00<?, ? examples/s]

Generating emocontext split:   0%|          | 0/22036 [00:00<?, ? examples/s]

Generating empathetic split:   0%|          | 0/81344 [00:00<?, ? examples/s]

Generating agnews split:   0%|          | 0/30400 [00:00<?, ? examples/s]

Generating yahootopics split:   0%|          | 0/500000 [00:00<?, ? examples/s]

Generating biasframes_sex split:   0%|          | 0/8808 [00:00<?, ? examples/s]

Generating biasframes_offensive split:   0%|          | 0/7676 [00:00<?, ? examples/s]

Generating biasframes_intent split:   0%|          | 0/7296 [00:00<?, ? examples/s]

Generating financialphrasebank split:   0%|          | 0/2070 [00:00<?, ? examples/s]

Generating appreviews split:   0%|          | 0/8000 [00:00<?, ? examples/s]

Generating hateoffensive split:   0%|          | 0/2586 [00:00<?, ? examples/s]

Generating trueteacher split:   0%|          | 0/17910 [00:00<?, ? examples/s]

Generating spam split:   0%|          | 0/2070 [00:00<?, ? examples/s]

Generating wikitoxic_toxicaggregated split:   0%|          | 0/20000 [00:00<?, ? examples/s]

Generating wikitoxic_obscene split:   0%|          | 0/17382 [00:00<?, ? examples/s]

Generating wikitoxic_identityhate split:   0%|          | 0/11424 [00:00<?, ? examples/s]

Generating wikitoxic_threat split:   0%|          | 0/10422 [00:00<?, ? examples/s]

Generating wikitoxic_insult split:   0%|          | 0/16854 [00:00<?, ? examples/s]

Generating manifesto split:   0%|          | 0/246008 [00:00<?, ? examples/s]

Generating capsotu split:   0%|          | 0/20790 [00:00<?, ? examples/s]

### Tokenize, train, evaluate

In [7]:
### Load model and tokenizer

# when training, start from a base checkpoint; otherwise evaluate an existing NLI model
model_name = (
    "microsoft/deberta-v3-base"  # alternative: "microsoft/deberta-v3-large"
    if DO_TRAIN
    else "sileod/deberta-v3-base-tasksource-nli"  # alternatives: "facebook/bart-large-mnli", "MoritzLaurer/DeBERTa-v3-base-mnli-fever-docnli-ling-2c"
)

max_length = 512

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# the binary label mapping is only imposed on a freshly fine-tuned model;
# an existing model keeps the mapping stored in its own config
model_kwargs = {}
if DO_TRAIN:
    label2id = {"entailment": 0, "not_entailment": 1}  # 3-class alternative: {"entailment": 0, "neutral": 1, "contradiction": 2}
    id2label = {0: "entailment", 1: "not_entailment"}  # 3-class alternative: {0: "entailment", 1: "neutral", 2: "contradiction"}
    model_kwargs = {"label2id": label2id, "id2label": id2label}

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, model_max_length=max_length)
model = AutoModelForSequenceClassification.from_pretrained(model_name, **model_kwargs).to(device)

# class names in label-id order, used by the metric functions
label_text_unique = list(label2id) if DO_TRAIN else list(model.config.id2label.values())
print(label_text_unique)



Device: cpu


(…)-base/resolve/main/tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

(…)deberta-v3-base/resolve/main/config.json:   0%|          | 0.00/579 [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


pytorch_model.bin:   0%|          | 0.00/371M [00:00<?, ?B/s]

Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'pooler.dense.weight', 'classifier.weight', 'pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


['entailment', 'not_entailment']


#### Tokenize

In [8]:
# Dynamic padding HF course: https://huggingface.co/course/chapter3/2?fw=pt

# without padding="max_length" & max_length=512, it should do dynamic padding.
def tokenize_func(examples):
    """Tokenize a batch of (text, hypothesis) pairs.

    `examples` is the batch handed over by `Dataset.map(..., batched=True)`; its "text"
    and "hypothesis" columns are encoded together as sentence pairs. Truncation is
    enabled and no padding is applied here, so padding is left to batching time.
    """
    return tokenizer(examples["text"], examples["hypothesis"], truncation=True)  # max_length=512,  padding=True

# training on:
encoded_dataset_train = dataset_train.map(tokenize_func, batched=True)
print(len(encoded_dataset_train))
# testing on:
encoded_dataset_test = dataset_test_concat_nli.map(tokenize_func, batched=True)
print(len(encoded_dataset_test))
# testing on individual datasets:
# (these keep all their original columns; "label_text" is read from them again in later cells)
encoded_dataset_test_disaggregated = dataset_test_disaggregated.map(tokenize_func, batched=True)

# remove columns the library does not expect
encoded_dataset_train = encoded_dataset_train.remove_columns(["hypothesis", "text"])
encoded_dataset_test = encoded_dataset_test.remove_columns(["hypothesis", "text"])


Map:   0%|          | 0/987881 [00:00<?, ? examples/s]

987881


Map:   0%|          | 0/59140 [00:00<?, ? examples/s]

59140


Map:   0%|          | 0/9815 [00:00<?, ? examples/s]

Map:   0%|          | 0/9832 [00:00<?, ? examples/s]

Map:   0%|          | 0/19652 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1200 [00:00<?, ? examples/s]

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Map:   0%|          | 0/4893 [00:00<?, ? examples/s]

Map:   0%|          | 0/5934 [00:00<?, ? examples/s]

Map:   0%|          | 0/2132 [00:00<?, ? examples/s]

Map:   0%|          | 0/20000 [00:00<?, ? examples/s]

Map:   0%|          | 0/20000 [00:00<?, ? examples/s]

Map:   0%|          | 0/20000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2922 [00:00<?, ? examples/s]

Map:   0%|          | 0/175466 [00:00<?, ? examples/s]

Map:   0%|          | 0/221760 [00:00<?, ? examples/s]

Map:   0%|          | 0/12000 [00:00<?, ? examples/s]

Map:   0%|          | 0/22036 [00:00<?, ? examples/s]

Map:   0%|          | 0/81344 [00:00<?, ? examples/s]

Map:   0%|          | 0/30400 [00:00<?, ? examples/s]

Map:   0%|          | 0/500000 [00:00<?, ? examples/s]

Map:   0%|          | 0/8808 [00:00<?, ? examples/s]

Map:   0%|          | 0/7676 [00:00<?, ? examples/s]

Map:   0%|          | 0/7296 [00:00<?, ? examples/s]

Map:   0%|          | 0/2070 [00:00<?, ? examples/s]

Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2586 [00:00<?, ? examples/s]

Map:   0%|          | 0/17910 [00:00<?, ? examples/s]

Map:   0%|          | 0/2070 [00:00<?, ? examples/s]

Map:   0%|          | 0/20000 [00:00<?, ? examples/s]

Map:   0%|          | 0/17382 [00:00<?, ? examples/s]

Map:   0%|          | 0/11424 [00:00<?, ? examples/s]

Map:   0%|          | 0/10422 [00:00<?, ? examples/s]

Map:   0%|          | 0/16854 [00:00<?, ? examples/s]

Map:   0%|          | 0/246008 [00:00<?, ? examples/s]

Map:   0%|          | 0/20790 [00:00<?, ? examples/s]

In [9]:
# sanity check: distribution of the text labels in the "mnli_m" test split
encoded_dataset_test_disaggregated["mnli_m"].to_pandas()["label_text"].value_counts()

not_entailment    6336
entailment        3479
Name: label_text, dtype: int64

#### Training

In [10]:
# release memory: https://huggingface.co/blog/optimize-llm

def flush():
    """Free memory that is no longer referenced.

    Always runs Python's garbage collector. When a CUDA device is available, it
    additionally empties the CUDA cache and resets the peak-memory statistics.
    The CUDA calls are skipped on CPU-only machines, where resetting the statistics
    would raise, so the function is safe to call unconditionally.

    Returns:
        None
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()


In [11]:

# function for computing metrics for normally formatted classification tasks
# here, this is used for the standard NLI datasets like MNLI, ANLI etc
def compute_metrics_standard(eval_pred, label_text_alphabetical=None):
    """Compute aggregate classification metrics from raw model logits.

    Args:
        eval_pred: object exposing `.predictions` (logits with one row per example and
            one column per class, or a tuple whose first element holds these logits)
            and `.label_ids` (integer gold labels).
        label_text_alphabetical: optional list of class names where position i is the
            name of class id i. It is only used for the printed per-class report;
            when omitted, the report is printed with the numeric class ids instead.

    Returns:
        dict with macro/micro F1, precision and recall, plus balanced and plain accuracy.
    """
    labels = eval_pred.label_ids
    pred_logits = eval_pred.predictions
    # some models return a tuple of outputs; the logits are its first element
    if isinstance(pred_logits, tuple):
        pred_logits = pred_logits[0]
    preds_max = np.argmax(pred_logits, axis=1)  # argmax on each row (axis=1) in the tensor

    # metrics
    # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(labels, preds_max, average='macro')
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(labels, preds_max, average='micro')
    acc_balanced = balanced_accuracy_score(labels, preds_max)
    acc_not_balanced = accuracy_score(labels, preds_max)

    metrics = {'f1_macro': f1_macro,
            'f1_micro': f1_micro,
            'accuracy_balanced': acc_balanced,
            'accuracy': acc_not_balanced,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'precision_micro': precision_micro,
            'recall_micro': recall_micro,
            }
    print("Aggregate metrics: ", metrics)

    # per-class report: class id i is reported under the name at position i
    report_kwargs = dict(sample_weight=None, digits=2, output_dict=True, zero_division='warn')
    if label_text_alphabetical is not None:
        report_kwargs["labels"] = np.arange(len(label_text_alphabetical))
        report_kwargs["target_names"] = label_text_alphabetical
    print("Detailed metrics: ", classification_report(labels, preds_max, **report_kwargs), "\n")

    return metrics


# function to compute metrics for classification tasks that have been reformatted into the NLI format
# here, this is used for the non-NLI classification tasks (which were converted to NLI format)
def compute_metrics_nli_binary(eval_pred, label_text_alphabetical=None):
    """Compute classification metrics for a task that was expanded into NLI format.

    The rows are assumed to come in consecutive groups of `n` rows per premise, where
    `n` is the number of distinct class names: one row per class hypothesis, with gold
    label 0 (entailment) for the true class and 1 for all others. For each group, the
    predicted class is the hypothesis with the highest entailment logit.

    Args:
        eval_pred: pair `(predictions, labels)`. `predictions` holds the logits with one
            row per (premise, hypothesis) pair; for BART models it is a tuple whose first
            element holds the logits. `labels` holds the binary NLI gold labels.
        label_text_alphabetical: list of the task's class names; position i is reported
            as class i. Required despite the `None` default.

    Returns:
        dict with macro/micro F1, precision and recall, plus balanced and plain accuracy.

    Raises:
        ValueError: if `label_text_alphabetical` is missing or empty, or if the number of
            rows is not a multiple of the number of classes (the groups would be
            misaligned and the metrics silently wrong).
    """
    if label_text_alphabetical is None:
        raise ValueError("label_text_alphabetical is required: it defines the number of hypotheses per premise.")

    predictions, labels = eval_pred

    # hacky special handling for BART encoder-decoder model
    is_bart = "bart" in model_name
    if is_bart:
        predictions = predictions[0]

    predictions = np.asarray(predictions)
    labels = np.asarray(labels)

    n_hypotheses = len(set(label_text_alphabetical))
    if n_hypotheses == 0 or len(predictions) != len(labels) or len(predictions) % n_hypotheses != 0:
        raise ValueError(
            f"Cannot split {len(predictions)} predictions and {len(labels)} labels "
            f"into groups of {n_hypotheses} hypotheses per premise."
        )

    # column holding the entailment logit: 0 by default,
    # 2 for bart, which has the label sequence ['contradiction', 'neutral', 'entailment']
    entailment_column = 2 if is_bart else 0

    # one row per premise, one column per hypothesis; pick the most entailed hypothesis
    entailment_logits = predictions[:, entailment_column].reshape(-1, n_hypotheses)
    hypo_position_highest_prob = np.argmax(entailment_logits, axis=1).tolist()

    # argmin to detect the position of the 0 among the 1s
    label_position_gold = np.argmin(labels.reshape(-1, n_hypotheses), axis=1).tolist()

    # for inspection
    print("Highest probability prediction per premise: ", hypo_position_highest_prob)
    print("Correct label per premise: ", label_position_gold)

    ## metrics
    # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(label_position_gold, hypo_position_highest_prob, average='macro')
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(label_position_gold, hypo_position_highest_prob, average='micro')
    acc_balanced = balanced_accuracy_score(label_position_gold, hypo_position_highest_prob)
    acc_not_balanced = accuracy_score(label_position_gold, hypo_position_highest_prob)
    metrics = {'f1_macro': f1_macro,
               'f1_micro': f1_micro,
               'accuracy_balanced': acc_balanced,
               'accuracy': acc_not_balanced,
               'precision_macro': precision_macro,
               'recall_macro': recall_macro,
               'precision_micro': precision_micro,
               'recall_micro': recall_micro,
               }
    print("Aggregate metrics: ", metrics)
    print("Detailed metrics: ", classification_report(
        label_position_gold,
        hypo_position_highest_prob,
        labels=np.arange(len(label_text_alphabetical)),
        target_names=label_text_alphabetical,
        sample_weight=None, digits=2, output_dict=True,
        zero_division='warn'),
    "\n")

    return metrics


In [None]:
## test logging with wandb
wandb.login(key=config.WANDB_ACCESS_TOKEN)

# timestamp (minute resolution) that makes the run name and output directory unique
now = datetime.now().strftime("%Y-%m-%d-%H-%M")
run_name = "-".join([model_name.split('/')[-1], "zeroshot", now])
print(now)
print(run_name)

# settings read by the W&B integration of the Trainer
# https://huggingface.co/docs/transformers/v4.34.0/en/main_classes/callback#transformers.integrations.WandbCallback
os.environ.update({
    # project the run is logged to
    "WANDB_PROJECT": f"nli-zeroshot-{DATE}",
    # "end": upload the model at the end of training; "checkpoint": upload every args.save_steps;
    # "false": never upload. Use along with load_best_model_at_end to upload the best model.
    "WANDB_LOG_MODEL": "false",
    # "gradients", "all", "parameters" or "false"; "all" logs gradients and parameters
    "WANDB_WATCH": "parameters",
})


In [15]:
training_directory = f'./results/{model_name.split("/")[-1]}-zeroshot-{now}'

# mixed precision only on GPU; mDeBERTa does not support FP16 yet.
# the check is case-insensitive so that it also catches ids like "microsoft/mdeberta-v3-base"
fp16_bool = torch.cuda.is_available() and "mdeberta" not in model_name.lower()

# https://huggingface.co/transformers/main_classes/trainer.html#transformers.TrainingArguments
# large models get smaller per-device batches; gradient accumulation keeps the
# effective train batch size at 32 in both cases
is_large_model = "large" in model_name
eval_batch = 64 if is_large_model else 64*2
per_device_train_batch_size = 16 if is_large_model else 32
gradient_accumulation_steps = 2 if is_large_model else 1

if USING_COLAB:
    # less GPU memory on Colab: quarter the per-device train batch and compensate via accumulation
    per_device_train_batch_size = max(1, per_device_train_batch_size // 4)
    gradient_accumulation_steps = gradient_accumulation_steps * 4
    eval_batch = max(1, (eval_batch // 32) if is_large_model else (eval_batch // 8))

if "bart" in model_name:
    # bart gets a much smaller eval batch; clamp to 1 because the reduction
    # would otherwise round down to an invalid batch size of 0 when combined with the Colab settings
    eval_batch = max(1, (eval_batch // 64) if is_large_model else (eval_batch // 32))


train_args = TrainingArguments(
    output_dir=training_directory,
    logging_dir=f'{training_directory}/logs',
    #deepspeed="ds_config_zero3.json",  # if using deepspeed
    lr_scheduler_type= "linear",
    group_by_length=False,  # can increase speed with dynamic padding, by grouping similar length texts https://huggingface.co/transformers/main_classes/trainer.html
    learning_rate=9e-6 if is_large_model else 2e-5,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=eval_batch,
    gradient_accumulation_steps=gradient_accumulation_steps,  # (!adapt/halve batch size accordingly). accumulates gradients over X steps, only then backward/update. decreases memory usage, but also slightly speed
    #eval_accumulation_steps=2,
    num_train_epochs=3,
    #max_steps=400,
    #warmup_steps=0,  # 1000,
    warmup_ratio=0.06,  #0.1, 0.06
    weight_decay=0.01,  #0.1,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    fp16=fp16_bool,   # ! only makes sense at batch-size > 8. loads two copies of model weights, which creates overhead. https://huggingface.co/transformers/performance.html?#fp16
    fp16_full_eval=fp16_bool,
    evaluation_strategy="epoch",
    seed=SEED_GLOBAL,
    #eval_steps=300,  # evaluate after n steps if evaluation_strategy=='steps'. defaults to logging_steps
    save_strategy="epoch",  # options: "no"/"steps"/"epoch"
    #save_steps=1_000_000,  # Number of updates steps before two checkpoint saves.
    save_total_limit=3,  # If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in output_dir
    #logging_strategy="epoch",
    report_to="all",  # "all"
    run_name=run_name,
    #push_to_hub=True,
    #push_to_hub_model_id="test97531", #f"{model_name}-finetuned-{task}",
    #hub_token="XXX",  # for pushing to hub
)


In [16]:
# Build the Trainer: trains on the full training set and, after each epoch, evaluates on the
# concatenated test set with the standard (non-chunked) metric function.
trainer = Trainer(
    model=model,
    #model_init=model_init,
    tokenizer=tokenizer,
    args=train_args,
    train_dataset=encoded_dataset_train,  #.shard(index=1, num_shards=200),  # https://huggingface.co/docs/datasets/processing.html#sharding-the-dataset-shard
    eval_dataset=encoded_dataset_test,  #.shard(index=1, num_shards=20),
    # the lambda binds the class names; the Trainer itself only passes the predictions object
    compute_metrics=lambda x: compute_metrics_standard(x, label_text_alphabetical=label_text_unique)  #compute_metrics,
    #data_collator=data_collator,  # for weighted sampling per dataset; for dynamic padding probably not necessary because done by default  https://huggingface.co/course/chapter3/3?fw=pt
)

if device == "cuda":
    # free memory
    flush()
    # NOTE(review): the return value is discarded, so `model` itself stays referenced - confirm this is intended
    release_memory(model)
    #del (model, trainer)


In [17]:
# train
# fine-tuning only runs when DO_TRAIN is set; otherwise the loaded model is evaluated as-is
if DO_TRAIN:
    trainer.train()


#### Evaluation

In [None]:
# a specific model could be loaded for evaluation here instead, e.g.:
#model = AutoModelForSequenceClassification.from_pretrained('./results/nli-few-shot/all-nli-3c/DeBERTa-v3-mnli-fever-anli-v1',   # nli_effect/distilroberta-paraphrase-mnli-fever-anli-v1
#                                                           label2id=label2id, id2label=id2label).to(device)

# free memory
if device == "cuda":
    flush()
    release_memory(model)

# tasks whose name contains one of these strings are skipped
datasets_not_to_evaluate = ["dummy_dataset"]   # "anthropic", "banking77", "massive", "empathetic"

# genuine NLI test sets: scored row by row with the standard metrics.
# every other task was converted to NLI format and is scored per premise.
nli_task_names = ("mnli_m", "mnli_mm", "fevernli", "anli_r1", "anli_r2", "anli_r3", "wanli", "lingnli")

result_dic = {}
for key_task_name, value_dataset in tqdm.tqdm(encoded_dataset_test_disaggregated.items(), desc="Iterations over testsets"):
    print(f"\n*** Evaluating task: {key_task_name}. Length of dataset: {len(value_dataset)}")

    # skip selected datasets
    if any(dataset_name in key_task_name for dataset_name in datasets_not_to_evaluate):
        continue

    is_nli_task = key_task_name in nli_task_names
    if not is_nli_task:
        # eval not_nli datasets: the sorted unique class names define the hypotheses per premise
        label_text_alphabetical_task = np.unique(value_dataset["label_text"]).tolist()
        trainer.compute_metrics = lambda x: compute_metrics_nli_binary(x, label_text_alphabetical=label_text_alphabetical_task)
    elif len(label_text_unique) == 2:
        # eval nli datasets. only works for binary nli models/datasets
        trainer.compute_metrics = lambda x: compute_metrics_standard(x, label_text_alphabetical=label_text_unique)
    else:
        raise ValueError(f"Issue with task: {key_task_name}")

    result = trainer.evaluate(eval_dataset=value_dataset)
    result_dic[key_task_name] = result
    print(f"Result for task {key_task_name}: ", result, "\n")

    # free memory between tasks
    if device == "cuda":
        flush()
        release_memory(model)


print("\n\nOverall results: ", result_dic)


In [None]:
# add disaggregated metrics to wandb
# result_dic maps each task name to the metrics dict returned by trainer.evaluate
wandb.log(result_dic)
# print the run identifiers so the logged results can be found again later
print("wandb.run.id: ", wandb.run.id)
print("wandb.run.name: ", wandb.run.name)

# close the run; nothing can be logged to it afterwards
wandb.finish()


In [None]:
## automatic creation of a .md file with a results table for the model card
# https://mdutils.readthedocs.io/en/latest/mdutils.html#subpackages
mdFile = MdUtils(file_name=f'README-{model_name.split("/")[-1]}-{DATE}', title='Model Card')

# one table column per evaluated task; values are converted to strings for the table cells
row_dataset_names = list(result_dic.keys())
row_metrics = [str(round(result["eval_accuracy"], 3)) for result in result_dic.values()]
row_samp_per_sec = [str(round(result["eval_samples_per_second"], 0)) for result in result_dic.values()]

# the table is passed as one flat list, row by row; each row starts with its header cell
table_lst = ["Datasets"] + row_dataset_names + ["Accuracy"] + row_metrics + [f"Inference text/sec (A100, batch={eval_batch})"] + row_samp_per_sec

# create markdown table with results
#mdFile.new_line()
results_table_me = mdFile.new_table(columns=len(row_dataset_names) + 1, rows=3, text=table_lst, text_align='center')
print(results_table_me)

# write results_table_me to the training directory
# the directory is created if needed (it may not exist when no checkpoint was saved),
# and the original working directory is restored even if writing the file fails
path_main = os.getcwd()
os.makedirs(training_directory, exist_ok=True)
os.chdir(training_directory)
try:
    mdFile.create_md_file()
finally:
    os.chdir(path_main)


In [None]:

if UPLOAD_TO_HUB and DO_TRAIN:
    #trainer.push_to_hub()  # does not work for some reason. wheel spins but nothing happens.

    # save the trainer's current model to disk
    # (with load_best_model_at_end=True this is the best checkpoint after training)
    model_path = f"{training_directory}/best-{model_name.split('/')[-1]}-{DATE}"

    trainer.save_model(output_dir=model_path)

    # reload the saved model in half precision for the upload
    print(os.getcwd())
    model = AutoModelForSequenceClassification.from_pretrained(model_path, torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, model_max_length=512)

    ## Push to hub
    #!sudo apt-get install git-lfs
    #!huggingface-cli login
    # unnecessary if token provided below

    # model and tokenizer go to the same private repo; `token` replaces the deprecated `use_auth_token`
    # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.push_to_hub
    repo_id = f'MoritzLaurer/{model_name.split("/")[-1]}-zeroshot-v1.1'
    model.push_to_hub(repo_id=repo_id, use_temp_dir=True, private=True, token=config.HF_ACCESS_TOKEN)
    tokenizer.push_to_hub(repo_id=repo_id, use_temp_dir=True, private=True, token=config.HF_ACCESS_TOKEN)
