In [1]:
from datasets import Dataset

In [2]:
import argparse
import json
import math
import os
import random
from time import time
import mlflow
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, f1_score
from collections import defaultdict

# import pytrec_eval
import torch
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
from torch.utils.data import DataLoader, RandomSampler
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, AutoTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup
from accelerate import Accelerator


torch.backends.cuda.matmul.allow_tf32 = True

from watchog.dataset import (
    # collate_fn,
    TURLColTypeTablewiseDataset,
    TURLRelExtTablewiseDataset,
    SatoCVTablewiseDataset,
    ColPoplTablewiseDataset
)

from watchog.dataset import TableDataset, SupCLTableDataset, SemtableCVTablewiseDataset, GittablesTablewiseDataset, GittablesColwiseDatasetDecoder
from watchog.model import BertMultiPairPooler, BertForMultiOutputClassification, BertForMultiOutputClassificationColPopl
from watchog.model import SupCLforTable, UnsupCLforTable, lm_mp
from watchog.utils import load_checkpoint, f1_score_multilabel, collate_fn, get_col_pred, ColPoplEvaluator
from watchog.utils import task_num_class_dict
from accelerate import DistributedDataParallelKwargs
import wandb
import os

In [3]:

    parser = argparse.ArgumentParser()
    parser.add_argument("--wandb", type=bool, default=False)
    parser.add_argument("--model", type=str, default="Watchog")
    parser.add_argument("--unlabeled_train_only", type=bool, default=False)
    parser.add_argument("--context_encoding_type", type=str, default="v1")
    parser.add_argument("--pool_version", type=str, default="v0.2")
    parser.add_argument("--random_sample", type=bool, default=False)
    parser.add_argument("--comment", type=str, default="debug", help="to distinguish the runs")
    parser.add_argument(
        "--shortcut_name",
        default="bert-base-uncased",
        type=str,
        help="Huggingface model shortcut name ",
    )
    parser.add_argument(
        "--max_length",
        default=64,
        type=int,
        help=
        "The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument(
        "--max_num_col",
        default=8,
        type=int,
    )   

    parser.add_argument(
        "--batch_size",
        default=8,
        type=int,
        help="Batch size",
    )
    parser.add_argument(
        "--epoch",
        default=15,
        type=int,
        help="Number of epochs for training",
    )
    parser.add_argument(
        "--random_seed",
        default=4649,
        type=int,
        help="Random seed",
    )
    
    parser.add_argument(
        "--train_n_seed_cols",
        default=-1,
        type=int,
        help="number of seeding columns in training",
    )

    parser.add_argument(
        "--num_classes",
        default=78,
        type=int,
        help="Number of classes",
    )
    parser.add_argument("--multi_gpu",
                        action="store_true",
                        default=False,
                        help="Use multiple GPU")
    parser.add_argument("--fp16",
                        action="store_true",
                        default=False,
                        help="Use FP16")
    parser.add_argument("--warmup",
                        type=float,
                        default=0.,
                        help="Warmup ratio")
    parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--task",
                        type=str,
                        default='gt-semtab22-dbpedia-all0',
                        choices=[
                            "sato0", "sato1", "sato2", "sato3", "sato4",
                            "msato0", "msato1", "msato2", "msato3", "msato4",
                            "gt-dbpedia0", "gt-dbpedia1", "gt-dbpedia2", "gt-dbpedia3", "gt-dbpedia4",
                            "gt-dbpedia-all0", "gt-dbpedia-all1", "gt-dbpedia-all2", "gt-dbpedia-all3", "gt-dbpedia-all4",
                            "gt-schema-all0", "gt-schema-all1", "gt-schema-all2", "gt-schema-all3", "gt-schema-all4",
                            "gt-semtab22-dbpedia", "gt-semtab22-dbpedia0", "gt-semtab22-dbpedia1", "gt-semtab22-dbpedia2", "gt-semtab22-dbpedia3", "gt-semtab22-dbpedia4",
                            "gt-semtab22-dbpedia-all", "gt-semtab22-dbpedia-all0", "gt-semtab22-dbpedia-all1", "gt-semtab22-dbpedia-all2", "gt-semtab22-dbpedia-all3", "gt-semtab22-dbpedia-all4",
                            "gt-semtab22-schema-class-all", "gt-semtab22-schema-property-all",
                            "turl", "turl-re", "col-popl-1", "col-popl-2", "col-popl-3", "row-popl",
                            "col-popl-turl-0", "col-popl-turl-1", "col-popl-turl-2",
                            "col-popl-turl-mdonly-0", "col-popl-turl-mdonly-1", "col-popl-turl-mdonly-2"
                        ],
                        help="Task names}")
    parser.add_argument("--colpair",
                        action="store_true",
                        help="Use column pair embedding")
    parser.add_argument("--metadata",
                        action="store_true",
                        help="Use column header metadata")
    parser.add_argument("--from_scratch",
                        action="store_true",
                        help="Training from scratch")
    parser.add_argument("--cl_tag",
                        type=str,
                        default="wikitables/simclr/bert_100000_10_32_256_5e-05_sample_row4,sample_row4_tfidf_entity_column_0.05_0_last.pt",
                        help="path to the pre-trained file")
    parser.add_argument("--dropout_prob",
                        type=float,
                        default=0.5)
    parser.add_argument("--eval_test",
                        action="store_true",
                        help="evaluate on testset and do not save the model file")
    parser.add_argument("--small_tag",
                        type=str,
                        default="semi1",
                        help="e.g., by_table_t5_v1")
    parser.add_argument("--data_path",
                        type=str,
                        default="/data/zhihao/TU/")
    parser.add_argument("--pretrained_ckpt_path",
                        type=str,
                        default="/data/zhihao/TU/Watchog/model/")    


    args = parser.parse_args([])

