In [1]:
# Kill all processess on GPU
# !fuser -v /dev/nvidia* -k

# Libraries

In [2]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    %pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    %pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    %pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    %pip install --no-deps unsloth
%pip install trl==0.19.1 # Fix error: ImportError: cannot import name 'ConstantLengthDataset' from 'trl.trainer.utils'

In [3]:
import os
import functools
import gc
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader
from datasets import load_dataset, Dataset
from transformers import default_data_collator, AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from pprint import pprint

In [4]:
def download_hf_model(repo_id, checkpoint=None, max_checkpoints=2000, checkpoint_step=25):
    local_dir = repo_id.split('/')[-1]
    ignore_checkpoints = None
    
    if checkpoint is not None:
        ignore_checkpoints = [f'checkpoint-{i}/*' for i in range(0, max_checkpoints, checkpoint_step) if i != checkpoint]

    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        ignore_patterns=ignore_checkpoints,
    )

    if checkpoint is not None:
        return os.path.join(local_dir, f'checkpoint-{checkpoint}')
    return local_dir

@torch.no_grad()
def check_lora_parameters(model):
    for n, p in model.named_parameters():
        if 'lora' in n:
            print(f"- {'Name':<8}:", n)
            print(f"- {'Mean':<8}:", p.mean().item())
            print(f"- {'Min':<8}:", p.min().item())
            print(f"- {'Max':<8}:", p.max().item())
            break

@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens=50, skip_special_tokens=True):
    device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors='pt')
    outputs = model.generate(input_ids=inputs['input_ids'].to(device), max_new_tokens=max_new_tokens)
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=skip_special_tokens)[0])

def load_hf_dataset_from_lora(
    lora_repo_id,
    train_size = 5000,
    test_size = 1000,
):
    # Get task and language
    task, lang, _ = lora_repo_id.split('B-')[-1].split('K-')[0].split('-')

    # Set up Hugging Face configuration
    data_id_map = {
        'wikipedia': 'wikimedia/wikipedia',
        'gsm8k': 'openai/gsm8k',
    }
    data_id = data_id_map[task]
    data_dir = f'20231101.{lang}' if task == 'wikipedia' else 'main'
    split = f'train[:{(train_size+test_size)}]'

    # Load dataset
    # TODO: Use streaming to not download the entire dataset
    # dataset = load_dataset(data_id, data_dir=data_dir, split=split)

    # Use streaming
    dataset_stream = load_dataset(data_id, data_dir=data_dir, split='train', streaming=True)

    # Manually take train_size + test_size samples
    total_size = train_size + test_size
    sliced_data = []
    for i, example in enumerate(dataset_stream):
        if i >= total_size:
            break
        sliced_data.append(example)

    # Convert to regular in-memory dataset
    dataset = Dataset.from_list(sliced_data)
    
    return dataset

# Config

In [5]:
# Project configuration
seed = 69
lora_model_device = 'cuda:0'
nero_model_device = 'cuda:1'

# Training configuration
block_size = 96
batch_size = 4
num_epochs = 1
max_global_steps = 100
resume_step = 0
lr = 1e-4

model_configs = {
    # L1T1 (Source Language - Source Task)
    'source': {
        'label': 'L1T1',
        'hf_lora_id': 'alxxtexxr/L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650',
        'checkpoint': 650,
    },

    # L2T1 (Target Language - Source Task)
    'target': {
        'label': 'L2T1',
        'hf_lora_id': 'alxxtexxr/L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629',
        'checkpoint': 650,
    },

    # L1T2 (Source Language - Target Task)
    # 'target': {
    #     'label': 'L1T2',
    #     'hf_lora_id': 'alxxtexxr/L3.1-8B-gsm8k-en-5K-LoRA-v20250701060457',
    #     'checkpoint': 1875,
    # },
}

for key, config in model_configs.items():
    lora_dir = download_hf_model(config['hf_lora_id'], config['checkpoint'])
    model_configs[key]['lora_dir'] = lora_dir
    model_configs[key]['lora_path'] = os.path.join(lora_dir, 'adapter_model.safetensors')
    model_configs[key]['lora_config'] = LoraConfig.from_pretrained(lora_dir)

print("Model configurations:",)
for key, config in model_configs.items():
    print(f"- {key}:")
    for config_name, config_value in config.items():
        if config_name == 'lora_config':
            continue
        print(f"{'-':>3} {config_name:<10}: {config_value}")
print()

assert model_configs['source']['lora_config'].base_model_name_or_path == model_configs['target']['lora_config'].base_model_name_or_path, "Base models must be the same"
base_model_name = model_configs['source']['lora_config'].base_model_name_or_path
print(f"Base model name: {base_model_name}")

tokenizer = AutoTokenizer.from_pretrained(base_model_name)

Fetching 20 files:   0%|          | 0/20 [00:00<?, ?it/s]

Fetching 20 files:   0%|          | 0/20 [00:00<?, ?it/s]

Model configurations:
- source:
  - label     : L1T1
  - hf_lora_id: alxxtexxr/L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650
  - checkpoint: 650
  - lora_dir  : L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650/checkpoint-650
  - lora_path : L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650/checkpoint-650/adapter_model.safetensors
- target:
  - label     : L2T1
  - hf_lora_id: alxxtexxr/L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629
  - checkpoint: 650
  - lora_dir  : L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629/checkpoint-650
  - lora_path : L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629/checkpoint-650/adapter_model.safetensors

