In [1]:
# Kill all processess on GPU
# !fuser -v /dev/nvidia* -k

# Libraries

In [2]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    %pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    %pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    %pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    %pip install --no-deps unsloth
%pip install trl==0.19.1 # Fix error: ImportError: cannot import name 'ConstantLengthDataset' from 'trl.trainer.utils'

In [2]:
from unsloth import FastLanguageModel
import os
import functools
import gc
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from pprint import pprint

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [3]:
def download_hf_model(repo_id, checkpoint):
    local_dir = repo_id.split('/')[-1]
    ignore_checkpoints = [f'checkpoint-{i}/*' for i in range(0, 2000, 25) if i != checkpoint]

    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        ignore_patterns=ignore_checkpoints,
    )

    if checkpoint:
        return os.path.join(local_dir, f'checkpoint-{checkpoint}')
    return local_dir

@torch.no_grad()
def check_lora_parameters(model):
    for n, p in model.named_parameters():
        if 'lora' in n:
            print(f"- {'Name':<8}:", n)
            print(f"- {'Mean':<8}:", p.mean().item())
            print(f"- {'Min':<8}:", p.min().item())
            print(f"- {'Max':<8}:", p.max().item())
            break

@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens=50, skip_special_tokens=True):
    device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors='pt')
    outputs = model.generate(input_ids=inputs['input_ids'].to(device), max_new_tokens=max_new_tokens)
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=skip_special_tokens)[0])

# Config

In [4]:
# Project configuration
seed = 69
device = 'cuda'

# Model configuration
max_seq_length = 1024
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

model_configs = {
    'L1T1': {
        'lora_id': 'alxxtexxr/L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650',
        'checkpoint': 650,
    },
    'L2T1': {
        'lora_id': 'alxxtexxr/L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629',
        'checkpoint': 650,
    },
}

for key, config in model_configs.items():
    model_configs[key]['lora_dir'] = download_hf_model(config['lora_id'], config['checkpoint'])

print("Model configurations:",)
pprint(model_configs)
print()

lora_config = LoraConfig.from_pretrained(model_configs['L2T1']['lora_dir'])
print("LoRA configuration:")
pprint(lora_config.__dict__)

Model configurations:
{'L1T1': {'checkpoint': 650,
          'lora_dir': 'L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650/checkpoint-650',
          'lora_id': 'alxxtexxr/L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650'},
 'L2T1': {'checkpoint': 650,
          'lora_dir': 'L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629/checkpoint-650',
          'lora_id': 'alxxtexxr/L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629'}}

