## 详细介绍请查看官方文档
https://docs.wandb.ai/guides/integrations/huggingface

## 使用 wandb 记录 Hugging Face Transformers 大模型微调实验

In [None]:
# Setup consists mainly of a few environment variables plus TrainingArguments options.
import os
from transformers import TrainingArguments, Trainer
os.environ["WANDB_PROJECT"] = "<my-amazing-project>"  # name your W&B project
os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # log all model checkpoints


# `...` is a placeholder for your other arguments (sketch, not runnable as-is).
args = TrainingArguments(..., report_to="wandb")  # turn on W&B logging
trainer = Trainer(..., args=args)

## 以下是具体的使用方法

1. 登录wandb

In [None]:
# Log in to W&B.
import wandb
# Alternative: provide the API key through an environment variable first:
# import os
# os.environ["WANDB_API_KEY"] = "xxx"
wandb.login()

# You can also log in by setting the W&B API key environment variable.

2. 给wandb设置项目名称

In [None]:
import os
os.environ["WANDB_PROJECT"] = "xxx"
# W&B project name; if not set, it defaults to "huggingface".

3. 模型训练

In [None]:
# `report_to` controls whether results are sent to wandb
# `run_name` sets the name of the W&B run
# `logging_steps` sets how often metrics are logged

from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    # other args and kwargs here
    report_to="wandb",  # enable logging to W&B
    run_name="bert-base-high-lr",  # name of the W&B run (optional)
    logging_steps=1,  # how often to log to W&B
)

trainer = Trainer(
    # other args and kwargs here
    args=args,  # your training args
)

trainer.train()  # start training and logging to W&B

4. 上传模型权重（checkpoint）到 wandb

In [None]:
# WANDB_LOG_MODEL selects when model weights are uploaded; two values are shown:
#   "checkpoint": upload a checkpoint every `args.save_steps` (from TrainingArguments)
#   "end": upload the model once training has finished
# NOTE(review): the W&B docs also list "false" (no upload) — confirm for your version.
# NOTE(review): this cell comes after `trainer.train()` above; set the variable
# before creating the Trainer (as in the first cell) so it applies to that run.
# https://docs.wandb.ai/guides/model_registry

import os

os.environ["WANDB_LOG_MODEL"] = "checkpoint"

5. 模型训练后可视化评估输出

In [None]:
# Read a value from the active run's summary after training; see
# https://docs.wandb.ai/guides/integrations/huggingface#custom-logging-log-and-view-evaluation-samples-during-training
# No `run` variable exists at this point in the notebook, so use `wandb.run`,
# the currently active run (it stays open until `wandb.finish()` is called).
# NOTE(review): assumes a value was logged under "article_topics_extract".
import wandb

wandb.run.summary["article_topics_extract"]

6. 结束wandb

In [None]:
# If running in Jupyter or Google Colab, end the W&B run explicitly.
wandb.finish()

7. 可视化结果

In [None]:
# View the logged results in the W&B web app:
# https://docs.wandb.ai/guides/track/app

## 高级功能

### 保存最好的模型

In [None]:
# Set load_best_model_at_end=True in TrainingArguments.
# NOTE(review): with WANDB_LOG_MODEL enabled, the model logged at the end should
# then be the best checkpoint — confirm in the W&B Hugging Face integration docs.

### 加载保存的模型

In [None]:
# Load a model previously logged to W&B as an Artifact.
# AutoModelForSequenceClassification was never imported in this notebook,
# so import it here to keep the cell runnable on its own.
import wandb
from transformers import AutoModelForSequenceClassification

# NOTE(review): `num_labels` is not defined in this notebook; set it to the
# number of classes the saved model was trained with before running this cell.

# Create a new run
with wandb.init(project="amazon_sentiment_analysis") as run:
    # Pass the name and version (alias) of the Artifact.
    # Presumably "model-" + the run_name used during training — confirm in
    # the Artifacts tab of your W&B project.
    my_model_name = "model-bert-base-high-lr:latest"
    my_model_artifact = run.use_artifact(my_model_name)

    # Download model weights to a folder and return the path
    model_dir = my_model_artifact.download()

    # Load your Hugging Face model from that folder
    #  using the same model class
    model = AutoModelForSequenceClassification.from_pretrained(
        model_dir, num_labels=num_labels
    )

    # Do additional training, or run inference

### 从一个checkpoint继续训练

In [None]:
# Resume an earlier W&B run and continue training from its latest checkpoint.
import os

import wandb
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

last_run_id = "xxxxxxxx"  # fetch the run_id from your wandb workspace

# resume the wandb run from the run_id
with wandb.init(
    project=os.environ["WANDB_PROJECT"],
    id=last_run_id,
    resume="must",
) as run:
    # Connect the run's checkpoint Artifact to the resumed run.
    my_checkpoint_name = f"checkpoint-{last_run_id}:latest"
    # Use the checkpoint artifact named above; passing `my_model_name` here
    # would fetch the final-model artifact (or fail if it is undefined).
    my_checkpoint_artifact = run.use_artifact(my_checkpoint_name)

    # Download checkpoint to a folder and return the path
    checkpoint_dir = my_checkpoint_artifact.download()

    # reinitialize your model and trainer
    # NOTE(review): `num_labels` must be defined before running this cell.
    model = AutoModelForSequenceClassification.from_pretrained(
        "<model_name>", num_labels=num_labels
    )
    # your awesome training arguments here.
    training_args = TrainingArguments()

    trainer = Trainer(model=model, args=training_args)

    # make sure use the checkpoint dir to resume training from the checkpoint
    trainer.train(resume_from_checkpoint=checkpoint_dir)

