In [1]:
# Kill all processess on GPU
!fuser -v /dev/nvidia* -k

# Libraries

In [2]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    %pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    %pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    %pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    %pip install --no-deps unsloth
%pip install trl==0.19.1 # Fix error: ImportError: cannot import name 'ConstantLengthDataset' from 'trl.trainer.utils'

In [3]:
from unsloth import FastLanguageModel
import os
import functools
import gc
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import default_data_collator, AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from pprint import pprint

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [4]:
def download_hf_model(repo_id, checkpoint):
    local_dir = repo_id.split('/')[-1]
    ignore_checkpoints = [f'checkpoint-{i}/*' for i in range(0, 2000, 25) if i != checkpoint]

    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        ignore_patterns=ignore_checkpoints,
    )

    if checkpoint:
        return os.path.join(local_dir, f'checkpoint-{checkpoint}')
    return local_dir

@torch.no_grad()
def check_lora_parameters(model):
    for n, p in model.named_parameters():
        if 'lora' in n:
            print(f"- {'Name':<8}:", n)
            print(f"- {'Mean':<8}:", p.mean().item())
            print(f"- {'Min':<8}:", p.min().item())
            print(f"- {'Max':<8}:", p.max().item())
            break

@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens=50, skip_special_tokens=True):
    device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors='pt')
    outputs = model.generate(input_ids=inputs['input_ids'].to(device), max_new_tokens=max_new_tokens)
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=skip_special_tokens)[0])

def load_hf_dataset_from_lora(
    lora_repo_id,
    train_size = 5000,
    test_size = 1000,
):
    # Get task and language
    task, lang, _ = lora_repo_id.split('B-')[-1].split('K-')[0].split('-')

    # Set up Hugging Face configuration
    data_id_map = {
        'wikipedia': 'wikimedia/wikipedia',
        'gsm8k': 'openai/gsm8k',
    }
    data_id = data_id_map[task]
    data_dir = f'20231101.{lang}' if task == 'wikipedia' else 'main'
    split = f'train[:{(train_size+test_size)}]'

    # Load dataset
    dataset = load_dataset(data_id, data_dir=data_dir, split=split) # TODO: Limit dataset size first
    return dataset

# Config

In [5]:
# Project configuration
seed = 69
device = 'cuda'

# Model configuration
max_seq_length = 1024
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

model_configs = {
    'L1T1': {
        'hf_lora_id': 'alxxtexxr/L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650',
        'checkpoint': 650,
        'lora_dir': None,
        'lora_path': None,
        'lora_config': None,
    },
    'L2T1': {
        'hf_lora_id': 'alxxtexxr/L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629',
        'checkpoint': 650,
        'lora_dir': None,
        'lora_path': None,
        'lora_config': None,
    },
}

for key, config in model_configs.items():
    lora_dir = download_hf_model(config['hf_lora_id'], config['checkpoint'])
    model_configs[key]['lora_dir'] = lora_dir
    model_configs[key]['lora_path'] = os.path.join(lora_dir, 'adapter_model.safetensors')
    model_configs[key]['lora_config'] = LoraConfig.from_pretrained(lora_dir)

print("Model configurations:",)
for key, config in model_configs.items():
    print(f"- {key}:")
    for config_name, config_value in config.items():
        if config_name == 'lora_config':
            continue
        print(f"{'-':>3} {config_name:<10}: {config_value}")
print()

assert model_configs['L1T1']['lora_config'].base_model_name_or_path == model_configs['L2T1']['lora_config'].base_model_name_or_path, "Base models must be the same"
base_model_name = model_configs['L1T1']['lora_config'].base_model_name_or_path
print(f"Base model name: {base_model_name}")

tokenizer = AutoTokenizer.from_pretrained(base_model_name)

Model configurations:
- L1T1:
  - hf_lora_id: alxxtexxr/L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650
  - checkpoint: 650
  - lora_dir  : L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650/checkpoint-650
  - lora_path : L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650/checkpoint-650/adapter_model.safetensors
- L2T1:
  - hf_lora_id: alxxtexxr/L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629
  - checkpoint: 650
  - lora_dir  : L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629/checkpoint-650
  - lora_path : L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629/checkpoint-650/adapter_model.safetensors

Base model name: unsloth/meta-llama-3.1-8b-unsloth-bnb-4bit


# Model

## LoRA Model

### References
- https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/bnb.py
- https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py

