Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Consume much more gpt memory running eval_rm #3614

Merged
merged 5 commits into from Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 16 additions & 11 deletions model/model_eval/eval_rm.py
Expand Up @@ -7,6 +7,7 @@
from model_training.custom_datasets.ranking_collator import RankingDataCollator
from model_training.metrics import RewardMetrics
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.trainer_utils import EvalPrediction
from utils import write_to_json
Expand All @@ -29,15 +30,16 @@ def get_ranking_dataset(dataset, split):
def batch_inference(inputs, model):
batch, cu_lens = inputs
batch = {k: v.to(model.device) for k, v in batch.items()}
logits = (
model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
)
.logits.detach()
.cpu()
.numpy()
)

with torch.no_grad():
logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits.detach().cpu()

if logits.dtype == torch.bfloat16:
# As of Numpy 1.21.4, NumPy does not support bfloat16 (see
# https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ).
# Until Numpy adds bfloat16, we must convert float32.
logits = logits.to(torch.float32)
logits = logits.numpy()

labels = []
for i, (s, e) in enumerate(zip(cu_lens[:-1], cu_lens[1:])):
Expand All @@ -54,6 +56,7 @@ def batch_inference(inputs, model):
parser.add_argument("--metrics", type=str, help="metrics to evaluate", default="accuracy")
parser.add_argument("--batch_size", type=int, help="Batch Size", default=8)
parser.add_argument("--device", type=str, help="device", default="cuda")
parser.add_argument("--dtype", type=str, help="data type", default=None)
args = parser.parse_args().__dict__

if args.get("device") != "cpu":
Expand All @@ -64,7 +67,9 @@ def batch_inference(inputs, model):
model_name = args.get("model")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
model_name, torch_dtype="auto" if not args.dtype else args.dtype
)
model.eval()
model.to(device)
max_length = args.get("max_length") or model.config.max_position_embeddings
Expand All @@ -77,7 +82,7 @@ def batch_inference(inputs, model):
metrics = args.get("metrics").split(",")
compute_metrics = RewardMetrics(metrics)
score_dict = defaultdict(float)
for i, data in enumerate(dataset):
for i, data in enumerate(tqdm(dataset)):
eval_pred = batch_inference(data, model)
results = compute_metrics(eval_pred)
for metric in metrics:
Expand Down
6 changes: 2 additions & 4 deletions model/model_training/custom_datasets/__init__.py
Expand Up @@ -128,11 +128,9 @@ def get_one_dataset(
elif dataset_name == "gpt4all":
dataset = Gpt4All(mode=mode, cache_dir=data_path)
elif dataset_name == "prosocial_dialogue":
train = ProsocialDialogue(cache_dir=data_path, split="train")
eval = ProsocialDialogue(cache_dir=data_path, split="validation")
dataset = ProsocialDialogue(cache_dir=data_path, split="train")
elif dataset_name == "explain_prosocial":
train = ProsocialDialogueExplaination(cache_dir=data_path, split="train")
eval = ProsocialDialogueExplaination(cache_dir=data_path, split="validation")
dataset = ProsocialDialogueExplaination(cache_dir=data_path, split="train")
elif dataset_name == "soda":
dataset = SODA(data_path, **kwargs)
elif dataset_name == "soda_dialogue":
Expand Down
5 changes: 2 additions & 3 deletions model/model_training/custom_datasets/qa_datasets.py
Expand Up @@ -519,10 +519,9 @@ def __init__(self, cache_dir: str | Path, mode: str = "sft", input_max_length: i
self.mode = mode

dataset = load_dataset(
"gozfarb/ShareGPT_Vicuna_unfiltered",
"Aeala/ShareGPT_Vicuna_unfiltered",
cache_dir=cache_dir,
data_files=["ShareGPT_2023.05.02v0_unfiltered_cleaned_split.json"],
revision="7b8551404f3de5704d634e7516b9ff77be3e2700",
data_files=["ShareGPT_V4.3_unfiltered_cleaned_split.json"],
)["train"]

self.pairs = []
Expand Down