Do not merge, used only to indicate the diff between bugfix and main codebase #16

Draft
wants to merge 1 commit into base: main
1 change: 0 additions & 1 deletion peft_pretraining/args_utils.py
@@ -38,7 +38,6 @@ def check_args_torchrun_main(args):
# just for clearer hparam logging to wandb
args.relora = None
args.lora_r = None
args.force_keep_original = False

if args.total_batch_size is None:
args.gradient_accumulation = args.gradient_accumulation or 1
143 changes: 49 additions & 94 deletions peft_pretraining/relora.py
@@ -21,9 +21,6 @@ class ReLoRaConfig:
lora_alpha: int
lora_dropout: float
target_modules: List[str]
keep_original_weights: bool
lora_only: bool = False
trainable_scaling: bool = False
quantize: str = None
use_double_quant: bool = False

@@ -42,8 +39,6 @@ def merge_and_reinit_functional(module):
nn.init.kaiming_uniform_(module.lora_A.weight, a=math.sqrt(5))

nn.init.zeros_(module.lora_B.weight)
if module.trainable_scaling:
nn.init.zeros_(module.scaling)


class ReLoRaModel(torch.nn.Module):
@@ -55,9 +50,7 @@ def __init__(
r=128,
lora_alpha=32,
lora_dropout=0.1,
keep_original_weights=True,
lora_only=False,
trainable_scaling=False,

quantize=None,
use_double_quant=False,
):
@@ -70,16 +63,12 @@ def __init__(
self.lora_alpha = lora_alpha
self.lora_dropout = lora_dropout
self.target_modules = target_modules
self.keep_original_weights = keep_original_weights
self.lora_only = lora_only
self.trainable_scaling = trainable_scaling

self._config = ReLoRaConfig(
r=r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
target_modules=target_modules,
keep_original_weights=keep_original_weights,
quantize=quantize,
use_double_quant=use_double_quant,
)
@@ -98,10 +87,10 @@ def __init__(
if not any(target_key in module_name for target_key in target_modules_list):
continue

weight_data = module.weight.data if keep_original_weights else None
weight_data = module.weight.data
bias_data = None
if module.bias is not None:
bias_data = module.bias.data if keep_original_weights else None
bias_data = module.bias.data

new_module = ReLoRaLinear(
module.in_features,
@@ -110,22 +99,15 @@ def __init__(
r=self.r,
lora_alpha=self.lora_alpha,
lora_dropout=self.lora_dropout,
lora_only=self.lora_only,
trainable_scaling=self.trainable_scaling,
quantize=quantize,
weight_data=weight_data,
bias_data=bias_data,
bnb_4bit_use_double_quant=use_double_quant,
)
if self.keep_original_weights:
# make the LoRA'ed network exactly the same as the original network at initialization
nn.init.zeros_(new_module.lora_A.weight)
assert new_module.lora_A.bias is None
assert new_module.lora_B.bias is None

if self.lora_only:
assert not self.keep_original_weights
module.weight = None
# NOTE: these are LoRA biases, not network biases; the network can still have bias vectors
assert new_module.lora_A.bias is None
assert new_module.lora_B.bias is None

del module

@@ -159,14 +141,6 @@ def from_pretrained(cls, path):
config = AutoConfig.from_pretrained(path)

base_model = AutoModelForCausalLM.from_config(config)
if "keep_original" in relora_config:
print("WARNING: keep_original is deprecated. Use lora_only instead.")
print(f"keep_original: {relora_config['keep_original']}")
relora_config["lora_only"] = not relora_config.pop("keep_original")
relora_config["keep_original_weights"] = not relora_config["lora_only"]

if "trainable_scaling" not in relora_config:
relora_config["trainable_scaling"] = False

model = cls(base_model, **relora_config)

@@ -187,10 +161,8 @@ def __init__(
*,
lora_alpha: int = 1,
lora_dropout: float = 0.1,
lora_only: bool = False,
weight_data=None,
bias_data=None,
trainable_scaling: bool = False,
bias=True,
device=None,
dtype=None,
@@ -206,72 +178,58 @@ def __init__(
if r <= 0:
raise ValueError("r must be positive. If you want r == 0, use the original model.")

if lora_only:
self.weight = None
self.bias = None
# if full model weight + lora weight
if bias_data is None:
bias_data = torch.zeros(out_features, device=device, dtype=dtype, requires_grad=True) if bias else None
self.bias = nn.Parameter(bias_data) if bias else None

if weight_data is None:
# note that our trainable weights are W_a and W_b
weight_data = torch.zeros(out_features, in_features, device=device, dtype=dtype, requires_grad=False)

if quantize is None:
self.weight = nn.Parameter(weight_data, requires_grad=False)
elif quantize == "4bit":
self.weight = bnb.nn.Params4bit(
weight_data,
requires_grad=False,
compress_statistics=bnb_4bit_use_double_quant,
quant_type=bnb_4bit_quant_type,
)
elif quantize == "8bit":
logger.warning("Int8 currently does not support merge_and_reinit! It will fail")
self.weight = bnb.nn.Int8Params(
weight_data,
requires_grad=False,
)
else:
# if full model weight + lora weight
if bias_data is None:
bias_data = torch.zeros(out_features, device=device, dtype=dtype, requires_grad=True) if bias else None
self.bias = nn.Parameter(bias_data) if bias else None

if weight_data is None:
# note that our trainable weights are W_a and W_b
weight_data = torch.zeros(out_features, in_features, device=device, dtype=dtype, requires_grad=False)

if quantize is None:
self.weight = nn.Parameter(weight_data, requires_grad=False)
elif quantize == "4bit":
self.weight = bnb.nn.Params4bit(
weight_data,
requires_grad=False,
compress_statistics=bnb_4bit_use_double_quant,
quant_type=bnb_4bit_quant_type,
)
elif quantize == "8bit":
logger.warning("Int8 currently does not support merge_and_reinit! It will fail")
self.weight = bnb.nn.Int8Params(
weight_data,
requires_grad=False,
)
else:
raise ValueError(f"Unknown quantize type: {quantize}")
raise ValueError(f"Unknown quantize type: {quantize}")

self.in_features = in_features
self.out_features = out_features
self.r = r
self.lora_alpha = lora_alpha
self.lora_dropout = nn.Dropout(p=lora_dropout)
self.lora_only = lora_only
self.trainable_scaling = trainable_scaling
self.quantize = quantize

if r > 0:
self.lora_A = nn.Linear(in_features, r, bias=False)
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
self.lora_B = nn.Linear(r, out_features, bias=False)
nn.init.zeros_(self.lora_B.weight)
if trainable_scaling:
self.scaling = nn.Parameter(torch.tensor([1.]), requires_grad=True)
else:
self.scaling = self.lora_alpha / self.r

# Freezing the pre-trained weight matrix
if not self.lora_only:
self.weight.requires_grad = False

def _post_lora_scale(self):
if self.trainable_scaling:
return self.scaling.tanh()
if r == 0:
# r == 0 is used for debugging whether ReLoRA linear works exactly like no LoRA
return

self.lora_A = nn.Linear(in_features, r, bias=False)
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
self.lora_B = nn.Linear(r, out_features, bias=False)
nn.init.zeros_(self.lora_B.weight)
self.scaling = self.lora_alpha / self.r

# Freezing the pre-trained weight matrix
self.weight.requires_grad = False

def _post_lora_scale(self):
return self.scaling

@torch.no_grad()
def merge_and_reinit(self):
if self.lora_only:
print("WARNING: Skipping merge and reinit, because only lora parameters are used")
return

if not self.quantize:
self.weight.data += self.lora_B.weight @ self.lora_A.weight * self._post_lora_scale()
elif self.quantize == "4bit":
@@ -303,21 +261,18 @@ def merge_and_reinit(self):
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))

nn.init.zeros_(self.lora_B.weight)
if self.trainable_scaling:
nn.init.zeros_(self.scaling)

def forward(self, x: torch.Tensor):
if self.lora_only:
# just lora
return self.lora_B(self.lora_A(self.lora_dropout(x))) * self._post_lora_scale()

if self.quantize == "4bit":
result = bnb.matmul_4bit(x, self.weight.t(), bias=self.bias, quant_state=self.weight.quant_state)
elif self.quantize == "8bit":
result = bnb.matmul(x, self.weight.t(), bias=self.bias, quant_state=self.weight.quant_state)
else:
result = F.linear(x, self.weight, bias=self.bias)

if self.r > 0:
result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self._post_lora_scale()
if self.r == 0:
# r == 0 is used for debugging whether ReLoRA linear works exactly like no LoRA
return result

result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self._post_lora_scale()
return result
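
For orientation, the net effect of the relora.py changes is that ReLoRaLinear always keeps a frozen dense weight plus a LoRA A/B pair with a fixed lora_alpha / r scaling, and merge_and_reinit folds B @ A into the frozen weight before re-initializing the pair. Below is a minimal, self-contained sketch of that behavior, not the PR's code: the class name MiniReLoRaLinear, the layer sizes, and the tolerances are illustrative, while the init, forward, and merge math mirror the diff above.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MiniReLoRaLinear(nn.Module):
    """Toy version of the simplified ReLoRaLinear: frozen dense weight plus a low-rank A/B pair."""

    def __init__(self, in_features, out_features, r=8, lora_alpha=32, lora_dropout=0.1):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.lora_A = nn.Linear(in_features, r, bias=False)
        self.lora_B = nn.Linear(r, out_features, bias=False)
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)       # zero B means the wrapped layer starts identical to the frozen one
        self.lora_dropout = nn.Dropout(p=lora_dropout)
        self.scaling = lora_alpha / r            # fixed scalar now that trainable_scaling is gone

    def forward(self, x):
        frozen = F.linear(x, self.weight)
        lora = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
        return frozen + lora

    @torch.no_grad()
    def merge_and_reinit(self):
        # fold the low-rank update into the frozen weight, then restart A and B
        self.weight.data += self.lora_B.weight @ self.lora_A.weight * self.scaling
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

layer = MiniReLoRaLinear(16, 32).eval()          # eval() disables dropout for a deterministic check
with torch.no_grad():
    layer.lora_B.weight.normal_()                # pretend some LoRA training has happened
x = torch.randn(2, 16)
y_before = layer(x)
layer.merge_and_reinit()
torch.testing.assert_close(layer(x), y_before, rtol=1e-4, atol=1e-4)

The final check illustrates the invariant the reset relies on: because lora_B is re-initialized to zero, merging and then restarting the LoRA pair leaves the layer's function unchanged.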
13 changes: 0 additions & 13 deletions torchrun_main.py
@@ -85,9 +85,6 @@ def parse_args(args=None):
help="Use random pruning to reduce optimizer matrix internal dimensionality.")
parser.add_argument("--optimizer_magnitude_pruning", default=0.0, type=float,
help="Use magnitude pruning to reduce optimizer matrix internal dimensionality.")
parser.add_argument("--force_keep_original", default=False, type=lambda x: x.lower() == "true",
help=("Keep original model parameters even if relora is None. "
"Useful for making sure that full-LoRa model is equivalent to model+LoRa."))

parser.add_argument("--optimizer", default="Adam", help="Could be adam (for AdamW) or adam_zero for ZeroRedundancyOptimizer(AdamW)")
parser.add_argument("--lr", type=float, default=1e-4)
@@ -529,13 +526,6 @@ def main(args):
params_before = sum(p.numel() for p in model.parameters())

if args.use_peft:
need_linear_weight = (
args.relora is not None
or args.force_keep_original
or args.warmed_up_model is not None
)
logger.info(f"Wrapping model with LoRA ({need_linear_weight=})")

# target modules should define all linear layers from transformer block
# "attn" and "mlp" are used in LLaMA
# "attention" and "mlp" are used in Pythia
@@ -545,9 +535,6 @@
lora_alpha=args.lora_alpha,
lora_dropout=0.1,
target_modules=["attn", "attention", "mlp"],
trainable_scaling=args.train_scaling,
keep_original_weights=True,
lora_only=not need_linear_weight,
quantize=args.quantize,
use_double_quant=args.use_double_quant,
)
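
For context, here is a sketch (not part of the PR) of how the model-wrapping call in torchrun_main.py reads once force_keep_original and the need_linear_weight logic are removed. The keyword arguments follow the diff above; the leading model argument, r=args.lora_r, and the parameter-count logging are assumptions based on the surrounding script, where model, args, and logger are the script's own objects.

# Sketch only: wrapping is now unconditional whenever --use_peft is set.
model = ReLoRaModel(
    model,
    r=args.lora_r,
    lora_alpha=args.lora_alpha,
    lora_dropout=0.1,
    target_modules=["attn", "attention", "mlp"],
    quantize=args.quantize,
    use_double_quant=args.use_double_quant,
)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"Trainable params after wrapping: {trainable_params:,}")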