## 查看训练数据的数量

In [1]:
import logging
from pathlib import Path
import os
import sys
import json
import numpy as np

import datasets
import torch

import transformers
from transformers import (
    CONFIG_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.trainer_callback import TrainerState
from transformers.trainer import TRAINER_STATE_NAME
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

from doremi.training_args import ModelArguments, DataTrainingArguments, FullTrainingArguments
import doremi.dataloader as data_utils
from doremi.trainer import DoReMiTrainer
import doremi.models as doremi_models
try:
    from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config
except Exception:
    pass

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Locations of the preprocessed data and of the domain-weight config for this run.
# NOTE(review): the data dir is "slim_preprocessed" while the config is the
# "rp_baseline" one — confirm this pairing is intended.
dataset_dir = "/home/wth/My_codes/doremi/data/slim_preprocessed/preprocessed"
domain_config_path = "/home/wth/My_codes/doremi/configs/rp_baseline_50kvocab_nopack.json"
with open(domain_config_path, "r", encoding="utf-8") as f:
    domain_config = json.load(f)

# Domain name -> weight mappings read from the config.
train_domain_weights_dict = domain_config["train_domain_weights"]
eval_domain_weights_dict = domain_config["eval_domain_weights"]
# Sort the domain names so the order is deterministic regardless of JSON key order;
# sorted() already returns a list, so no extra list() wrapper is needed.
domain_list = sorted(train_domain_weights_dict)
dataset_name = "RedPajamaCommonCrawl"
cache_dir = "/home/wth/My_codes/doremi/cache"
# NOTE(review): presumably None means "no cap on training samples" — confirm
# against the data loader that consumes this value.
max_train_samples = None


### 检查 DoReMiTrainer class 的功能

In [1]:
import math
import warnings
import json
import re
from pathlib import Path
import wandb
import numpy as np
from collections import defaultdict
from tqdm import tqdm
import torch
import torch.distributed as dist
from torch import nn
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from datasets import IterableDataset
from transformers import Trainer
from transformers.utils import ExplicitEnum, is_torch_tpu_available
from transformers.optimization import get_scheduler
from transformers.utils import logging
from transformers.trainer import is_sagemaker_mp_enabled
from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.trainer_utils import (
        has_length,
        denumpify_detensorize,
        EvalLoopOutput,
        enable_full_determinism,
        set_seed,
        get_last_checkpoint,
        PREFIX_CHECKPOINT_DIR
)
from transformers.trainer_pt_utils import find_batch_size

from doremi.eval_datasets import get_eval_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Load a domain-weight config and print its train/eval weights for inspection.
domain_config_path = "/home/wth/My_codes/doremi/configs/slim_baseline_50kvocab_nopack_bucket.json"
# Pass the encoding explicitly: without it, open() uses the platform default,
# which is not guaranteed to be UTF-8.
with open(domain_config_path, "r", encoding="utf-8") as f:
    domain_config = json.load(f)
print(domain_config)
train_domain_weights_dict = domain_config["train_domain_weights"]
eval_domain_weights_dict = domain_config["eval_domain_weights"]
# Sorted so domain_list (and therefore sampling_weights) has a deterministic order;
# sorted() already returns a list.
domain_list = sorted(train_domain_weights_dict)
# One weight per domain, aligned index-for-index with domain_list.
sampling_weights = torch.tensor([train_domain_weights_dict[domain] for domain in domain_list])
print("train domain weights dict: ")
print(train_domain_weights_dict)
print("eval domain weights dict: ")
print(eval_domain_weights_dict)
print("domain list: ")
print(domain_list)
print("sampling weights: ")
print(sampling_weights)

{'train_domain_weights': {'RedPajamaC4_4': 0.046629283524404826, 'RedPajamaArXiv_7': 0.00210244756631777, 'RedPajamaWikipedia_5': 0.00434940298422881, 'RedPajamaC4_6': 0.046731289066705956, 'RedPajamaC4_8': 0.04627651435728008, 'RedPajamaGithub_6': 0.004210562107207825, 'RedPajamaCommonCrawl_7': 0.036962841647729526, 'RedPajamaArXiv_4': 0.0021038643099608413, 'RedPajamaC4_2': 0.04665053467905089, 'RedPajamaStackExchange_6': 0.0042190625690662524, 'RedPajamaStackExchange_3': 0.004313984393152028, 'RedPajamaBook_1': 0.0017270105009038824, 'RedPajamaGithub_8': 0.004335235547798097, 'RedPajamaArXiv_3': 0.002106697797246984, 'RedPajamaWikipedia_1': 0.004389071806234805, 'RedPajamaArXiv_6': 0.0021265322082499815, 'RedPajamaGithub_9': 0.00421622908178011, 'RedPajamaBook_2': 0.0017270105009038824, 'RedPajamaWikipedia_2': 0.00434940298422881, 'RedPajamaStackExchange_5': 0.004251647672856892, 'RedPajamaStackExchange_2': 0.004221896056352395, 'RedPajamaC4_9': 0.04647627521095313, 'RedPajamaCommon