In [1]:
# Kill all processess on GPU
!fuser -v /dev/nvidia* -k

# Libraries

In [2]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    %pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    %pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    %pip install sentencepiece protobuf "datasets>=4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    %pip install --no-deps unsloth
%pip install trl==0.19.1 # Fix error: ImportError: cannot import name 'ConstantLengthDataset' from 'trl.trainer.utils'

In [3]:
import os
import functools
import gc
import json
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from typing import Optional, Literal, Union, List
from torch.utils.data import DataLoader
from datasets import load_dataset, Dataset
from transformers import default_data_collator, AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from huggingface_hub import snapshot_download, create_repo, upload_folder
from safetensors.torch import load_file, save_file
from datetime import datetime
from pprint import pprint

import cuml
from cuml.decomposition import PCA as cuPCA
from cuml.manifold import UMAP as cuUMAP

In [4]:
def download_hf_model(
        repo_id: str, 
        checkpoint: Optional[int], 
        max_checkpoints: str = 10_000, 
        checkpoint_steps: str = 25,
    ):
    local_dir = repo_id.split('/')[-1]
    ignore_checkpoints = None
    
    if checkpoint is not None:
        ignore_checkpoints = [f'checkpoint-{i}/*' for i in range(0, max_checkpoints, checkpoint_steps) if i != checkpoint]

    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        ignore_patterns=ignore_checkpoints,
    )

    checkpoint_dir = None
    if checkpoint is not None:
        checkpoint_dir = os.path.join(local_dir, f'checkpoint-{checkpoint}')
    return local_dir, checkpoint_dir

def check_loss_and_grad_norm(
    model, 
    tokenizer, 
    prompt="Paris is the capital of",
):
    # Set model to training mode
    model.train()

    # Zero gradients manually
    for p in model.parameters():
        p.grad = None

    # Forward pass
    inputs = tokenizer(
        prompt, 
        return_tensors='pt',
    ).to(next(model.parameters()).device)

    outputs = model(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        labels=inputs['input_ids'],
        use_cache=False, # Disable cache to not conflict with gradient checkpointing
    )
    if isinstance(outputs, tuple):
        outputs = outputs[0]
    print("Loss:", outputs.loss)

    # Backward pass
    if outputs.loss.grad_fn is None:
        print("Gradient norm:", None)
        return

    outputs.loss.backward()

    # Compute gradient norm
    grad_norm = 0.0
    for n, p in model.named_parameters():
        if p.requires_grad:
            p_grad_norm = p.grad.data.norm(2)
            grad_norm += p_grad_norm.item() ** 2
    grad_norm = grad_norm ** 0.5

    print("Gradient norm:", grad_norm)

def check_parameter(n, p):
    print(f"- {'name':<8}:", n)
    print(f"- {'device':<8}:", p.device)
    print(f"- {'dtype':<8}:", p.dtype)
    print(f"- {'mean':<8}:", p.mean().item())
    print(f"- {'min':<8}:", p.min().item())
    print(f"- {'max':<8}:", p.max().item())

def check_lora_parameters(model, prefix=None):
    prefix = 'lora.' + prefix if prefix != None else 'lora'
    for n, p in model.named_parameters():
        if prefix in n:
            check_parameter(n, p)
            break

@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens=32, device=None, skip_special_tokens=True):
    device = device or next(model.parameters()).device
    model.eval()
    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    outputs = model.generate(input_ids=inputs['input_ids'], max_new_tokens=max_new_tokens)
    model.train()
    return tokenizer.decode(outputs[0], skip_special_tokens=skip_special_tokens)

def get_task_and_lang_from_repo_id(repo_id: str):
    task, lang, _ = repo_id.split('B-')[-1].split('K-')[0].split('-')
    return task, lang

def load_hf_dataset_from_lora(
    lora_repo_id: str,
    train_size: int = 5000,
    test_size: int = 1000,
):
    # Get task and language
    task, lang = get_task_and_lang_from_repo_id(lora_repo_id)

    # Set up Hugging Face configuration
    data_id_map = {
        'wikipedia': 'wikimedia/wikipedia',
        'gsm8k': 'openai/gsm8k',
    }
    data_id = data_id_map[task]
    data_dir = f'20231101.{lang}' if task == 'wikipedia' else 'main'
    
    # Load dataset using streaming
    dataset_stream = load_dataset(data_id, data_dir=data_dir, split='train', streaming=True)

    # Manually take train_size + test_size samples
    total_size = train_size + test_size
    sliced_data = []
    for i, example in enumerate(dataset_stream):
        if i >= total_size:
            break
        sliced_data.append(example)

    # Convert to regular in-memory dataset
    dataset = Dataset.from_list(sliced_data)
    
    return dataset

def load_hf_dataset(
    lang, 
    task,
    split='train',
    train_size = 5000,
    test_size = 1000,
):
    # Set up Hugging Face configuration
    data_id_map = {
        'wikipedia': 'wikimedia/wikipedia',
        'gsm8k': 'openai/gsm8k',
    }
    data_id = data_id_map[task]
    data_dir = f'20231101.{lang}' if task == 'wikipedia' else 'main'

    # Use streaming
    dataset_stream = load_dataset(data_id, data_dir=data_dir, split=split, streaming=True)

    # Manually take train_size + test_size samples
    total_size = train_size + test_size
    sliced_data = []
    for i, example in enumerate(dataset_stream):
        if i >= total_size:
            break
        sliced_data.append(example)

    # Convert to regular in-memory dataset
    dataset = Dataset.from_list(sliced_data)
    
    return dataset

