In [None]:

class LoRALinear(nn.Linear):
    """``nn.Linear`` with an optional low-rank (LoRA) additive update.

    With ``lora_rank > 0`` the layer computes

        y = x @ W.T + b + dropout(x @ A.T @ B.T) * (lora_alpha / lora_rank)

    where ``A`` (``lora_A.weight``, shape ``(rank, in_features)``) and ``B``
    (``lora_B.weight``, shape ``(out_features, rank)``) are the adapter
    matrices.  ``B`` is zero-initialised, so a freshly built layer produces
    exactly the output of the underlying ``nn.Linear``.

    Switching to evaluation mode folds ``scaling * B @ A`` into ``self.weight``
    so inference costs a single matmul; switching back to training mode
    subtracts it again.  ``has_weights_merged`` tracks which state the weight
    is in.  Note that a ``state_dict`` taken in eval mode therefore contains
    the merged weight.

    With ``lora_rank == 0`` (the default) the layer is a plain ``nn.Linear``.
    """

    def __init__(self,
                 # nn.Linear parameters
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 device=None,
                 dtype=None,
                 # LoRA parameters
                 lora_rank: int = 0,
                 lora_alpha: float = 0.0,
                 lora_dropout: float = 0.0,
                 ) -> None:
        """Build the base linear layer and, if requested, the LoRA adapters.

        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            bias: Whether the base linear layer has a bias term.
            device: Device for the base layer and the adapters.
            dtype: Dtype for the base layer and the adapters.
            lora_rank: Rank of the low-rank update; values ``<= 0`` disable LoRA.
            lora_alpha: Numerator of the scaling factor ``lora_alpha / lora_rank``.
            lora_dropout: Dropout probability applied to the adapter output.
        """
        # Initialize the inherited class, nn.Linear.  Note that this already
        # calls reset_parameters() once, before any LoRA attribute exists.
        super().__init__(in_features, out_features, bias, device=device, dtype=dtype)

        self.has_weights_merged = False
        if lora_rank > 0:
            self.lora_dropout = nn.Dropout(lora_dropout)
            self.lora_scaling = lora_alpha / lora_rank

            # The adapters are pure low-rank matrices: no bias (a randomly
            # initialised bias would make the layer differ from the base
            # layer at initialisation), and they live on the same
            # device/dtype as the base weight.
            self.lora_A = nn.Linear(in_features, lora_rank, bias=False,
                                    device=device, dtype=dtype)
            self.lora_B = nn.Linear(lora_rank, out_features, bias=False,
                                    device=device, dtype=dtype)

            # Adapters start frozen; mark_only_lora_as_trainable() re-enables
            # their gradients once the surrounding model has been assembled.
            self.lora_A.weight.requires_grad = False
            self.lora_B.weight.requires_grad = False

            self.reset_parameters()

    def is_lora(self) -> bool:
        """Return True when the layer was built with ``lora_rank > 0``."""
        return hasattr(self, 'lora_A')

    def reset_parameters(self) -> None:
        """Re-initialise the base layer and, if present, the LoRA adapters."""
        nn.Linear.reset_parameters(self)
        if self.is_lora():
            # Initialize lora_A with kaiming_uniform_ using negative slope as math.sqrt(5)
            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
            # Zero lora_B so that B @ A == 0 and the adapter starts as a no-op
            nn.init.zeros_(self.lora_B.weight)

    def _lora_delta(self) -> torch.Tensor:
        """Return ``scaling * B @ A``, shaped like ``self.weight``."""
        return (self.lora_B.weight @ self.lora_A.weight) * self.lora_scaling

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the linear map, adding the LoRA branch unless it is merged."""
        base_out = super().forward(input)
        if not self.is_lora() or self.has_weights_merged:
            # Either a plain linear layer, or the update already lives in self.weight.
            return base_out
        lora_out = self.lora_B(self.lora_A(input))
        lora_out = self.lora_dropout(lora_out) * self.lora_scaling
        return base_out + lora_out

    def train(self, mode: bool = True) -> "LoRALinear":
        """Set training/eval mode and keep the merged state in sync with it.

        ``mode=True`` unmerges the LoRA update from ``self.weight`` (if merged);
        ``mode=False`` merges it (if not yet merged).  Parent modules propagate
        ``eval()`` to children through ``train(False)``, so the merge logic has
        to live here rather than only in ``eval()``.
        """
        super().train(mode)
        if not self.is_lora():
            return self
        if mode and self.has_weights_merged:
            # Back to training: restore the pristine base weight.
            with torch.no_grad():
                self.weight -= self._lora_delta()
            self.has_weights_merged = False
        elif not mode and not self.has_weights_merged:
            # Inference: fold the low-rank update into the base weight.
            with torch.no_grad():
                self.weight += self._lora_delta()
            self.has_weights_merged = True
        return self

    def eval(self) -> "LoRALinear":
        """Switch to evaluation mode, merging the LoRA update into the weight."""
        return self.train(False)

    def extra_repr(self) -> str:
        """Extend ``nn.Linear``'s repr with the LoRA hyper-parameters."""
        out = nn.Linear.extra_repr(self)
        if self.is_lora():
            out += f', lora_rank={self.lora_A.weight.shape[0]}, lora_scaling={self.lora_scaling}, lora_dropout={self.lora_dropout.p}'
        return out

def mark_only_lora_as_trainable(model: nn.Module) -> nn.Module:
    """Freeze every parameter of ``model`` except the LoRA adapter weights.

    All parameters are first set to ``requires_grad=False``; then, for each
    ``LoRALinear`` submodule that actually carries adapters, the weights of
    ``lora_A`` and ``lora_B`` are switched back on.  The model is modified in
    place and also returned for convenience.
    """
    # Step 1: freeze the whole model.
    model.requires_grad_(False)

    # Step 2: re-enable gradients on the adapter weights only.
    for layer in model.modules():
        if not isinstance(layer, LoRALinear):
            continue
        if not layer.is_lora():
            continue
        for adapter in (layer.lora_A, layer.lora_B):
            adapter.weight.requires_grad_(True)
    return model