In [6]:
class LoraLayer(nn.Module):
    def __init__(self, base_layer, rank, alpha, dropout, lora_bias, use_rslora, 
                 return_lora_output=False, debug=False):
        super().__init__()
        self.base_layer = base_layer
        self.device = base_layer.weight.device
        self.alpha = alpha
        self.lora_bias = lora_bias
        self.scaling = alpha / math.sqrt(rank) if use_rslora else alpha / rank
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
        self.return_lora_output = return_lora_output
        self.debug = debug

        # Extract input and output features from the base layer
        in_features = getattr(base_layer, 'in_features', None)
        out_features = getattr(base_layer, 'out_features', None)

        if in_features is None or out_features is None:
            raise ValueError(f"Cannot determine in_features or out_features from {base_layer}.")
        
        # LoRA decomposition: A (down-projection) and B (up-projection)
        self.lora_A = nn.Linear(in_features, rank, bias=lora_bias).to(self.device)  # Projects down
        self.lora_B = nn.Linear(rank, out_features, bias=lora_bias).to(self.device) # Projects up

        # Initialize LoRA matrices: A ~ N(0, 1/rank), B initialized to 0
        std = 1 / torch.sqrt(torch.tensor(rank).float())
        nn.init.normal_(self.lora_A.weight, mean=0.0, std=std)
        nn.init.zeros_(self.lora_B.weight)
        
    def forward(self, x):
        # Forward through base layer
        base_out = self.base_layer(x)

        if self.debug:
            print("================================================================")
            print("base_out.requires_grad:", base_out.requires_grad)
            print("base_out.grad_fn:", base_out.grad_fn)
            print()

        # LoRA transformation
        requires_conversion = not torch.is_autocast_enabled()
        if requires_conversion:
            x = x.to(self.lora_A.weight.dtype)
        lora_out = self.lora_B(self.lora_A(self.dropout(x))) * self.scaling
        if requires_conversion:
            lora_out = lora_out.to(base_out.dtype)
        
        if self.debug:
            print("lora_out.requires_grad:", lora_out.requires_grad)
            print("lora_out.grad_fn:", lora_out.grad_fn)
            print()

        output = base_out + lora_out

        if self.return_lora_output:
            return output, lora_out
        
        return output

    def load_lora_weights(self, state_dict, prefix):
        self.lora_A.weight.data = state_dict[f'{prefix}.lora_A.weight'].to(self.device)
        self.lora_B.weight.data = state_dict[f'{prefix}.lora_B.weight'].to(self.device)
        if self.lora_bias:
            self.lora_A.bias.data = state_dict[f'{prefix}.lora_A.bias'].to(self.device)
            self.lora_B.bias.data = state_dict[f'{prefix}.lora_B.bias'].to(self.device)
    
class LoraModel(nn.Module):
    def __init__(self, base_model: nn.Module, lora_config: LoraConfig, 
                 return_lora_outputs: bool=False, debug: bool=False):
        super().__init__()
        self.base_model = base_model
        self.lora_layers = nn.ModuleDict()
        self.return_lora_outputs = return_lora_outputs
        self.debug = debug

        # Wrap target layers with NeroLayer
        self._wrap_target_layers(lora_config)
    
    def _wrap_target_layers(self, lora_config):
        for module_name, module in self.base_model.named_modules():
            if isinstance(module, LoraLayer):
                # Convert module name format and store reference
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.lora_layers[module_name] = module
                continue

            if any(module_name.endswith(target_module) for target_module in lora_config.target_modules) and isinstance(module, nn.Linear):    
                parent_module, child_name = self._get_parent_module(module_name)
                lora_layer = LoraLayer(
                    module, 
                    lora_config.r, 
                    lora_config.lora_alpha, 
                    lora_config.lora_dropout, 
                    lora_config.lora_bias, 
                    lora_config.use_rslora,
                    return_lora_output=self.return_lora_outputs,
                    debug=self.debug,
                )
                setattr(parent_module, child_name, lora_layer)

                # Store LoRA layers for weight loading
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.lora_layers[module_name] = lora_layer
    
    def _get_parent_module(self, module_name):
        parts = module_name.split('.')
        parent_module = self.base_model
        for part in parts[:-1]:
            parent_module = getattr(parent_module, part)
        return parent_module, parts[-1]

    def freeze_all(self):
        for param in self.base_model.parameters():
            param.requires_grad = False
    
    def unfreeze_all(self):
        for param in self.base_model.parameters():
            param.requires_grad = True
        
        for lora_layer in self.lora_layers.values():
            for param in lora_layer.parameters():
                param.requires_grad = True
    
    def load_lora_weights(self, lora_path):
        state_dict = load_file(lora_path)
        prefix = list(state_dict.keys())[0].rsplit('model.', 1)[0] + 'model.'
        for lora_layer_name, lora_layer in self.lora_layers.items():
            lora_layer_name = lora_layer_name.replace('__DOT__', '.')
            lora_layer_name = prefix + lora_layer_name
            if f'{lora_layer_name}.lora_A.weight' in state_dict and f'{lora_layer_name}.lora_B.weight' in state_dict:
                lora_layer.load_lora_weights(state_dict, lora_layer_name)
            else:
                # TODO: Print warning message
                pass
        print("LoRA weights loaded successfully!")
    
    def forward(self, input_ids, attention_mask=None):
        if self.return_lora_outputs:
            lora_outs = {}
            
            def _hook_fn(layer_name, module, _in, _out):
                if isinstance(_out, tuple) and len(_out) == 2:
                    layer_out, lora_out = _out
                    lora_outs[layer_name] = lora_out # Store nero_out separately
                    return layer_out # Return only layer_out to avoid breaking model flow

            # Register hooks to extract nero_out during forward pass
            hooks = []
            for layer_name, layer in self.lora_layers.items():
                hook = layer.register_forward_hook(functools.partial(_hook_fn, layer_name))
                hooks.append(hook)
        
            try:
                output = self.base_model(input_ids, attention_mask=attention_mask)
            finally:
                # Remove hooks after forward pass, ensuring it's done even if an error occurs
                for hook in hooks:
                    hook.remove()

            return output, lora_outs
        
        return self.base_model(input_ids, attention_mask=attention_mask)
    
    def __getattr__(self, name):
        try:
            return super().__getattr__(name) # Try getting attribute from self
        except AttributeError:
            return getattr(self.base_model, name) # Fallback to base_model