def compute_grad_norm(params):
    grad_norm = 0.0
    for p in params:
        p_grad_norm = p.grad.data.norm(2)
        grad_norm += p_grad_norm.item() ** 2
    grad_norm = grad_norm ** 0.5
    return grad_norm

def compute_named_grad_norm(named_params):
    grad_norm = 0.0
    for n, p in named_params.items():
        if p.grad is not None:
            p_grad_norm = p.grad.data.norm(2)
            print("{n} p_grad_norm:", p_grad_norm)
            grad_norm += p_grad_norm.item() ** 2
        else:
            print(f"[WARN] No gradient for {n}")
    grad_norm = grad_norm ** 0.5
    return grad_norm

def format_float(v):
    if abs(v) < 0.0001 or abs(v) >= 10000:
        return f"{v:.4e}"
    else:
        return f"{v:.4f}"

# Configurations

In [5]:
# Project configuration
seed = 69
device = 'auto'

# Data configuration
hf_data_id = 'alxxtexxr/Nero-XLT-Dataset'
hf_data_dir = 'gsm8k_en_5K_1K_1K_512'

# Training configuration
batch_size = 4
num_epochs = 1
max_global_steps = None
grad_accumulation_steps = 2
clip_grad_norm = 1.0
lr = 2e-4
warmup_ratio = 0.1
# num_warmup_steps = 100
checkpoint_steps = 25
push_to_hf = False
generate_steps = 25
sample_prompt = '102452 + 102453 ='
distance_fn = 'euclidean'

# Model configurations
model_configs = {
    # L1T1 (Source Language - Source Task)
    'L1T1': {
        'hf_lora_id': 'alxxtexxr/L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650',
        'checkpoint': 650,
    },

    # L2T1 (Target Language - Source Task)
    'L2T1': {
        'hf_lora_id': 'alxxtexxr/L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629',
        'checkpoint': 650,
    },

    # L1T2 (Source Language - Target Task)
    # 'L1T2': {
    #     'hf_lora_id': 'alxxtexxr/L3.1-8B-gsm8k-en-5K-LoRA-v20250701060457',
    #     'checkpoint': 1875,
    # },
}

for key, config in model_configs.items():
    _, lora_dir = download_hf_model(config['hf_lora_id'], config['checkpoint'])
    model_configs[key]['lora_dir'] = lora_dir
    model_configs[key]['lora_path'] = os.path.join(lora_dir, 'adapter_model.safetensors')
    model_configs[key]['lora_config'] = LoraConfig.from_pretrained(lora_dir)

print("Model configurations:",)
for key, config in model_configs.items():
    print(f"- {key}:")
    for config_name, config_value in config.items():
        if config_name == 'lora_config':
            continue
        print(f"{'-':>3} {config_name:<10}: {config_value}")
print()

assert (
    model_configs['L1T1']['lora_config'].base_model_name_or_path == 
    model_configs['L2T1']['lora_config'].base_model_name_or_path
), "Base models must be the same"
base_model_name = model_configs['L1T1']['lora_config'].base_model_name_or_path
print(f"Base model name: {base_model_name}")

tokenizer = AutoTokenizer.from_pretrained(base_model_name)

Fetching 20 files:   0%|          | 0/20 [00:00<?, ?it/s]

Fetching 20 files:   0%|          | 0/20 [00:00<?, ?it/s]

Model configurations:
- L1T1:
  - hf_lora_id: alxxtexxr/L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650
  - checkpoint: 650
  - lora_dir  : L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650/checkpoint-650
  - lora_path : L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650/checkpoint-650/adapter_model.safetensors
- L2T1:
  - hf_lora_id: alxxtexxr/L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629
  - checkpoint: 650
  - lora_dir  : L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629/checkpoint-650
  - lora_path : L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629/checkpoint-650/adapter_model.safetensors

Base model name: unsloth/meta-llama-3.1-8b-unsloth-bnb-4bit


In [6]:
# Hugging Face configuration
hf_nero_id = None
resume_step = 0

if hf_nero_id is not None and resume_step > 0:
    print(f"[INFO] Downloading Nero checkpoint at step {resume_step} from Hugging Face repository:", hf_nero_id)
    nero_dir, _ = download_hf_model(hf_nero_id, resume_step)
    print(f"[INFO] Nero checkpoint downloaded successfully!")
else:
    hf_username = 'alxxtexxr'
    nero_dir = f'L3.1-8B-{hf_data_dir.replace("_", "-")}-Nero-{distance_fn}-v{datetime.now().strftime("%Y%m%d%H%M%S")}'
    print(f"[INFO] Creating Nero directory:", nero_dir)
    hf_nero_id = f'{hf_username}/{nero_dir}'
    os.makedirs(nero_dir, exist_ok=True)
    print(f"[INFO] Nero directory created!")

[INFO] Creating Nero directory: L3.1-8B-gsm8k-en-5K-1K-1K-512-Nero-euclidean-v20250821173023
[INFO] Nero directory created!


# Model

In [7]:
class NonlinearFunction(nn.Module):
    def __init__(self, dim, rank=8):
        super().__init__()
        self.A = nn.Linear(dim, rank, bias=False)
        self.B = nn.Linear(rank, dim, bias=False)
        self.act = nn.Tanh() # or GELU
    
    def forward(self, x, name=None):
        act_out = self.act(self.B(self.A(x)))
        # print()
        # print(f"{name}.act_out ->", act_out.mean().item(), act_out.max().item(), act_out.min().item())
        # print()
        return x + act_out

