In [1]:
# Kill all processess on GPU
!fuser -v /dev/nvidia* -k

# Libraries

In [2]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    %pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    %pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    %pip install sentencepiece protobuf "datasets>=4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    %pip install --no-deps unsloth
%pip install trl==0.19.1 # Fix error: ImportError: cannot import name 'ConstantLengthDataset' from 'trl.trainer.utils'

In [3]:
import os
import functools
import gc
import json
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from typing import Optional, Literal, Union, List
from torch.utils.data import DataLoader
from datasets import load_dataset, Dataset
from transformers import default_data_collator, AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from huggingface_hub import snapshot_download, create_repo, upload_folder
from safetensors.torch import load_file, save_file
from datetime import datetime
from pprint import pprint

import cuml
from cuml.decomposition import PCA as cuPCA
from cuml.manifold import UMAP as cuUMAP

import copy

In [4]:
def download_hf_model(
        repo_id: str, 
        checkpoint: Optional[int], 
        max_checkpoints: str = 10_000, 
        checkpoint_steps: str = 25,
    ):
    local_dir = repo_id.split('/')[-1]
    ignore_checkpoints = None
    
    if checkpoint is not None:
        ignore_checkpoints = [f'checkpoint-{i}/*' for i in range(0, max_checkpoints, checkpoint_steps) if i != checkpoint]

    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        ignore_patterns=ignore_checkpoints,
    )

    checkpoint_dir = None
    if checkpoint is not None:
        checkpoint_dir = os.path.join(local_dir, f'checkpoint-{checkpoint}')
    return local_dir, checkpoint_dir

def check_loss_and_grad_norm(
    model, 
    tokenizer, 
    prompt="Paris is the capital of",
):
    # Set model to training mode
    model.train()

    # Zero gradients manually
    for p in model.parameters():
        p.grad = None

    # Forward pass
    inputs = tokenizer(
        prompt, 
        return_tensors='pt',
    ).to(next(model.parameters()).device)

    outputs = model(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        labels=inputs['input_ids'],
        use_cache=False, # Disable cache to not conflict with gradient checkpointing
    )
    if isinstance(outputs, tuple):
        outputs = outputs[0]
        print("Loss 1:", outputs.loss)
        print("Loss 2:", outputs[1])
    else:
        print("Loss:", outputs.loss)

    # Backward pass
    if outputs.loss.grad_fn is None:
        print("Gradient norm:", None)
        return

    outputs.loss.backward()

    # Compute gradient norm
    grad_norm = 0.0
    for n, p in model.named_parameters():
        if p.requires_grad:
            p_grad_norm = p.grad.data.norm(2)
            grad_norm += p_grad_norm.item() ** 2
    grad_norm = grad_norm ** 0.5

    print("Gradient norm:", grad_norm)

def check_parameter(n, p):
    print(f"- {'name':<8}:", n)
    print(f"- {'device':<8}:", p.device)
    print(f"- {'dtype':<8}:", p.dtype)
    print(f"- {'mean':<8}:", p.mean().item())
    print(f"- {'min':<8}:", p.min().item())
    print(f"- {'max':<8}:", p.max().item())

def check_lora_parameters(model, prefix=None):
    prefix = 'lora.' + prefix if prefix != None else 'lora'
    for n, p in model.named_parameters():
        if prefix in n:
            check_parameter(n, p)
            break

@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens=100, device=None, skip_special_tokens=True):
    device = device or next(model.parameters()).device
    model.eval()
    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    outputs = model.generate(input_ids=inputs['input_ids'], max_new_tokens=max_new_tokens)
    model.train()
    return tokenizer.decode(outputs[0], skip_special_tokens=skip_special_tokens)

def get_task_and_lang_from_repo_id(repo_id: str):
    task, lang, _ = repo_id.split('B-')[-1].split('K-')[0].split('-')
    return task, lang

def load_hf_dataset_from_lora(
    lora_repo_id: str,
    train_size: int = 5000,
    test_size: int = 1000,
):
    # Get task and language
    task, lang = get_task_and_lang_from_repo_id(lora_repo_id)

    # Set up Hugging Face configuration
    data_id_map = {
        'wikipedia': 'wikimedia/wikipedia',
        'gsm8k': 'openai/gsm8k',
    }
    data_id = data_id_map[task]
    data_dir = f'20231101.{lang}' if task == 'wikipedia' else 'main'
    
    # Load dataset using streaming
    dataset_stream = load_dataset(data_id, data_dir=data_dir, split='train', streaming=True)

    # Manually take train_size + test_size samples
    total_size = train_size + test_size
    sliced_data = []
    for i, example in enumerate(dataset_stream):
        if i >= total_size:
            break
        sliced_data.append(example)

    # Convert to regular in-memory dataset
    dataset = Dataset.from_list(sliced_data)
    
    return dataset

def load_hf_dataset(
    lang, 
    task,
    split='train',
    train_size = 5000,
    test_size = 1000,
):
    # Set up Hugging Face configuration
    data_id_map = {
        'wikipedia': 'wikimedia/wikipedia',
        'gsm8k': 'openai/gsm8k',
    }
    data_id = data_id_map[task]
    data_dir = f'20231101.{lang}' if task == 'wikipedia' else 'main'

    # Use streaming
    dataset_stream = load_dataset(data_id, data_dir=data_dir, split=split, streaming=True)

    # Manually take train_size + test_size samples
    total_size = train_size + test_size
    sliced_data = []
    for i, example in enumerate(dataset_stream):
        if i >= total_size:
            break
        sliced_data.append(example)

    # Convert to regular in-memory dataset
    dataset = Dataset.from_list(sliced_data)
    
    return dataset

def compute_grad_norm(params):
    grad_norm = 0.0
    for p in params:
        p_grad_norm = p.grad.data.norm(2)
        grad_norm += p_grad_norm.item() ** 2
    grad_norm = grad_norm ** 0.5
    return grad_norm

def compute_named_grad_norm(named_params):
    grad_norm = 0.0
    for n, p in named_params.items():
        if p.grad is not None:
            p_grad_norm = p.grad.data.norm(2)
            print(f"{n}/p_grad_norm:", p_grad_norm)
            grad_norm += p_grad_norm.item() ** 2
        else:
            print(f"[WARN] No gradient for {n}")
    grad_norm = grad_norm ** 0.5
    return grad_norm

def format_float(v):
    if abs(v) < 0.0001 or abs(v) >= 10000:
        return f"{v:.4e}"
    else:
        return f"{v:.4f}"

# Configurations

In [5]:
# Project configuration
seed = 69
device = 'auto'

# Data configuration
hf_data_id = 'alxxtexxr/Nero-XLT-Dataset'
hf_data_dir = 'gsm8k_en_5K_1K_1K_512'

# Training configuration
batch_size = 4
num_epochs = 1
max_global_steps = None
grad_accumulation_steps = 2
clip_grad_norm = 1.0
lr = 2e-4
warmup_ratio = 0.1
# num_warmup_steps = 100
checkpoint_steps = 10
push_to_hf = False
generate_steps = 10
L1T2_sample_prompt = """### Instruction:
Solve the following math problem step by step.

### Question:
Dan plants 3 rose bushes. Each rose bush has 25 roses. Each rose has 8 thorns. How many thorns are there total?

### Answer:
"""
L2T2_sample_prompt = """### 命令:
次の数学の問題を段階的に解いてください。

### 質問:
ダンは3本のバラの灌木を植えました。各バラの灌木には25本のバラがあります。各バラには8本のトゲがあります。合計で何本のトゲがあるでしょうか？

### 答え:
"""