Base model name: unsloth/meta-llama-3.1-8b-unsloth-bnb-4bit


# Model

## LoRA Model

### References
- https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/bnb.py
- https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py

In [6]:
class LoraLayer(nn.Module):
    def __init__(self, base_layer, rank, alpha, dropout, lora_bias, use_rslora, 
                 return_lora_output=False, debug=False):
        super().__init__()
        self.base_layer = base_layer
        self.device = base_layer.weight.device
        self.alpha = alpha
        self.lora_bias = lora_bias
        self.scaling = alpha / math.sqrt(rank) if use_rslora else alpha / rank
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
        self.return_lora_output = return_lora_output
        self.debug = debug

        # Extract input and output features from the base layer
        in_features = getattr(base_layer, 'in_features', None)
        out_features = getattr(base_layer, 'out_features', None)

        if in_features is None or out_features is None:
            raise ValueError(f"Cannot determine in_features or out_features from {base_layer}.")
        
        # LoRA decomposition: A (down-projection) and B (up-projection)
        self.lora_A = nn.Linear(in_features, rank, bias=lora_bias).to(self.device)  # Projects down
        self.lora_B = nn.Linear(rank, out_features, bias=lora_bias).to(self.device) # Projects up

        # Initialize LoRA matrices: A ~ N(0, 1/rank), B initialized to 0
        std = 1 / torch.sqrt(torch.tensor(rank).float())
        nn.init.normal_(self.lora_A.weight, mean=0.0, std=std)
        nn.init.zeros_(self.lora_B.weight)
        
    def forward(self, x):
        # Forward through base layer
        base_out = self.base_layer(x)

        if self.debug:
            print("================================================================")
            print("base_out.requires_grad:", base_out.requires_grad)
            print("base_out.grad_fn:", base_out.grad_fn)
            print()

        # LoRA transformation
        requires_conversion = not torch.is_autocast_enabled()
        if requires_conversion:
            x = x.to(self.lora_A.weight.dtype)
        lora_out = self.lora_B(self.lora_A(self.dropout(x))) * self.scaling
        if requires_conversion:
            lora_out = lora_out.to(base_out.dtype)
        
        if self.debug:
            print("lora_out.requires_grad:", lora_out.requires_grad)
            print("lora_out.grad_fn:", lora_out.grad_fn)
            print()

        output = base_out + lora_out

        if self.return_lora_output:
            return output, lora_out
        
        return output

    def load_lora_params(self, state_dict, prefix):
        self.lora_A.weight.data = state_dict[f'{prefix}.lora_A.weight'].to(self.device)
        self.lora_B.weight.data = state_dict[f'{prefix}.lora_B.weight'].to(self.device)
        if self.lora_bias:
            self.lora_A.bias.data = state_dict[f'{prefix}.lora_A.bias'].to(self.device)
            self.lora_B.bias.data = state_dict[f'{prefix}.lora_B.bias'].to(self.device)
    
class LoraModel(nn.Module):
    def __init__(self, base_model: nn.Module, lora_config: LoraConfig, 
                 return_lora_outputs: bool=False, debug: bool=False):
        super().__init__()
        self.base_model = base_model
        self.lora_layers = nn.ModuleDict()
        self.return_lora_outputs = return_lora_outputs
        self.debug = debug

        # Wrap target layers with NeroLayer
        self._wrap_target_layers(lora_config)
    
    def _wrap_target_layers(self, lora_config):
        for module_name, module in self.base_model.named_modules():
            if isinstance(module, LoraLayer):
                # Convert module name format and store reference
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.lora_layers[module_name] = module
                continue

            if any(module_name.endswith(target_module) for target_module in lora_config.target_modules) and isinstance(module, nn.Linear):    
                parent_module, child_name = self._get_parent_module(module_name)
                lora_layer = LoraLayer(
                    module, 
                    lora_config.r, 
                    lora_config.lora_alpha, 
                    lora_config.lora_dropout, 
                    lora_config.lora_bias, 
                    lora_config.use_rslora,
                    return_lora_output=self.return_lora_outputs,
                    debug=self.debug,
                )
                setattr(parent_module, child_name, lora_layer)

                # Store LoRA layers for weight loading
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.lora_layers[module_name] = lora_layer
    
    def _get_parent_module(self, module_name):
        parts = module_name.split('.')
        parent_module = self.base_model
        for part in parts[:-1]:
            parent_module = getattr(parent_module, part)
        return parent_module, parts[-1]

    def set_return_lora_outputs(self, return_lora_outputs: bool):
        self.return_lora_outputs = return_lora_outputs
        for layer in self.lora_layers.values():
            layer.return_lora_output = return_lora_outputs

    def freeze_all(self):
        for param in self.base_model.parameters():
            param.requires_grad = False
    
    def unfreeze_all(self):
        for param in self.base_model.parameters():
            param.requires_grad = True
        
        for lora_layer in self.lora_layers.values():
            for param in lora_layer.parameters():
                param.requires_grad = True
    
    def load_lora_params(self, lora_path):
        state_dict = load_file(lora_path)
        prefix = list(state_dict.keys())[0].rsplit('model.', 1)[0] + 'model.'
        for lora_layer_name, lora_layer in self.lora_layers.items():
            lora_layer_name = lora_layer_name.replace('__DOT__', '.')
            lora_layer_name = prefix + lora_layer_name
            if f'{lora_layer_name}.lora_A.weight' in state_dict and f'{lora_layer_name}.lora_B.weight' in state_dict:
                lora_layer.load_lora_params(state_dict, lora_layer_name)
            else:
                # TODO: Print warning message
                pass
        print("LoRA weights loaded successfully!")
    
    def forward(self, input_ids, attention_mask=None):
        if self.return_lora_outputs:
            lora_outs = {}
            
            def _hook_fn(layer_name, module, _in, _out):
                if isinstance(_out, tuple) and len(_out) == 2:
                    layer_out, lora_out = _out
                    lora_outs[layer_name] = lora_out # Store nero_out separately
                    return layer_out # Return only layer_out to avoid breaking model flow

            # Register hooks to extract nero_out during forward pass
            hooks = []
            for layer_name, layer in self.lora_layers.items():
                hook = layer.register_forward_hook(functools.partial(_hook_fn, layer_name))
                hooks.append(hook)
        
            try:
                output = self.base_model(input_ids, attention_mask=attention_mask)
            finally:
                # Remove hooks after forward pass, ensuring it's done even if an error occurs
                for hook in hooks:
                    hook.remove()

            return output, lora_outs
        
        return self.base_model(input_ids, attention_mask=attention_mask)
    
    def __getattr__(self, name):
        try:
            return super().__getattr__(name) # Try getting attribute from self
        except AttributeError:
            return getattr(self.base_model, name) # Fallback to base_model