class NeroLayer(nn.Module):
    def __init__(self, base_layer,
                 
                 # LoRA parameters
                 L1T1_lora_params, 
                 L2T1_lora_params,
                 L2T2_lora_params,
                 eval_L2T2_lora = False,
                 
                 # Debugging parameters
                 debug=False,
                 module_name=None,
                 ):
        super().__init__()

        self.base_layer = base_layer
        self.device = base_layer.weight.device
        self._eval_L2T2_lora = eval_L2T2_lora
        self.debug = debug
        self.module_name = module_name

        # Extract input and output features from the base layer
        in_features = getattr(base_layer, 'in_features', None)
        out_features = getattr(base_layer, 'out_features', None)

        if in_features is None or out_features is None:
            raise ValueError(f"Cannot determine in_features or out_features from {base_layer}.")

        # Initialize LoRA layers
        self.lora = nn.ModuleDict({
            'L1T1': self._init_lora_layer('L1T1', in_features, out_features, **L1T1_lora_params, device=self.device), # frozen
            'L2T1': self._init_lora_layer('L2T1', in_features, out_features, **L2T1_lora_params, device=self.device), # frozen
            'L2T2': self._init_lora_layer('L2T2', in_features, out_features, **L2T2_lora_params, device=self.device), # trained alternately with nonlinear function
        })
        
        # Initialize nonlinear function layer
        self.nonlinear_fn = NonlinearFunction(out_features).to(self.device) # trained alternately with L2T2 LoRA layer
        # self.nonlinear_fn = nn.Identity()
        
    def _init_lora_layer(self, name, in_features, out_features, rank, alpha, dropout=0.0, bias=True, use_rslora=False, device=None):
        scaling = alpha / math.sqrt(rank) if use_rslora else alpha / rank
        scaling *= 0.1 # temporary to prevent overflow
        dropout_layer = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        layer = nn.ModuleDict({
            'A': nn.Linear(in_features, rank, bias=bias, device=device),
            'B': nn.Linear(rank, out_features, bias=bias, device=device),
            'dropout': dropout_layer
        })
        layer.scaling = scaling
        layer.bias_flag = bias
        nn.init.normal_(layer['A'].weight, 0.0, 1 / math.sqrt(rank))
        if name == 'L2T2':
            nn.init.normal_(layer['B'].weight, 0.0, 1e-3)
        else:
            nn.init.zeros_(layer['B'].weight)
        return layer

    def forward(self, x):
        # ================================================================
        # Base Layer
        # ================================================================
        base_out = self.base_layer(x)

        requires_conversion = not torch.is_autocast_enabled()
        if requires_conversion:
            x = x.to(self.lora.L1T1.A.weight.dtype)
            
        # ================================================================
        # L2T2 LoRA Layer (Trainable)
        # ================================================================
        # L2T2_lora_out = self.lora.L2T2.B(self.lora.L2T2.A(self.lora.L2T2.dropout(x))) #* self.lora.L2T2.scaling
        L2T2_lora_dropout_out = self.lora.L2T2.dropout(x)
        L2T2_lora_A_out = self.lora.L2T2.A(L2T2_lora_dropout_out)
        L2T2_lora_B_out = self.lora.L2T2.B(L2T2_lora_A_out)
        L2T2_lora_out = L2T2_lora_B_out * self.lora.L2T2.scaling
        
        if self._eval_L2T2_lora:
            if requires_conversion:
                L2T2_lora_out = L2T2_lora_out.to(base_out.dtype)
            return base_out + L2T2_lora_out

        # ================================================================
        # L1T1 LoRA Layer (Frozen)
        # ================================================================
        # L1T1_lora_out = self.lora.L1T1.B(self.lora.L1T1.A(self.lora.L1T1.dropout(x))) #* self.lora.L1T1.scaling
        L1T1_lora_dropout_out = self.lora.L1T1.dropout(x)
        L1T1_lora_A_out = self.lora.L1T1.A(L1T1_lora_dropout_out)
        L1T1_lora_B_out = self.lora.L1T1.B(L1T1_lora_A_out)
        L1T1_lora_out = L1T1_lora_B_out * self.lora.L1T1.scaling
        
        # ================================================================
        # L2T1 LoRA Layer (Frozen)
        # ================================================================
        # L2T1_lora_out = self.lora.L2T1.B(self.lora.L2T1.A(self.lora.L2T1.dropout(x))) #* self.lora.L2T1.scaling
        L2T1_lora_dropout_out = self.lora.L2T1.dropout(x)
        L2T1_lora_A_out = self.lora.L2T1.A(L2T1_lora_dropout_out)
        L2T1_lora_B_out = self.lora.L2T1.B(L2T1_lora_A_out)
        L2T1_lora_out = L2T1_lora_B_out * self.lora.L2T1.scaling

        # ================================================================
        # Output
        # ================================================================
        # Linear:
        # queen = king - man + king
        # L2T2 = L1T2 - L1T1 + L2T1
        
        # Nonlinear:
        # f(L2T2) = f(L1T2) - f(L1T1) + f(L2T1)
        # f(L1T2) = f(L2T2) - f(L2T1) + f(L1T1)
        # L1T2 = f^-1(f(L2T2) - f(L2T1) + f(L1T1))
        
        # with torch.no_grad():  # avoid affecting computation graph
        #     print()
        #     print("x:", x.mean().item(), x.max().item(), x.min().item())
        #     print()
        #     print("L2T2_lora_B_out:", L2T2_lora_B_out.mean().item(), L2T2_lora_B_out.max().item(), L2T2_lora_B_out.min().item())
        #     print("L2T2_lora_out:", L2T2_lora_out.mean().item(), L2T2_lora_out.max().item(), L2T2_lora_out.min().item())
        #     print("L1T1_lora_B_out:", L1T1_lora_B_out.mean().item(), L1T1_lora_B_out.max().item(), L1T1_lora_B_out.min().item())
        #     print("L1T1_lora_out:", L1T1_lora_out.mean().item(), L1T1_lora_out.max().item(), L1T1_lora_out.min().item())
        #     print("L2T1_lora_B_out:", L2T1_lora_B_out.mean().item(), L2T1_lora_B_out.max().item(), L2T1_lora_B_out.min().item())
        #     print("L2T1_lora_out:", L2T1_lora_out.mean().item(), L2T1_lora_out.max().item(), L2T1_lora_out.min().item())
        #     print()
        
        f_L2T2_out = self.nonlinear_fn(L2T2_lora_out, name='L2T2_lora_out')
        f_diff = self.nonlinear_fn(L1T1_lora_out, name='L1T1_lora_out') - self.nonlinear_fn(L2T1_lora_out, name='L2T1_lora_out')
        L1T2_out = f_L2T2_out - f_diff # approximate f^-1 as subtraction
        
        # with torch.no_grad():
        #     print("f_L2T2_out:", f_L2T2_out.mean().item(), f_L2T2_out.max().item(), f_L2T2_out.min().item())
        #     print("f_diff:", f_diff.mean().item(), f_diff.max().item(), f_diff.min().item())
        #     print("L1T2_out:", L1T2_out.mean().item(), L1T2_out.max().item(), L1T2_out.min().item())
        #     print()
        
        if requires_conversion:
            L1T2_out = L1T2_out.to(base_out.dtype)
        return base_out + L1T2_out

    def load_lora_params(self, mode: Literal['L1T1', 'L2T1', 'L2T2'], state_dict, prefix: str):
        self.lora[mode].A.weight.data = state_dict[f'{prefix}.lora_A.weight'].to(self.device)
        self.lora[mode].B.weight.data = state_dict[f'{prefix}.lora_B.weight'].to(self.device)
        if self.lora[mode].bias_flag:
            self.lora[mode].A.bias.data = state_dict[f'{prefix}.lora_A.bias'].to(self.device)
            self.lora[mode].B.bias.data = state_dict[f'{prefix}.lora_B.bias'].to(self.device)
            