In [4]:
    task = args.task
    if args.small_tag != "":
        args.eval_test = True
    
    args.num_classes = task_num_class_dict[task]
    if args.colpair:
        assert "turl-re" == task, "colpair can be only used for Relation Extraction"
    if args.metadata:
        assert "turl-re" == task or "turl" == task, "metadata can be only used for TURL datasets"
    if "col-popl":
        # metrics = {
        #     "accuracy": CategoricalAccuracy(tie_break=True),
        # }
        if args.train_n_seed_cols != -1:
            if "col-popl" in task:
                assert args.train_n_seed_cols == int(task[-1]),  "# of seed columns must match"

    print("args={}".format(json.dumps(vars(args))))

    max_length = args.max_length
    batch_size = args.batch_size
    num_train_epochs = args.epoch

    shortcut_name = args.shortcut_name

    if args.colpair and args.metadata:
        taskname = "{}-colpair-metadata".format(task)
    elif args.colpair:
        taskname = "{}-colpair".format(task)
    elif args.metadata:
        taskname = "{}-metadata".format(task)
    elif args.train_n_seed_cols == -1 and 'popl' in task:
        taskname = "{}-mix".format(task)
    else:
        taskname = "".join(task)

args={"wandb": false, "model": "Watchog", "unlabeled_train_only": false, "context_encoding_type": "v1", "pool_version": "v0.2", "random_sample": false, "comment": "debug", "shortcut_name": "bert-base-uncased", "max_length": 64, "max_num_col": 8, "batch_size": 8, "epoch": 15, "random_seed": 4649, "train_n_seed_cols": -1, "num_classes": 101, "multi_gpu": false, "fp16": false, "warmup": 0.0, "lr": 5e-05, "task": "gt-semtab22-dbpedia-all0", "colpair": false, "metadata": false, "from_scratch": false, "cl_tag": "wikitables/simclr/bert_100000_10_32_256_5e-05_sample_row4,sample_row4_tfidf_entity_column_0.05_0_last.pt", "dropout_prob": 0.5, "eval_test": true, "small_tag": "semi1", "data_path": "/data/zhihao/TU/", "pretrained_ckpt_path": "/data/zhihao/TU/Watchog/model/"}


In [5]:
task