# Model configurations
model_configs = {
    # L1T1 (Source Language - Source Task)
    'L1T1': {
        'hf_lora_id': 'alxxtexxr/L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650',
        'checkpoint': 650,
    },

    # L2T1 (Target Language - Source Task)
    'L2T1': {
        'hf_lora_id': 'alxxtexxr/L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629',
        'checkpoint': 650,
    },

    # L1T2 (Source Language - Target Task)
    # 'L1T2': {
    #     'hf_lora_id': 'alxxtexxr/L3.1-8B-gsm8k-en-5K-LoRA-v20250701060457',
    #     'checkpoint': 1875,
    # },
}

for key, config in model_configs.items():
    _, lora_dir = download_hf_model(config['hf_lora_id'], config['checkpoint'])
    model_configs[key]['lora_dir'] = lora_dir
    model_configs[key]['lora_path'] = os.path.join(lora_dir, 'adapter_model.safetensors')
    model_configs[key]['lora_config'] = LoraConfig.from_pretrained(lora_dir)

print("Model configurations:",)
for key, config in model_configs.items():
    print(f"- {key}:")
    for config_name, config_value in config.items():
        if config_name == 'lora_config':
            continue
        print(f"{'-':>3} {config_name:<10}: {config_value}")
print()

assert (
    model_configs['L1T1']['lora_config'].base_model_name_or_path == 
    model_configs['L2T1']['lora_config'].base_model_name_or_path
), "Base models must be the same"
base_model_name = model_configs['L1T1']['lora_config'].base_model_name_or_path
print(f"Base model name: {base_model_name}")

tokenizer = AutoTokenizer.from_pretrained(base_model_name)

Fetching 20 files:   0%|          | 0/20 [00:00<?, ?it/s]

Fetching 20 files:   0%|          | 0/20 [00:00<?, ?it/s]

Model configurations:
- L1T1:
  - hf_lora_id: alxxtexxr/L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650
  - checkpoint: 650
  - lora_dir  : L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650/checkpoint-650
  - lora_path : L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650/checkpoint-650/adapter_model.safetensors
- L2T1:
  - hf_lora_id: alxxtexxr/L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629
  - checkpoint: 650
  - lora_dir  : L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629/checkpoint-650
  - lora_path : L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629/checkpoint-650/adapter_model.safetensors

Base model name: unsloth/meta-llama-3.1-8b-unsloth-bnb-4bit


In [6]:
# Hugging Face configuration
hf_nero_id = None
resume_step = 0

if hf_nero_id is not None and resume_step > 0:
    print(f"[INFO] Downloading Nero checkpoint at step {resume_step} from Hugging Face repository:", hf_nero_id)
    nero_dir, _ = download_hf_model(hf_nero_id, resume_step)
    print(f"[INFO] Nero checkpoint downloaded successfully!")
else:
    hf_username = 'alxxtexxr'
    nero_dir = f'L3.1-8B-{hf_data_dir.replace("_", "-")}-Nero-v{datetime.now().strftime("%Y%m%d%H%M%S")}'
    print(f"[INFO] Creating Nero directory:", nero_dir)
    hf_nero_id = f'{hf_username}/{nero_dir}'
    os.makedirs(nero_dir, exist_ok=True)
    print(f"[INFO] Nero directory created!")

[INFO] Creating Nero directory: L3.1-8B-gsm8k-en-5K-1K-1K-512-Nero-v20250824121713
[INFO] Nero directory created!


# Model

In [7]:
class G(nn.Module):
    def __init__(self, dim, residual, rank, alpha, beta, use_rslora=False):
        super().__init__()
        self.residual = residual
        self.scaling = alpha / math.sqrt(rank) if use_rslora else alpha / rank
        self.beta = beta
        self.A = nn.Linear(dim, rank, bias=False)
        self.B = nn.Linear(rank, dim, bias=False)
        self.act = nn.GELU()
        self.ln = nn.LayerNorm(dim)
        
        # Initialize weights
        nn.init.normal_(self.A.weight, 0.0, 1 / math.sqrt(rank))
        nn.init.zeros_(self.B.weight)
    
    def forward(self, x):
        act_out = self.B(self.act(self.A(x))) * self.scaling * self.beta
        if self.residual:
            return x + act_out
        else:
            return act_out

