In [1]:
# Kill all processess on GPU
!fuser -v /dev/nvidia* -k

# Libraries

In [2]:
%%capture
# Install required libraries (optimized for Colab/Kaggle notebooks)
import os
if 'COLAB_' not in ''.join(os.environ.keys()):
    %pip install unsloth
else:
    # Do this only in Colab notebooks and Kaggle notebooks!
    %pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
    %pip install --no-deps cut_cross_entropy unsloth_zoo
    %pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
    %pip install --no-deps unsloth

In [3]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from unsloth import FastLanguageModel
from transformers import TrainingArguments 
from trl import SFTTrainer
from peft import LoraConfig
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from pprint import pprint

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


    PyTorch 2.5.1+cu121 with CUDA 1201 (you have 2.6.0+cu124)
    Python  3.11.11 (you have 3.11.11)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


🦥 Unsloth Zoo will now patch everything to make training faster!


# Config

In [4]:
# Project config
seed = 69
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model config
max_seq_length = 1024
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

# LoRA config
hf_lora_id = 'alxxtexxr/L3.1-8B-wikipedia-en-LoRA-v20250305134947'
lora_dir = hf_lora_id.split('/')[-1]

# Model

## Base Model

In [5]:
# Download the trained LoRA adapter to the local directory
snapshot_download(
    repo_id=hf_lora_id, 
    local_dir=lora_dir, 
    # ignore_patterns='checkpoint-*/*',
)

.gitattributes:   0%|          | 0.00/1.57k [00:00<?, ?B/s]

adapter_config.json:   0%|          | 0.00/812 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/83.9M [00:00<?, ?B/s]

README.md:   0%|          | 0.00/5.12k [00:00<?, ?B/s]

adapter_config.json:   0%|          | 0.00/812 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/83.9M [00:00<?, ?B/s]

optimizer.pt:   0%|          | 0.00/43.1M [00:00<?, ?B/s]

rng_state.pth:   0%|          | 0.00/14.2k [00:00<?, ?B/s]

scheduler.pt:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/459 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