'gt-semtab22-dbpedia-all0'

In [6]:
from transformers import AutoTokenizer, DataCollatorWithPadding
MAX_LEN = 512 
checkpoint = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(checkpoint, add_prefix_space=True, cache_dir="/data/zhihao/hf")
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
tokenizer.cls_token = tokenizer.bos_token
tokenizer.cls_token_id = tokenizer.bos_token_id

In [7]:
run_name = f"Llama_max_cols@{args.max_num_col}_{args.small_tag}_DS@{args.task}_scratch@{args.from_scratch}_maxlen@{args.max_length}_bs@{args.batch_size}"

In [8]:
src = 'dbpedia'
dataset_cls = GittablesColwiseDatasetDecoder
cv = 0
max_length = 64
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_dataset = dataset_cls(cv=cv,
                            split="train",
                            src=src,
                            tokenizer=tokenizer,
                            max_length=max_length,
                            gt_only='all' not in task,
                            device=device,
                            base_dirpath=os.path.join(args.data_path, "GitTables/semtab_gittables/2022"),
                            small_tag=args.small_tag,
                            max_num_col=args.max_num_col,
                            random_sample=args.random_sample,
                            context_encoding_type=args.context_encoding_type)
valid_dataset = dataset_cls(cv=cv,
                            split="valid", src=src,
                            tokenizer=tokenizer,
                            max_length=max_length,
                            gt_only='all' not in task or args.unlabeled_train_only,
                            device=device,
                            base_dirpath=os.path.join(args.data_path, "GitTables/semtab_gittables/2022"),
                            small_tag=args.small_tag,
                            max_num_col=args.max_num_col,
                            context_encoding_type=args.context_encoding_type
                            )
test_dataset = dataset_cls(cv=cv,
                            split="test", src=src,
                            tokenizer=tokenizer,
                            max_length=max_length,
                            gt_only='all' not in task or args.unlabeled_train_only,
                            device=device,
                            base_dirpath=os.path.join(args.data_path, "GitTables/semtab_gittables/2022"),
                            small_tag=args.small_tag,
                            max_num_col=args.max_num_col,
                            context_encoding_type=args.context_encoding_type)

train 1
train 2
train 3
train 4
train 3463
valid 1
valid 2
valid 3
valid 4
valid 885
test
test 1085


In [36]:
tokenizer.eos_token_id

2

In [9]:
from datasets import Dataset
# dataset_train = Dataset.from_pandas(train_dataset.table_df[["data_tensor", "label_tensor"]])
dataset_train = Dataset.from_dict({'input_ids': train_dataset.table_df["data_tensor"].tolist(), "label": train_dataset.table_df["label_tensor"].tolist()})
dataset_train.set_format("torch")
dataset_valid = Dataset.from_dict({'input_ids': valid_dataset.table_df["data_tensor"].tolist(), "label": valid_dataset.table_df["label_tensor"].tolist()})
dataset_valid.set_format("torch")
dataset_test = Dataset.from_dict({'input_ids': test_dataset.table_df["data_tensor"].tolist(), "label": test_dataset.table_df["label_tensor"].tolist()})
dataset_test.set_format("torch")

In [10]:
tokenizer.decode(dataset_train[0]["input_ids"])

'<s> False;False;False;False;False;False;False;False;False;False;False;False;False;False;False;False;False<s> 2021-03-22 14:20:57;2021-03-22 14:20:57;2021-03-22 14:20:57;2<s> 2016-08-04 17:26:14;2016-07-14 23:52:42;2016-11-01 15:28:12;2<s> 54ce99fa85c92b1d87678436e956a2e8;5b6cf869265c13af8566f192b4ab3d2a;<s> 2104505001950;2104505001950;2104505001950;2104505001950;21045<s> 163915846063865;155711202630775;205334637062213;15571120263076'

In [None]:
dataset_train[0]["input_ids"]

In [28]:
model.config.pad_token_id

2

In [17]:
sequence_lengths = torch.eq(dataset_train[0]["input_ids"], model.config.pad_token_id).int().argmax(-1) - 1