class NeroLayer(nn.Module):
    def __init__(self, base_layer,
                 
                 # LoRA parameters
                 L1T1_lora_params, 
                 L2T1_lora_params,
                 L2T2_lora_params,
                 lora_beta,
                 train_L2T2_lora_and_g=True,
                 eval_L1T2_lora=False,
                 eval_L2T2_lora=False,
                 return_layer_loss=True,
                 
                 # Debugging parameters
                 debug=False,
                 module_name=None,
                 ):
        super().__init__()

        self.base_layer = base_layer
        self.device = base_layer.weight.device
        self.lora_beta = lora_beta
        self._train_L2T2_lora_and_g = train_L2T2_lora_and_g
        self._eval_L1T2_lora = eval_L1T2_lora
        self._eval_L2T2_lora = eval_L2T2_lora
        self._return_layer_loss = return_layer_loss
        self.debug = debug
        self.module_name = module_name

        # Extract input and output features from the base layer
        in_features = getattr(base_layer, 'in_features', None)
        out_features = getattr(base_layer, 'out_features', None)

        if in_features is None or out_features is None:
            raise ValueError(f"Cannot determine in_features or out_features from {base_layer}.")

        # Initialize LoRA layers
        self.lora = nn.ModuleDict({
            'L1T1': self._init_lora_layer('L1T1', in_features, out_features, **L1T1_lora_params, device=self.device), # frozen
            'L2T1': self._init_lora_layer('L2T1', in_features, out_features, **L2T1_lora_params, device=self.device), # frozen
            'L2T2': self._init_lora_layer('L2T2', in_features, out_features, **L2T2_lora_params, device=self.device), # trainable
        })
        
        # Initialize nonlinear function layer
        g_rank = L2T2_lora_params['rank'] #// 2 # WARN: Hardcoded rank for g layer, can be adjusted
        self.g = G(
            out_features, 
            residual=True, 
            rank=g_rank, 
            alpha=L2T2_lora_params['alpha'],
            beta=self.lora_beta,
        ).to(self.device) # trainable
        
    def _init_lora_layer(self, name, in_features, out_features, rank, alpha, dropout=0.0, bias=True, use_rslora=False, device=None):
        scaling = alpha / math.sqrt(rank) if use_rslora else alpha / rank
        dropout_layer = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        layer = nn.ModuleDict({
            'A': nn.Linear(in_features, rank, bias=bias, device=device),
            'B': nn.Linear(rank, out_features, bias=bias, device=device),
            'dropout': dropout_layer
        })
        layer.scaling = scaling
        layer.bias_flag = bias
        if name == 'L2T2':
            nn.init.zeros_(layer.A.weight)
        else:
            nn.init.normal_(layer.A.weight, 0.0, 1 / math.sqrt(rank))
        nn.init.zeros_(layer.B.weight)
        return layer
    
    def set_lora_beta(self, lora_beta, verbose: bool=False):
        self.lora_beta = lora_beta
        self.g.beta = lora_beta
        
        if verbose:
            print(f"[INFO] LoRA beta on {self.module_name} layer is set to {lora_beta}!")

    def forward(self, x):
        print("================================================================")
        print(self.module_name)
        print("================================================================")
        
        # ================================================================
        # Base Layer
        # ================================================================
        base_out = self.base_layer(x)

        requires_conversion = not torch.is_autocast_enabled()
        if requires_conversion:
            # TODO: Add assertion to check if all weights have the same dtype
            x = x.to(self.lora.L2T2.A.weight.dtype)
            
        # ================================================================
        # L2T2 LoRA Layer (Trainable)
        # ================================================================
        # L2T2_lora_out = self.lora.L2T2.B(self.lora.L2T2.A(self.lora.L2T2.dropout(x))) #* self.lora.L2T2.scaling
        L2T2_lora_dropout_out = self.lora.L2T2.dropout(x)
        L2T2_lora_A_out = self.lora.L2T2.A(L2T2_lora_dropout_out)
        L2T2_lora_B_out = self.lora.L2T2.B(L2T2_lora_A_out)
        L2T2_lora_out = L2T2_lora_B_out * self.lora.L2T2.scaling * self.lora_beta
        
        if self._eval_L2T2_lora:
            if requires_conversion:
                L2T2_lora_out = L2T2_lora_out.to(base_out.dtype)
            return base_out + L2T2_lora_out

        # ================================================================
        # L1T1 LoRA Layer (Frozen)
        # ================================================================
        # L1T1_lora_out = self.lora.L1T1.B(self.lora.L1T1.A(self.lora.L1T1.dropout(x))) #* self.lora.L1T1.scaling
        L1T1_lora_dropout_out = self.lora.L1T1.dropout(x)
        L1T1_lora_A_out = self.lora.L1T1.A(L1T1_lora_dropout_out)
        L1T1_lora_B_out = self.lora.L1T1.B(L1T1_lora_A_out)
        L1T1_lora_out = L1T1_lora_B_out * self.lora.L1T1.scaling * self.lora_beta
        
        # ================================================================
        # L2T1 LoRA Layer (Frozen)
        # ================================================================
        # L2T1_lora_out = self.lora.L2T1.B(self.lora.L2T1.A(self.lora.L2T1.dropout(x))) #* self.lora.L2T1.scaling
        L2T1_lora_dropout_out = self.lora.L2T1.dropout(x)
        L2T1_lora_A_out = self.lora.L2T1.A(L2T1_lora_dropout_out)
        L2T1_lora_B_out = self.lora.L2T1.B(L2T1_lora_A_out)
        L2T1_lora_out = L2T1_lora_B_out * self.lora.L2T1.scaling * self.lora_beta
        
        # ================================================================
        # L1T2 Output
        # ================================================================
        # Assume linear shift:
        # king - queen = man - woman
        # L1T2 - L2T2 = L1T1 - L2T1
        
        # Assume difference-based nonlinear shift (with free L2T2):
        # king - queen = man - woman
        # L1T2 - L2T2 = L1T1 - L2T1
        # f(L1T2 - L2T2) = f(L1T1 - L2T1)
        # L1T2 - L2T2 = f^-1(f(L1T1 - L2T1))
        # L1T2 = L2T2 + f^-1(f(L1T1 - L2T1))
        # L1T2 = L2T2 + g(L1T1 - L2T1)
        g_out = self.g(L1T1_lora_out - L2T1_lora_out)
        L1T2_out = L2T2_lora_out + g_out
        # Note:
        # L2T2 weights should be initialized from L2T1 weights
        # so that during initialization: L1T2 = L2T2 + g(L1T1 - L2T1) => L1T2 = L2T1 + 0 => L1T2 = L2T1
        # This ensures the model behaves like sequential fine-tuning from L2T1 to L1T2,
        # with g(L1T1 - L2T1) acting as a constraint.
        
        # Assume difference-based nonlinear shift (without free L2T2):
        # king - queen = man - woman
        # king - man = queen - woman
        # L1T2 - L1T1 = L2T2 - L2T1
        # f(L1T2 - L1T1) = f(L2T2 - L2T1)
        # L1T2 - L1T1 = f^-1(f(L2T2 - L2T1))
        # L1T2 = L1T1 + f^-1(f(L2T2 - L2T1))
        # L1T2 = L1T1 + g(L2T2 - L2T1)
        # g_out = self.g(L2T2_lora_out - L2T1_lora_out)
        # L1T2_out = L1T1_lora_out + g_out
        
        # if self.debug:
        # with torch.no_grad():
        #     print("L1T1_lora_B_out:", (L1T1_lora_B_out.max().item(), L1T1_lora_B_out.mean().item(), L1T1_lora_B_out.min().item()))
        #     # print("L1T1/scaling:", self.lora.L1T1.scaling)
        #     print("L1T1_lora_out:", (L1T1_lora_out.max().item(), L1T1_lora_out.mean().item(), L1T1_lora_out.min().item()))
        #     print()
        #     print("L2T1_lora_B_out:", (L2T1_lora_B_out.max().item(), L2T1_lora_B_out.mean().item(), L2T1_lora_B_out.min().item()))
        #     # # print("L2T1/scaling:", self.lora.L2T1.scaling)
        #     print("L2T1_lora_out:", (L2T1_lora_out.max().item(), L2T1_lora_out.mean().item(), L2T1_lora_out.min().item()))
        #     print()
        #     print("L2T2_lora_B_out:", (L2T2_lora_B_out.max().item(), L2T2_lora_B_out.mean().item(), L2T2_lora_B_out.min().item()))
        #     # # print("L2T2/scaling:", self.lora.L2T2.scaling)
        #     print("L2T2_lora_out:", (L2T2_lora_out.max().item(), L2T2_lora_out.mean().item(), L2T2_lora_out.min().item()))
        #     print()
        #     # # print("g/scaling:", self.g.scaling)
        #     print("g_out:", (g_out.max().item(), g_out.mean().item(), g_out.min().item())) # must be zero at init
        #     print("L1T2_out:", (L1T2_out.max().item(), L1T2_out.mean().item(), L1T2_out.min().item())) # must be zero at init
        #     print()
        
        # ================================================================
        # Loss
        # ================================================================
        loss_fn = 'mse' # WARN: Hardcoded loss function
        if loss_fn is not None:
            if loss_fn == 'mse':
                loss = F.mse_loss(L2T2_lora_out, , reduction='mean')
            elif loss_fn == 'kl':
                temperature = 2.0 # WARN: Hardcoded temperature for KL divergence
                L1T2_out_detached = L1T2_out.detach()
                loss = F.kl_div(
                    F.log_softmax(L2T2_lora_out / temperature, dim=-1),
                    F.softmax(L1T2_out_detached / temperature, dim=-1),
                    reduction='batchmean'
                ) * (temperature ** 2)
            elif loss_fn == 'cosine':
                L1T2_out_detached = L1T2_out.detach()
                loss = 1 - F.cosine_similarity(L2T2_lora_out, L1T2_out_detached, dim=-1).mean()
            
            if self._return_layer_loss:
                if requires_conversion:
                    L1T2_out = L1T2_out.to(base_out.dtype)
                return base_out + L1T2_out, loss
            
        # ================================================================
        if requires_conversion:
            L1T2_out = L1T2_out.to(base_out.dtype)
        return base_out + L1T2_out

    def load_lora_params(self, mode: Literal['L1T1', 'L2T1', 'L2T2'], state_dict, prefix: str):
        self.lora[mode].A.weight.data = state_dict[f'{prefix}.lora_A.weight'].to(self.device)
        self.lora[mode].B.weight.data = state_dict[f'{prefix}.lora_B.weight'].to(self.device)
        if self.lora[mode].bias_flag:
            self.lora[mode].A.bias.data = state_dict[f'{prefix}.lora_A.bias'].to(self.device)
            self.lora[mode].B.bias.data = state_dict[f'{prefix}.lora_B.bias'].to(self.device)
        if mode in ['L1T1', 'L2T1']:
            self.lora['L2T2'].A.weight.data += state_dict[f'{prefix}.lora_A.weight'].to(self.device)
            self.lora['L2T2'].B.weight.data += state_dict[f'{prefix}.lora_B.weight'].to(self.device)
            if self.lora['L2T2'].bias_flag:
                self.lora['L2T2'].A.bias.data += state_dict[f'{prefix}.lora_A.bias'].to(self.device)
                self.lora['L2T2'].B.bias.data += state_dict[f'{prefix}.lora_B.bias'].to(self.device)
            