class NeroModel(nn.Module):
    def __init__(self, 
                 base_model: nn.Module, 
                 L1T1_lora_config: LoraConfig, 
                 L2T1_lora_config: LoraConfig, 
                 debug: bool = False,
                ):
        super().__init__()
        
        self.base_model = base_model
        self._eval_L2T2_lora = False
        self.debug = debug

        # Wrap target layers with NeroLayer
        self.nero_layers = nn.ModuleDict()
        self._wrap_target_layers(L1T1_lora_config, L2T1_lora_config)
        
    def _wrap_target_layers(self, L1T1_lora_config, L2T1_lora_config):
        assert L1T1_lora_config.target_modules == L2T1_lora_config.target_modules, "[ERROR] L1T1 and L2T1 LoRA configurations must have the same target modules."

        for module_name, module in self.base_model.named_modules():
            if isinstance(module, NeroLayer):
                # Convert module name format and store reference
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.nero_layers[module_name] = module
                continue

            if any(module_name.endswith(target_module) for target_module in L1T1_lora_config.target_modules) and isinstance(module, nn.Linear):    
                parent_module, child_name = self._get_parent_module(module_name)
                nero_layer = NeroLayer(
                    base_layer=module,
                    eval_L2T2_lora=self._eval_L2T2_lora,

                    # L1T1 LoRA parameters
                    L1T1_lora_params={
                        'rank': L1T1_lora_config.r, 
                        'alpha': L1T1_lora_config.lora_alpha, 
                        'dropout': L1T1_lora_config.lora_dropout,
                        'bias': L1T1_lora_config.lora_bias,
                        'use_rslora': L1T1_lora_config.use_rslora,
                    },
                
                    # L2T1 LoRA parameters
                    L2T1_lora_params={
                        'rank': L2T1_lora_config.r, 
                        'alpha': L2T1_lora_config.lora_alpha, 
                        'dropout': L2T1_lora_config.lora_dropout,
                        'bias': L2T1_lora_config.lora_bias,
                        'use_rslora': L2T1_lora_config.use_rslora,
                    },

                    # L2T2 parameters (for temporary, use L2T1 LoRA parameters)
                    L2T2_lora_params={
                        'rank': L2T1_lora_config.r, 
                        'alpha': L2T1_lora_config.lora_alpha, 
                        'dropout': L2T1_lora_config.lora_dropout,
                        'bias': L2T1_lora_config.lora_bias,
                        'use_rslora': L2T1_lora_config.use_rslora,
                    },
                    
                    # Debugging parameters
                    debug=self.debug,
                    module_name=module_name,
                )
                setattr(parent_module, child_name, nero_layer)

                # Store LoRA layers for weight loading
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.nero_layers[module_name] = nero_layer
    
    def _get_parent_module(self, module_name):
        parts = module_name.split('.')
        parent_module = self.base_model
        
        for part in parts[:-1]:
            parent_module = getattr(parent_module, part)
            
        return parent_module, parts[-1]
    
    def train_L2T2_lora(self, verbose: bool=False):
        self.freeze_all_except_L2T2_lora()
        
        for layer in self.nero_layers.values():
            layer._eval_L2T2_lora = False

        if verbose:
            print(f"[INFO] Training L2T2 LoRA!")
    
    def train_nonlinear_fn(self, verbose: bool=False):
        self.freeze_all_except_nonlinear_fn()
        
        for layer in self.nero_layers.values():
            layer._eval_L2T2_lora = False
        
        if verbose:
            print(f"[INFO] Training nonlinear functions!")
            
    def eval_L2T2_lora(self, verbose: bool=False):
        self.freeze_all()
        
        for layer in self.nero_layers.values():
            layer._eval_L2T2_lora = True
        
        if verbose:
            print(f"[INFO] Evaluating L2T2 LoRA!")
            
    def freeze_all(self, verbose: bool=False):
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        if verbose:
            print("[INFO] All layers are frozen!")

    def freeze_all_except_L2T2_lora(self, verbose=False):
        self.freeze_all(verbose=verbose)
        
        for nero_layer in self.nero_layers.values():
            for param_name, param in nero_layer.named_parameters():
                if 'lora.L2T2' in param_name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
        
        if verbose:
            print("[INFO] All layers are frozen except L2T2 LoRA layers!")
    
    def freeze_all_except_nonlinear_fn(self, verbose=False):
        self.freeze_all(verbose=verbose)
        
        for nero_layer in self.nero_layers.values():
            for param_name, param in nero_layer.named_parameters():
                if 'nonlinear_fn' in param_name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
        
        if verbose:
            print("[INFO] All layers are frozen except nonlinear functions layers")
    
    def load_lora_params(self, mode: Literal['L1T1', 'L2T1', 'L2T2'], lora_path: str):
        if not os.path.exists(lora_path):
            raise FileNotFoundError("[ERROR] LoRA file not found:", lora_path)
        
        if lora_path.endswith('.safetensors'):
            state_dict = load_file(lora_path)
        else:
            state_dict = torch.load(lora_path, map_location='cpu')

        prefix = list(state_dict.keys())[0].rsplit('model.', 1)[0] + 'model.'

        for nero_layer_name, nero_layer in self.nero_layers.items():
            nero_layer_name = nero_layer_name.replace('__DOT__', '.')
            nero_layer_name = prefix + nero_layer_name
            if f'{nero_layer_name}.lora_A.weight' in state_dict and f'{nero_layer_name}.lora_B.weight' in state_dict:
                nero_layer.load_lora_params(mode, state_dict, nero_layer_name)
            else:
                # TODO: Print warning message
                pass

        print(f"[INFO] {mode} LoRA parameters loaded successfully!")
    
    def forward(self, *args, **kwargs):
        return self.base_model(*args, **kwargs)
    
    def __getattr__(self, name):
        try:
            return super().__getattr__(name) # Try getting attribute from self
        except AttributeError:
            return getattr(self.base_model, name) # Fallback to base_model

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name, 
    device_map=device,
)
model = NeroModel(
    base_model, 
    L1T1_lora_config=model_configs['L1T1']['lora_config'], 
    L2T1_lora_config=model_configs['L2T1']['lora_config'], 
    debug=False,
)