In [18]:
sequence_lengths

tensor(-1)

In [30]:
torch.eq(dataset_train[0]["input_ids"], model.config.pad_token_id).int().argmax()

tensor(34)

In [31]:
-1%512

511

In [26]:
sequence_lengths

tensor(33)

In [33]:
sum(dataset_train[0]["input_ids"] == 2)

tensor(1)

In [34]:
torch.eq(dataset_train[0]["input_ids"], model.config.pad_token_id).int().argmax(-1)

tensor(35)

In [11]:
len(dataset_train)

3463

In [12]:
# train_dataset.table_df["data_tensor"].tolist()

In [14]:
# dataset_train["label"]

In [12]:
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
@dataclass
class DoduoCollatorWithPadding:
    pad_token_id: int

    def __call__(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]: 
        data = torch.nn.utils.rnn.pad_sequence(
            [sample["input_ids"] for sample in samples], padding_value=self.pad_token_id)
        label = torch.cat([sample["label"] for sample in samples])
        batch = {"input_ids": data.T, "labels": label}
        if "idx" in samples[0]:
            batch["idx"] = [sample["idx"] for sample in samples]
        if "cls_indexes" in samples[0]:
            cls_indexes = torch.nn.utils.rnn.pad_sequence(
                [sample["cls_indexes"] for sample in samples], padding_value=0)
            batch["cls_indexes"] = cls_indexes
        return batch
        
data_collator = DoduoCollatorWithPadding(pad_token_id=tokenizer.pad_token_id)
# data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [13]:
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    AutoTokenizer,
    TrainingArguments,
    GenerationConfig
)

# model = AutoModelForCausalLM.from_pretrained(
#           checkpoint, quantization_config=bnb_config, device_map={"": 0}, cache_dir="/data/zhihao/hf"
# )

In [14]:
from transformers.models.llama.modeling_llama import LlamaPreTrainedModel, LlamaForSequenceClassification

In [15]:
import copy
import importlib
import json
import os
import warnings
from collections import OrderedDict

from transformers.configuration_utils import PretrainedConfig
from transformers.dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from transformers.utils import (
    CONFIG_NAME,
    cached_file,
    copy_func,
    extract_commit_hash,
    find_adapter_config_file,
    is_peft_available,
    logging,
    requires_backends,
)
from transformers.models.auto.configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
from watchog.llm_model import LlamaForColTypeClassification
def load_pretrained_llm(pretrained_model_name_or_path, *model_args, **kwargs):
    config = kwargs.pop("config", None)
    trust_remote_code = kwargs.pop("trust_remote_code", None)
    kwargs["_from_auto"] = True
    hub_kwargs_names = [
        "cache_dir",
        "force_download",
        "local_files_only",
        "proxies",
        "resume_download",
        "revision",
        "subfolder",
        "use_auth_token",
        "token",
    ]
    hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
    code_revision = kwargs.pop("code_revision", None)
    commit_hash = kwargs.pop("_commit_hash", None)
    adapter_kwargs = kwargs.pop("adapter_kwargs", None)

    token = hub_kwargs.pop("token", None)
    use_auth_token = hub_kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError(
                "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
            )
        token = use_auth_token

    if token is not None:
        hub_kwargs["token"] = token

    if commit_hash is None:
        if not isinstance(config, PretrainedConfig):
            # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
            resolved_config_file = cached_file(
                pretrained_model_name_or_path,
                CONFIG_NAME,
                _raise_exceptions_for_gated_repo=False,
                _raise_exceptions_for_missing_entries=False,
                _raise_exceptions_for_connection_errors=False,
                **hub_kwargs,
            )
            commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
        else:
            commit_hash = getattr(config, "_commit_hash", None)

    if is_peft_available():
        if adapter_kwargs is None:
            adapter_kwargs = {}
            if token is not None:
                adapter_kwargs["token"] = token

        maybe_adapter_path = find_adapter_config_file(
            pretrained_model_name_or_path, _commit_hash=commit_hash, **adapter_kwargs
        )

        if maybe_adapter_path is not None:
            with open(maybe_adapter_path, "r", encoding="utf-8") as f:
                adapter_config = json.load(f)

                adapter_kwargs["_adapter_model_path"] = pretrained_model_name_or_path
                pretrained_model_name_or_path = adapter_config["base_model_name_or_path"]

    if not isinstance(config, PretrainedConfig):
        kwargs_orig = copy.deepcopy(kwargs)
        # ensure not to pollute the config object with torch_dtype="auto" - since it's
        # meaningless in the context of the config object - torch.dtype values are acceptable
        if kwargs.get("torch_dtype", None) == "auto":
            _ = kwargs.pop("torch_dtype")
        # to not overwrite the quantization_config if config has a quantization_config
        if kwargs.get("quantization_config", None) is not None:
            _ = kwargs.pop("quantization_config")

        config, kwargs = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            return_unused_kwargs=True,
            trust_remote_code=trust_remote_code,
            code_revision=code_revision,
            _commit_hash=commit_hash,
            **hub_kwargs,
            **kwargs,
        )

        # if torch_dtype=auto was passed here, ensure to pass it on
        if kwargs_orig.get("torch_dtype", None) == "auto":
            kwargs["torch_dtype"] = "auto"
        if kwargs_orig.get("quantization_config", None) is not None:
            kwargs["quantization_config"] = kwargs_orig["quantization_config"]


    # Set the adapter kwargs
    kwargs["adapter_kwargs"] = adapter_kwargs



    model_class = LlamaForColTypeClassification
    return model_class.from_pretrained(
        pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    )