class NeroModel(nn.Module):
    def __init__(self, 
                 base_model: nn.Module, 
                 lora_beta: float,
                 L1T1_lora_config: LoraConfig, 
                 L2T1_lora_config: LoraConfig,
                 train_L2T2_lora_and_g: bool=True,
                 eval_L1T2_lora: bool=False,
                 eval_L2T2_lora: bool=False,
                 return_layer_loss: bool=True,
                 debug: bool=False
                ):
        super().__init__()
        
        self.base_model = base_model
        self.lora_beta = lora_beta
        self._train_L2T2_lora_and_g = train_L2T2_lora_and_g
        self._eval_L1T2_lora = eval_L1T2_lora
        self._eval_L2T2_lora = eval_L2T2_lora
        self._return_layer_loss = return_layer_loss
        self.debug = debug

        # Wrap target layers with NeroLayer
        self.nero_layers = nn.ModuleDict()
        self._wrap_target_layers(L1T1_lora_config, L2T1_lora_config)
        
    def _wrap_target_layers(self, L1T1_lora_config, L2T1_lora_config):
        assert L1T1_lora_config.target_modules == L2T1_lora_config.target_modules, "[ERROR] L1T1 and L2T1 LoRA configurations must have the same target modules."

        for module_name, module in self.base_model.named_modules():
            if isinstance(module, NeroLayer):
                # Convert module name format and store reference
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.nero_layers[module_name] = module
                continue

            if any(module_name.endswith(target_module) for target_module in L1T1_lora_config.target_modules) and isinstance(module, nn.Linear):    
                parent_module, child_name = self._get_parent_module(module_name)
                nero_layer = NeroLayer(
                    base_layer=module,
                    eval_L1T2_lora=self._eval_L1T2_lora,
                    eval_L2T2_lora=self._eval_L2T2_lora,
                    return_layer_loss=self._return_layer_loss,

                    # L1T1 LoRA parameters
                    L1T1_lora_params={
                        'rank': L1T1_lora_config.r, 
                        'alpha': L1T1_lora_config.lora_alpha, 
                        'dropout': L1T1_lora_config.lora_dropout,
                        'bias': L1T1_lora_config.lora_bias,
                        'use_rslora': L1T1_lora_config.use_rslora,
                    },
                
                    # L2T1 LoRA parameters
                    L2T1_lora_params={
                        'rank': L2T1_lora_config.r, 
                        'alpha': L2T1_lora_config.lora_alpha, 
                        'dropout': L2T1_lora_config.lora_dropout,
                        'bias': L2T1_lora_config.lora_bias,
                        'use_rslora': L2T1_lora_config.use_rslora,
                    },

                    # L2T2 parameters (for temporary, use L2T1 LoRA parameters)
                    L2T2_lora_params={
                        'rank': L2T1_lora_config.r, 
                        'alpha': L2T1_lora_config.lora_alpha, 
                        'dropout': L2T1_lora_config.lora_dropout,
                        'bias': L2T1_lora_config.lora_bias,
                        'use_rslora': L2T1_lora_config.use_rslora,
                    },
                    lora_beta=self.lora_beta,
                    
                    # Debugging parameters
                    debug=self.debug,
                    module_name=module_name,
                )
                setattr(parent_module, child_name, nero_layer)

                # Store LoRA layers for weight loading
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.nero_layers[module_name] = nero_layer
    
    def _get_parent_module(self, module_name):
        parts = module_name.split('.')
        parent_module = self.base_model
        
        for part in parts[:-1]:
            parent_module = getattr(parent_module, part)
            
        return parent_module, parts[-1]
    
    def set_debug(self, debug: bool=False, verbose: bool=False):
        self.debug = debug
        
        for nero_layer in self.nero_layers.values():
            nero_layer.debug = debug
        
        if verbose:
            print(f"[INFO] Debugging is {'enabled' if debug else 'disabled'}!")
    
    def set_lora_beta(self, lora_beta: bool=False, verbose: bool=False):
        self.lora_beta = lora_beta
        
        for nero_layer in self.nero_layers.values():
            nero_layer.set_lora_beta(lora_beta)
        
        if verbose:
            print(f"[INFO] LoRA beta is set to {lora_beta}.")
    
    def set_return_layer_loss(self, return_layer_loss: bool=True, verbose: bool=False):
        self._return_layer_loss = return_layer_loss
        
        for layer in self.nero_layers.values():
            layer._return_layer_loss = return_layer_loss
        
        if verbose:
            print(f"[INFO] Return layer loss set to '{return_layer_loss}'.")
    
    def train_L2T2_lora_and_g(self, verbose: bool=False):
        self.freeze_all_except_L2T2_lora_and_g()
        
        self._train_L2T2_lora_and_g = True
        self._eval_L1T2_lora = False
        self._eval_L2T2_lora = False
        # self._return_layer_loss = True
        
        for nero_layer in self.nero_layers.values():
            nero_layer._eval_L1T2_lora = False
            nero_layer._eval_L2T2_lora = False
            # nero_layer._return_layer_loss = True

        if verbose:
            print(f"[INFO] Training L2T2 LoRA and g layers!")
    
    def eval_L1T2_lora(self, verbose: bool=False):
        self.freeze_all()
        
        self._train_L2T2_lora_and_g = False
        self._eval_L1T2_lora = True
        self._eval_L2T2_lora = False
        self._return_layer_loss = False
        
        for nero_layer in self.nero_layers.values():
            nero_layer._eval_L1T2_lora = True
            nero_layer._eval_L2T2_lora = False
            nero_layer._return_layer_loss = False
        
        if verbose:
            print(f"[INFO] Evaluating L1T2 LoRA!")
            
    def eval_L2T2_lora(self, verbose: bool=False):
        self.freeze_all()
        
        self._train_L2T2_lora_and_g = False
        self._eval_L1T2_lora = False
        self._eval_L2T2_lora = True
        self._return_layer_loss = False
        
        for nero_layer in self.nero_layers.values():
            nero_layer._eval_L1T2_lora = False
            nero_layer._eval_L2T2_lora = True
            nero_layer._return_layer_loss = False
        
        if verbose:
            print(f"[INFO] Evaluating L2T2 LoRA!")
            
    def freeze_all(self, verbose: bool=False):
        for p in self.base_model.parameters():
            p.requires_grad = False
        
        if verbose:
            print("[INFO] All layers are frozen!")
    
    def freeze_all_except_L2T2_lora_and_g(self, verbose=False):
        self.freeze_all(verbose=verbose)
        
        for nero_layer in self.nero_layers.values():
            for n, p in nero_layer.named_parameters():
                if 'lora.L2T2.' in n or 'g.' in n:
                    p.requires_grad = True
                else:
                    p.requires_grad = False
        
        if verbose:
            print("[INFO] All layers are frozen except L2T2 LoRA and g layers!")
    
    def load_lora_params(self, mode: Literal['L1T1', 'L2T1', 'L2T2'], lora_path: str):
        if not os.path.exists(lora_path):
            raise FileNotFoundError("[ERROR] LoRA file not found:", lora_path)
        
        if lora_path.endswith('.safetensors'):
            state_dict = load_file(lora_path)
        else:
            state_dict = torch.load(lora_path, map_location='cpu')

        prefix = list(state_dict.keys())[0].rsplit('model.', 1)[0] + 'model.'

        for nero_layer_name, nero_layer in self.nero_layers.items():
            nero_layer_name = nero_layer_name.replace('__DOT__', '.')
            nero_layer_name = prefix + nero_layer_name
            if f'{nero_layer_name}.lora_A.weight' in state_dict and f'{nero_layer_name}.lora_B.weight' in state_dict:
                nero_layer.load_lora_params(mode, state_dict, nero_layer_name)
            else:
                # TODO: Print warning message
                pass

        print(f"[INFO] {mode} LoRA parameters loaded successfully!")
        if mode == 'L2T1':
            print(f"[INFO] L2T2 LoRA parameters loaded successfully!")
    
    def forward(self, *args, **kwargs):
        if self._return_layer_loss:
            layer_losses = []
            # L2T2_lora_outs = {}
            # L2T2_tgts = {}
            
            def _hook_fn(layer_name, module, _in, _out):
                assert isinstance(_out, tuple) and len(_out) == 2
                layer_out, layer_loss = _out
                layer_losses.append(layer_loss)
                return layer_out # Return only `layer_out` to avoid breaking model flow
                
                # assert isinstance(_out, tuple) and len(_out) == 3
                # layer_out, L2T2_lora_out, L2T2_tgt = _out
                # L2T2_lora_outs[layer_name] = L2T2_lora_out
                # L2T2_tgts[layer_name] = L2T2_tgt
                # return layer_out # Return only `layer_out` to avoid breaking model flow

            # Register hooks to extract hidden_out during forward pass
            hooks = []
            for layer_name, layer in self.nero_layers.items():
                hook = layer.register_forward_hook(functools.partial(_hook_fn, layer_name))
                hooks.append(hook)
        
            try:
                outputs = self.base_model(*args, **kwargs)
            finally:
                # Remove hooks after forward pass, ensuring it's done even if an error occurs
                for hook in hooks:
                    hook.remove()
                    
            # Move all `layer_losses` to the same device
            layer_losses = [t.to(layer_losses[0].device) for t in layer_losses]
            
            # Average `layer_losses`
            layer_loss = torch.stack(layer_losses).mean()
            return outputs, layer_loss
            
            # return outputs, L2T2_lora_outs, L2T2_tgts
        
        return self.base_model(*args, **kwargs)
    
    def __getattr__(self, name):
        try:
            return super().__getattr__(name) # Try getting attribute from self
        except AttributeError:
            return getattr(self.base_model, name) # Fallback to base_model

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name, 
    device_map=device,
)
model = NeroModel(
    base_model, 
    L1T1_lora_config=model_configs['L1T1']['lora_config'], 
    L2T1_lora_config=model_configs['L2T1']['lora_config'], 
    lora_beta=1.0,
    return_layer_loss=False,
)