### 自定义初始化

In [None]:
# Start the W&B run yourself to customize its project, name, tags and group.
# NOTE(review): presumably call this before creating the Trainer so the W&B
# callback reuses this run instead of starting a new one — confirm.
wandb.init(
    project="amazon_sentiment_analysis",
    name="bert-base-high-lr",
    tags=["baseline", "high-lr"],
    group="bert",
)

### 在训练中记录和观察评估样本

In [None]:
# Instantiate the Trainer as normal
# (sketch: pass your usual model / args / datasets to Trainer here)
trainer = Trainer()

# Instantiate the new logging callback, passing it the Trainer object
# NOTE(review): `WandbEvalsCallback` is a placeholder name that is not defined
# in this notebook; substitute your own WandbCallback subclass.
evals_callback = WandbEvalsCallback(trainer, tokenizer, ...)

# Add the callback to the Trainer
trainer.add_callback(evals_callback)

# Begin Trainer training as normal
trainer.train()

### 在训练中查看评估样本

In [None]:
from transformers.integrations import WandbCallback
import pandas as pd


def decode_predictions(tokenizer, predictions):
    """Turn label ids and predicted logits back into text.

    Args:
        tokenizer: Object providing ``batch_decode`` (e.g. a Hugging Face
          tokenizer).
        predictions: Output of ``Trainer.predict``; must expose ``label_ids``
          and ``predictions`` (logits, reduced with argmax over the last axis).

    Returns:
        A dict with ``"labels"`` (decoded gold labels) and ``"predictions"``
        (decoded argmax predictions), each a list of strings.
    """
    best_ids = predictions.predictions.argmax(axis=-1)
    return {
        "labels": tokenizer.batch_decode(predictions.label_ids),
        "predictions": tokenizer.batch_decode(best_ids),
    }


class WandbPredictionProgressCallback(WandbCallback):
    """Custom WandbCallback to log model predictions during training.

    On every evaluation whose epoch is a multiple of ``freq``, this callback
    runs ``trainer.predict`` on a fixed subset of the validation set, decodes
    the predictions and labels with the tokenizer, and logs them to W&B as a
    ``wandb.Table`` under the key ``"sample_predictions"``. This makes it
    possible to watch the model's predictions evolve as training progresses.

    Attributes:
        trainer (Trainer): The Hugging Face Trainer instance.
        tokenizer (AutoTokenizer): The tokenizer associated with the model.
        sample_dataset (Dataset): The first ``num_samples`` examples of the
          validation dataset (fewer if it is smaller), used for predictions.
        freq (int): Log predictions every ``freq`` epochs.
    """

    def __init__(self, trainer, tokenizer, val_dataset,
                 num_samples=100, freq=2):
        """Initializes the WandbPredictionProgressCallback instance.

        Args:
            trainer (Trainer): The Hugging Face Trainer instance.
            tokenizer (AutoTokenizer): The tokenizer associated
              with the model.
            val_dataset (Dataset): The validation dataset.
            num_samples (int, optional): Number of samples to select from
              the validation dataset for generating predictions. Clamped to
              the dataset length. Defaults to 100.
            freq (int, optional): Frequency of logging, in epochs.
              Defaults to 2.
        """
        super().__init__()
        self.trainer = trainer
        self.tokenizer = tokenizer
        # Clamp so a validation set smaller than `num_samples` does not make
        # `select` fail with an out-of-range index.
        num_samples = min(num_samples, len(val_dataset))
        self.sample_dataset = val_dataset.select(range(num_samples))
        self.freq = freq

    def on_evaluate(self, args, state, control, **kwargs):
        """Log a table of sample predictions every `freq` epochs."""
        super().on_evaluate(args, state, control, **kwargs)
        # `state.epoch` is None when evaluation runs outside of training
        # (e.g. `trainer.evaluate()` before `train()`); there is no epoch to
        # schedule against then, so skip instead of failing on `None % freq`.
        if state.epoch is None or state.epoch % self.freq != 0:
            return
        # generate predictions
        predictions = self.trainer.predict(self.sample_dataset)
        # decode predictions and labels
        predictions = decode_predictions(self.tokenizer, predictions)
        # add predictions to a wandb.Table
        predictions_df = pd.DataFrame(predictions)
        predictions_df["epoch"] = state.epoch
        records_table = self._wandb.Table(dataframe=predictions_df)
        # log the table to wandb
        self._wandb.log({"sample_predictions": records_table})


# First, instantiate the Trainer
# NOTE(review): `model`, `training_args`, `lm_datasets` and `tokenizer` are
# assumed to be defined in earlier cells.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["validation"],
)

# Instantiate the WandbPredictionProgressCallback, sampling from the same
# `lm_datasets` validation split the Trainer evaluates on.
# Predictions are only logged on evaluations at epochs that are multiples of
# `freq`, so evaluation should run at epoch boundaries.
progress_callback = WandbPredictionProgressCallback(
    trainer=trainer,
    tokenizer=tokenizer,
    val_dataset=lm_datasets["validation"],
    num_samples=10,
    freq=2,
)

# Add the callback to the trainer
trainer.add_callback(progress_callback)