base_lora_model = AutoModelForCausalLM.from_pretrained(base_model_name)
lora_model = LoraModel(
    base_lora_model, 
    model_configs['L2T1']['lora_config'],
    return_lora_outputs=True,
    debug=False,
)

In [7]:
print("Check LoRA parameters (unloaded):")
check_lora_parameters(lora_model)
print()

lora_model.load_lora_weights(model_configs['L2T1']['lora_path'])
print()

print("Check LoRA parameters (loaded):")
check_lora_parameters(lora_model)

Check LoRA parameters (unloaded):
- Name    : base_model.model.layers.0.self_attn.q_proj.lora_A.weight
- Mean    : 0.003538851160556078
- Min     : -1.3572509288787842
- Max     : 1.3094143867492676

LoRA weights loaded successfully!

Check LoRA parameters (loaded):
- Name    : base_model.model.layers.0.self_attn.q_proj.lora_A.weight
- Mean    : 2.2152138626552187e-05
- Min     : -0.06327299773693085
- Max     : 0.0625513345003128


In [8]:
lora_model.freeze_all()

In [9]:
lora_model.train()
device = next(lora_model.parameters()).device
inputs = tokenizer("Preheat the oven to 350 degrees and place the cookie dough", return_tensors='pt')
lora_model_outs = lora_model(input_ids=inputs['input_ids'].to(device))

In [10]:
lora_model_outs[1]['layers__DOT__0__DOT__self_attn__DOT__q_proj']

tensor([[[-0.0036, -0.0022,  0.0020,  ..., -0.0032, -0.0023, -0.0026],
         [-0.0068, -0.0042,  0.0053,  ..., -0.0093, -0.0069, -0.0077],
         [-0.0037, -0.0036,  0.0024,  ..., -0.0055, -0.0041, -0.0046],
         ...,
         [-0.0059, -0.0045,  0.0033,  ..., -0.0072, -0.0053, -0.0059],
         [-0.0064, -0.0063,  0.0004,  ..., -0.0070, -0.0050, -0.0058],
         [-0.0051, -0.0008,  0.0057,  ..., -0.0058, -0.0042, -0.0048]]],
       device='cuda:0', dtype=torch.float16)

## Nero Layer