base_lora_model = AutoModelForCausalLM.from_pretrained(base_model_name, device_map=lora_model_device)
lora_model = LoraModel(
    base_lora_model, 
    model_configs['target']['lora_config'],
    return_lora_outputs=True,
    debug=False,
)

In [7]:
print("Check LoRA parameters (unloaded):")
check_lora_parameters(lora_model)
print()

lora_model.load_lora_params(model_configs['target']['lora_path'])
print()

print("Check LoRA parameters (loaded):")
check_lora_parameters(lora_model)

Check LoRA parameters (unloaded):
- Name    : base_model.model.layers.0.self_attn.q_proj.lora_A.weight
- Mean    : -0.002439698437228799
- Min     : -1.4944459199905396
- Max     : 1.4414706230163574

LoRA weights loaded successfully!

Check LoRA parameters (loaded):
- Name    : base_model.model.layers.0.self_attn.q_proj.lora_A.weight
- Mean    : 2.2152138626552187e-05
- Min     : -0.06327299773693085
- Max     : 0.0625513345003128


In [8]:
lora_model.freeze_all()

In [9]:
lora_model.train()
device = next(lora_model.parameters()).device
inputs = tokenizer("Preheat the oven to 350 degrees and place the cookie dough", return_tensors='pt')
lora_model_outs = lora_model(input_ids=inputs['input_ids'].to(device))

In [10]:
_lora_logits = lora_model_outs[0].logits
_lora_outs = lora_model_outs[1]
print("_lora_logits.requires_grad", _lora_logits.requires_grad)
print("_lora_logits.grad_fn", _lora_logits.grad_fn)
print("next(iter(_lora_outs.values()).requires_grad", next(iter(_lora_outs.values())).requires_grad)
print("next(iter(_lora_outs.values()).grad_fn", next(iter(_lora_outs.values())).grad_fn)

