diff --git a/peft_pretraining/args_utils.py b/peft_pretraining/args_utils.py
index b31844a..24a2576 100644
--- a/peft_pretraining/args_utils.py
+++ b/peft_pretraining/args_utils.py
@@ -38,7 +38,6 @@ def check_args_torchrun_main(args):
         # just for more clear hparam logging to wandb
         args.relora = None
         args.lora_r = None
-        args.force_keep_original = False
 
     if args.total_batch_size is None:
         args.gradient_accumulation = args.gradient_accumulation or 1
diff --git a/peft_pretraining/relora.py b/peft_pretraining/relora.py
index 23cd951..b53b0ff 100644
--- a/peft_pretraining/relora.py
+++ b/peft_pretraining/relora.py
@@ -21,9 +21,6 @@ class ReLoRaConfig:
     lora_alpha: int
     lora_dropout: float
     target_modules: List[str]
-    keep_original_weights: bool
-    lora_only: bool = False
-    trainable_scaling: bool = False
     quantize: str = None
     use_double_quant: bool = False
 
@@ -42,8 +39,6 @@ def merge_and_reinit_functional(module):
 
     nn.init.kaiming_uniform_(module.lora_A.weight, a=math.sqrt(5))
     nn.init.zeros_(module.lora_B.weight)
-    if module.trainable_scaling:
-        nn.init.zeros_(module.scaling)
 
 
 class ReLoRaModel(torch.nn.Module):
@@ -55,9 +50,7 @@ def __init__(
         r=128,
         lora_alpha=32,
         lora_dropout=0.1,
-        keep_original_weights=True,
-        lora_only=False,
-        trainable_scaling=False,
+
         quantize=None,
         use_double_quant=False,
     ):