In [23]:
# from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig
# from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
# import torch
# compute_dtype = getattr(torch, "float16")
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# # bnb_config = BitsAndBytesConfig(
# #         load_in_4bit=True,
# #         bnb_4bit_quant_type="nf4",
# #         bnb_4bit_compute_dtype=compute_dtype,
# #         bnb_4bit_use_double_quant=True,
# # )
# bnb_config = BitsAndBytesConfig(
#     load_in_8bit=True,  # Enable 8-bit quantization
#     bnb_8bit_compute_dtype=compute_dtype,  # Use FP16 for computation
#     bnb_8bit_use_double_quant=True
# )
# model = load_pretrained_llm(
#   pretrained_model_name_or_path=checkpoint,
#   num_labels=args.num_classes,
#   device_map={'': device},
#   quantization_config=bnb_config,
#   cache_dir="/data/zhihao/hf"
# )
# model.resize_token_embeddings(len(tokenizer))
# #Configure the pad token in the model
# model.config.pad_token_id = tokenizer.pad_token_id
# model.config.use_cache = False # Gradient checkpointing is used by default but not compatible with caching

In [16]:
from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
import torch
compute_dtype = getattr(torch, "float16")

# bnb_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_quant_type="nf4",
#         bnb_4bit_compute_dtype=compute_dtype,
#         bnb_4bit_use_double_quant=True,
# )
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,  # Enable 8-bit quantization
    bnb_8bit_compute_dtype=compute_dtype,  # Use FP16 for computation
    bnb_8bit_use_double_quant=True
)

model =  AutoModelForSequenceClassification.from_pretrained(
  pretrained_model_name_or_path=checkpoint,
  num_labels=args.num_classes,
  device_map={'': device},
  quantization_config=bnb_config,
  cache_dir="/data/zhihao/hf"
)
model.resize_token_embeddings(len(tokenizer))
#Configure the pad token in the model
model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False # Gradient checkpointing is used by default but not compatible with caching

model = prepare_model_for_kbit_training(model)

Unused kwargs: ['bnb_8bit_compute_dtype', 'bnb_8bit_use_double_quant']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Llama-2-7b-hf and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
args.num_classes

101

In [19]:
from peft import get_peft_model, LoraConfig, TaskType
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS, r=16, lora_alpha=16, lora_dropout=0.05, bias="none", 
    target_modules=[
        "q_proj",
        "v_proj",  
    ],
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 8,802,304 || all params: 6,616,559,616 || trainable%: 0.1330