In [8]:
print("Check unloaded L1T1 LoRA parameters:")
check_lora_parameters(model, prefix='L1T1')
print()

model.load_lora_params('L1T1', model_configs['L1T1']['lora_path'])
print()

print("Check loaded L1T1 LoRA parameters:")
check_lora_parameters(model, prefix='L1T1')
print()

Check unloaded L1T1 LoRA parameters:
- name    : base_model.model.layers.0.self_attn.q_proj.lora.L1T1.A.weight
- device  : cuda:0
- dtype   : torch.float32
- mean    : -0.0011653788387775421
- min     : -1.404322624206543
- max     : 1.4684089422225952

[INFO] L1T1 LoRA parameters loaded successfully!

Check loaded L1T1 LoRA parameters:
- name    : base_model.model.layers.0.self_attn.q_proj.lora.L1T1.A.weight
- device  : cuda:0
- dtype   : torch.float32
- mean    : 6.287686119321734e-05
- min     : -0.04176201671361923
- max     : 0.04242725297808647



In [9]:
print("Check unloaded L2T1 LoRA parameters:")
check_lora_parameters(model, prefix='L2T1')
print()

model.load_lora_params('L2T1', model_configs['L2T1']['lora_path'])
print()

print("Check loaded L2T1 LoRA parameters:")
check_lora_parameters(model, prefix='L2T1')
print()

Check unloaded L2T1 LoRA parameters:
- name    : base_model.model.layers.0.self_attn.q_proj.lora.L2T1.A.weight
- device  : cuda:0
- dtype   : torch.float32
- mean    : 0.0026187323965132236
- min     : -1.3684216737747192
- max     : 1.397903561592102

[INFO] L2T1 LoRA parameters loaded successfully!

Check loaded L2T1 LoRA parameters:
- name    : base_model.model.layers.0.self_attn.q_proj.lora.L2T1.A.weight
- device  : cuda:0
- dtype   : torch.float32
- mean    : 2.2152138626552187e-05
- min     : -0.06327299773693085
- max     : 0.0625513345003128



In [10]:
model.freeze_all_except_L2T2_lora()
print()

# check_loss_and_grad_norm(nero_model, tokenizer)




In [11]:
model.gradient_checkpointing_enable({'use_reentrant': False})
print("[INFO] Gradient checkpointing enabled!")
print()

check_loss_and_grad_norm(model, tokenizer)

[INFO] Gradient checkpointing enabled!

Loss: tensor(3.5031, device='cuda:0', grad_fn=<ToCopyBackward0>)


  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


Gradient norm: 25.50963737989262


# Data

In [12]:
datasets = load_dataset(hf_data_id, data_dir=hf_data_dir)
print(datasets)

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 5000
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 1000
    })
})


In [13]:
train_loader = DataLoader(
    datasets['train'], 
    batch_size=batch_size, 
    shuffle=True, 
    collate_fn=default_data_collator,
)
print("[INFO] Total batches:", len(train_loader))