LoRA configuration:
{'_custom_modules': None,
 'alpha_pattern': {},
 'auto_mapping': None,
 'base_model_name_or_path': 'unsloth/meta-llama-3.1-8b-unsloth-bnb-4bit',
 'bias': 'none',
 'corda_config': None,
 'eva_config': None,
 'exclude_modules': None,
 'fan_in_fan_out': False,
 'inference_mode': True,
 'init_lora_weights': True,
 'layer_replication': None,
 'layers_pattern': None,
 'layers_to_transform': None,
 'loftq_config': {},
 'lora_alpha': 16,
 'lora_bias': False,
 'lora_dropout': 0,
 'megatron_config': None,
 'megatron_core': 'megatron.core',
 'modules_to_save': None,
 'peft_

# Model

## LoRA Model

### References
- https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/bnb.py
- https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py

In [5]:
class LoraLayer(nn.Module):
    def __init__(self, base_layer, rank, alpha, dropout, lora_bias, use_rslora, return_lora_output=False):
        super().__init__()
        self.base_layer = base_layer
        self.device = base_layer.weight.device
        self.alpha = alpha
        self.lora_bias = lora_bias
        self.scaling = alpha / math.sqrt(rank) if use_rslora else alpha / rank
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
        self.return_lora_output = return_lora_output

        # Extract input and output features from the base layer
        in_features = getattr(base_layer, 'in_features', None)
        out_features = getattr(base_layer, 'out_features', None)

        if in_features is None or out_features is None:
            raise ValueError(f"Cannot determine in_features or out_features from {base_layer}.")
        
        # LoRA decomposition: A (down-projection) and B (up-projection)
        self.lora_A = nn.Linear(in_features, rank, bias=lora_bias).to(self.device)  # Projects down
        self.lora_B = nn.Linear(rank, out_features, bias=lora_bias).to(self.device) # Projects up

        # Initialize LoRA matrices: A ~ N(0, 1/rank), B initialized to 0
        std = 1 / torch.sqrt(torch.tensor(rank).float())
        nn.init.normal_(self.lora_A.weight, mean=0.0, std=std)
        nn.init.zeros_(self.lora_B.weight)
        
    def forward(self, x):
        # Forward through base layer
        base_out = self.base_layer(x)

        print("================================================================")
        print("base_out.requires_grad:", base_out.requires_grad)
        print("base_out.grad_fn:", base_out.grad_fn)
        print()

        # LoRA transformation
        requires_conversion = not torch.is_autocast_enabled()
        if requires_conversion:
            x = x.to(self.lora_A.weight.dtype)
        lora_out = self.lora_B(self.lora_A(self.dropout(x))) * self.scaling
        if requires_conversion:
            lora_out = lora_out.to(base_out.dtype)
        
        print("lora_out.requires_grad:", lora_out.requires_grad)
        print("lora_out.grad_fn:", lora_out.grad_fn)
        print()

        output = base_out + lora_out

        if self.return_lora_output:
            return output, lora_out
        
        return output

    def load_lora_weights(self, state_dict, prefix):
        self.lora_A.weight.data = state_dict[f'{prefix}.lora_A.weight'].to(self.device)
        self.lora_B.weight.data = state_dict[f'{prefix}.lora_B.weight'].to(self.device)
        if self.lora_bias:
            self.lora_A.bias.data = state_dict[f'{prefix}.lora_A.bias'].to(self.device)
            self.lora_B.bias.data = state_dict[f'{prefix}.lora_B.bias'].to(self.device)
    
class LoraModel(nn.Module):
    def __init__(self, base_model: nn.Module, lora_config: LoraConfig, return_lora_outputs=False):
        super().__init__()
        self.base_model = base_model
        self.lora_layers = nn.ModuleDict()
        self.return_lora_outputs = return_lora_outputs

        # Wrap target layers with NeroLayer
        self._wrap_target_layers(lora_config)
    
    def _wrap_target_layers(self, lora_config):
        for module_name, module in self.base_model.named_modules():
            if isinstance(module, LoraLayer):
                # Convert module name format and store reference
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.lora_layers[module_name] = module
                continue

            if any(module_name.endswith(target_module) for target_module in lora_config.target_modules) and isinstance(module, nn.Linear):    
                parent_module, child_name = self._get_parent_module(module_name)
                lora_layer = LoraLayer(
                    module, 
                    lora_config.r, 
                    lora_config.lora_alpha, 
                    lora_config.lora_dropout, 
                    lora_config.lora_bias, 
                    lora_config.use_rslora,
                    return_lora_output=self.return_lora_outputs,
                )
                setattr(parent_module, child_name, lora_layer)

                # Store LoRA layers for weight loading
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.lora_layers[module_name] = lora_layer
    
    def _get_parent_module(self, module_name):
        parts = module_name.split('.')
        parent_module = self.base_model
        for part in parts[:-1]:
            parent_module = getattr(parent_module, part)
        return parent_module, parts[-1]

    def freeze_all(self):
        for param in self.base_model.parameters():
            param.requires_grad = False
    
    def unfreeze_all(self):
        for param in self.base_model.parameters():
            param.requires_grad = True
        
        for lora_layer in self.lora_layers.values():
            for param in lora_layer.parameters():
                param.requires_grad = True
    
    def load_lora_weights(self, lora_path):
        state_dict = load_file(lora_path)
        prefix = list(state_dict.keys())[0].rsplit('model.', 1)[0] + 'model.'
        for lora_layer_name, lora_layer in self.lora_layers.items():
            lora_layer_name = lora_layer_name.replace('__DOT__', '.')
            lora_layer_name = prefix + lora_layer_name
            if f'{lora_layer_name}.lora_A.weight' in state_dict and f'{lora_layer_name}.lora_B.weight' in state_dict:
                lora_layer.load_lora_weights(state_dict, lora_layer_name)
            else:
                # TODO: Print warning message
                pass
        print("LoRA weights loaded successfully!")
    
    def forward(self, input_ids, attention_mask=None):
        if self.return_lora_outputs:
            lora_outs = {}
            
            def _hook_fn(layer_name, module, _in, _out):
                if isinstance(_out, tuple) and len(_out) == 2:
                    layer_out, lora_out = _out
                    lora_outs[layer_name] = lora_out # Store nero_out separately
                    return layer_out # Return only layer_out to avoid breaking model flow

            # Register hooks to extract nero_out during forward pass
            hooks = []
            for layer_name, layer in self.lora_layers.items():
                hook = layer.register_forward_hook(functools.partial(_hook_fn, layer_name))
                hooks.append(hook)
        
            try:
                output = self.base_model(input_ids, attention_mask=attention_mask)
            finally:
                # Remove hooks after forward pass, ensuring it's done even if an error occurs
                for hook in hooks:
                    hook.remove()

            return output, lora_outs
        
        return self.base_model(input_ids, attention_mask=attention_mask)
    
    def __getattr__(self, name):
        try:
            return super().__getattr__(name) # Try getting attribute from self
        except AttributeError:
            return getattr(self.base_model, name) # Fallback to base_model

base_model_1 = AutoModelForCausalLM.from_pretrained(lora_config.base_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(lora_config.base_model_name_or_path)
lora_model = LoraModel(base_model_1, lora_config, return_lora_outputs=True)

In [6]:
print("Check LoRA parameters (unloaded):")
check_lora_parameters(lora_model)
print()

lora_path = os.path.join(model_configs['L1T1']['lora_dir'], 'adapter_model.safetensors')
lora_model.load_lora_weights(lora_path)
print()

print("Check LoRA parameters (loaded):")
check_lora_parameters(lora_model)

Check LoRA parameters (unloaded):
- Name    : base_model.model.layers.0.self_attn.q_proj.lora_A.weight
- Mean    : 0.0014135234523564577
- Min     : -1.449187994003296
- Max     : 1.4309760332107544

LoRA weights loaded successfully!

Check LoRA parameters (loaded):
- Name    : base_model.model.layers.0.self_attn.q_proj.lora_A.weight
- Mean    : 6.287686119321734e-05
- Min     : -0.04176201671361923
- Max     : 0.04242725297808647


In [8]:
lora_model.freeze_all()

In [9]:
lora_model.train()
device = next(lora_model.parameters()).device
inputs = tokenizer("Preheat the oven to 350 degrees and place the cookie dough", return_tensors='pt')
lora_model_outs = lora_model(input_ids=inputs['input_ids'].to(device))

base_out.requires_grad: False
base_out.grad_fn: None

lora_out.requires_grad: False
lora_out.grad_fn: None

base_out.requires_grad: False
base_out.grad_fn: None

lora_out.requires_grad: False
lora_out.grad_fn: None

base_out.requires_grad: False
base_out.grad_fn: None

lora_out.requires_grad: False
lora_out.grad_fn: None

base_out.requires_grad: False
base_out.grad_fn: None

lora_out.requires_grad: False
lora_out.grad_fn: None

base_out.requires_grad: False
base_out.grad_fn: None

lora_out.requires_grad: False
lora_out.grad_fn: None

base_out.requires_grad: False
base_out.grad_fn: None

lora_out.requires_grad: False
lora_out.grad_fn: None

base_out.requires_grad: False
base_out.grad_fn: None

lora_out.requires_grad: False
lora_out.grad_fn: None

base_out.requires_grad: False
base_out.grad_fn: None

lora_out.requires_grad: False
lora_out.grad_fn: None

base_out.requires_grad: False
base_out.grad_fn: None

lora_out.requires_grad: False
lora_out.grad_fn: None

base_out.requires_grad: Fals

In [11]:
lora_model_outs[1]['layers__DOT__0__DOT__self_attn__DOT__q_proj']

tensor([[[-3.3689e-04,  9.8288e-05,  1.9665e-03,  ..., -2.6093e-03,
          -2.6684e-03, -2.7065e-03],
         [ 8.0585e-04, -6.7253e-03, -6.6681e-03,  ..., -3.8910e-02,
          -3.5675e-02, -3.6743e-02],
         [-1.2369e-03, -7.9651e-03, -1.1253e-02,  ..., -2.9968e-02,
          -2.6855e-02, -2.7893e-02],
         ...,
         [ 1.1005e-03, -6.9275e-03, -4.7417e-03,  ..., -3.6438e-02,
          -3.3875e-02, -3.4790e-02],
         [-4.1902e-05, -1.1978e-02, -1.4435e-02,  ..., -5.2795e-02,
          -4.8126e-02, -4.9835e-02],
         [ 5.3215e-04, -9.1553e-03, -1.2909e-02,  ..., -4.5715e-02,
          -4.1351e-02, -4.2847e-02]]], device='cuda:0', dtype=torch.float16)

## Nero Layer

In [12]:
class NeroLayer(nn.Module):
    def __init__(self, base_layer, 
                 # LoRA parameters
                 rank, alpha, dropout, lora_bias, use_rslora, 
                 # Nero parameters
                 nero_bias=False, 
                 return_nero_output=False,
                 ):
        super().__init__()
        self.base_layer = base_layer
        self.device = base_layer.weight.device
        self.alpha = alpha
        self.lora_bias = lora_bias
        self.scaling = alpha / math.sqrt(rank) if use_rslora else alpha / rank
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
        self.return_nero_output = return_nero_output

        # Extract input and output features from the base layer
        in_features = getattr(base_layer, 'in_features', None)
        out_features = getattr(base_layer, 'out_features', None)

        if in_features is None or out_features is None:
            raise ValueError(f"Cannot determine in_features or out_features from {base_layer}.")
        
        # LoRA decomposition: A (down-projection) and B (up-projection)
        self.lora_A = nn.Linear(in_features, rank, bias=lora_bias).to(self.device)  # Projects down
        self.lora_B = nn.Linear(rank, out_features, bias=lora_bias).to(self.device) # Projects up

        # Initialize LoRA matrices: A ~ N(0, 1/rank), B initialized to 0
        std = 1 / torch.sqrt(torch.tensor(rank).float())
        nn.init.normal_(self.lora_A.weight, mean=0.0, std=std)
        nn.init.zeros_(self.lora_B.weight)

        # Nero decomposition: additional transformation applied to LoRA output
        self.nero_A = nn.Linear(out_features, rank, bias=nero_bias).to(self.device)
        self.nero_B = nn.Linear(rank, out_features, bias=nero_bias).to(self.device)

        # Initialize Nero matrices similarly
        nn.init.normal_(self.nero_A.weight, mean=0.0, std=std)
        nn.init.zeros_(self.nero_B.weight)
        
    def forward(self, x):
        # Forward through base layer
        base_out = self.base_layer(x)

        print("================================================================")
        print("base_out.requires_grad:", base_out.requires_grad)
        print("base_out.grad_fn:", base_out.grad_fn)
        print()

        # LoRA transformation
        requires_conversion = not torch.is_autocast_enabled()
        if requires_conversion:
            x = x.to(self.lora_A.weight.dtype)
        lora_out = self.lora_B(self.lora_A(self.dropout(x))) * self.scaling
        # if requires_conversion:
        #     lora_out = lora_out.to(base_out.dtype)

        print("lora_out.requires_grad:", lora_out.requires_grad)
        print("lora_out.grad_fn:", lora_out.grad_fn)
        print()

        nero_out = F.relu(self.nero_B(self.nero_A(self.dropout(lora_out))) * self.scaling)
        if requires_conversion:
            nero_out = nero_out.to(base_out.dtype)
        self.last_nero_out = nero_out

        print("nero_out.requires_grad:", nero_out.requires_grad)
        print("nero_out.grad_fn:", nero_out.grad_fn)
        print()

        output = base_out + nero_out

        if self.return_nero_output:
            return output, nero_out
        
        return output

    def load_lora_weights(self, state_dict, prefix):
        self.lora_A.weight.data = state_dict[f'{prefix}.lora_A.weight'].to(self.device)
        self.lora_B.weight.data = state_dict[f'{prefix}.lora_B.weight'].to(self.device)
        # self.lora_A.weight.requires_grad = False
        # self.lora_B.weight.requires_grad = False
        if self.lora_bias:
            self.lora_A.bias.data = state_dict[f'{prefix}.lora_A.bias'].to(self.device)
            self.lora_B.bias.data = state_dict[f'{prefix}.lora_B.bias'].to(self.device)
            # self.lora_A.bias.requires_grad = False
            # self.lora_B.bias.requires_grad = False
    
class NeroModel(nn.Module):
    def __init__(self, base_model: nn.Module, lora_config: LoraConfig, nero_bias: bool=False, 
                 return_nero_outputs: bool=False,
                 ):
        super().__init__()
        self.base_model = base_model
        self.nero_bias = nero_bias
        self.nero_layers = nn.ModuleDict()
        self.return_nero_outputs = return_nero_outputs

        # Wrap target layers with NeroLayer
        self._wrap_target_layers(lora_config)
        
    def _wrap_target_layers(self, lora_config):
        for module_name, module in self.base_model.named_modules():
            if isinstance(module, NeroLayer):
                # Convert module name format and store reference
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.nero_layers[module_name] = module
                continue

            if any(module_name.endswith(target_module) for target_module in lora_config.target_modules) and isinstance(module, nn.Linear):    
                parent_module, child_name = self._get_parent_module(module_name)
                nero_layer = NeroLayer(
                    module, 
                    lora_config.r, 
                    lora_config.lora_alpha, 
                    lora_config.lora_dropout, 
                    lora_config.lora_bias, 
                    lora_config.use_rslora,
                    nero_bias=self.nero_bias,
                    return_nero_output=self.return_nero_outputs,
                )
                setattr(parent_module, child_name, nero_layer)

                # Store LoRA layers for weight loading
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.nero_layers[module_name] = nero_layer
    
    def _get_parent_module(self, module_name):
        parts = module_name.split('.')
        parent_module = self.base_model
        for part in parts[:-1]:
            parent_module = getattr(parent_module, part)
        return parent_module, parts[-1]

    def freeze_all_except_nero(self):
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        for nero_layer in self.nero_layers.values():
            for param_name, param in nero_layer.named_parameters():
                if 'nero_A' in param_name or 'nero_B' in param_name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
    
    def unfreeze_all(self):
        for param in self.base_model.parameters():
            param.requires_grad = True
        
        for nero_layer in self.nero_layers.values():
            for param in nero_layer.parameters():
                param.requires_grad = True
    
    def load_lora_weights(self, lora_path):
        state_dict = load_file(lora_path)
        prefix = list(state_dict.keys())[0].rsplit('model.', 1)[0] + 'model.'
        for nero_layer_name, nero_layer in self.nero_layers.items():
            nero_layer_name = nero_layer_name.replace('__DOT__', '.')
            nero_layer_name = prefix + nero_layer_name
            if f'{nero_layer_name}.lora_A.weight' in state_dict and f'{nero_layer_name}.lora_B.weight' in state_dict:
                nero_layer.load_lora_weights(state_dict, nero_layer_name)
            else:
                # TODO: Print warning message
                pass
        print("LoRA weights loaded successfully!")
    
    def forward(self, input_ids, attention_mask=None):
        if self.return_nero_outputs:
            nero_outs = {}
            
            def _hook_fn(layer_name, module, _in, _out):
                if isinstance(_out, tuple) and len(_out) == 2:
                    layer_out, nero_out = _out
                    nero_outs[layer_name] = nero_out # Store nero_out separately
                    return layer_out # Return only layer_out to avoid breaking model flow

            # Register hooks to extract nero_out during forward pass
            hooks = []
            for layer_name, layer in self.nero_layers.items():
                hook = layer.register_forward_hook(functools.partial(_hook_fn, layer_name))
                hooks.append(hook)
        
            try:
                output = self.base_model(input_ids, attention_mask=attention_mask)
            finally:
                # Remove hooks after forward pass, ensuring it's done even if an error occurs
                for hook in hooks:
                    hook.remove()

            return output, nero_outs
        
        return self.base_model(input_ids, attention_mask=attention_mask)
    
    def __getattr__(self, name):
        try:
            return super().__getattr__(name) # Try getting attribute from self
        except AttributeError:
            return getattr(self.base_model, name) # Fallback to base_model

base_model_2 = AutoModelForCausalLM.from_pretrained(lora_config.base_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(lora_config.base_model_name_or_path)
nero_model = NeroModel(
    base_model_2, 
    lora_config, 
    nero_bias=True, 
    return_nero_outputs=True,
)

In [13]:
print("Check LoRA parameters (unloaded):")
check_lora_parameters(nero_model)
print()

lora_path = os.path.join(model_configs['L2T1']['lora_dir'], 'adapter_model.safetensors')
nero_model.load_lora_weights(lora_path)
print()

print("Check LoRA parameters (loaded):")
check_lora_parameters(nero_model)

Check LoRA parameters (unloaded):
- Name    : base_model.model.layers.0.self_attn.q_proj.lora_A.weight
- Mean    : 0.0011637862771749496
- Min     : -1.6798632144927979
- Max     : 1.491278052330017

LoRA weights loaded successfully!

Check LoRA parameters (loaded):
- Name    : base_model.model.layers.0.self_attn.q_proj.lora_A.weight
- Mean    : 2.2152138626552187e-05
- Min     : -0.06327299773693085
- Max     : 0.0625513345003128


In [14]:
nero_model.freeze_all_except_nero()

In [15]:
nero_model.train()
device = next(nero_model.parameters()).device
inputs = tokenizer("Preheat the oven to 350 degrees and place the cookie dough", return_tensors='pt')
nero_model_outs = nero_model(input_ids=inputs['input_ids'].to(device))

base_out.requires_grad: False
base_out.grad_fn: None

lora_out.requires_grad: False
lora_out.grad_fn: None

nero_out.requires_grad: True
nero_out.grad_fn: <ToCopyBackward0 object at 0x794b183e6fb0>

base_out.requires_grad: False
base_out.grad_fn: None

lora_out.requires_grad: False
lora_out.grad_fn: None

nero_out.requires_grad: True
nero_out.grad_fn: <ToCopyBackward0 object at 0x794b183e6e60>

base_out.requires_grad: False
base_out.grad_fn: None

lora_out.requires_grad: False
lora_out.grad_fn: None

nero_out.requires_grad: True
nero_out.grad_fn: <ToCopyBackward0 object at 0x794b183e6fb0>

base_out.requires_grad: True
base_out.grad_fn: <ToCopyBackward0 object at 0x794b183e6e60>

lora_out.requires_grad: True
lora_out.grad_fn: <MulBackward0 object at 0x794b183e6fb0>

nero_out.requires_grad: True
nero_out.grad_fn: <ToCopyBackward0 object at 0x794b183e6e60>

base_out.requires_grad: True
base_out.grad_fn: <ToCopyBackward0 object at 0x794b183e40a0>

lora_out.requires_grad: True
lora_out.grad

In [16]:
nero_model_outs[1]['layers__DOT__0__DOT__self_attn__DOT__q_proj']

tensor([[[0.3853, 0.0000, 0.5693,  ..., 0.0213, 0.0000, 0.0000],
         [0.3853, 0.0000, 0.5693,  ..., 0.0213, 0.0000, 0.0000],
         [0.3853, 0.0000, 0.5693,  ..., 0.0213, 0.0000, 0.0000],
         ...,
         [0.3853, 0.0000, 0.5693,  ..., 0.0213, 0.0000, 0.0000],
         [0.3853, 0.0000, 0.5693,  ..., 0.0213, 0.0000, 0.0000],
         [0.3853, 0.0000, 0.5693,  ..., 0.0213, 0.0000, 0.0000]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)

# Training

In [17]:
def loss_func_v1(nero_outs, lora_outs):
    assert nero_outs.keys() == lora_outs.keys() # TODO: Print warning message
    total_loss = 0.0

    for layer_name in lora_outs.keys():
        # Normalized MSE loss
        mse_loss = F.mse_loss(nero_outs[layer_name], lora_outs[layer_name], reduction='sum') / torch.sum(nero_outs[layer_name] ** 2)
        total_loss += mse_loss

    return total_loss / len(lora_outs)  # Averaging loss across layers

loss = loss_func_v1(nero_model_outs[1], lora_model_outs[1])
print(loss)

tensor(1.0049, device='cuda:0', dtype=torch.float16, grad_fn=<DivBackward0>)


In [18]:
loss.backward()

In [None]:
def loss_func(pred_outs, gt_outs, lambda_reg, lora_A_list, lora_B_list):
    total_loss = 0.0
    num_layers = len(gt_outs)

    for i in range(num_layers):
        # Normalized MSE loss
        mse_loss = F.mse_loss(pred_outs[i], gt_outs[i], reduction='sum') / torch.sum(pred_outs[i] ** 2)
        
        # L2 regularization for LoRA matrices
        reg_loss = lambda_reg * (torch.norm(lora_A_list[i], p=2) ** 2 + torch.norm(lora_B_list[i], p=2) ** 2)

        total_loss += mse_loss + reg_loss

    return total_loss / num_layers  # Averaging loss across layers