
Commit

Add gradient checkpointing, comment out mink, add autoselecting padding token.

Signed-off-by: Szymon Duchniewicz <szymon.duchniewicz.20@ucl.ac.uk>
Willmish committed Jun 9, 2024
1 parent afbb8c5 commit 109016d
Showing 1 changed file with 39 additions and 13 deletions.
52 changes: 39 additions & 13 deletions llm_unlearn_ucl/unlearn_harm.py
@@ -32,9 +32,9 @@
from accelerate import Accelerator
from datasets import load_dataset
from parse_args import parse_args
from peft import AdaLoraConfig, TaskType, get_peft_model
#from peft import AdaLoraConfig, TaskType, get_peft_model
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, Adafactor
from transformers.tokenization_utils_base import BatchEncoding
from utils import (
compute_kl,
@@ -133,6 +133,7 @@ def run_training_batch(
question_prefix_str: str = "",
answer_prefix_str: str = "",
):
"""
# Calculate min-k% prob score on bad_batch using the unmodified pre-trained model
mink_probs_base = compute_mink_prob(
model=pretrained_model,
@@ -166,6 +167,7 @@ def run_training_batch(
device=device,
compute_for_answer_only=False,
)
"""
############ GA on answer only. ############
bad_loss = get_answer_loss("ga", bad_batch, model, device=device)

@@ -201,22 +203,23 @@ def run_training_batch(
"bad_loss": -bad_loss,
"normal_loss": normal_loss,
"final_loss": loss,
"ratio (bad) mink unlearning/reference": np.mean(mink_probs_after_step)
/ np.mean(mink_probs_base),
"ratio (normal) mink unlearning/reference": np.mean(
mink_probs_after_step_normal
)
/ np.mean(mink_probs_base_normal),
#"ratio (bad) mink unlearning/reference": np.mean(mink_probs_after_step)
#/ np.mean(mink_probs_base),
#"ratio (normal) mink unlearning/reference": np.mean(
# mink_probs_after_step_normal
#)
#/ np.mean(mink_probs_base_normal),
}
)

stats = (
f"epoch: {epoch}, batch: {idx}, "
f"samples seen: {samples_count}, "
f"bad_loss: {-bad_loss:.2f}, "
f"bad_loss: {-bad_loss}, "
f"current_div_loss: {normal_loss:.2f}, "
f"ratio (bad) mink unlearning/reference: {np.mean(mink_probs_after_step)/np.mean(mink_probs_base):.3f}, "
f"ratio (normal) mink unlearning/reference: {np.mean(mink_probs_after_step_normal)/np.mean(mink_probs_base_normal):.3f}"
#f"ratio (bad) mink unlearning/reference: {np.mean(mink_probs_after_step)/np.mean(mink_probs_base):.3f}, "
#f"ratio (normal) mink unlearning/reference: {np.mean(mink_probs_after_step_normal)/np.mean(mink_probs_base_normal):.3f}"
)
logging.info(stats)
print(stats)
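
Note on the metric disabled above: the min-k% values come from the Min-K% Prob score, i.e. the average log-likelihood of the k% least likely tokens in a sequence, tracked here as a ratio against the frozen reference model. The repository's compute_mink_prob helper is not shown in this diff; the sketch below is only a generic illustration of the metric, with the (1, seq_len) input shape and k=0.2 assumed.

    import torch
    import torch.nn.functional as F

    def mink_prob(model, input_ids, k=0.2):
        # Average log-prob of the k% least likely tokens (generic sketch, not the repo helper).
        with torch.no_grad():
            logits = model(input_ids).logits                   # (1, seq_len, vocab)
        log_probs = F.log_softmax(logits[0, :-1], dim=-1)      # predictions for tokens 1..seq_len-1
        token_lp = log_probs.gather(1, input_ids[0, 1:].unsqueeze(-1)).squeeze(-1)
        n_lowest = max(1, int(k * token_lp.numel()))
        return torch.topk(token_lp, n_lowest, largest=False).values.mean().item()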
@@ -266,6 +269,23 @@ def main(args) -> None:

tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=args.cache_dir)

model.gradient_checkpointing_enable()
print("Enabling Gradient checkpointing: ", model.model.gradient_checkpointing)
if tokenizer.pad_token_id is None:
# Fix for ValueError: Pipeline(and causalLM?) with tokenizer without pad_token cannot do batching. You can try to set it with `pipe.tokenizer.pad_token_id = model.config.eos_token_id`.
if model.config.eos_token_id is not None:
tokenizer.pad_token_id = model.config.eos_token_id
model.generation_config.pad_token_id = model.config.eos_token_id
elif tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = tokenizer.eos_token_id
else:
raise ValueError("Pad token for the model's tokenizer is not defined and could not be set automatically; check the model config on HF.")
# TODO: should this have a separate check, or reuse the same one as the tokenizer?
#if model.generation_config.pad_token_id is None:
# model.generation_config.pad_token_id = model.config.eos_token_id
print(tokenizer.pad_token_id)
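
With the fallback above, a pad token is guaranteed before any batched tokenization. A quick sanity check (the example strings are arbitrary, not from the training data): padding a ragged batch is exactly the situation that raised the ValueError quoted in the comment when pad_token was unset.

    # Hypothetical check: padding a ragged batch requires a pad token.
    batch = tokenizer(
        ["short prompt", "a somewhat longer prompt"],
        padding=True,
        return_tensors="pt",
    )
    print(batch["input_ids"].shape)  # (2, length_of_longest_sequence)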

# Load data to unlearn.
if args.unlearning_dataset == "PKU-Alignment/PKU-SafeRLHF":
# filter entries with harmful responses and draw random samples from the remaining dataset.
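
The same hunk also enables gradient checkpointing, trading extra forward compute for lower activation memory during unlearning. A minimal sketch of the usual pattern; the use_cache line and the is_gradient_checkpointing check are assumptions about a typical setup, not part of this commit.

    # Recompute activations during the backward pass instead of storing them.
    model.gradient_checkpointing_enable()
    # The KV cache is normally disabled while training with checkpointing.
    model.config.use_cache = False
    print("Gradient checkpointing enabled:", model.is_gradient_checkpointing)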
@@ -286,7 +306,7 @@ def main(args) -> None:
# NOTE: full dataset like bytedance.
train_bad_dataset = full_bad_dataset

Path(args.samples_save_dir).mkdir(exist_ok=True)
Path(args.samples_save_dir).mkdir(parents=True, exist_ok=True)
bad_sample_path = f"{args.samples_save_dir}/bad_{args.samples_count if args.sequential > 0 else 'full'}_samples.json"
with open(bad_sample_path, "w") as fin:
print(f"Writing bad samples to {bad_sample_path}")
@@ -371,7 +391,7 @@ def main(args) -> None:
# NOTE: full dataset like bytedance.
train_bad_dataset = full_bad_dataset

Path(args.samples_save_dir).mkdir(exist_ok=True)
Path(args.samples_save_dir).mkdir(parents=True, exist_ok=True)
bad_sample_path = f"{args.samples_save_dir}/symbolic_{args.samples_count if args.sequential > 0 else 'full'}_samples.json"
with open(bad_sample_path, "w") as fin:
print(f"Writing symbolic samples to {bad_sample_path}")
@@ -484,7 +504,9 @@ def main(args) -> None:
)
wandb.log_artifact(data_sample_artifacts)

optimizer = AdamW(model.parameters(), lr=args.lr)
#optimizer = AdamW(model.parameters(), lr=args.lr)
optimizer = Adafactor(model.parameters(), lr=args.lr, relative_step=False)


# Prepare.
# num_training_steps = args.max_unlearn_steps
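
A note on the optimizer swap above: Adafactor keeps factored second-moment estimates, so it needs far less optimizer memory than AdamW. When an explicit lr is supplied, the transformers documentation suggests also disabling the internal schedule and parameter scaling; the extra flags below are that suggested configuration, not necessarily what this commit intends.

    from transformers.optimization import Adafactor

    optimizer = Adafactor(
        model.parameters(),
        lr=args.lr,
        scale_parameter=False,  # do not rescale updates by parameter RMS
        relative_step=False,    # use the supplied lr instead of Adafactor's own schedule
        warmup_init=False,
    )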
@@ -533,6 +555,10 @@ def main(args) -> None:
args.model_name, cache_dir=args.cache_dir
)
pretrained_model.to(device)

if pretrained_model.generation_config.pad_token_id is None:
# If the model does not have a pad token, use the EOS token
pretrained_model.generation_config.pad_token_id = pretrained_model.config.eos_token_id
print("Model loaded.")

print("#################### START UNLEARNING ####################")