_lora_logits.requires_grad False
_lora_logits.grad_fn None
next(iter(_lora_outs.values()).requires_grad False
next(iter(_lora_outs.values()).grad_fn None


In [11]:
lora_model.eval()
lora_model.set_return_lora_outputs(False)
generate_text(
    lora_model, 
    tokenizer, 
    prompt="Preheat the oven to 350 degrees and place the cookie dough",
)
lora_model.set_return_lora_outputs(True)
lora_model.train();

Preheat the oven to 350 degrees and place the cookie dough in the refrigerator for 10 minutes. Remove the cookie dough from the refrigerator and slice into 1/4-inch thick slices. Place the cookie dough slices on a baking sheet lined with parchment paper or a non-stick baking mat. Bake the cookies


## Nero Layer

In [12]:
class NeroLayer(nn.Module):
    def __init__(self, base_layer, 
                 # LoRA parameters
                 rank, alpha, dropout, lora_bias, use_rslora, 
                 # Nero parameters
                 nero_bias=False, 
                 return_nero_output=False,
                 # For debugging 
                 debug=False,
                 module_name=None,
                 ):
        super().__init__()
        self.base_layer = base_layer
        self.device = base_layer.weight.device
        self.alpha = alpha
        self.lora_bias = lora_bias
        self.scaling = alpha / math.sqrt(rank) if use_rslora else alpha / rank
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
        self.return_nero_output = return_nero_output

        # For debugging
        self.debug = debug
        self.module_name = module_name

        # Extract input and output features from the base layer
        in_features = getattr(base_layer, 'in_features', None)
        out_features = getattr(base_layer, 'out_features', None)

        if in_features is None or out_features is None:
            raise ValueError(f"Cannot determine in_features or out_features from {base_layer}.")
        
        # LoRA decomposition: A (down-projection) and B (up-projection)
        self.lora_A = nn.Linear(in_features, rank, bias=lora_bias).to(self.device)  # Projects down
        self.lora_B = nn.Linear(rank, out_features, bias=lora_bias).to(self.device) # Projects up

        # Initialize LoRA matrices: A ~ N(0, 1/rank), B initialized to 0
        std = 1 / torch.sqrt(torch.tensor(rank).float())
        nn.init.normal_(self.lora_A.weight, mean=0.0, std=std)
        nn.init.zeros_(self.lora_B.weight)

        # Nero decomposition: additional transformation applied to LoRA output
        self.nero_A = nn.Linear(out_features, rank, bias=nero_bias).to(self.device)
        self.nero_B = nn.Linear(rank, out_features, bias=nero_bias).to(self.device)

        # Initialize Nero matrices similarly
        nn.init.normal_(self.nero_A.weight, mean=0.0, std=std)
        nn.init.zeros_(self.nero_B.weight)
        
    def forward(self, x):
        # Forward through base layer
        base_out = self.base_layer(x)

        if self.debug:
            print("================================================================")
            print(self.module_name)
            print("================================================================")
            print("base_out.requires_grad:", base_out.requires_grad)
            print("base_out.grad_fn:", base_out.grad_fn)
            print()

        # LoRA transformation
        requires_conversion = not torch.is_autocast_enabled()
        if requires_conversion:
            x = x.to(self.lora_A.weight.dtype)
        lora_out = self.lora_B(self.lora_A(self.dropout(x))) * self.scaling
        # if requires_conversion:
        #     lora_out = lora_out.to(base_out.dtype)

        if self.debug:
            print("lora_out.requires_grad:", lora_out.requires_grad)
            print("lora_out.grad_fn:", lora_out.grad_fn)
            print()

        # nero_out = F.relu(self.nero_B(self.nero_A(self.dropout(lora_out))) * self.scaling)
        nero_dropout_out = self.dropout(lora_out)
        nero_A_out = self.nero_A(nero_dropout_out)
        nero_B_out = self.nero_B(nero_A_out)
        nero_scaling_out = nero_B_out * self.scaling
        nero_out = F.relu(nero_scaling_out)
        if requires_conversion:
            nero_out = nero_out.to(base_out.dtype)

        if self.debug:
            print("nero_out.requires_grad:", nero_out.requires_grad)
            print("nero_out.grad_fn:", nero_out.grad_fn)
            print()

            nero_out_has_nan = torch.isnan(nero_out).any()
            if nero_out_has_nan:
                print("!!! NERO OUT HAS NAN !!!")
                print("nero_out:")
                print(nero_out)
                print()
                print("nero_scaling_out:")
                print(nero_scaling_out)
                print()
                print("nero_B_out:")
                print(nero_B_out)
                print()
                print("nero_A_out:")
                print(nero_A_out)
                print()
                print("nero_dropout_out:")
                print(nero_dropout_out)
                print()
                print("lora_out:")
                print(lora_out)
                print()

        # Add `base_out` with gradients-detached `nero_out`, 
        # so that `base_out` does not carry gradients
        # nero_out_detached = nero_out.detach()

        # if self.debug:
        #     print("nero_out_detached.requires_grad:", nero_out_detached.requires_grad)
        #     print("nero_out_detached.grad_fn:", nero_out_detached.grad_fn)
        #     print()

        # output = base_out + nero_out_detached
        output = base_out + nero_out

        if self.debug:
            print("output.requires_grad:", output.requires_grad)
            print("output.grad_fn:", output.grad_fn)
            print()

        if self.return_nero_output:
            return output, nero_out
        
        return output

    def load_lora_params(self, state_dict, prefix):
        self.lora_A.weight.data = state_dict[f'{prefix}.lora_A.weight'].to(self.device)
        self.lora_B.weight.data = state_dict[f'{prefix}.lora_B.weight'].to(self.device)
        if self.lora_bias:
            self.lora_A.bias.data = state_dict[f'{prefix}.lora_A.bias'].to(self.device)
            self.lora_B.bias.data = state_dict[f'{prefix}.lora_B.bias'].to(self.device)
    
class NeroModel(nn.Module):
    def __init__(self, base_model: nn.Module, lora_config: LoraConfig, nero_bias: bool=False, 
                 return_nero_outputs: bool=False, debug: bool=False):
        super().__init__()
        self.base_model = base_model
        self.nero_bias = nero_bias
        self.nero_layers = nn.ModuleDict()
        self.return_nero_outputs = return_nero_outputs
        self.debug = debug

        # Wrap target layers with NeroLayer
        self._wrap_target_layers(lora_config)
        
    def _wrap_target_layers(self, lora_config):
        for module_name, module in self.base_model.named_modules():
            if isinstance(module, NeroLayer):
                # Convert module name format and store reference
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.nero_layers[module_name] = module
                continue

            if any(module_name.endswith(target_module) for target_module in lora_config.target_modules) and isinstance(module, nn.Linear):    
                parent_module, child_name = self._get_parent_module(module_name)
                nero_layer = NeroLayer(
                    module, 
                    lora_config.r, 
                    lora_config.lora_alpha, 
                    lora_config.lora_dropout, 
                    lora_config.lora_bias, 
                    lora_config.use_rslora,
                    nero_bias=self.nero_bias,
                    return_nero_output=self.return_nero_outputs,
                    debug=self.debug,
                    module_name=module_name,
                )
                setattr(parent_module, child_name, nero_layer)

                # Store LoRA layers for weight loading
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.nero_layers[module_name] = nero_layer
    
    def _get_parent_module(self, module_name):
        parts = module_name.split('.')
        parent_module = self.base_model
        for part in parts[:-1]:
            parent_module = getattr(parent_module, part)
        return parent_module, parts[-1]
    
    def set_return_nero_outputs(self, return_nero_outputs: bool):
        self.return_nero_outputs = return_nero_outputs
        for layer in self.nero_layers.values():
            layer.return_nero_output = return_nero_outputs

    def freeze_all_except_nero(self):
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        for nero_layer in self.nero_layers.values():
            for param_name, param in nero_layer.named_parameters():
                if 'nero_A' in param_name or 'nero_B' in param_name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
    
    def unfreeze_all(self):
        for param in self.base_model.parameters():
            param.requires_grad = True
        
        for nero_layer in self.nero_layers.values():
            for param in nero_layer.parameters():
                param.requires_grad = True
    
    def load_lora_params(self, lora_path):
        state_dict = load_file(lora_path)
        prefix = list(state_dict.keys())[0].rsplit('model.', 1)[0] + 'model.'
        for nero_layer_name, nero_layer in self.nero_layers.items():
            nero_layer_name = nero_layer_name.replace('__DOT__', '.')
            nero_layer_name = prefix + nero_layer_name
            if f'{nero_layer_name}.lora_A.weight' in state_dict and f'{nero_layer_name}.lora_B.weight' in state_dict:
                nero_layer.load_lora_params(state_dict, nero_layer_name)
            else:
                # TODO: Print warning message
                pass
        print("LoRA weights loaded successfully!")
    
    def forward(self, input_ids, attention_mask=None):
        if self.return_nero_outputs:
            nero_outs = {}
            
            def _hook_fn(layer_name, module, _in, _out):
                if isinstance(_out, tuple) and len(_out) == 2:
                    layer_out, nero_out = _out
                    nero_outs[layer_name] = nero_out # Store nero_out separately
                    return layer_out # Return only layer_out to avoid breaking model flow

            # Register hooks to extract nero_out during forward pass
            hooks = []
            for layer_name, layer in self.nero_layers.items():
                hook = layer.register_forward_hook(functools.partial(_hook_fn, layer_name))
                hooks.append(hook)
        
            try:
                output = self.base_model(input_ids, attention_mask=attention_mask)
            finally:
                # Remove hooks after forward pass, ensuring it's done even if an error occurs
                for hook in hooks:
                    hook.remove()

            return output, nero_outs
        
        return self.base_model(input_ids, attention_mask=attention_mask)
    
    def __getattr__(self, name):
        try:
            return super().__getattr__(name) # Try getting attribute from self
        except AttributeError:
            return getattr(self.base_model, name) # Fallback to base_model

base_nero_model = AutoModelForCausalLM.from_pretrained(base_model_name, device_map=nero_model_device)
nero_model = NeroModel(
    base_nero_model, 
    model_configs['source']['lora_config'], 
    nero_bias=True, 
    return_nero_outputs=True,
    debug=False,
)

In [13]:
print("Check LoRA parameters (unloaded):")
check_lora_parameters(nero_model)
print()

nero_model.load_lora_params(model_configs['source']['lora_path'])
print()

print("Check LoRA parameters (loaded):")
check_lora_parameters(nero_model)

Check LoRA parameters (unloaded):
- Name    : base_model.model.layers.0.self_attn.q_proj.lora_A.weight
- Mean    : -0.0019485268276184797
- Min     : -1.4076730012893677
- Max     : 1.3847874402999878

LoRA weights loaded successfully!

Check LoRA parameters (loaded):
- Name    : base_model.model.layers.0.self_attn.q_proj.lora_A.weight
- Mean    : 6.287686119321734e-05
- Min     : -0.04176201671361923
- Max     : 0.04242725297808647


In [14]:
nero_model.freeze_all_except_nero()

In [15]:
nero_model.train()
device = next(nero_model.parameters()).device
inputs = tokenizer("Preheat the oven to 350 degrees and place the cookie dough", return_tensors='pt')
nero_model_outs = nero_model(
    input_ids=inputs['input_ids'].to(device), 
    attention_mask=inputs['attention_mask'].to(device)
)

In [16]:
_nero_logits = nero_model_outs[0].logits
_nero_outs = nero_model_outs[1]
print("_nero_logits.requires_grad", _nero_logits.requires_grad)
print("_nero_logits.grad_fn", _nero_logits.grad_fn)
print("next(iter(_nero_outs.values()).requires_grad", next(iter(_nero_outs.values())).requires_grad)
print("next(iter(_nero_outs.values()).grad_fn", next(iter(_nero_outs.values())).grad_fn)

_nero_logits.requires_grad True
_nero_logits.grad_fn <UnsafeViewBackward0 object at 0x7b51d4ef78e0>
next(iter(_nero_outs.values()).requires_grad True
next(iter(_nero_outs.values()).grad_fn <ToCopyBackward0 object at 0x7b51d4ef78e0>


In [17]:
nero_model.eval()
nero_model.set_return_nero_outputs(False)
generate_text(
    nero_model, 
    tokenizer, 
    prompt="Preheat the oven to 350 degrees and place the cookie dough",
)
nero_model.set_return_nero_outputs(True)
nero_model.train();

Preheat the oven to 350 degrees and place the cookie dough Destruction_<?_<?_<? gloss_<? �_<?470_<?_<? сбор_<?_<?_<? �_<? �芳 � сборhton Destruction_<?_<?vak470_<? Shore470_<?onse�姫MOOTH470 �470 gloss Destruction Satoshi470� �_<?_<? Destruction_<? gloss Rica


In [18]:
dataset = load_hf_dataset_from_lora(model_configs['target']['hf_lora_id'], test_size=0)

In [19]:
if 'gsm8k' in model_configs['target']['hf_lora_id']:
    def format_prompt(example):
        gsm8k_prompt = """### Instruction:
Solve the following math problem step by step.

### Question: 
{question}

### Answer: 
{answer}"""

        return {'text': gsm8k_prompt.format(
            question=example['question'], 
            answer=example['answer'],
        )}

    def tokenize_fn(example):
        return tokenizer(
            example["text"],
            truncation=True,
            padding='max_length',
            max_length=block_size,
        )

    def add_labels(example):
        example['labels'] = example['input_ids'].copy()
        return example

    dataset = dataset.map(format_prompt)
    dataset = dataset.map(tokenize_fn, batched=True, remove_columns=dataset.column_names)
    dataset = dataset.map(add_labels)
else:
    eos_token = tokenizer.eos_token
    tokenizer.pad_token = eos_token

    # Tokenize dataset
    def tokenize_fn(example):
        return tokenizer(example['text'])

    dataset = dataset.map(
        tokenize_fn, 
        batched=True, 
        remove_columns=dataset.column_names,
    )

    # Concatenate all tokens into one long stream, then split into blocks
    def group_texts(examples):
        concatenated = []
        for input_ids in examples['input_ids']:
            concatenated += input_ids

        total_length = len(concatenated) // block_size * block_size

        input_ids = [concatenated[i:i + block_size] for i in range(0, total_length, block_size)]
        attention_mask = [[1] * block_size for _ in input_ids]
        labels = input_ids.copy()

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
        }

    dataset = dataset.map(
        group_texts, 
        batched=True, 
        batch_size=1000,
        remove_columns=dataset.column_names,
    )

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (141723 > 131072). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

In [20]:
train_loader = DataLoader(
    dataset, 
    batch_size=batch_size, 
    shuffle=True, 
    collate_fn=default_data_collator,
)
print("Total batches:", len(train_loader))
print("First batch input IDs shape:", next(iter(train_loader))['input_ids'].shape)

Total batches: 73856
First batch input IDs shape: torch.Size([4, 96])


# Training

In [21]:
# _debugged = False
# _layer_name = None
# _nero_out = None
# _lora_out = None
# _mse_loss_unnormed = None
# _nero_out_sum = None

# def loss_fn_v1(nero_outs, lora_outs):
#     assert nero_outs.keys() == lora_outs.keys() # TODO: Print warning message
#     total_loss = 0.0

#     for layer_name in lora_outs.keys():
#         nero_out = nero_outs[layer_name]
#         lora_out = lora_outs[layer_name]

#         # Normalized MSE loss
#         # mse_loss = F.mse_loss(nero_out, lora_out, reduction='sum') / torch.sum(nero_outs[layer_name] ** 2)
#         mse_loss_unnormed = F.mse_loss(nero_out, lora_out, reduction='sum')
#         nero_out_sum = torch.sum(nero_outs[layer_name] ** 2)
#         mse_loss = mse_loss_unnormed / nero_out_sum

#         print("================================================================")
#         print(layer_name)
#         print("================================================================")
#         print("mse_loss:", mse_loss)
#         print("nero_out_sum:", nero_out_sum)
#         print("mse_loss_unnormed:", mse_loss_unnormed)
#         print()

#         global _debugged
#         if (torch.isinf(nero_out_sum) or torch.isnan(mse_loss_unnormed)) and not _debugged:
#             global _layer_name
#             global _nero_out
#             global _lora_out
#             global _mse_loss_unnormed
#             global _nero_out_sum
#             _layer_name = layer_name
#             _nero_out = nero_out
#             _lora_out = lora_out
#             _mse_loss_unnormed = mse_loss_unnormed
#             _nero_out_sum = nero_out_sum
#             _debugged = True

#         if torch.isnan(mse_loss):
#             nero_out_has_nan = torch.isnan(nero_out).any()
#             lora_out_has_nan = torch.isnan(lora_out).any()

#             print("nero_out_has_nan:", nero_out_has_nan)
#             print("lora_out_has_nan:", lora_out_has_nan)
#             print()

#             # if nero_out_has_nan:
#             print("nero_out:")
#             print(nero_out)
#             print()
            
#             # if lora_out_has_nan:
#             print("lora_out:")
#             print(lora_out)
#             print()

#         total_loss += mse_loss

#     return total_loss / len(lora_outs)  # Averaging loss across layers

# # loss = loss_func_v1(nero_model_outs[1], lora_model_outs[1])
# # print(loss)

In [22]:
# def loss_fn_v2(nero_outs, lora_outs):
#     assert nero_outs.keys() == lora_outs.keys()
#     total_loss = 0.0

#     for layer_name in lora_outs.keys():
#         print("================================================================")
#         print(layer_name)
#         print("================================================================")

#         nero_out = nero_outs[layer_name]
#         lora_out = lora_outs[layer_name]

#         diff = nero_out - lora_out
#         print("mean(abs(diff)):", diff.abs().mean())
#         print("mean((diff)^2):", (diff ** 2).mean())
#         print()
        
#         print("max abs(diff):", diff.abs().max().item())
#         print("mean abs(diff):", diff.abs().mean().item())
#         print("std abs(diff):", diff.abs().std().item())
#         print()

#         # Use mean MSE to prevent overflow and keep scale uniform
#         mse_loss = F.mse_loss(nero_out.float(), lora_out.float(), reduction='mean')  # Compute in float32

#         print("mse_loss:", mse_loss)
#         print()

#         total_loss += mse_loss

#     return total_loss / len(lora_outs)


In [23]:
# _debugged = False
# _nero_out = None
# _lora_out = None

# def loss_fn_v3(nero_outs, lora_outs, debug=False):
#     assert nero_outs.keys() == lora_outs.keys()
#     loss_device = next(iter(nero_outs.values())).device
#     total_loss = torch.tensor(0.0, device=loss_device)
    
#     for layer_name in lora_outs.keys():
#         nero_out = nero_outs[layer_name]
#         lora_out = lora_outs[layer_name].to(nero_out.device)

#         # Use mean MAE to prevent overflow and keep scale uniform
#         # mae_loss = F.l1_loss(nero_out.float(), lora_out.float(), reduction='mean')  
#         # mae_loss = torch.mean(torch.abs(nero_out.float() - lora_out.float()))
#         mae_loss = torch.mean(torch.abs(nero_out.float() - lora_out.float()), dim=-1).mean() # scale by sequence length

#         if debug:
#             print("================================================================")
#             print(layer_name)
#             print("================================================================")
            
#             global _debugged
#             if not _debugged:
#                 global _nero_out
#                 global _lora_out
#                 _nero_out = nero_out
#                 _lora_out = lora_out
#                 _debugged = True

#             diff = nero_out - lora_out
#             print("mean(abs(diff)):", diff.abs().mean())
#             print("mean((diff)^2):", (diff ** 2).mean())
#             print()
            
#             print("max abs(diff):", diff.abs().max().item())
#             print("mean abs(diff):", diff.abs().mean().item())
#             print("std abs(diff):", diff.abs().std().item())
#             print()

#             print("mae_loss:", mae_loss)
#             print()

#         total_loss += mae_loss.to(loss_device)

#     return total_loss / len(lora_outs)

In [24]:
def loss_fn_v4(nero_outs, lora_outs, nero_logits, lora_logits, alpha=1.0, beta=0.00015, temperature=2.0, debug=False):
    assert nero_outs.keys() == lora_outs.keys()
    loss_device = next(iter(nero_outs.values())).device
    total_hidden_loss = torch.tensor(0.0, device=loss_device)

    # --- Hidden representation loss ---
    for layer_name in lora_outs.keys():
        nero_out = nero_outs[layer_name]
        lora_out = lora_outs[layer_name].to(nero_out.device)

        # Scale by sequence length to avoid length-sensitive loss scaling
        hidden_loss = torch.mean(torch.abs(nero_out.float() - lora_out.float()), dim=-1).mean()

        if debug:
            print("Layer:", layer_name)
            print("hidden_loss:", hidden_loss.item())

        total_hidden_loss += hidden_loss

    total_hidden_loss /= len(lora_outs)
    total_hidden_loss *= alpha

    # --- Logit KL divergence loss ---
    # Important: apply softmax with temperature for distillation
    logit_loss = F.kl_div(
        F.log_softmax(nero_logits.to(loss_device) / temperature, dim=-1),
        F.softmax(lora_logits.to(loss_device) / temperature, dim=-1),
        reduction='batchmean'
    ) * (temperature ** 2)
    logit_loss *= beta

    if debug:
        print("Logit KL loss:", logit_loss.item())

    # --- Final combined loss ---
    total_loss = total_hidden_loss + logit_loss
    return total_loss, total_hidden_loss, logit_loss


In [29]:
loss_fn_v4(
    nero_outs=_nero_outs,
    lora_outs=_lora_outs,
    nero_logits=_nero_logits,
    lora_logits=_lora_logits
)

(tensor(0.3478, device='cuda:1', grad_fn=<AddBackward0>),
 tensor(0.1808, device='cuda:1', grad_fn=<MulBackward0>),
 tensor(0.1670, device='cuda:1', dtype=torch.float16, grad_fn=<MulBackward0>))

aa

In [25]:
nero_params = [p for n, p in nero_model.named_parameters() if p.requires_grad]
optimizer = torch.optim.Adam(nero_params, lr=lr)
scaler = torch.cuda.amp.GradScaler()

wandb.init(
    project='Nero-XLT',
    config=dict(
        seed = seed,
        lora_model_device = lora_model_device,
        nero_model_device = nero_model_device,
        block_size = block_size,
        batch_size = batch_size,
        num_epochs = num_epochs,
        max_global_steps = max_global_steps,
        resume_step = resume_step,
        lr = lr,
    ),
)

# Ensure model devices are set if not specified
if nero_model_device is None:
    nero_model_device = next(iter(nero_model.parameters())).device
if lora_model_device is None:
    lora_model_device = next(iter(lora_model.parameters())).device

global_step = 0
max_global_steps = max_global_steps or len(train_loader) * num_epochs
done = False

for epoch in range(num_epochs):
    for step, batch in enumerate(train_loader):
        if global_step >= max_global_steps:
            done = True
            break

        if global_step < resume_step:
            global_step += 1
            continue

        # Flush gradients
        optimizer.zero_grad()

        # Move inputs to devices
        nero_input_ids = batch['input_ids'].to(nero_model_device)
        nero_attention_mask = batch['attention_mask'].to(nero_model_device)

        with torch.cuda.amp.autocast():
            # Forward pass for nero
            nero_model_outs, nero_outs = nero_model(
                input_ids=nero_input_ids,
                attention_mask=nero_attention_mask
            )

            # Forward pass for lora
            lora_input_ids = nero_input_ids.to(lora_model_device)
            lora_attention_mask = nero_attention_mask.to(lora_model_device)
            lora_model_outs, lora_outs = lora_model(
                input_ids=lora_input_ids,
                attention_mask=lora_attention_mask
            )

            # Loss computation
            loss, hidden_loss, logit_loss = loss_fn_v4(
                nero_outs, 
                lora_outs, 
                nero_model_outs.logits, 
                lora_model_outs.logits,
                debug=False,
            )

        # Backward pass
        scaler.scale(loss).backward()
        
        # Compute gradient norm
        # total_norm = 0.0
        # for p in nero_params:
        #     param_norm = p.grad.data.norm(2)
        #     total_norm += param_norm.item() ** 2
        # total_norm = total_norm ** 0.5
        # print(f"Gradient norm: {total_norm:.4f}")
        
        # Update parameters
        scaler.step(optimizer)
        scaler.update()

        # Logging
        wandb.log({
            'epoch': epoch,
            'step': global_step,
            'loss': loss.item(),
            'hidden_loss': hidden_loss.item(),
            'logit_loss': logit_loss.item(),
        })
        print(f"epoch: {epoch}/{num_epochs}, step: {global_step}/{max_global_steps}, loss: {loss.item():.4f}, hidden_loss: {hidden_loss.item():.4f}, logit_loss: {logit_loss.item():.4f}")

        global_step += 1
    
    if done:
        break

wandb.finish()

  scaler = torch.cuda.amp.GradScaler()
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malimtegar[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112385333338656, max=1.0…

  with torch.cuda.amp.autocast():


epoch: 0/1, step: 0/100, loss: 0.3967, hidden_loss: 0.1817, logit_loss: 0.2150
epoch: 0/1, step: 1/100, loss: 0.3899, hidden_loss: 0.1816, logit_loss: 0.2083
epoch: 0/1, step: 2/100, loss: 0.3874, hidden_loss: 0.1813, logit_loss: 0.2061
epoch: 0/1, step: 3/100, loss: 0.3853, hidden_loss: 0.1809, logit_loss: 0.2043
epoch: 0/1, step: 4/100, loss: 0.4259, hidden_loss: 0.1812, logit_loss: 0.2447
epoch: 0/1, step: 5/100, loss: 0.3534, hidden_loss: 0.1810, logit_loss: 0.1724
epoch: 0/1, step: 6/100, loss: 0.3629, hidden_loss: 0.1808, logit_loss: 0.1821
epoch: 0/1, step: 7/100, loss: 0.3890, hidden_loss: 0.1806, logit_loss: 0.2084
epoch: 0/1, step: 8/100, loss: 0.3553, hidden_loss: 0.1804, logit_loss: 0.1749
epoch: 0/1, step: 9/100, loss: 0.3580, hidden_loss: 0.1803, logit_loss: 0.1777
epoch: 0/1, step: 10/100, loss: 0.4100, hidden_loss: 0.1806, logit_loss: 0.2294
epoch: 0/1, step: 11/100, loss: 0.3599, hidden_loss: 0.1801, logit_loss: 0.1799
epoch: 0/1, step: 12/100, loss: 0.3570, hidden_los

VBox(children=(Label(value='0.024 MB of 0.024 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
hidden_loss,██████████▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▅▄▄▄▄▃▃▃▃▃▂▂▂▂▁
logit_loss,▆█▄▆▅▅▅▅▆▆▆▄▅▅▄▆▆▅▅▄▅▄▄▃▃▄▅▃▃▄▂▂▃▃▂▂▄▃▁▁
loss,▇▆█▅▆▅▅▆▅▅▆▅▅▅▆▆▅▆▆▅▅▅▄▄▄▅▅▃▃▂▃▂▂▃▂▃▁▃▂▁
step,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇█

0,1
epoch,0.0
hidden_loss,0.13782
logit_loss,0.10883
loss,0.24665
step,99.0


In [48]:
# Save Nero parameters
nero_params_path = f"nero_params_L1T1_to_{model_configs['target']['label']}.pth"
lora_state_dict = {k: v for k, v in nero_model.state_dict().items() if 'nero_' in k}
torch.save(lora_state_dict, nero_params_path)
print("Nero parameters saved to: ", nero_params_path)

Nero parameters saved to nero_params_L1T1_to_L1T2.pth


In [28]:
nero_model.eval()
nero_model.set_return_nero_outputs(False)
generate_text(
    nero_model, 
    tokenizer, 
    prompt="空が青い",
)
nero_model.set_return_nero_outputs(True)
nero_model.train();

空が青い年575の年757年3766年が年年77の年637のに7年年、7の年3のの年5年77577年7年7年