In [11]:
class NeroLayer(nn.Module):
    def __init__(self, base_layer, 
                 # LoRA parameters
                 rank, alpha, dropout, lora_bias, use_rslora, 
                 # Nero parameters
                 nero_bias=False, 
                 return_nero_output=False,
                 # For debugging 
                 debug=False,
                 module_name=None,
                 ):
        super().__init__()
        self.base_layer = base_layer
        self.device = base_layer.weight.device
        self.alpha = alpha
        self.lora_bias = lora_bias
        self.scaling = alpha / math.sqrt(rank) if use_rslora else alpha / rank
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
        self.return_nero_output = return_nero_output

        # For debugging
        self.debug = debug
        self.module_name = module_name

        # Extract input and output features from the base layer
        in_features = getattr(base_layer, 'in_features', None)
        out_features = getattr(base_layer, 'out_features', None)

        if in_features is None or out_features is None:
            raise ValueError(f"Cannot determine in_features or out_features from {base_layer}.")
        
        # LoRA decomposition: A (down-projection) and B (up-projection)
        self.lora_A = nn.Linear(in_features, rank, bias=lora_bias).to(self.device)  # Projects down
        self.lora_B = nn.Linear(rank, out_features, bias=lora_bias).to(self.device) # Projects up

        # Initialize LoRA matrices: A ~ N(0, 1/rank), B initialized to 0
        std = 1 / torch.sqrt(torch.tensor(rank).float())
        nn.init.normal_(self.lora_A.weight, mean=0.0, std=std)
        nn.init.zeros_(self.lora_B.weight)

        # Nero decomposition: additional transformation applied to LoRA output
        self.nero_A = nn.Linear(out_features, rank, bias=nero_bias).to(self.device)
        self.nero_B = nn.Linear(rank, out_features, bias=nero_bias).to(self.device)

        # Initialize Nero matrices similarly
        nn.init.normal_(self.nero_A.weight, mean=0.0, std=std)
        nn.init.zeros_(self.nero_B.weight)
        
    def forward(self, x):
        # Forward through base layer
        base_out = self.base_layer(x)

        if self.debug:
            print("================================================================")
            print(self.module_name)
            print("================================================================")
            print("base_out.requires_grad:", base_out.requires_grad)
            print("base_out.grad_fn:", base_out.grad_fn)
            print()

        # LoRA transformation
        requires_conversion = not torch.is_autocast_enabled()
        if requires_conversion:
            x = x.to(self.lora_A.weight.dtype)
        lora_out = self.lora_B(self.lora_A(self.dropout(x))) * self.scaling
        # if requires_conversion:
        #     lora_out = lora_out.to(base_out.dtype)

        if self.debug:
            print("lora_out.requires_grad:", lora_out.requires_grad)
            print("lora_out.grad_fn:", lora_out.grad_fn)
            print()

        # nero_out = F.relu(self.nero_B(self.nero_A(self.dropout(lora_out))) * self.scaling)
        nero_dropout_out = self.dropout(lora_out)
        nero_A_out = self.nero_A(nero_dropout_out)
        nero_B_out = self.nero_B(nero_A_out)
        nero_scaling_out = nero_B_out * self.scaling
        nero_out = F.relu(nero_scaling_out)
        if requires_conversion:
            nero_out = nero_out.to(base_out.dtype)

        if self.debug:
            print("nero_out.requires_grad:", nero_out.requires_grad)
            print("nero_out.grad_fn:", nero_out.grad_fn)
            print()

            nero_out_has_nan = torch.isnan(nero_out).any()
            if nero_out_has_nan:
                print("!!! NERO OUT HAS NAN !!!")
                print("nero_out:")
                print(nero_out)
                print()
                print("nero_scaling_out:")
                print(nero_scaling_out)
                print()
                print("nero_B_out:")
                print(nero_B_out)
                print()
                print("nero_A_out:")
                print(nero_A_out)
                print()
                print("nero_dropout_out:")
                print(nero_dropout_out)
                print()
                print("lora_out:")
                print(lora_out)
                print()

        # Add `base_out` with gradients-detached `nero_out`, 
        # so that `base_out` does not carry gradients
        nero_out_detached = nero_out.detach()

        if self.debug:
            print("nero_out_detached.requires_grad:", nero_out_detached.requires_grad)
            print("nero_out_detached.grad_fn:", nero_out_detached.grad_fn)
            print()

        output = base_out + nero_out_detached

        if self.debug:
            print("output.requires_grad:", output.requires_grad)
            print("output.grad_fn:", output.grad_fn)
            print()

        if self.return_nero_output:
            return output, nero_out
        
        return output

    def load_lora_weights(self, state_dict, prefix):
        self.lora_A.weight.data = state_dict[f'{prefix}.lora_A.weight'].to(self.device)
        self.lora_B.weight.data = state_dict[f'{prefix}.lora_B.weight'].to(self.device)
        if self.lora_bias:
            self.lora_A.bias.data = state_dict[f'{prefix}.lora_A.bias'].to(self.device)
            self.lora_B.bias.data = state_dict[f'{prefix}.lora_B.bias'].to(self.device)
    