In [8]:
print("Check unloaded L1T1 LoRA parameters:")
check_lora_parameters(model, prefix='L1T1')
print()

model.load_lora_params('L1T1', model_configs['L1T1']['lora_path'])
print()

print("Check loaded L1T1 LoRA parameters:")
check_lora_parameters(model, prefix='L1T1')
print()

Check unloaded L1T1 LoRA parameters:
- name    : base_model.model.layers.0.self_attn.q_proj.lora.L1T1.A.weight
- device  : cuda:0
- dtype   : torch.float32
- mean    : 0.002062347484752536
- min     : -1.3849756717681885
- max     : 1.4044522047042847

[INFO] L1T1 LoRA parameters loaded successfully!

Check loaded L1T1 LoRA parameters:
- name    : base_model.model.layers.0.self_attn.q_proj.lora.L1T1.A.weight
- device  : cuda:0
- dtype   : torch.float32
- mean    : 6.287686119321734e-05
- min     : -0.04176201671361923
- max     : 0.04242725297808647



In [9]:
print("Check unloaded L2T1 LoRA parameters:")
check_lora_parameters(model, prefix='L2T1')
print()

print("Check unloaded L2T2 LoRA parameters:")
check_lora_parameters(model, prefix='L2T2')
print()

model.load_lora_params('L2T1', model_configs['L2T1']['lora_path'])
print()

print("Check loaded L2T1 LoRA parameters:")
check_lora_parameters(model, prefix='L2T1')
print()

print("Check loaded L2T2 LoRA parameters:")
check_lora_parameters(model, prefix='L2T2')
print()

Check unloaded L2T1 LoRA parameters:
- name    : base_model.model.layers.0.self_attn.q_proj.lora.L2T1.A.weight
- device  : cuda:0
- dtype   : torch.float32
- mean    : -0.00022871472174301744
- min     : -1.4289981126785278
- max     : 1.4715492725372314

Check unloaded L2T2 LoRA parameters:
- name    : base_model.model.layers.0.self_attn.q_proj.lora.L2T2.A.weight
- device  : cuda:0
- dtype   : torch.float32
- mean    : 6.287686119321734e-05
- min     : -0.04176201671361923
- max     : 0.04242725297808647

[INFO] L2T1 LoRA parameters loaded successfully!
[INFO] L2T2 LoRA parameters loaded successfully!

Check loaded L2T1 LoRA parameters:
- name    : base_model.model.layers.0.self_attn.q_proj.lora.L2T1.A.weight
- device  : cuda:0
- dtype   : torch.float32
- mean    : 2.2152138626552187e-05
- min     : -0.06327299773693085
- max     : 0.0625513345003128

Check loaded L2T2 LoRA parameters:
- name    : base_model.model.layers.0.self_attn.q_proj.lora.L2T2.A.weight
- device  : cuda:0
- dtype

In [10]:
model.freeze_all_except_L2T2_lora_and_g()
print()

# check_loss_and_grad_norm(nero_model, tokenizer)




In [11]:
model.set_debug(False, verbose=True)

[INFO] Debugging is disabled!


In [12]:
# Set model to training mode
model.train()

# Zero gradients manually
for p in model.parameters():
    p.grad = None

# Forward pass
inputs = tokenizer(
    L2T2_sample_prompt, 
    return_tensors='pt',
).to(next(model.parameters()).device)

outputs = model(
    input_ids=inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    labels=inputs['input_ids'],
    use_cache=False, # Disable cache to not conflict with gradient checkpointing
)

model.layers.0.self_attn.q_proj
L1T1_lora_B_out: (0.05553344264626503, 1.6210209651035257e-05, -0.04610259830951691)
L1T1_lora_out: (0.11106688529253006, 3.242041930207051e-05, -0.09220519661903381)

L2T1_lora_B_out: (0.09217894822359085, 0.0002234637358924374, -0.09333131462335587)
L2T1_lora_out: (0.1843578964471817, 0.0004469274717848748, -0.18666262924671173)

L2T2_lora_B_out: (0.19206136465072632, 0.00032500174711458385, -0.18994779884815216)
L2T2_lora_out: (0.38412272930145264, 0.0006500034942291677, -0.3798955976963043)

g_out: (0.24636539816856384, 0.000203075964236632, -0.225641667842865)
L1T2_out: (0.3418782949447632, 0.0002354963799007237, -0.3017212748527527)

model.layers.0.self_attn.k_proj
L1T1_lora_B_out: (0.021090906113386154, 2.7187588784727268e-05, -0.023282183334231377)
L1T1_lora_out: (0.04218181222677231, 5.4375177569454536e-05, -0.04656436666846275)

