Add gradient checkpointing, comment out mink, add autoselecting paddi… #101

Draft
wants to merge 4 commits into base: main
85 changes: 65 additions & 20 deletions llm_unlearn_ucl/unlearn_harm.py
@@ -32,9 +32,10 @@
from accelerate import Accelerator
from datasets import load_dataset
from parse_args import parse_args
from peft import AdaLoraConfig, TaskType, get_peft_model

# from peft import AdaLoraConfig, TaskType, get_peft_model
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, Adafactor
from transformers.tokenization_utils_base import BatchEncoding
from utils import (
compute_kl,
@@ -133,6 +134,7 @@ def run_training_batch(
question_prefix_str: str = "",
answer_prefix_str: str = "",
):
"""
# Calculate min-k% prob score on bad_batch using the unmodified pre-trained model
mink_probs_base = compute_mink_prob(
model=pretrained_model,
@@ -166,6 +168,7 @@
device=device,
compute_for_answer_only=False,
)
"""
############ GA on answer only. ############
bad_loss = get_answer_loss("ga", bad_batch, model, device=device)
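For context on the gradient-ascent step above: get_answer_loss("ga", ...) is expected to return the negated next-token cross-entropy over the answer tokens, so that an ordinary optimizer step increases the loss on the harmful answers. A minimal sketch of that idea, assuming the batch carries input_ids, attention_mask, and labels with prompt tokens masked as -100 (the real helper lives in utils.py and may differ):

import torch
import torch.nn.functional as F

def ga_answer_loss(batch, model, device="cuda"):
    # Hypothetical stand-in for get_answer_loss(operation="ga", ...).
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["labels"].to(device)  # prompt tokens assumed masked with -100

    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    # Shift so that position t predicts token t+1.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]

    ce = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=-100,
    )
    # Gradient ascent: negate the loss so optimizer.step() pushes it up.
    return -ce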

@@ -201,22 +204,23 @@
"bad_loss": -bad_loss,
"normal_loss": normal_loss,
"final_loss": loss,
"ratio (bad) mink unlearning/reference": np.mean(mink_probs_after_step)
/ np.mean(mink_probs_base),
"ratio (normal) mink unlearning/reference": np.mean(
mink_probs_after_step_normal
)
/ np.mean(mink_probs_base_normal),
# "ratio (bad) mink unlearning/reference": np.mean(mink_probs_after_step)
# / np.mean(mink_probs_base),
# "ratio (normal) mink unlearning/reference": np.mean(
# mink_probs_after_step_normal
# )
# / np.mean(mink_probs_base_normal),
}
)

stats = (
f"epoch: {epoch}, batch: {idx}, "
f"samples seen: {samples_count}, "
f"bad_loss: {-bad_loss:.2f}, "
f"bad_loss: {-bad_loss}, "
f"current_div_loss: {normal_loss:.2f}, "
f"ratio (bad) mink unlearning/reference: {np.mean(mink_probs_after_step)/np.mean(mink_probs_base):.3f}, "
f"ratio (normal) mink unlearning/reference: {np.mean(mink_probs_after_step_normal)/np.mean(mink_probs_base_normal):.3f}"
# f"ratio (bad) mink unlearning/reference: {np.mean(mink_probs_after_step)/np.mean(mink_probs_base):.3f}, "
# f"ratio (normal) mink unlearning/reference: {np.mean(mink_probs_after_step_normal)/np.mean(mink_probs_base_normal):.3f}"
)
logging.info(stats)
print(stats)
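The ratios commented out above track the min-k% prob score (Shi et al., "Detecting Pretraining Data from Large Language Models"): for each sequence, the average log-probability of the k% tokens the model finds least likely, used here as a rough memorisation signal before and after each unlearning step. A hedged sketch of the metric itself, not of compute_mink_prob from utils.py, which may handle masking and batching differently:

import torch
import torch.nn.functional as F

@torch.no_grad()
def mink_prob(model, input_ids, k=0.2):
    # input_ids: (1, seq_len) tensor already on the model's device.
    logits = model(input_ids=input_ids).logits[:, :-1, :]
    targets = input_ids[:, 1:]
    log_probs = F.log_softmax(logits, dim=-1)
    token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    # Keep only the k% least-likely tokens and average their log-probs.
    n_min = max(1, int(k * token_log_probs.size(1)))
    lowest, _ = torch.topk(token_log_probs, n_min, largest=False)
    return lowest.mean().item()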
@@ -232,21 +236,32 @@ def main(args) -> None:
assert (
args.samples_count // args.sequential
) % args.batch_size == 0, "samples in each 'sequence' (--samples_count / --sequential) should be a multiple of batch_size."
accelerator = Accelerator() # accelerator precision can be specified if required.
if args.use_quantized:
accelerator = Accelerator(
mixed_precision="bf16"
) # accelerator precision can be specified if required.
else:
accelerator = (
Accelerator()
) # accelerator precision can be specified if required.
device = accelerator.device

print(f"Loading model {args.model_name} for training...")
if args.use_quantized:
print("QUANTIZED")
# Low-precision (bf16) loading path
model = AutoModelForCausalLM.from_pretrained(
args.model_name,
cache_dir=args.cache_dir,
load_in_8bit=True,
torch_dtype=torch.float32,
torch_dtype=torch.bfloat16,
# load_in_8bit=True,
# torch_dtype=torch.float32,
)
model.to(device)
else:
model = AutoModelForCausalLM.from_pretrained(
args.model_name, cache_dir=args.cache_dir
args.model_name,
cache_dir=args.cache_dir,
)

print("Model loaded.")
@@ -266,6 +281,26 @@ def main(args) -> None:

tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=args.cache_dir)

model.gradient_checkpointing_enable()
print("Enabling Gradient checkpointing: ", model.model.gradient_checkpointing)
if tokenizer.pad_token_id is None:
# Fix for ValueError: Pipeline(and causalLM?) with tokenizer without pad_token cannot do batching. You can try to set it with `pipe.tokenizer.pad_token_id = model.config.eos_token_id`.
if model.config.eos_token_id is not None:
tokenizer.pad_token_id = model.config.eos_token_id
model.generation_config.pad_token_id = model.config.eos_token_id
elif tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token = tokenizer.eos_token
else:
raise ValueError(
"Pad token for the model's tokenizer is not defined and could not be set automatically; check the model config on Hugging Face."
)
# TODO: should this have a separate check, or the same one as the tokenizer?
# if model.generation_config.pad_token_id is None:
# model.generation_config.pad_token_id = model.config.eos_token_id
print(tokenizer.pad_token_id)

# Load data to unlearn.
if args.unlearning_dataset == "PKU-Alignment/PKU-SafeRLHF":
# filter entries with harmful responses and draw random samples from the remaining dataset.
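One caveat on the model.gradient_checkpointing_enable() call added in this hunk: most Hugging Face causal LMs warn that gradient checkpointing is incompatible with use_cache=True during training and disable the cache themselves. Explicitly turning it off keeps the logs clean; whether this particular script needs it depends on the model config, so treat the lines below as a hedged suggestion rather than part of the diff:

model.gradient_checkpointing_enable()
# Checkpointing recomputes activations in the backward pass, which conflicts
# with the generation-time KV cache; disable it for training.
model.config.use_cache = False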
@@ -286,7 +321,7 @@
# NOTE: full dataset like bytedance.
train_bad_dataset = full_bad_dataset

Path(args.samples_save_dir).mkdir(exist_ok=True)
Path(args.samples_save_dir).mkdir(parents=True, exist_ok=True)
bad_sample_path = f"{args.samples_save_dir}/bad_{args.samples_count if args.sequential > 0 else 'full'}_samples.json"
with open(bad_sample_path, "w") as fin:
print(f"Writing bad samples to {bad_sample_path}")
@@ -371,7 +406,7 @@ def main(args) -> None:
# NOTE: full dataset like bytedance.
train_bad_dataset = full_bad_dataset

Path(args.samples_save_dir).mkdir(exist_ok=True)
Path(args.samples_save_dir).mkdir(parents=True, exist_ok=True)
bad_sample_path = f"{args.samples_save_dir}/symbolic_{args.samples_count if args.sequential > 0 else 'full'}_samples.json"
with open(bad_sample_path, "w") as fin:
print(f"Writing symbolic samples to {bad_sample_path}")
@@ -484,7 +519,8 @@ def main(args) -> None:
)
wandb.log_artifact(data_sample_artifacts)

optimizer = AdamW(model.parameters(), lr=args.lr)
# optimizer = AdamW(model.parameters(), lr=args.lr)
optimizer = Adafactor(model.parameters(), lr=args.lr, relative_step=False)

# Prepare.
# num_training_steps = args.max_unlearn_steps
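On the AdamW-to-Adafactor swap above: once an explicit learning rate is passed, relative_step must be False, and the transformers documentation pairs that with scale_parameter=False and warmup_init=False. A sketch of that fuller configuration, where the two extra flags are an assumption rather than something this PR sets:

from transformers.optimization import Adafactor

optimizer = Adafactor(
    model.parameters(),
    lr=args.lr,
    relative_step=False,    # required when an explicit lr is supplied
    scale_parameter=False,  # disable Adafactor's internal lr scaling
    warmup_init=False,      # warmup_init is only valid with relative_step=True
)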
@@ -525,14 +561,23 @@
pretrained_model = AutoModelForCausalLM.from_pretrained(
args.model_name,
cache_dir=args.cache_dir,
load_in_8bit=True,
torch_dtype=torch.float32,
torch_dtype=torch.bfloat16,
# load_in_8bit=True,
# torch_dtype=torch.float32,
)
pretrained_model.to(device)
else:
pretrained_model = AutoModelForCausalLM.from_pretrained(
args.model_name, cache_dir=args.cache_dir
args.model_name,
cache_dir=args.cache_dir,
)
pretrained_model.to(device)

if pretrained_model.generation_config.pad_token_id is None:
# If the model does not have a pad token, use the EOS token
pretrained_model.generation_config.pad_token_id = (
pretrained_model.config.eos_token_id
)
print("Model loaded.")

print("#################### START UNLEARNING ####################")