class NeroModel(nn.Module):
    def __init__(self, base_model: nn.Module, lora_config: LoraConfig, nero_bias: bool=False, 
                 return_nero_outputs: bool=False, debug: bool=False):
        super().__init__()
        self.base_model = base_model
        self.nero_bias = nero_bias
        self.nero_layers = nn.ModuleDict()
        self.return_nero_outputs = return_nero_outputs
        self.debug = debug

        # Wrap target layers with NeroLayer
        self._wrap_target_layers(lora_config)
        
    def _wrap_target_layers(self, lora_config):
        for module_name, module in self.base_model.named_modules():
            if isinstance(module, NeroLayer):
                # Convert module name format and store reference
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.nero_layers[module_name] = module
                continue

            if any(module_name.endswith(target_module) for target_module in lora_config.target_modules) and isinstance(module, nn.Linear):    
                parent_module, child_name = self._get_parent_module(module_name)
                nero_layer = NeroLayer(
                    module, 
                    lora_config.r, 
                    lora_config.lora_alpha, 
                    lora_config.lora_dropout, 
                    lora_config.lora_bias, 
                    lora_config.use_rslora,
                    nero_bias=self.nero_bias,
                    return_nero_output=self.return_nero_outputs,
                    debug=self.debug,
                    module_name=module_name,
                )
                setattr(parent_module, child_name, nero_layer)

                # Store LoRA layers for weight loading
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.nero_layers[module_name] = nero_layer
    
    def _get_parent_module(self, module_name):
        parts = module_name.split('.')
        parent_module = self.base_model
        for part in parts[:-1]:
            parent_module = getattr(parent_module, part)
        return parent_module, parts[-1]

    def freeze_all_except_nero(self):
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        for nero_layer in self.nero_layers.values():
            for param_name, param in nero_layer.named_parameters():
                if 'nero_A' in param_name or 'nero_B' in param_name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
    
    def unfreeze_all(self):
        for param in self.base_model.parameters():
            param.requires_grad = True
        
        for nero_layer in self.nero_layers.values():
            for param in nero_layer.parameters():
                param.requires_grad = True
    
    def load_lora_weights(self, lora_path):
        state_dict = load_file(lora_path)
        prefix = list(state_dict.keys())[0].rsplit('model.', 1)[0] + 'model.'
        for nero_layer_name, nero_layer in self.nero_layers.items():
            nero_layer_name = nero_layer_name.replace('__DOT__', '.')
            nero_layer_name = prefix + nero_layer_name
            if f'{nero_layer_name}.lora_A.weight' in state_dict and f'{nero_layer_name}.lora_B.weight' in state_dict:
                nero_layer.load_lora_weights(state_dict, nero_layer_name)
            else:
                # TODO: Print warning message
                pass
        print("LoRA weights loaded successfully!")
    
    def forward(self, input_ids, attention_mask=None):
        if self.return_nero_outputs:
            nero_outs = {}
            
            def _hook_fn(layer_name, module, _in, _out):
                if isinstance(_out, tuple) and len(_out) == 2:
                    layer_out, nero_out = _out
                    nero_outs[layer_name] = nero_out # Store nero_out separately
                    return layer_out # Return only layer_out to avoid breaking model flow

            # Register hooks to extract nero_out during forward pass
            hooks = []
            for layer_name, layer in self.nero_layers.items():
                hook = layer.register_forward_hook(functools.partial(_hook_fn, layer_name))
                hooks.append(hook)
        
            try:
                output = self.base_model(input_ids, attention_mask=attention_mask)
            finally:
                # Remove hooks after forward pass, ensuring it's done even if an error occurs
                for hook in hooks:
                    hook.remove()

            return output, nero_outs
        
        return self.base_model(input_ids, attention_mask=attention_mask)
    
    def __getattr__(self, name):
        try:
            return super().__getattr__(name) # Try getting attribute from self
        except AttributeError:
            return getattr(self.base_model, name) # Fallback to base_model

base_nero_model = AutoModelForCausalLM.from_pretrained(base_model_name)
nero_model = NeroModel(
    base_nero_model, 
    model_configs['L1T1']['lora_config'], 
    nero_bias=True, 
    return_nero_outputs=True,
    debug=False,
)

In [12]:
print("Check LoRA parameters (unloaded):")
check_lora_parameters(nero_model)
print()

nero_model.load_lora_weights(model_configs['L1T1']['lora_path'])
print()

print("Check LoRA parameters (loaded):")
check_lora_parameters(nero_model)

Check LoRA parameters (unloaded):
- Name    : base_model.model.layers.0.self_attn.q_proj.lora_A.weight
- Mean    : -0.0002531199133954942
- Min     : -1.5828715562820435
- Max     : 1.3799933195114136

LoRA weights loaded successfully!

Check LoRA parameters (loaded):
- Name    : base_model.model.layers.0.self_attn.q_proj.lora_A.weight
- Mean    : 6.287686119321734e-05
- Min     : -0.04176201671361923
- Max     : 0.04242725297808647


In [13]:
nero_model.freeze_all_except_nero()

In [14]:
nero_model.train()
device = next(nero_model.parameters()).device
inputs = tokenizer("Preheat the oven to 350 degrees and place the cookie dough", return_tensors='pt')
nero_model_outs = nero_model(input_ids=inputs['input_ids'].to(device))