@@ -70,16 +63,12 @@ def __init__(
         self.lora_alpha = lora_alpha
         self.lora_dropout = lora_dropout
         self.target_modules = target_modules
-        self.keep_original_weights = keep_original_weights
-        self.lora_only = lora_only
-        self.trainable_scaling = trainable_scaling
 
         self._config = ReLoRaConfig(
             r=r,
             lora_alpha=lora_alpha,
             lora_dropout=lora_dropout,
             target_modules=target_modules,
-            keep_original_weights=keep_original_weights,
             quantize=quantize,
             use_double_quant=use_double_quant,
         )
@@ -98,10 +87,10 @@ def __init__(
             if not any(target_key in module_name for target_key in target_modules_list):
                 continue
 
-            weight_data = module.weight.data if keep_original_weights else None
+            weight_data = module.weight.data
             bias_data = None
             if module.bias is not None:
-                bias_data = module.bias.data if keep_original_weights else None
+                bias_data = module.bias.data
             new_module = ReLoRaLinear(
                 module.in_features,
                 module.out_features,
@@ -110,22 +99,15 @@ def __init__(
                 r=self.r,
                 lora_alpha=self.lora_alpha,
                 lora_dropout=self.lora_dropout,
-                lora_only=self.lora_only,
-                trainable_scaling=self.trainable_scaling,
                 quantize=quantize,
                 weight_data=weight_data,
                 bias_data=bias_data,
                 bnb_4bit_use_double_quant=use_double_quant,
             )
 
-            if self.keep_original_weights:
-                # make lora'ed network to be exacty the same as the original network at initialization
-                nn.init.zeros_(new_module.lora_A.weight)
-                assert new_module.lora_A.bias is None
-                assert new_module.lora_B.bias is None
-            if self.lora_only:
-                assert not self.keep_original_weights
-                module.weight = None
+            # NOTE: these are LoRA biases, not network biases; the network can still have bias vectors
+            assert new_module.lora_A.bias is None
+            assert new_module.lora_B.bias is None
 
             del module
 
@@ -159,14 +141,6 @@ def from_pretrained(cls, path):
 
         config = AutoConfig.from_pretrained(path)
         base_model = AutoModelForCausalLM.from_config(config)
 
-        if "keep_original" in relora_config:
-            print("WARNING: keep_original is deprecated. Use lora_only instead.")
-            print(f"keep_original: {relora_config['keep_original']}")
-            relora_config["lora_only"] = not relora_config.pop("keep_original")
-            relora_config["keep_original_weights"] = not relora_config["lora_only"]
-
-        if "trainable_scaling" not in relora_config:
-            relora_config["trainable_scaling"] = False
 
         model = cls(base_model, **relora_config)
@@ -187,10 +161,8 @@ def __init__(
         *,
         lora_alpha: int = 1,
         lora_dropout: float = 0.1,
-        lora_only: bool = False,
         weight_data=None,
         bias_data=None,
-        trainable_scaling: bool = False,
         bias=True,
         device=None,
         dtype=None,
@@ -206,72 +178,58 @@ def __init__(
         if r <= 0:
             raise ValueError("r must be positive. If you want r == 0, use the original model.")
 
-        if lora_only:
-            self.weight = None
-            self.bias = None
+        # if full model weight + lora weight
+        if bias_data is None:
+            bias_data = torch.zeros(out_features, device=device, dtype=dtype, requires_grad=True) if bias else None
+        self.bias = nn.Parameter(bias_data) if bias else None
+
+        if weight_data is None:
+            # note that our trainable weight are W_a and W_b
+            weight_data = torch.zeros(out_features, in_features, device=device, dtype=dtype, requires_grad=False)
+
+        if quantize is None:
+            self.weight = nn.Parameter(weight_data, requires_grad=False)
+        elif quantize == "4bit":
+            self.weight = bnb.nn.Params4bit(
+                weight_data,
+                requires_grad=False,
+                compress_statistics=bnb_4bit_use_double_quant,
+                quant_type=bnb_4bit_quant_type,
+            )
+        elif quantize == "8bit":
+            logger.warning("Int8 currently does not support merge_and_reinit! It will fail")
+            self.weight = bnb.nn.Int8Params(
+                weight_data,
+                requires_grad=False,
+            )
         else:
-            # if full model weight + lora weight
-            if bias_data is None:
-                bias_data = torch.zeros(out_features, device=device, dtype=dtype, requires_grad=True) if bias else None
-            self.bias = nn.Parameter(bias_data) if bias else None
-
-            if weight_data is None:
-                # note that our trainable weight are W_a and W_b
-                weight_data = torch.zeros(out_features, in_features, device=device, dtype=dtype, requires_grad=False)
-
-            if quantize is None:
-                self.weight = nn.Parameter(weight_data, requires_grad=False)
-            elif quantize == "4bit":
-                self.weight = bnb.nn.Params4bit(
-                    weight_data,
-                    requires_grad=False,
-                    compress_statistics=bnb_4bit_use_double_quant,
-                    quant_type=bnb_4bit_quant_type,
-                )
-            elif quantize == "8bit":
-                logger.warning("Int8 currently does not support merge_and_reinit! It will fail")
-                self.weight = bnb.nn.Int8Params(
-                    weight_data,
-                    requires_grad=False,
-                )
-            else:
-                raise ValueError(f"Unknown quantize type: {quantize}")
+            raise ValueError(f"Unknown quantize type: {quantize}")
 
         self.in_features = in_features
         self.out_features = out_features
         self.r = r
         self.lora_alpha = lora_alpha
         self.lora_dropout = nn.Dropout(p=lora_dropout)
-        self.lora_only = lora_only
-        self.trainable_scaling = trainable_scaling
         self.quantize = quantize
 
-        if r > 0:
-            self.lora_A = nn.Linear(in_features, r, bias=False)
-            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
-            self.lora_B = nn.Linear(r, out_features, bias=False)
-            nn.init.zeros_(self.lora_B.weight)
-            if trainable_scaling:
-                self.scaling = nn.Parameter(torch.tensor([1.]), requires_grad=True)
-            else:
-                self.scaling = self.lora_alpha / self.r
-
-        # Freezing the pre-trained weight matrix
-        if not self.lora_only:
-            self.weight.requires_grad = False
-
-    def _post_lora_scale(self):
-        if self.trainable_scaling:
-            return self.scaling.tanh()
+        if r == 0:
+            # r == 0 is used for debugging whether ReLoRaLinear behaves exactly like the original linear layer (no LoRA)
+            return
+        self.lora_A = nn.Linear(in_features, r, bias=False)
+        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
+        self.lora_B = nn.Linear(r, out_features, bias=False)
+        nn.init.zeros_(self.lora_B.weight)
+        self.scaling = self.lora_alpha / self.r
+
+        # Freezing the pre-trained weight matrix
+        self.weight.requires_grad = False
+
+    def _post_lora_scale(self):
         return self.scaling
 
     @torch.no_grad()
     def merge_and_reinit(self):
-        if self.lora_only:
-            print("WARNING: Skipping merge and reinit, because only lora parameters are used")
-            return
-
         if not self.quantize:
             self.weight.data += self.lora_B.weight @ self.lora_A.weight * self._post_lora_scale()
         elif self.quantize == "4bit":
@@ -303,14 +261,8 @@ def merge_and_reinit(self):
 
         nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
         nn.init.zeros_(self.lora_B.weight)
-        if self.trainable_scaling:
-            nn.init.zeros_(self.scaling)
 
     def forward(self, x: torch.Tensor):
-        if self.lora_only:
-            # just lora
-            return self.lora_B(self.lora_A(self.lora_dropout(x))) * self._post_lora_scale()
-
         if self.quantize == "4bit":
             result = bnb.matmul_4bit(x, self.weight.t(), bias=self.bias, quant_state=self.weight.quant_state)
         elif self.quantize == "8bit":
@@ -318,6 +270,9 @@ def forward(self, x: torch.Tensor):
         else:
             result = F.linear(x, self.weight, bias=self.bias)
 
-        if self.r > 0:
-            result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self._post_lora_scale()
+        if self.r == 0:
+            # r == 0 is used for debugging whether ReLoRaLinear behaves exactly like the original linear layer (no LoRA)
+            return result
+
+        result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self._post_lora_scale()
         return result
diff --git a/torchrun_main.py b/torchrun_main.py
index a4263d2..c2a4232 100644
--- a/torchrun_main.py
+++ b/torchrun_main.py
@@ -85,9 +85,6 @@ def parse_args(args=None):
                         help="Use random pruning to reduce optimizer matrix internal dimensionality.")
     parser.add_argument("--optimizer_magnitude_pruning", default=0.0, type=float,
                         help="Use magnitude pruning to reduce optimizer matrix internal dimensionality.")
-    parser.add_argument("--force_keep_original", default=False, type=lambda x: x.lower() == "true",
-                        help=("Keep original model parameters even if relora is None. "
-                              "Useful for making sure that full-LoRa model is equivalent to model+LoRa."))
     parser.add_argument("--optimizer", default="Adam",
                         help="Could be adam (for AdamW) or adam_zero for ZeroRedundancyOptimizer(AdamW)")
     parser.add_argument("--lr", type=float, default=1e-4)
@@ -529,13 +526,6 @@ def main(args):
     params_before = sum(p.numel() for p in model.parameters())
 
     if args.use_peft:
-        need_linear_weight = (
-            args.relora is not None
-            or args.force_keep_original
-            or args.warmed_up_model is not None
-        )
-        logger.info(f"Wrapping model with LoRA ({need_linear_weight=})")
-
         # target modules should define all linear layers from transformer block
         # "attn" and "mlp" are used in LLaMA
         # "attention" and "mlp" are used in Pythia
@@ -545,9 +535,6 @@ def main(args):
             lora_alpha=args.lora_alpha,
             lora_dropout=0.1,
             target_modules=["attn", "attention", "mlp"],
-            trainable_scaling=args.train_scaling,
-            keep_original_weights=True,
-            lora_only=not need_linear_weight,
            quantize=args.quantize,
             use_double_quant=args.use_double_quant,
         )