L2T1_lora_B_out: (0.04138234630227089, -0.0003364249423611909, -0.056755971163511276)
L2T1_lora_out: (0.0827646926045

In [13]:
outputs.loss

tensor(4.3433, device='cuda:0', grad_fn=<ToCopyBackward0>)

In [14]:
# L2T2_lora_outs = outputs[1]
# L2T2_tgts = outputs[2]

# loss_device = next(iter(L2T2_tgts.values())).device

# total_hidden_loss = torch.tensor(0.0, device=loss_device)
# for layer_name in L2T2_tgts.keys():
#     # print(layer_name)
#     L2T2_lora_out = L2T2_lora_outs[layer_name]
#     L2T2_tgt = L2T2_tgts[layer_name].detach()
    
#     # print(L2T2_lora_out.shape, L2T2_tgt.shape)
#     # print(L2T2_lora_out.grad_fn, L2T2_tgt.grad_fn)
#     hidden_loss = F.mse_loss(L2T2_lora_out, L2T2_tgt, reduction='mean')
#     total_hidden_loss += hidden_loss.to(loss_device)
    
#     # print()
# print(total_hidden_loss)

In [14]:
model.gradient_checkpointing_enable({'use_reentrant': False})
print("[INFO] Gradient checkpointing enabled!")
print()

# model.set_return_hidden_outputs(False)
check_loss_and_grad_norm(model, tokenizer)
# model.set_return_hidden_outputs(True)

[INFO] Gradient checkpointing enabled!

model.layers.0.self_attn.q_proj
L1T1_lora_B_out: (0.05573910102248192, 1.1024943887605332e-05, -0.04592706635594368)
L1T1_lora_out: (0.11147820204496384, 2.2049887775210664e-05, -0.09185413271188736)

L2T1_lora_B_out: (0.04789327457547188, 0.00019716875976882875, -0.05974280461668968)
L2T1_lora_out: (0.09578654915094376, 0.0003943375195376575, -0.11948560923337936)

L2T2_lora_B_out: (0.1532118022441864, 0.000313695112708956, -0.16253496706485748)
L2T2_lora_out: (0.3064236044883728, 0.000627390225417912, -0.32506993412971497)

g_out: (0.24000284075737, 0.00023305288050323725, -0.21710637211799622)
L1T2_out: (0.35148105025291443, 0.0002551026991568506, -0.308960497379303)

model.layers.0.self_attn.k_proj
L1T1_lora_B_out: (0.021090907976031303, -3.678964276332408e-05, -0.023282183334231377)
L1T1_lora_out: (0.04218181595206261, -7.357928552664816e-05, -0.04656436666846275)

L2T1_lora_B_out: (0.03269142284989357, -0.0001890525163616985, -0.03203193843

  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


L1T1_lora_B_out: (0.22498826682567596, 1.343503754469566e-05, -0.21202944219112396)
L1T1_lora_out: (0.44997653365135193, 2.687007508939132e-05, -0.4240588843822479)

L2T1_lora_B_out: (0.2551562488079071, -9.468446660321206e-05, -0.1911800354719162)
L2T1_lora_out: (0.5103124976158142, -0.00018936893320642412, -0.3823600709438324)

L2T2_lora_B_out: (0.4419698715209961, -0.0001815831201383844, -0.3491738438606262)
L2T2_lora_out: (0.8839397430419922, -0.0003631662402767688, -0.6983476877212524)

g_out: (0.7078957557678223, -0.00017379748169332743, -0.674232542514801)
L1T2_out: (1.1578723192214966, -0.0001469272538088262, -1.0982913970947266)

model.layers.31.mlp.down_proj
model.layers.30.self_attn.q_proj
L1T1_lora_B_out: (0.034300025552511215, -5.554730523726903e-05, -0.03530000150203705)
L1T1_lora_out: (0.06860005110502243, -0.00011109461047453806, -0.0706000030040741)

L2T1_lora_B_out: (0.10309743881225586, 8.892474579624832e-05, -0.07557198405265808)
L2T1_lora_out: (0.20619487762451172,

# Data

In [16]:
datasets = load_dataset(hf_data_id, data_dir=hf_data_dir)
print(datasets)

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 5000
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 1000
    })
})


In [17]:
train_loader = DataLoader(
    datasets['train'], 
    batch_size=batch_size, 
    shuffle=True, 
    collate_fn=default_data_collator,
)
print("[INFO] Total batches:", len(train_loader))

[INFO] Total batches: 1250


In [18]:
# Sanity check
first_batch = next(iter(train_loader))
print("First batch data shape (input_ids, attention_mask):")
print((
    first_batch['input_ids'].shape, 
    first_batch['attention_mask'].shape, 
))
print()

first_batch_text = tokenizer.batch_decode(first_batch['input_ids'], skip_special_tokens=True)[0]
print("First batch text:")
print(first_batch_text[:100], "...")
print()

# check_loss_and_grad_norm(model, tokenizer, prompt=first_batch_text)

First batch data shape (input_ids, attention_mask):
(torch.Size([4, 512]), torch.Size([4, 512]))

First batch text:
### Instruction:
Solve the following math problem step by step.

### Question: 
Mrs. Snyder used to  ...



# Training

In [19]:
# Set the model to training mode
model.train()

# If `device` is not specified or set to 'auto', use the model's device
# if device is None or device == 'auto':
device = next(iter(model.parameters())).device

# Set up optimizer and gradient scaler
# for L2T2 LoRA
model.train_L2T2_lora_and_g()
L2T2_lora_params = [p for n, p in model.named_parameters() if p.requires_grad]
L2T2_optimizer = torch.optim.Adam(L2T2_lora_params, lr=lr)
L2T2_scaler = torch.cuda.amp.GradScaler()

# for nonlinear functions
# model.train_nonlinear_fn()
# nonlinear_fn_params = [p for n, p in model.named_parameters() if p.requires_grad]
# nonlinear_fn_optimizer = torch.optim.Adam(nonlinear_fn_params, lr=lr)
# nonlinear_fn_scaler = torch.cuda.amp.GradScaler()