In [15]:
nero_model_outs[1]['layers__DOT__0__DOT__self_attn__DOT__q_proj']

tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0839],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0839],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0839],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0839],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0839],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0839]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<ToCopyBackward0>)

# Data

In [16]:
dataset = load_hf_dataset_from_lora(model_configs['L2T1']['hf_lora_id'])

In [17]:
eos_token = tokenizer.eos_token
tokenizer.pad_token = eos_token

# Tokenize dataset
def tokenize_fn(example):
    return tokenizer(example['text'])

dataset_tokenized = dataset.map(
    tokenize_fn, 
    batched=True, 
    remove_columns=dataset.column_names,
)

# Concatenate all tokens into one long stream, then split into blocks
block_size = 16

def group_texts(examples):
    concatenated = []
    for input_ids in examples['input_ids']:
        concatenated += input_ids

    total_length = len(concatenated) // block_size * block_size

    input_ids = [concatenated[i:i + block_size] for i in range(0, total_length, block_size)]
    attention_mask = [[1] * block_size for _ in input_ids]
    labels = input_ids.copy()

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
    }

dataset_grouped = dataset_tokenized.map(
    group_texts, 
    batched=True, 
    batch_size=1000,
    remove_columns=dataset_tokenized.column_names,
)

In [18]:
batch_size = 4
train_loader = DataLoader(
    dataset_grouped, 
    batch_size=batch_size, 
    shuffle=True, 
    collate_fn=default_data_collator,
)

# Training

In [19]:
# _debugged = False
# _layer_name = None
# _nero_out = None
# _lora_out = None
# _mse_loss_unnormed = None
# _nero_out_sum = None

# def loss_func_v1(nero_outs, lora_outs):
#     assert nero_outs.keys() == lora_outs.keys() # TODO: Print warning message
#     total_loss = 0.0

#     for layer_name in lora_outs.keys():
#         nero_out = nero_outs[layer_name]
#         lora_out = lora_outs[layer_name]

#         # Normalized MSE loss
#         # mse_loss = F.mse_loss(nero_out, lora_out, reduction='sum') / torch.sum(nero_outs[layer_name] ** 2)
#         mse_loss_unnormed = F.mse_loss(nero_out, lora_out, reduction='sum')
#         nero_out_sum = torch.sum(nero_outs[layer_name] ** 2)
#         mse_loss = mse_loss_unnormed / nero_out_sum

#         print("================================================================")
#         print(layer_name)
#         print("================================================================")
#         print("mse_loss:", mse_loss)
#         print("nero_out_sum:", nero_out_sum)
#         print("mse_loss_unnormed:", mse_loss_unnormed)
#         print()

#         global _debugged
#         if (torch.isinf(nero_out_sum) or torch.isnan(mse_loss_unnormed)) and not _debugged:
#             global _layer_name
#             global _nero_out
#             global _lora_out
#             global _mse_loss_unnormed
#             global _nero_out_sum
#             _layer_name = layer_name
#             _nero_out = nero_out
#             _lora_out = lora_out
#             _mse_loss_unnormed = mse_loss_unnormed
#             _nero_out_sum = nero_out_sum
#             _debugged = True

#         if torch.isnan(mse_loss):
#             nero_out_has_nan = torch.isnan(nero_out).any()
#             lora_out_has_nan = torch.isnan(lora_out).any()

#             print("nero_out_has_nan:", nero_out_has_nan)
#             print("lora_out_has_nan:", lora_out_has_nan)
#             print()

#             # if nero_out_has_nan:
#             print("nero_out:")
#             print(nero_out)
#             print()
            
#             # if lora_out_has_nan:
#             print("lora_out:")
#             print(lora_out)
#             print()

#         total_loss += mse_loss

#     return total_loss / len(lora_outs)  # Averaging loss across layers

# # loss = loss_func_v1(nero_model_outs[1], lora_model_outs[1])
# # print(loss)

In [20]:
# def loss_func_v2(nero_outs, lora_outs):
#     assert nero_outs.keys() == lora_outs.keys()
#     total_loss = 0.0

#     for layer_name in lora_outs.keys():
#         print("================================================================")
#         print(layer_name)
#         print("================================================================")

#         nero_out = nero_outs[layer_name]
#         lora_out = lora_outs[layer_name]

#         diff = nero_out - lora_out
#         print("mean(abs(diff)):", diff.abs().mean())
#         print("mean((diff)^2):", (diff ** 2).mean())
#         print()
        
#         print("max abs(diff):", diff.abs().max().item())
#         print("mean abs(diff):", diff.abs().mean().item())
#         print("std abs(diff):", diff.abs().std().item())
#         print()