trainer_state.json:   0%|          | 0.00/16.7k [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.75k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/5.12k [00:00<?, ?B/s]

adapter_config.json:   0%|          | 0.00/812 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/83.9M [00:00<?, ?B/s]

optimizer.pt:   0%|          | 0.00/43.1M [00:00<?, ?B/s]

rng_state.pth:   0%|          | 0.00/14.2k [00:00<?, ?B/s]

scheduler.pt:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/459 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

trainer_state.json:   0%|          | 0.00/32.7k [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.75k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/5.12k [00:00<?, ?B/s]

adapter_config.json:   0%|          | 0.00/812 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/83.9M [00:00<?, ?B/s]

optimizer.pt:   0%|          | 0.00/44.3M [00:00<?, ?B/s]

rng_state.pth:   0%|          | 0.00/14.2k [00:00<?, ?B/s]

scheduler.pt:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/459 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

trainer_state.json:   0%|          | 0.00/48.9k [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.75k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/5.12k [00:00<?, ?B/s]

adapter_config.json:   0%|          | 0.00/812 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/83.9M [00:00<?, ?B/s]

optimizer.pt:   0%|          | 0.00/44.3M [00:00<?, ?B/s]

rng_state.pth:   0%|          | 0.00/14.2k [00:00<?, ?B/s]

scheduler.pt:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/459 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

trainer_state.json:   0%|          | 0.00/65.0k [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.75k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/5.12k [00:00<?, ?B/s]

adapter_config.json:   0%|          | 0.00/812 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/83.9M [00:00<?, ?B/s]

optimizer.pt:   0%|          | 0.00/44.3M [00:00<?, ?B/s]

rng_state.pth:   0%|          | 0.00/14.2k [00:00<?, ?B/s]

scheduler.pt:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/459 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

trainer_state.json:   0%|          | 0.00/81.4k [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.75k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/5.12k [00:00<?, ?B/s]

adapter_config.json:   0%|          | 0.00/812 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/83.9M [00:00<?, ?B/s]

optimizer.pt:   0%|          | 0.00/44.3M [00:00<?, ?B/s]

rng_state.pth:   0%|          | 0.00/14.2k [00:00<?, ?B/s]

scheduler.pt:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/459 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

trainer_state.json:   0%|          | 0.00/97.5k [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.75k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/5.12k [00:00<?, ?B/s]

adapter_config.json:   0%|          | 0.00/812 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/83.9M [00:00<?, ?B/s]

optimizer.pt:   0%|          | 0.00/44.3M [00:00<?, ?B/s]

rng_state.pth:   0%|          | 0.00/14.2k [00:00<?, ?B/s]

scheduler.pt:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/459 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

trainer_state.json:   0%|          | 0.00/114k [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.75k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/5.12k [00:00<?, ?B/s]

adapter_config.json:   0%|          | 0.00/812 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/83.9M [00:00<?, ?B/s]

optimizer.pt:   0%|          | 0.00/44.3M [00:00<?, ?B/s]

rng_state.pth:   0%|          | 0.00/14.2k [00:00<?, ?B/s]

scheduler.pt:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/459 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

trainer_state.json:   0%|          | 0.00/122k [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.75k [00:00<?, ?B/s]

(…).tfevents.1741182714.5172c9540b89.8159.0:   0%|          | 0.00/50.1k [00:00<?, ?B/s]

(…).tfevents.1741192136.40cd2355fd32.1061.0:   0%|          | 0.00/29.5k [00:00<?, ?B/s]

(…).tfevents.1741219006.8da8ca61af90.1084.0:   0%|          | 0.00/29.5k [00:00<?, ?B/s]

(…).tfevents.1741232833.506a328f2d3f.1088.0:   0%|          | 0.00/74.1k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/459 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.75k [00:00<?, ?B/s]

'/content/L3.1-8B-wikipedia-en-LoRA-v20250305134947'

In [6]:
# Load the base model and tokenizer
base_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name='unsloth/Meta-Llama-3.1-8B',
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)

# base_model = base_model.to(device)
base_model.eval()

==((====))==  Unsloth 2025.3.18: Fast Llama patching. Transformers: 4.49.0.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/5.96G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/235 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/459 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096, padding_idx=128004)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((409

## LoRA Model

In [7]:
lora_config = LoraConfig.from_pretrained(lora_dir)
pprint(lora_config.__dict__)

{'_custom_modules': None,
 'alpha_pattern': {},
 'auto_mapping': None,
 'base_model_name_or_path': 'unsloth/meta-llama-3.1-8b-unsloth-bnb-4bit',
 'bias': 'none',
 'eva_config': None,
 'exclude_modules': None,
 'fan_in_fan_out': False,
 'inference_mode': True,
 'init_lora_weights': True,
 'layer_replication': None,
 'layers_pattern': None,
 'layers_to_transform': None,
 'loftq_config': {},
 'lora_alpha': 16,
 'lora_bias': False,
 'lora_dropout': 0,
 'megatron_config': None,
 'megatron_core': 'megatron.core',
 'modules_to_save': None,
 'peft_type': <PeftType.LORA: 'LORA'>,
 'r': 8,
 'rank_pattern': {},
 'revision': None,
 'runtime_config': LoraRuntimeConfig(ephemeral_gpu_offload=False),
 'target_modules': {'down_proj',
                    'gate_proj',
                    'k_proj',
                    'o_proj',
                    'q_proj',
                    'up_proj',
                    'v_proj'},
 'task_type': 'CAUSAL_LM',
 'use_dora': False,
 'use_rslora': False}


### References
- https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/bnb.py
- https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py

In [8]:
class LoraLayer(nn.Module):
    def __init__(self, base_layer, rank, alpha, dropout, lora_bias, use_rslora):
        super().__init__()
        self.base_layer = base_layer
        self.device = base_layer.weight.device
        self.alpha = alpha
        self.lora_bias = lora_bias
        self.scaling = alpha / math.sqrt(rank) if use_rslora else alpha / rank
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()

        # Extract input and output features from the base layer
        in_features = getattr(base_layer, 'in_features', None)
        out_features = getattr(base_layer, 'out_features', None)

        if in_features is None or out_features is None:
            raise ValueError(f"Cannot determine in_features or out_features from {base_layer}")
        
        # LoRA decomposition: A (down-projection) and B (up-projection)
        self.lora_A = nn.Linear(in_features, rank, bias=lora_bias).to(self.device)  # Projects down
        self.lora_B = nn.Linear(rank, out_features, bias=lora_bias).to(self.device) # Projects up

        # Initialize LoRA matrices: A ~ N(0, 1/rank), B initialized to 0
        std = 1 / torch.sqrt(torch.tensor(rank).float())
        nn.init.normal_(self.lora_A.weight, mean=0.0, std=std)
        nn.init.zeros_(self.lora_B.weight)
        
    def forward(self, x):
        # Forward through base layer
        base_out = self.base_layer(x)

        # LoRA transformation
        requires_conversion = not torch.is_autocast_enabled()
        if requires_conversion:
            x = x.to(self.lora_A.weight.dtype)
        lora_out = self.lora_B(self.lora_A(self.dropout(x))) * self.scaling
        if requires_conversion:
            lora_out = lora_out.to(base_out.dtype)

        return base_out + lora_out

    def load_lora_weights(self, state_dict, prefix):
        self.lora_A.weight.data = state_dict[f'{prefix}.lora_A.weight'].to(self.device)
        self.lora_B.weight.data = state_dict[f'{prefix}.lora_B.weight'].to(self.device)
        if self.lora_bias:
            self.lora_A.bias.data = state_dict[f'{prefix}.lora_A.bias'].to(self.device)
            self.lora_B.bias.data = state_dict[f'{prefix}.lora_B.bias'].to(self.device)
    
class LoraModel(nn.Module):
    def __init__(self, base_model: nn.Module, lora_config: LoraConfig):
        super().__init__()
        self.base_model = base_model
        self.lora_layers = nn.ModuleDict()

        # Wrap target layers with NeroLayer
        self._wrap_target_layers(lora_config)
        
        # Freeze all parameters except Nero-specific weights
        self.freeze_except_nero()
    
    def _wrap_target_layers(self, lora_config):
        for module_name, module in self.base_model.named_modules():
            if isinstance(module, LoraLayer):
                # Convert module name format and store reference
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.lora_layers[module_name] = module
                continue

            if any(module_name.endswith(target_module) for target_module in lora_config.target_modules) and isinstance(module, nn.Linear):    
                parent_module, child_name = self._get_parent_module(module_name)
                lora_layer = LoraLayer(
                    module, 
                    lora_config.r, 
                    lora_config.lora_alpha, 
                    lora_config.lora_dropout, 
                    lora_config.lora_bias, 
                    lora_config.use_rslora,
                )
                setattr(parent_module, child_name, lora_layer)

                # Store LoRA layers for weight loading
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.lora_layers[module_name] = lora_layer
    
    def _get_parent_module(self, module_name):
        parts = module_name.split('.')
        parent_module = self.base_model
        for part in parts[:-1]:
            parent_module = getattr(parent_module, part)
        return parent_module, parts[-1]

    def freeze_except_nero(self):
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        for lora_layer in self.lora_layers.values():
            for param_name, param in lora_layer.named_parameters():
                if 'lora_A' in param_name or 'lora_B' in param_name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
    
    def unfreeze_all(self):
        for param in self.base_model.parameters():
            param.requires_grad = True
        
        for lora_layer in self.lora_layers.values():
            for param in lora_layer.parameters():
                param.requires_grad = True
    
    def load_lora_weights(self, lora_path):
        state_dict = load_file(lora_path)
        prefix = list(state_dict.keys())[0].rsplit('model.', 1)[0] + 'model.'
        for lora_layer_name, lora_layer in self.lora_layers.items():
            lora_layer_name = lora_layer_name.replace('__DOT__', '.')
            lora_layer_name = prefix + lora_layer_name
            if f'{lora_layer_name}.lora_A.weight' in state_dict and f'{lora_layer_name}.lora_B.weight' in state_dict:
                lora_layer.load_lora_weights(state_dict, lora_layer_name)
            else:
                # TODO: Print a warning message
                pass
        print("LoRA weights loaded successfully!")
    
    def forward(self, input_ids, attention_mask=None):
        # Run forward pass
        out = self.base_model(input_ids, attention_mask=attention_mask)
        return out
    
    def __getattr__(self, name):
        try:
            return super().__getattr__(name) # Try getting attribute from self
        except AttributeError:
            return getattr(self.base_model, name) # Fallback to base_model

lora_model = LoraModel(base_model, lora_config)
print(lora_model)

LoraModel(
  (base_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096, padding_idx=128004)
      (layers): ModuleList(
        (0): LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): LoraLayer(
              (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
              (dropout): Identity()
              (lora_A): Linear(in_features=4096, out_features=8, bias=False)
              (lora_B): Linear(in_features=8, out_features=4096, bias=False)
            )
            (k_proj): LoraLayer(
              (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)
              (dropout): Identity()
              (lora_A): Linear(in_features=4096, out_features=8, bias=False)
              (lora_B): Linear(in_features=8, out_features=1024, bias=False)
            )
            (v_proj): LoraLayer(
              (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=F

In [32]:
@torch.no_grad()
def check_lora_parameters(model):
    print("Check LoRA parameters:")
    for n, p in model.named_parameters():
        if 'lora' in n:
            print(f"- {'Name':<8}:", n)
            print(f"- {'Mean':<8}:", p.mean().item())
            print(f"- {'Min':<8}:", p.min().item())
            print(f"- {'Max':<8}:", p.max().item())
            break

check_lora_parameters(lora_model)

Check LoRA parameters:
- Name    : base_model.model.layers.0.self_attn.q_proj.lora_A.weight
- Mean    : 0.0013487511314451694
- Min     : -1.440632700920105
- Max     : 1.5253840684890747


In [33]:
lora_path = os.path.join(lora_dir, 'adapter_model.safetensors')
lora_model.load_lora_weights(lora_path)
lora_model.eval()

LoRA weights loaded successfully!


LoraModel(
  (base_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096, padding_idx=128004)
      (layers): ModuleList(
        (0): LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): LoraLayer(
              (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
              (dropout): Identity()
              (lora_A): Linear(in_features=4096, out_features=8, bias=False)
              (lora_B): Linear(in_features=8, out_features=4096, bias=False)
            )
            (k_proj): LoraLayer(
              (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)
              (dropout): Identity()
              (lora_A): Linear(in_features=4096, out_features=8, bias=False)
              (lora_B): Linear(in_features=8, out_features=1024, bias=False)
            )
            (v_proj): LoraLayer(
              (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=F

In [34]:
check_lora_parameters(lora_model)

Check LoRA parameters:
- Name    : base_model.model.layers.0.self_attn.q_proj.lora_A.weight
- Mean    : 4.93806364829652e-05
- Min     : -0.03398667275905609
- Max     : 0.033232565969228745


In [36]:
@torch.no_grad()
def generate_text(model, prompt, max_new_tokens=50, skip_special_tokens=True):
    inputs = tokenizer(prompt, return_tensors='pt')
    outputs = model.generate(input_ids=inputs['input_ids'].to(device), max_new_tokens=max_new_tokens)
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=skip_special_tokens)[0])

generate_text(lora_model, prompt="Preheat the oven to 350 degrees and place the cookie dough", skip_special_tokens=False)

<|begin_of_text|>Preheat the oven to 350 degrees and place the cookie dough on the cookie sheet. Bake for 15-20 minutes or until the cookies are golden brown and crispy.
Can you bake cookies at 350?
Yes, you can bake cookies at 350 degrees Fahrenheit. The baking time will vary depending on the


## Nero Layer

In [39]:
class NeroLayer(nn.Module):
    def __init__(self, base_layer, rank, alpha, lora_bias, nero_bias=False):
        super().__init__()
        self.base_layer = base_layer
        self.alpha = alpha
        self.lora_bias = lora_bias
        self.device = base_layer.weight.device

        # Extract input and output features from the base layer
        in_features = getattr(base_layer, 'in_features', None)
        out_features = getattr(base_layer, 'out_features', None)

        if in_features is None or out_features is None:
            raise ValueError(f"Cannot determine in_features or out_features from {base_layer}")
        
        # LoRA decomposition: A (down-projection) and B (up-projection)
        self.lora_A = nn.Linear(in_features, rank, bias=lora_bias).to(dtype=torch.float16, device=self.device)  # Projects down
        self.lora_B = nn.Linear(rank, out_features, bias=lora_bias).to(dtype=torch.float16, device=self.device) # Projects up

        # Initialize LoRA matrices: A ~ N(0, 1/rank), B initialized to 0
        std = 1 / torch.sqrt(torch.tensor(rank).float())
        nn.init.normal_(self.lora_A.weight, mean=0.0, std=std)
        nn.init.zeros_(self.lora_B.weight)

        # Nero decomposition: additional transformation applied to LoRA output
        self.nero_A = nn.Linear(out_features, rank, bias=nero_bias).to(dtype=torch.float16, device=self.device)
        self.nero_B = nn.Linear(rank, out_features, bias=nero_bias).to(dtype=torch.float16, device=self.device)

        # Initialize Nero matrices similarly
        nn.init.normal_(self.nero_A.weight, mean=0.0, std=std)
        nn.init.zeros_(self.nero_B.weight)
        
    def forward(self, x):
        # Forward through base layer
        base_out = self.base_layer(x)

        # LoRA transformation
        # print(f'{x.dtype=}')
        lora_A_out = self.lora_A(x)
        # print(f'{lora_A_out.dtype=}')
        lora_B_out = self.lora_B(lora_A_out)
        # print(f'{lora_B_out.dtype=}')
        lora_out = self.alpha * lora_B_out
        # print(f'{lora_out.dtype=}')

        # Nero transformation (applied on top of LoRA output)
        nero_A_out = self.nero_A(lora_out)
        nero_B_out = self.nero_B(nero_A_out)
        nero_out = F.relu(self.alpha * nero_B_out)
        
        return base_out + nero_out, nero_out

    def load_lora_weights(self, state_dict, prefix):
        self.lora_A.weight.data = state_dict[f'{prefix}.lora_A.weight'].to(dtype=torch.float16, device=self.device)
        self.lora_B.weight.data = state_dict[f'{prefix}.lora_B.weight'].to(dtype=torch.float16, device=self.device)
        if self.lora_bias:
            self.lora_A.bias.data = state_dict[f'{prefix}.lora_A.bias'].to(dtype=torch.float16, device=self.device)
            self.lora_B.bias.data = state_dict[f'{prefix}.lora_B.bias'].to(dtype=torch.float16, device=self.device)
    
class NeroModel(nn.Module):
    def __init__(self, base_model: nn.Module, lora_config: LoraConfig, nero_bias: bool=False):
        super().__init__()
        self.base_model = base_model
        self.nero_bias = nero_bias
        self.nero_layers = nn.ModuleDict()

        # Wrap target layers with NeroLayer
        self._wrap_target_layers(lora_config)
        
        # Freeze all parameters except Nero-specific weights
        self.freeze_except_nero()
    
    def _wrap_target_layers(self, lora_config):
        for module_name, module in self.base_model.named_modules():
            if isinstance(module, NeroLayer):
                # Convert module name format and store reference
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.nero_layers[module_name] = module
                continue

            if any(module_name.endswith(target_module) for target_module in lora_config.target_modules) and isinstance(module, nn.Linear):    
                parent_module, child_name = self._get_parent_module(module_name)
                nero_layer = NeroLayer(module, lora_config.r, lora_config.lora_alpha, lora_config.lora_bias, nero_bias=self.nero_bias)
                setattr(parent_module, child_name, nero_layer)

                # Store LoRA layers for weight loading
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.nero_layers[module_name] = nero_layer
    
    def _get_parent_module(self, module_name):
        parts = module_name.split('.')
        parent_module = self.base_model
        for part in parts[:-1]:
            parent_module = getattr(parent_module, part)
        return parent_module, parts[-1]

    def freeze_except_nero(self):
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        for nero_layer in self.nero_layers.values():
            for param_name, param in nero_layer.named_parameters():
                if 'nero_A' in param_name or 'nero_B' in param_name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
    
    def unfreeze_all(self):
        for param in self.base_model.parameters():
            param.requires_grad = True
        
        for nero_layer in self.nero_layers.values():
            for param in nero_layer.parameters():
                param.requires_grad = True
    
    def load_lora_weights(self, lora_path):
        state_dict = load_file(lora_path)
        prefix = list(state_dict.keys())[0].rsplit('model.', 1)[0] + 'model.'
        for nero_layer_name, nero_layer in self.nero_layers.items():
            nero_layer_name = nero_layer_name.replace('__DOT__', '.')
            nero_layer_name = prefix + nero_layer_name
            if f'{nero_layer_name}.lora_A.weight' in state_dict and f'{nero_layer_name}.lora_B.weight' in state_dict:
                nero_layer.load_lora_weights(state_dict, nero_layer_name)
            else:
                # TODO: Print a warning message
                pass
        print("LoRA weights loaded successfully!")
    
    def forward(self, input_ids, attention_mask=None):
        nero_outs = {}
        
        def _hook_fn(module, _in, _out):
            layer_name = module._name # Assume each layer has a unique name
            if isinstance(_out, tuple) and len(_out) == 2:
                layer_out, nero_out = _out
                nero_outs[layer_name] = nero_out # Store nero_out separately
                return layer_out # Return only layer_out to avoid breaking model flow

        # Register hooks to extract nero_out during forward pass
        hooks = []
        for layer_name, layer in self.nero_layers.items():
            layer._name = layer_name # Assign unique name to each layer
            hook = layer.register_forward_hook(_hook_fn)
            hooks.append(hook)
        
        # Run forward pass
        out = self.base_model(input_ids, attention_mask=attention_mask)

        # Remove hooks after forward pass
        for hook in hooks:
            hook.remove()

        return out, nero_outs # Return both main output and collected nero_outs

nero_model = NeroModel(base_model, lora_config, nero_bias=True)
print(nero_model)

NeroModel(
  (base_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096, padding_idx=128004)
      (layers): ModuleList(
        (0): LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): LoraLayer(
              (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
              (lora_A): Linear(in_features=4096, out_features=8, bias=False)
              (lora_B): Linear(in_features=8, out_features=4096, bias=False)
            )
            (k_proj): LoraLayer(
              (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)
              (lora_A): Linear(in_features=4096, out_features=8, bias=False)
              (lora_B): Linear(in_features=8, out_features=1024, bias=False)
            )
            (v_proj): LoraLayer(
              (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)
              (lora_A): Linear(in_features=4096, out_features=8, b