[INFO] Total batches: 1250


In [14]:
# Sanity check
first_batch = next(iter(train_loader))
print("First batch data shape (input_ids, attention_mask):")
print((
    first_batch['input_ids'].shape, 
    first_batch['attention_mask'].shape, 
))
print()

first_batch_text = tokenizer.batch_decode(first_batch['input_ids'], skip_special_tokens=True)[0]
print("First batch text:")
print(first_batch_text[:100], "...")
print()

check_loss_and_grad_norm(model, tokenizer)

First batch data shape (input_ids, attention_mask):
(torch.Size([4, 512]), torch.Size([4, 512]))

First batch text:
### Instruction:
Solve the following math problem step by step.

### Question: 
Jamie owns 4 Persian ...

Loss: tensor(3.5031, device='cuda:0', grad_fn=<ToCopyBackward0>)
Gradient norm: 25.50963737989262


In [15]:
check_loss_and_grad_norm(model, tokenizer, prompt=first_batch_text)

Loss: tensor(1.5417, device='cuda:0', grad_fn=<ToCopyBackward0>)
Gradient norm: 7.547106782720999


# Training

In [16]:
# Set the model to training mode
model.train()

# If `device` is not specified or set to 'auto', use the model's device
# if device is None or device == 'auto':
device = next(iter(model.parameters())).device

# Set up optimizer and gradient scaler
# for L2T2 LoRA
model.train_L2T2_lora()
L2T2_lora_params = [p for n, p in model.named_parameters() if p.requires_grad]
L2T2_optimizer = torch.optim.Adam(L2T2_lora_params, lr=lr)
L2T2_scaler = torch.cuda.amp.GradScaler()

# for nonlinear functions
model.train_nonlinear_fn()
nonlinear_fn_params = [p for n, p in model.named_parameters() if p.requires_grad]
nonlinear_fn_optimizer = torch.optim.Adam(nonlinear_fn_params, lr=lr)
nonlinear_fn_scaler = torch.cuda.amp.GradScaler()