#         # Use mean MSE to prevent overflow and keep scale uniform
#         mse_loss = F.mse_loss(nero_out.float(), lora_out.float(), reduction='mean')  # Compute in float32

#         print("mse_loss:", mse_loss)
#         print()

#         total_loss += mse_loss

#     return total_loss / len(lora_outs)


In [21]:
_debugged = False
_nero_out = None
_lora_out = None

def loss_func_v3(nero_outs, lora_outs, debug=False):
    assert nero_outs.keys() == lora_outs.keys()
    total_loss = 0.0

    for layer_name in lora_outs.keys():
        nero_out = nero_outs[layer_name]
        lora_out = lora_outs[layer_name]

        # Use mean MAE to prevent overflow and keep scale uniform
        # mae_loss = F.l1_loss(nero_out.float(), lora_out.float(), reduction='mean')  
        mae_loss = torch.mean(torch.abs(nero_out.float() - lora_out.float()))

        if debug:
            print("================================================================")
            print(layer_name)
            print("================================================================")
            
            global _debugged
            if not _debugged:
                global _nero_out
                global _lora_out
                _nero_out = nero_out
                _lora_out = lora_out
                _debugged = True

            diff = nero_out - lora_out
            print("mean(abs(diff)):", diff.abs().mean())
            print("mean((diff)^2):", (diff ** 2).mean())
            print()
            
            print("max abs(diff):", diff.abs().max().item())
            print("mean abs(diff):", diff.abs().mean().item())
            print("std abs(diff):", diff.abs().std().item())
            print()

            print("mae_loss:", mae_loss)
            print()

        total_loss += mae_loss

    return total_loss / len(lora_outs)

In [22]:
lr = 1e-4
nero_params = [p for n, p in nero_model.named_parameters() if p.requires_grad]
optimizer = torch.optim.Adam(nero_params, lr=lr)

num_epochs = 1
for epoch in range(num_epochs):
    for step, batch in enumerate(train_loader):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        _, nero_outs = nero_model(input_ids=input_ids, attention_mask=attention_mask)
        _, lora_outs = lora_model(input_ids=input_ids, attention_mask=attention_mask)

        # print("================================================================")
        # print("nero_outs:", nero_outs)
        # print("lora_outs:", lora_outs)
        # print()

        loss = loss_func_v3(nero_outs, lora_outs, debug=False)

        # print("================================================================")
        # print("LOSS")
        # print("================================================================")
        # print("loss:", loss)

        loss.backward()
        optimizer.step()
        
        print(f"epoch: {epoch + 1}/{num_epochs}, step: {step + 1}/{len(train_loader)}, loss: {loss.item()}")
        # break

epoch: 1/1, step: 1/514310, loss: 0.18136386573314667
epoch: 1/1, step: 2/514310, loss: 0.18185503780841827
epoch: 1/1, step: 3/514310, loss: 0.18154877424240112
epoch: 1/1, step: 4/514310, loss: 0.1818685680627823
epoch: 1/1, step: 5/514310, loss: 0.1815982609987259
epoch: 1/1, step: 6/514310, loss: 0.18147484958171844
epoch: 1/1, step: 7/514310, loss: 0.18116679787635803
epoch: 1/1, step: 8/514310, loss: 0.18142275512218475
epoch: 1/1, step: 9/514310, loss: 0.18123763799667358
epoch: 1/1, step: 10/514310, loss: 0.1808454543352127
epoch: 1/1, step: 11/514310, loss: 0.18146687746047974
epoch: 1/1, step: 12/514310, loss: 0.1809486597776413
epoch: 1/1, step: 13/514310, loss: 0.18075381219387054
epoch: 1/1, step: 14/514310, loss: 0.18101638555526733
epoch: 1/1, step: 15/514310, loss: 0.18098264932632446
epoch: 1/1, step: 16/514310, loss: 0.1814201921224594
epoch: 1/1, step: 17/514310, loss: 0.181378573179245
epoch: 1/1, step: 18/514310, loss: 0.18099746108055115
epoch: 1/1, step: 19/51431

In [43]:
diff = torch.rand_like(_nero_out) - torch.rand_like(_lora_out)  # Uniform noise
mae = diff.abs().mean()
print(mae)

tensor(0.3337, device='cuda:0', dtype=torch.float16)


In [44]:
cos_sim = F.cosine_similarity(_nero_out.flatten(1), _lora_out.flatten(1), dim=-1).mean()
print("cosine similarity:", cos_sim.item())

cosine similarity: 0.0115203857421875