In [20]:
def print_trainable_parameters(model):
    print("Trainable Parameters:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.shape}, {param.dtype}")

# print_trainable_parameters(model)

In [21]:
import evaluate
import numpy as np

def compute_metrics(eval_pred):
    # All metrics are already predefined in the HF `evaluate` package
    f1_metric= evaluate.load("f1")


    logits, labels = eval_pred # eval_pred is the tuple of predictions and labels returned by the model
    predictions = np.argmax(logits, axis=-1)
    references = labels
    macro_f1 = f1_metric.compute(average='macro', predictions=predictions, references=references)
    micro_f1 = f1_metric.compute(average='micro', predictions=predictions, references=references)
    class_f1 = f1_metric.compute(average=None, predictions=predictions, references=references)
    torch.save({"predictions": predictions, "references": references}, "./results/llm_eval_pred.pt")
    return {"macro_f1": macro_f1, "micro_f1": micro_f1}


  warn(f"Failed to load image Python extension: {e}")


In [22]:
from transformers import TrainingArguments, Trainer
from transformers import Trainer
from torch import inf

class WeightedCELossTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        # Get model's predictions
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # Compute custom loss
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss


training_args = TrainingArguments(
        output_dir="/data/zhihao/hf/checkpoints",
        evaluation_strategy="epoch",
        save_strategy="epoch",
        do_eval=True,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=1,
        per_device_eval_batch_size=args.batch_size,
        # log_level="debug",
        optim= "adamw_torch", # "paged_adamw_32bit",
        # save_strategy="no",  # No checkpoints will be saved
        # save_steps=None,  # Explicitly set to None
        save_total_limit=1, 
        load_best_model_at_end=True,  
        metric_for_best_model="eval_loss",  # Metric to monitor
        greater_is_better=False,         
        logging_steps=50, #change to 100
        learning_rate=1e-4,
        # eval_steps=5, #change to 200
        fp16=True,
        num_train_epochs=args.epoch, # remove "#"
        # max_steps=10, #remove this
        lr_scheduler_type="constant",
        warmup_ratio= 0.1,
        max_grad_norm=0.3,
        weight_decay=0.001,
        report_to="wandb",
        run_name=run_name,
)


trainer = WeightedCELossTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset_train,
    eval_dataset=dataset_valid, 
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [23]:
trainer.train()

[2024-08-01 12:27:51,168] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av



huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av



Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[34m[1mwandb[0m: Currently logged in as: [33mtommyding[0m. Use [1m`wandb login --relogin`[0m to force relogin
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Epoch,Training Loss,Validation Loss,Macro F1,Micro F1
1,3.2995,2.579391,{'f1': 0.1072324007328501},{'f1': 0.4338983050847457}
2,1.9661,2.163586,{'f1': 0.23367323611734758},{'f1': 0.5186440677966102}
3,1.3797,2.171382,{'f1': 0.26555081217115956},{'f1': 0.5050847457627119}
4,0.7987,2.267467,{'f1': 0.2978441060802984},{'f1': 0.5310734463276836}
5,0.4092,2.495165,{'f1': 0.2685406024401992},{'f1': 0.5322033898305085}
6,0.2102,2.546117,{'f1': 0.296101279147423},{'f1': 0.5299435028248588}
7,0.1317,2.65992,{'f1': 0.31581023057029156},{'f1': 0.5514124293785311}
8,0.0747,2.712093,{'f1': 0.2876772922872507},{'f1': 0.5446327683615819}
9,0.0628,2.782274,{'f1': 0.33378483320612196},{'f1': 0.5514124293785311}
10,0.0361,2.83694,{'f1': 0.33278482779308854},{'f1': 0.5694915254237288}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enab

TrainOutput(global_step=2175, training_loss=0.578339083852439, metrics={'train_runtime': 25278.6354, 'train_samples_per_second': 2.055, 'train_steps_per_second': 0.086, 'total_flos': 1.0206230020664525e+18, 'train_loss': 0.578339083852439, 'epoch': 15.0})