# Set up LR scheduler
max_global_steps = max_global_steps or len(train_loader) * num_epochs
warmup_steps = int(warmup_ratio * max_global_steps)
if warmup_ratio > 0:
    # If `warmup_ratio` > 0, use cosine annealing scheduler with warm-up 
    from transformers import get_cosine_schedule_with_warmup # type: ignore
    max_optimizer_steps = (max_global_steps // grad_accumulation_steps) // 2 # divide by 2 because we train in 2 modes
    num_warmup_steps = int(warmup_ratio * max_optimizer_steps)
    L2T2_scheduler = get_cosine_schedule_with_warmup(
        L2T2_optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=max_optimizer_steps,

        # Cycle length is 2000 steps or less, but at least 1 cycle, and at most 10 cycles
        num_cycles=min(max(1, max_optimizer_steps // 2000), 10), 
    )
    # nonlinear_fn_scheduler = get_cosine_schedule_with_warmup(
    #     nonlinear_fn_optimizer,
    #     num_warmup_steps=num_warmup_steps,
    #     num_training_steps=max_optimizer_steps,

    #     # Cycle length is 2000 steps or less, but at least 1 cycle, and at most 10 cycles
    #     num_cycles=min(max(1, max_optimizer_steps // 2000), 10), 
    # )
else:
    # If `warmup_ratio` is 0, use a dummy scheduler that returns constant LR
    from torch.optim.lr_scheduler import LambdaLR # type: ignore
    L2T2_scheduler = LambdaLR(L2T2_optimizer, lr_lambda=lambda step: 1.0)
    # nonlinear_fn_scheduler = LambdaLR(nonlinear_fn_optimizer, lr_lambda=lambda step: 1.0)

# Initialize W&B
wandb.init(
    project='Nero-XLT',
    reinit=True, # End previous run and start a new one
    config=dict(
        # Project configuration
        seed = seed,
        device = device,

        # Data configuration
        hf_data_id = hf_data_id,
        hf_data_dir = hf_data_dir,

        # Training configuration
        batch_size = batch_size,
        num_epochs = num_epochs,
        max_global_steps = max_global_steps,
        grad_accumulation_steps = grad_accumulation_steps,
        clip_grad_norm = clip_grad_norm,
        lr = lr,
        warmup_ratio = warmup_ratio,
        checkpoint_steps = checkpoint_steps,
        resume_step = resume_step,
    ),
)

# Resume training
global_step = resume_step
start_epoch = 0

def load_trainer_params(target, model, optimizer, scheduler, scaler, checkpoint_dir, device):
    # Load Nero parameters
    nero_path = os.path.join(checkpoint_dir, f'{target}.safetensors')
    # model.load_nero_params(mode, nero_path)

    # Load optimizer state
    optimizer_path = os.path.join(checkpoint_dir, f'{target}_optimizer.pt')
    optimizer.load_state_dict(torch.load(optimizer_path, map_location=device))
    
    # Move optimizer state to the correct device
    for param in optimizer.state:
        param_device = param.device
        param_dtype = param.dtype
        for key, value in optimizer.state[param].items():
            if isinstance(value, torch.Tensor):
                optimizer.state[param][key] = value.to(device=param_device, dtype=param_dtype)

    # Load scheduler state
    scheduler_path = os.path.join(checkpoint_dir, f'{target}_scheduler.pt')
    scheduler.load_state_dict(torch.load(scheduler_path, map_location=device))

    # Load scaler state
    scaler_path = os.path.join(checkpoint_dir, f'{target}_scaler.pt')
    scaler.load_state_dict(torch.load(scaler_path, map_location=device))

if resume_step > 0:
    checkpoint_dir = os.path.join(nero_dir, f'checkpoint-{resume_step}')
    print(f"[INFO] Resuming training from checkpoint directory:", checkpoint_dir)

    # Load trainer parameters
    load_trainer_params('L2T2_lora', model, L2T2_optimizer, L2T2_scheduler, L2T2_scaler, checkpoint_dir, device)
    # load_trainer_params('nonlinear_fn', model, nonlinear_fn_optimizer, nonlinear_fn_scheduler, nonlinear_fn_scaler, checkpoint_dir, device)

    # Load trainer state
    trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json')
    if os.path.exists(trainer_state_path):
        with open(trainer_state_path, 'r') as f:
            trainer_state = json.load(f)
        log_history = trainer_state.get('log_history', [])
        start_epoch = log_history[-1]['epoch'] if log_history else 0
        print(f"[INFO] Resuming training from epoch {start_epoch} and step {resume_step}.")

    # Load RNG state for reproducibility
    rng_path = os.path.join(checkpoint_dir, 'rng_state.pth')
    if os.path.exists(rng_path):
        rng_state = torch.load(rng_path)
        random.setstate(rng_state['python'])
        np.random.set_state(rng_state['numpy'])
        torch.set_rng_state(rng_state['cpu'])
        if torch.cuda.is_available() and rng_state['cuda']:
            torch.cuda.set_rng_state_all(rng_state['cuda'])
    
    if resume_step % grad_accumulation_steps != 0:
        print("[WARN] Resuming mid-gradient accumulation cycle. Make sure this is intended.")
else:
    if push_to_hf:
        # If it's new training, create Hugging Face repository
        print(f"[INFO] Creating Hugging Face repository:", hf_nero_id) # print the link instead
        create_repo(repo_id=hf_nero_id, repo_type='model', exist_ok=True)
        print(f"[INFO] Hugging Face repository created successfully!")

  L2T2_scaler = torch.cuda.amp.GradScaler()
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malimtegar[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112777455557282, max=1.0…

In [20]:
def get_lora_beta_schedule(num_warmup_steps, start=0.0, schedule='linear'):
    def lora_beta_scheduler(global_step):
        if schedule == 'constant':
            return start  # fixed at start value

        if global_step >= num_warmup_steps:
            return 1.0  # fully on after warmup

        progress = global_step / num_warmup_steps
        if schedule == 'linear':
            # interpolate from start -> 1.0
            return start + (1.0 - start) * progress
        elif schedule == 'cosine':
            # cosine from start -> 1.0
            return start + (1.0 - start) * (0.5 * (1 - math.cos(math.pi * progress)))
        else:
            raise ValueError(f"Unknown schedule: {schedule}")
    return lora_beta_scheduler

lora_beta_scheduler = get_lora_beta_schedule(num_warmup_steps=500, start=0.05, schedule='cosine')

In [21]:
# model.eval()
# generated = generate_text(model, tokenizer, L1T2_sample_prompt, device=device)
# print("================================")
# print("CHECK GENERATED TEXT (L1T2)")
# print("================================")
# print(f"{'Generated':<9}:", generated)
# print()

# model.eval_L2T2_lora()
# generated = generate_text(model, tokenizer, L2T2_sample_prompt, device=device)
# print("================================")
# print("CHECK GENERATED TEXT (L2T2)")
# print("================================")
# print(f"{'Generated':<9}:", generated)
# print()

In [22]:
target = 'L2T2_lora' # or 'nonlinear_fn'
log_history = []
done = False

# Safety: Zero gradients at the start of gradient accumulation cycle
# This ensures there are no leftover gradients when resuming mid-cycle or after a previous cycle was interrupted
if global_step % grad_accumulation_steps == 0:
    L2T2_optimizer.zero_grad(set_to_none=True)

# Training loop
for epoch in range(start_epoch, num_epochs):
    for step, batch in enumerate(train_loader):
        # Skip previously completed steps
        if global_step <= resume_step:
            global_step += 1
            continue

        # Stop training if `max_global_steps` reached
        if global_step >= max_global_steps:
            done = True
            break

        # Move inputs to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        with torch.cuda.amp.autocast():
            # Set model to training mode
            model.train_L2T2_lora_and_g()
            # model.set_return_layer_loss(False)
            
            # Update LoRA beta
            lora_beta = lora_beta_scheduler(global_step)
            model.set_lora_beta(lora_beta)
            
            # Forward pass            
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids,
                use_cache=False, # Disable cache to avoid conflict with gradient checkpointing
            )            
            model.set_return_layer_loss(False)
            
            # total_loss = None
            # lm_loss = None
            # layer_loss = None
            # if isinstance(outputs, tuple) and len(outputs) == 2:
            #     lm_loss = outputs[0]
            #     layer_loss = outputs[1]
            #     total_loss = lm_loss + layer_loss
            # else:
            #     lm_loss = outputs.loss
            #     total_loss = lm_loss
            
            lm_loss = outputs.loss
            # L2T2_lora_outs = outputs[1]
            # L2T2_tgts = outputs[2]

            # loss_device = next(iter(L2T2_tgts.values())).device

            # hidden_loss = torch.tensor(0.0, device=loss_device)
            # for layer_name in L2T2_tgts.keys():
            #     # print(layer_name)
            #     L2T2_lora_out = L2T2_lora_outs[layer_name]
            #     L2T2_tgt = L2T2_tgts[layer_name].detach()
                
            #     # print(L2T2_lora_out.shape, L2T2_tgt.shape)
            #     # print(L2T2_lora_out.grad_fn, L2T2_tgt.grad_fn)
            #     hidden_loss_i = F.mse_loss(L2T2_lora_out, L2T2_tgt, reduction='mean')
            #     hidden_loss += hidden_loss_i.to(loss_device)
                
            #     # print()
            # # print(hidden_loss)
            # hidden_loss = outputs[1]
            
            loss = lm_loss / grad_accumulation_steps

        log = {
            'target': target,
            'epoch': epoch,
            'step': global_step,
        }

        # Backward pass
        # with torch.autograd.set_detect_anomaly(True):
        log['loss'] = loss.item() * grad_accumulation_steps
        # log['lm_loss'] = lm_loss.item()
        # log['hidden_loss'] = hidden_loss.item()
        L2T2_scaler.scale(loss).backward()
        
        # Update parameters only at the end of gradient accumulation cycle
        grad_norm_log = {}
        if (global_step + 1) % grad_accumulation_steps == 0:
            if target == 'L2T2_lora':
                # Unscale gradients before computing gradient norm and applying clipping
                L2T2_scaler.unscale_(L2T2_optimizer)
                
                # DEBUG: Grad norm
                # named_params = {n: p for n, p in model.named_parameters() if 'lora.L2T2.' in n or 'g.' in n}
                # compute_named_grad_norm(named_params)

                # Compute gradient norm
                grad_norm = compute_grad_norm(L2T2_lora_params)
                grad_norm_log['grad_norm'] = grad_norm

                # Clip gradients
                if clip_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(L2T2_lora_params, clip_grad_norm)
                
                # Compute clipped gradient norm
                grad_norm_clipped = compute_grad_norm(L2T2_lora_params)
                grad_norm_log['grad_norm_clipped'] = grad_norm_clipped

                # Update parameters
                L2T2_scaler.step(L2T2_optimizer)
                L2T2_scaler.update()
                L2T2_scheduler.step()

                # Zero gradients for the next gradient accumulation cycle
                L2T2_optimizer.zero_grad(set_to_none=True)
            # else:
            #     print("TEST 3")
            #     # Unscale gradients before computing gradient norm and applying clipping
            #     nonlinear_fn_scaler.unscale_(nonlinear_fn_optimizer)
                
            #     # Compute gradient norm
            #     grad_norm = compute_grad_norm(nonlinear_fn_params)
            #     grad_norm_log['nonlinear_fn/grad_norm'] = grad_norm

            #     # Clip gradients
            #     if clip_grad_norm is not None:
            #         torch.nn.utils.clip_grad_norm_(nonlinear_fn_params, clip_grad_norm)
                
            #     # Compute clipped gradient norm
            #     grad_norm_clipped = compute_grad_norm(nonlinear_fn_params)
            #     grad_norm_log['nonlinear_fn/grad_norm_clipped'] = grad_norm_clipped

            #     # Update parameters
            #     nonlinear_fn_scaler.step(nonlinear_fn_optimizer)
            #     nonlinear_fn_scaler.update()
            #     nonlinear_fn_scheduler.step()

            #     # Zero gradients for the next gradient accumulation cycle
            #     nonlinear_fn_optimizer.zero_grad(set_to_none=True)

            # # After updating parameters, toggle training target
            # target = 'nonlinear_fn' if target == 'L2T2_lora' else 'L2T2_lora'

        # Logging
        scheduler_log = {
            'lr': L2T2_scheduler.get_last_lr()[0],
            'lora_beta': lora_beta,
        }
        log = {
            **log, 
            **scheduler_log, 
            **grad_norm_log,
        }
        log_history.append(log)
        wandb.log(log)
        print(", ".join(
            f"{k}: {format_float(v)}" if isinstance(v, float) else f"{k}: {v}"
            for k, v in log.items()
        ))
        
        # Save and push checkpoint every `checkpoint_steps`
        if global_step > 0 and global_step % checkpoint_steps == 0:
            # Create checkpoint directory
            checkpoint_dir = os.path.join(nero_dir, f'checkpoint-{global_step}')
            os.makedirs(checkpoint_dir, exist_ok=True)

            # Save Nero parameters, along with optimizer, scheduler, and scaler states
            # L1T2_nero_state_dict = {n: p.detach().cpu() for n, p in model.named_parameters() if 'L1T2' in n}
            # save_file(L1T2_nero_state_dict, os.path.join(checkpoint_dir, 'L1T2_nero.safetensors'))
            # torch.save(L1T2_optimizer.state_dict(), os.path.join(checkpoint_dir, 'L1T2_optimizer.pt'))
            # torch.save(L1T2_scheduler.state_dict(), os.path.join(checkpoint_dir, 'L1T2_scheduler.pt'))
            # torch.save(L1T2_scaler.state_dict(), os.path.join(checkpoint_dir, 'L1T2_scaler.pt'))

            L2T2_nero_state_dict = {n: p.detach().cpu() for n, p in model.named_parameters() if 'L2T2' in n}
            save_file(L2T2_nero_state_dict, os.path.join(checkpoint_dir, 'L2T2_nero.safetensors'))
            torch.save(L2T2_optimizer.state_dict(), os.path.join(checkpoint_dir, 'L2T2_optimizer.pt'))
            torch.save(L2T2_scheduler.state_dict(), os.path.join(checkpoint_dir, 'L2T2_scheduler.pt'))
            torch.save(L2T2_scaler.state_dict(), os.path.join(checkpoint_dir, 'L2T2_scaler.pt'))

            # Save trainer state for resuming training
            trainer_state = {
                'epoch': epoch,
                'global_step': global_step,
                'log_history': log_history,
            }
            with open(os.path.join(checkpoint_dir, 'trainer_state.json'), 'w') as f:
                json.dump(trainer_state, f, indent=2)

            # Save RNG state for reproducibility
            rng_state = {
                'python': random.getstate(),
                'numpy': np.random.get_state(),
                'cpu': torch.get_rng_state(),
                'cuda': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else [],
            }
            torch.save(rng_state, os.path.join(checkpoint_dir, 'rng_state.pth'))

            # Upload checkpoint directory to Hugging Face repository
            if push_to_hf:
                upload_folder(
                    folder_path=checkpoint_dir,
                    repo_id=hf_nero_id,
                    path_in_repo=f"checkpoint-{global_step}",
                    commit_message=f"Add checkpoint at step {global_step}",
                    repo_type='model',
                )
        
        # Check generated text every `generate_steps`
        if global_step > 0 and global_step % generate_steps == 0:
            print()
            
            model.eval()
            generated = generate_text(model, tokenizer, L1T2_sample_prompt, device=device)
            print("================================")
            print("CHECK GENERATED TEXT (L1T2)")
            print("================================")
            print(generated)
            print()

            model.eval_L2T2_lora()
            generated = generate_text(model, tokenizer, L2T2_sample_prompt, device=device)
            print("================================")
            print("CHECK GENERATED TEXT (L2T2)")
            print("================================")
            print(generated)
            print()
        
        global_step += 1
    
    if done:
        break

wandb.finish()

  with torch.cuda.amp.autocast():
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


target: L2T2_lora, epoch: 0, step: 1, loss: 6.8698, lr: 6.4516e-06, lora_beta: 0.0500, grad_norm: 0.6779, grad_norm_clipped: 0.6779
target: L2T2_lora, epoch: 0, step: 2, loss: 8.1905, lr: 6.4516e-06, lora_beta: 0.0500
target: L2T2_lora, epoch: 0, step: 3, loss: 7.9981, lr: 1.2903e-05, lora_beta: 0.0501, grad_norm: 1.5343, grad_norm_clipped: 1.0000
target: L2T2_lora, epoch: 0, step: 4, loss: 7.8096, lr: 1.2903e-05, lora_beta: 0.0502
target: L2T2_lora, epoch: 0, step: 5, loss: 8.1152, lr: 1.9355e-05, lora_beta: 0.0502, grad_norm: 1.4493, grad_norm_clipped: 1.0000
target: L2T2_lora, epoch: 0, step: 6, loss: 8.1369, lr: 1.9355e-05, lora_beta: 0.0503
target: L2T2_lora, epoch: 0, step: 7, loss: 6.9513, lr: 2.5806e-05, lora_beta: 0.0505, grad_norm: 1.3610, grad_norm_clipped: 1.0000
target: L2T2_lora, epoch: 0, step: 8, loss: 6.8713, lr: 2.5806e-05, lora_beta: 0.0506
target: L2T2_lora, epoch: 0, step: 9, loss: 7.3652, lr: 3.2258e-05, lora_beta: 0.0508, grad_norm: 1.2395, grad_norm_clipped: 1.0

KeyboardInterrupt: 