In [32]:
# torch.isnan(_nero_out).any() # False
# torch.isnan(_lora_out).any() # False
# torch.isinf(_nero_out).any() # False
# torch.isinf(_lora_out).any() # False
F.mse_loss(_nero_out, _lora_out, reduction='sum')

tensor(inf, device='cuda:0', dtype=torch.float16, grad_fn=<MseLossBackward0>)

In [41]:
print("_nero_out.dtype:", _nero_out.dtype)
print("_lora_out.dtype:", _lora_out.dtype)
print("_nero_out.shape:", _nero_out.shape)
print("_nero_out.abs().max():", _nero_out.abs().max())
print("_lora_out.abs().max():", _lora_out.abs().max())
print("(_nero_out - _lora_out).abs().max():", (_nero_out - _lora_out).abs().max())
print("F.mse_loss(_nero_out, _lora_out, reduction='mean'):", F.mse_loss(_nero_out, _lora_out, reduction='mean'))
print("F.mse_loss(_nero_out, _lora_out, reduction='sum'):", F.mse_loss(_nero_out, _lora_out, reduction='sum'))

_nero_out.dtype: torch.float16
_lora_out.dtype: torch.float16
_nero_out.shape: torch.Size([4, 16, 14336])
_nero_out.abs().max(): tensor(0.7070, device='cuda:0', dtype=torch.float16, grad_fn=<MaxBackward1>)
_lora_out.abs().max(): tensor(0.0555, device='cuda:0', dtype=torch.float16)
(_nero_out - _lora_out).abs().max(): tensor(0.7393, device='cuda:0', dtype=torch.float16, grad_fn=<MaxBackward1>)
F.mse_loss(_nero_out, _lora_out, reduction='mean'): tensor(0.0822, device='cuda:0', dtype=torch.float16,
       grad_fn=<MseLossBackward0>)
F.mse_loss(_nero_out, _lora_out, reduction='sum'): tensor(inf, device='cuda:0', dtype=torch.float16, grad_fn=<MseLossBackward0>)


In [24]:
print(torch.cuda.memory_allocated())

# if 'base_model' in globals():
#     # base_model.to('cpu')
#     del base_model

gc.collect()
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated())

15571826176
15571826176


In [37]:
for obj in gc.get_objects():
    try:
        byte_size = obj.element_size() * obj.nelement()
        if torch.is_tensor(obj) and obj.is_cuda and byte_size > 1024**2 and byte_size < 1024**3:
            print(f"Tensor: {type(obj)}, Size: {obj.size()}, Memory: {byte_size / 1024**2:.2f} MB")
            del obj
    except:
        pass
gc.collect()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Tensor: <class 'torch.nn.parameter.Parameter'>, Size: torch.Size([128256, 4096]), Memory: 1002.00 MB
Tensor: <class 'torch.Tensor'>, Size: torch.Size([1, 14, 128256]), Memory: 3.42 MB
Tensor: <class 'torch.nn.parameter.Parameter'>, Size: torch.Size([128256, 4096]), Memory: 1002.00 MB
Tensor: <class 'torch.Tensor'>, Size: torch.Size([1, 8388608]), Memory: 8.00 MB
Tensor: <class 'torch.Tensor'>, Size: torch.Size([1, 29360128]), Memory: 28.00 MB
Tensor: <class 'torch.Tensor'>, Size: torch.Size([1, 29360128]), Memory: 28.00 MB
Tensor: <class 'torch.Tensor'>, Size: torch.Size([1, 29360128]), Memory: 28.00 MB
Tensor: <class 'torch.Tensor'>, Size: torch.Size([1, 8388608]), Memory: 8.00 MB
Tensor: <class 'torch.Tensor'>, Size: torch.Size([1, 2097152]), Memory: 2.00 MB
Tensor: <class 'torch.Tensor'>, Size: torch.Size([1, 2097152]), Memory: 2.00 MB
Tensor: <class 'torch.Tensor'>, Size: torch.Size([1, 8388608]), Memory: 8.00 MB
Tensor: <class 'torch.Tensor'>, Size: torch.Size([1, 2097152]), Memor

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

0

In [None]:
def loss_func(pred_outs, gt_outs, lambda_reg, lora_A_list, lora_B_list):
    total_loss = 0.0
    num_layers = len(gt_outs)

    for i in range(num_layers):
        # Normalized MSE loss
        mse_loss = F.mse_loss(pred_outs[i], gt_outs[i], reduction='sum') / torch.sum(pred_outs[i] ** 2)
        
        # L2 regularization for LoRA matrices
        reg_loss = lambda_reg * (torch.norm(lora_A_list[i], p=2) ** 2 + torch.norm(lora_B_list[i], p=2) ** 2)

        total_loss += mse_loss + reg_loss

    return total_loss / num_layers  # Averaging loss across layers