# Set up LR scheduler
max_global_steps = max_global_steps or len(train_loader) * num_epochs
warmup_steps = int(warmup_ratio * max_global_steps)
if warmup_ratio > 0:
    # If `warmup_ratio` > 0, use cosine annealing scheduler with warm-up 
    from transformers import get_cosine_schedule_with_warmup # type: ignore
    max_optimizer_steps = (max_global_steps // grad_accumulation_steps) // 2 # divide by 2 because we train in 2 modes
    num_warmup_steps = int(warmup_ratio * max_optimizer_steps)
    L2T2_scheduler = get_cosine_schedule_with_warmup(
        L2T2_optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=max_optimizer_steps,

        # Cycle length is 2000 steps or less, but at least 1 cycle, and at most 10 cycles
        num_cycles=min(max(1, max_optimizer_steps // 2000), 10), 
    )
    nonlinear_fn_scheduler = get_cosine_schedule_with_warmup(
        nonlinear_fn_optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=max_optimizer_steps,

        # Cycle length is 2000 steps or less, but at least 1 cycle, and at most 10 cycles
        num_cycles=min(max(1, max_optimizer_steps // 2000), 10), 
    )
else:
    # If `warmup_ratio` is 0, use a dummy scheduler that returns constant LR
    from torch.optim.lr_scheduler import LambdaLR # type: ignore
    L2T2_scheduler = LambdaLR(L2T2_optimizer, lr_lambda=lambda step: 1.0)
    nonlinear_fn_scheduler = LambdaLR(nonlinear_fn_optimizer, lr_lambda=lambda step: 1.0)

# Initialize W&B
wandb.init(
    project='Nero-XLT',
    reinit=True, # End previous run and start a new one
    config=dict(
        # Project configuration
        seed = seed,
        device = device,

        # Data configuration
        hf_data_id = hf_data_id,
        hf_data_dir = hf_data_dir,

        # Training configuration
        batch_size = batch_size,
        num_epochs = num_epochs,
        max_global_steps = max_global_steps,
        grad_accumulation_steps = grad_accumulation_steps,
        clip_grad_norm = clip_grad_norm,
        lr = lr,
        warmup_ratio = warmup_ratio,
        checkpoint_steps = checkpoint_steps,
        distance_fn = distance_fn,
        resume_step = resume_step,
    ),
)

# Resume training
global_step = resume_step
start_epoch = 0

def load_trainer_params(target, model, optimizer, scheduler, scaler, checkpoint_dir, device):
    # Load Nero parameters
    nero_path = os.path.join(checkpoint_dir, f'{target}.safetensors')
    # model.load_nero_params(mode, nero_path)

    # Load optimizer state
    optimizer_path = os.path.join(checkpoint_dir, f'{target}_optimizer.pt')
    optimizer.load_state_dict(torch.load(optimizer_path, map_location=device))
    
    # Move optimizer state to the correct device
    for param in optimizer.state:
        param_device = param.device
        param_dtype = param.dtype
        for key, value in optimizer.state[param].items():
            if isinstance(value, torch.Tensor):
                optimizer.state[param][key] = value.to(device=param_device, dtype=param_dtype)

    # Load scheduler state
    scheduler_path = os.path.join(checkpoint_dir, f'{target}_scheduler.pt')
    scheduler.load_state_dict(torch.load(scheduler_path, map_location=device))

    # Load scaler state
    scaler_path = os.path.join(checkpoint_dir, f'{target}_scaler.pt')
    scaler.load_state_dict(torch.load(scaler_path, map_location=device))

if resume_step > 0:
    checkpoint_dir = os.path.join(nero_dir, f'checkpoint-{resume_step}')
    print(f"[INFO] Resuming training from checkpoint directory:", checkpoint_dir)

    # Load trainer parameters
    load_trainer_params('L2T2_lora', model, L2T2_optimizer, L2T2_scheduler, L2T2_scaler, checkpoint_dir, device)
    load_trainer_params('nonlinear_fn', model, nonlinear_fn_optimizer, nonlinear_fn_scheduler, nonlinear_fn_scaler, checkpoint_dir, device)

    # Load trainer state
    trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json')
    if os.path.exists(trainer_state_path):
        with open(trainer_state_path, 'r') as f:
            trainer_state = json.load(f)
        log_history = trainer_state.get('log_history', [])
        start_epoch = log_history[-1]['epoch'] if log_history else 0
        print(f"[INFO] Resuming training from epoch {start_epoch} and step {resume_step}.")

    # Load RNG state for reproducibility
    rng_path = os.path.join(checkpoint_dir, 'rng_state.pth')
    if os.path.exists(rng_path):
        rng_state = torch.load(rng_path)
        random.setstate(rng_state['python'])
        np.random.set_state(rng_state['numpy'])
        torch.set_rng_state(rng_state['cpu'])
        if torch.cuda.is_available() and rng_state['cuda']:
            torch.cuda.set_rng_state_all(rng_state['cuda'])
    
    if resume_step % grad_accumulation_steps != 0:
        print("[WARN] Resuming mid-gradient accumulation cycle. Make sure this is intended.")
else:
    if push_to_hf:
        # If it's new training, create Hugging Face repository
        print(f"[INFO] Creating Hugging Face repository:", hf_nero_id) # print the link instead
        create_repo(repo_id=hf_nero_id, repo_type='model', exist_ok=True)
        print(f"[INFO] Hugging Face repository created successfully!")

  L2T2_scaler = torch.cuda.amp.GradScaler()
  nonlinear_fn_scaler = torch.cuda.amp.GradScaler()
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malimtegar[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112919377774233, max=1.0…

In [17]:
target = 'L2T2_lora' # or 'nonlinear_fn'
log_history = []
done = False

# Safety: Zero gradients at the start of gradient accumulation cycle
# This ensures there are no leftover gradients when resuming mid-cycle or after a previous cycle was interrupted
if global_step % grad_accumulation_steps == 0:
    L2T2_optimizer.zero_grad(set_to_none=True)

# Training loop
for epoch in range(start_epoch, num_epochs):
    for step, batch in enumerate(train_loader):
        # Skip previously completed steps
        if global_step <= resume_step:
            global_step += 1
            continue

        # Stop training if `max_global_steps` reached
        if global_step >= max_global_steps:
            done = True
            break

        # Move inputs to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        with torch.cuda.amp.autocast():
            # Forward pass
            if target == 'L2T2_lora':
                model.train_L2T2_lora()
            else:
                model.train_nonlinear_fn()
            
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids,
                use_cache=False, # Disable cache to avoid conflict with gradient checkpointing
            )
            
            _loss = outputs.loss
            print("_loss:", _loss)
            loss = _loss / grad_accumulation_steps
            print("loss:", loss)

        log = {
            'target': target,
            'epoch': epoch,
            'step': global_step,
        }

        # Backward pass
        # with torch.autograd.set_detect_anomaly(True):
        log['loss'] = loss.item() * grad_accumulation_steps
        if target == 'L2T2_lora':
            L2T2_scaler.scale(loss).backward()
        else:
            nonlinear_fn_scaler.scale(loss).backward()
        
        # Update parameters only at the end of gradient accumulation cycle
        grad_norm_log = {}
        print("TEST 0")
        if (global_step + 1) % grad_accumulation_steps == 0:
            print("TEST 1")
            if target == 'L2T2_lora':
                print("TEST 2")
                # Unscale gradients before computing gradient norm and applying clipping
                L2T2_scaler.unscale_(L2T2_optimizer)
                
                named_params = {n: p for n, p in model.named_parameters() if 'L2T2' in n}
                compute_named_grad_norm(named_params)

                # Compute gradient norm
                grad_norm = compute_grad_norm(L2T2_lora_params)
                grad_norm_log['grad_norm'] = grad_norm

                # Clip gradients
                if clip_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(L2T2_lora_params, clip_grad_norm)
                
                # Compute clipped gradient norm
                grad_norm_clipped = compute_grad_norm(L2T2_lora_params)
                grad_norm_log['grad_norm_clipped'] = grad_norm_clipped

                # Update parameters
                L2T2_scaler.step(L2T2_optimizer)
                L2T2_scaler.update()
                L2T2_scheduler.step()

                # Zero gradients for the next gradient accumulation cycle
                L2T2_optimizer.zero_grad(set_to_none=True)
            else:
                print("TEST 3")
                # Unscale gradients before computing gradient norm and applying clipping
                nonlinear_fn_scaler.unscale_(nonlinear_fn_optimizer)
                
                # Compute gradient norm
                grad_norm = compute_grad_norm(nonlinear_fn_params)
                grad_norm_log['nonlinear_fn/grad_norm'] = grad_norm

                # Clip gradients
                if clip_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(nonlinear_fn_params, clip_grad_norm)
                
                # Compute clipped gradient norm
                grad_norm_clipped = compute_grad_norm(nonlinear_fn_params)
                grad_norm_log['nonlinear_fn/grad_norm_clipped'] = grad_norm_clipped

                # Update parameters
                nonlinear_fn_scaler.step(nonlinear_fn_optimizer)
                nonlinear_fn_scaler.update()
                nonlinear_fn_scheduler.step()

                # Zero gradients for the next gradient accumulation cycle
                nonlinear_fn_optimizer.zero_grad(set_to_none=True)

            # After updating parameters, toggle training target
            target = 'nonlinear_fn' if target == 'L2T2_lora' else 'L2T2_lora'

        # Logging
        lr_log = {
            'L1T2_lora/lr': L2T2_scheduler.get_last_lr()[0],
            'nonlinear_fn/lr': nonlinear_fn_scheduler.get_last_lr()[0]
        }
        log = {
            **log, 
            **lr_log, 
            **grad_norm_log,
        }
        log_history.append(log)
        wandb.log(log)
        print(", ".join(
            f"{k}: {format_float(v)}" if isinstance(v, float) else f"{k}: {v}"
            for k, v in log.items()
        ))
        
        # Save and push checkpoint every `checkpoint_steps`
        if global_step > 0 and global_step % checkpoint_steps == 0:
            # Create checkpoint directory
            checkpoint_dir = os.path.join(nero_dir, f'checkpoint-{global_step}')
            os.makedirs(checkpoint_dir, exist_ok=True)

            # Save Nero parameters, along with optimizer, scheduler, and scaler states
            # L1T2_nero_state_dict = {n: p.detach().cpu() for n, p in model.named_parameters() if 'L1T2' in n}
            # save_file(L1T2_nero_state_dict, os.path.join(checkpoint_dir, 'L1T2_nero.safetensors'))
            # torch.save(L1T2_optimizer.state_dict(), os.path.join(checkpoint_dir, 'L1T2_optimizer.pt'))
            # torch.save(L1T2_scheduler.state_dict(), os.path.join(checkpoint_dir, 'L1T2_scheduler.pt'))
            # torch.save(L1T2_scaler.state_dict(), os.path.join(checkpoint_dir, 'L1T2_scaler.pt'))

            L2T2_nero_state_dict = {n: p.detach().cpu() for n, p in model.named_parameters() if 'L2T2' in n}
            save_file(L2T2_nero_state_dict, os.path.join(checkpoint_dir, 'L2T2_nero.safetensors'))
            torch.save(L2T2_optimizer.state_dict(), os.path.join(checkpoint_dir, 'L2T2_optimizer.pt'))
            torch.save(L2T2_scheduler.state_dict(), os.path.join(checkpoint_dir, 'L2T2_scheduler.pt'))
            torch.save(L2T2_scaler.state_dict(), os.path.join(checkpoint_dir, 'L2T2_scaler.pt'))

            # Save trainer state for resuming training
            trainer_state = {
                'epoch': epoch,
                'global_step': global_step,
                'log_history': log_history,
            }
            with open(os.path.join(checkpoint_dir, 'trainer_state.json'), 'w') as f:
                json.dump(trainer_state, f, indent=2)

            # Save RNG state for reproducibility
            rng_state = {
                'python': random.getstate(),
                'numpy': np.random.get_state(),
                'cpu': torch.get_rng_state(),
                'cuda': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else [],
            }
            torch.save(rng_state, os.path.join(checkpoint_dir, 'rng_state.pth'))

            # Upload checkpoint directory to Hugging Face repository
            if push_to_hf:
                upload_folder(
                    folder_path=checkpoint_dir,
                    repo_id=hf_nero_id,
                    path_in_repo=f"checkpoint-{global_step}",
                    commit_message=f"Add checkpoint at step {global_step}",
                    repo_type='model',
                )
        
        # Check generated text every `generate_steps`
        if global_step > 0 and global_step % generate_steps == 0:
            model.set_train_L2T2_nero(False)
            generated = generate_text(model, tokenizer, sample_prompt, device=device)
            print("================================")
            print("CHECK GENERATED TEXT")
            print("================================")
            print(f"{'Prompt':<9}:", sample_prompt)
            print(f"{'Generated':<9}:", generated)
            print()
        
        global_step += 1
    
    if done:
        break

wandb.finish()

  with torch.cuda.amp.autocast():


_loss: tensor(6.9696, device='cuda:0', grad_fn=<ToCopyBackward0>)
loss: tensor(3.4848, device='cuda:0', grad_fn=<DivBackward0>)


  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


TEST 0
TEST 1
TEST 2
{n} p_grad_norm: tensor(nan, device='cuda:0')
{n} p_grad_norm: tensor(nan, device='cuda:0')
{n} p_grad_norm: tensor(nan, device='cuda:0')
{n} p_grad_norm: tensor(nan, device='cuda:0')
{n} p_grad_norm: tensor(nan, device='cuda:0')
{n} p_grad_norm: tensor(nan, device='cuda:0')
{n} p_grad_norm: tensor(nan, device='cuda:0')
{n} p_grad_norm: tensor(nan, device='cuda:0')
{n} p_grad_norm: tensor(0.0032, device='cuda:0')
{n} p_grad_norm: tensor(1.0554, device='cuda:0')
{n} p_grad_norm: tensor(0.0053, device='cuda:0')
{n} p_grad_norm: tensor(1.7919, device='cuda:0')
{n} p_grad_norm: tensor(0.0125, device='cuda:0')
{n} p_grad_norm: tensor(3.8190, device='cuda:0')
{n} p_grad_norm: tensor(0.0007, device='cuda:0')
{n} p_grad_norm: tensor(0.1697, device='cuda:0')
{n} p_grad_norm: tensor(0.0004, device='cuda:0')
{n} p_grad_norm: tensor(0.1290, device='cuda:0')
{n} p_grad_norm: tensor(0.0336, device='cuda:0')
{n} p_grad_norm: tensor(inf, device='cuda:0')
{n} p_grad_norm: tensor(0.

KeyboardInterrupt: 