In [2]:
# Kill all processess on GPU
!fuser -v /dev/nvidia* -k

# Libraries

In [3]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    %pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    %pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    %pip install sentencepiece protobuf "datasets>=4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    %pip install --no-deps unsloth
%pip install trl==0.19.1 # Fix error: ImportError: cannot import name 'ConstantLengthDataset' from 'trl.trainer.utils'

In [4]:
import os
import functools
import gc
import json
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from typing import Optional
from torch.utils.data import DataLoader
from datasets import load_dataset, Dataset
from transformers import default_data_collator, AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from huggingface_hub import snapshot_download, create_repo, upload_folder
from safetensors.torch import load_file, save_file
from datetime import datetime
from pprint import pprint

: 

In [None]:
def download_hf_model(
        repo_id: str, 
        checkpoint: Optional[int], 
        max_checkpoints: str = 10_000, 
        checkpoint_steps: str = 25,
    ):
    local_dir = repo_id.split('/')[-1]
    ignore_checkpoints = None
    
    if checkpoint is not None:
        ignore_checkpoints = [f'checkpoint-{i}/*' for i in range(0, max_checkpoints, checkpoint_steps) if i != checkpoint]

    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        ignore_patterns=ignore_checkpoints,
    )

    checkpoint_dir = None
    if checkpoint is not None:
        checkpoint_dir = os.path.join(local_dir, f'checkpoint-{checkpoint}')
    return local_dir, checkpoint_dir

def check_loss_and_grad_norm(
    model, 
    tokenizer, 
    prompt="Paris is the capital of",
):
    # Set model to training mode
    model.train()

    # Zero gradients manually
    for p in model.parameters():
        p.grad = None

    # Forward pass
    inputs = tokenizer(
        prompt, 
        return_tensors='pt',
    ).to(next(model.parameters()).device)

    outputs = model(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        labels=inputs['input_ids'],
        use_cache=False, # Disable cache to not conflict with gradient checkpointing
    )
    if isinstance(outputs, tuple):
        outputs = outputs[0]
    print("Loss:", outputs.loss)

    # Backward pass
    if outputs.loss.grad_fn is None:
        print("Gradient norm:", None)
        return

    outputs.loss.backward()

    # Compute gradient norm
    grad_norm = 0.0
    for n, p in model.named_parameters():
        if p.requires_grad:
            p_grad_norm = p.grad.data.norm(2)
            grad_norm += p_grad_norm.item() ** 2
    grad_norm = grad_norm ** 0.5

    print("Gradient norm:", grad_norm)

def check_parameter(n, p):
    print(f"- {'name':<8}:", n)
    print(f"- {'device':<8}:", p.device)
    print(f"- {'dtype':<8}:", p.dtype)
    print(f"- {'mean':<8}:", p.mean().item())
    print(f"- {'min':<8}:", p.min().item())
    print(f"- {'max':<8}:", p.max().item())

def check_lora_parameters(model, prefix=None):
    prefix = prefix + '_lora' if prefix != None else 'lora'
    for n, p in model.named_parameters():
        if prefix in n:
            check_parameter(n, p)
            break

@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens=32, device=None, skip_special_tokens=True):
    device = device or next(model.parameters()).device
    model.eval()
    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    outputs = model.generate(input_ids=inputs['input_ids'], max_new_tokens=max_new_tokens)
    model.train()
    return tokenizer.decode(outputs[0], skip_special_tokens=skip_special_tokens)

def get_task_and_lang_from_repo_id(repo_id: str):
    task, lang, _ = repo_id.split('B-')[-1].split('K-')[0].split('-')
    return task, lang

def load_hf_dataset_from_lora(
    lora_repo_id: str,
    train_size: int = 5000,
    test_size: int = 1000,
):
    # Get task and language
    task, lang = get_task_and_lang_from_repo_id(lora_repo_id)

    # Set up Hugging Face configuration
    data_id_map = {
        'wikipedia': 'wikimedia/wikipedia',
        'gsm8k': 'openai/gsm8k',
    }
    data_id = data_id_map[task]
    data_dir = f'20231101.{lang}' if task == 'wikipedia' else 'main'
    
    # Load dataset using streaming
    dataset_stream = load_dataset(data_id, data_dir=data_dir, split='train', streaming=True)

    # Manually take train_size + test_size samples
    total_size = train_size + test_size
    sliced_data = []
    for i, example in enumerate(dataset_stream):
        if i >= total_size:
            break
        sliced_data.append(example)

    # Convert to regular in-memory dataset
    dataset = Dataset.from_list(sliced_data)
    
    return dataset

def load_hf_dataset(
    lang, 
    task,
    split='train',
    train_size = 5000,
    test_size = 1000,
):
    # Set up Hugging Face configuration
    data_id_map = {
        'wikipedia': 'wikimedia/wikipedia',
        'gsm8k': 'openai/gsm8k',
    }
    data_id = data_id_map[task]
    data_dir = f'20231101.{lang}' if task == 'wikipedia' else 'main'

    # Use streaming
    dataset_stream = load_dataset(data_id, data_dir=data_dir, split=split, streaming=True)

    # Manually take train_size + test_size samples
    total_size = train_size + test_size
    sliced_data = []
    for i, example in enumerate(dataset_stream):
        if i >= total_size:
            break
        sliced_data.append(example)

    # Convert to regular in-memory dataset
    dataset = Dataset.from_list(sliced_data)
    
    return dataset

def compute_grad_norm(params):
    grad_norm = 0.0
    for p in params:
        p_grad_norm = p.grad.data.norm(2)
        grad_norm += p_grad_norm.item() ** 2
    grad_norm = grad_norm ** 0.5
    return grad_norm

def compute_named_grad_norm(named_params):
    grad_norm = 0.0
    for n, p in named_params.items():
        if p.grad is not None:
            p_grad_norm = p.grad.data.norm(2)
            grad_norm += p_grad_norm.item() ** 2
        else:
            print(f"[WARN] No gradient for {n}")
    grad_norm = grad_norm ** 0.5
    return grad_norm

def format_float(v):
    if abs(v) < 0.0001 or abs(v) >= 10000:
        return f"{v:.4e}"
    else:
        return f"{v:.4f}"

# Configurations

In [None]:
# Project configuration
seed = 69
device = 'auto'

# Data configuration
hf_data_id = 'alxxtexxr/Nero-XLT-Dataset'
hf_data_dir = 'gsm8k_en_5K_1K_1K_512'

# Training configuration
batch_size = 4
num_epochs = 1
max_global_steps = None
grad_accumulation_steps = 2
clip_grad_norm = 3.0
lr = 2e-4
warmup_ratio = 0.03
# num_warmup_steps = 100
checkpoint_steps = 50
generate_steps = 50
sample_prompt = '日本は'

# Model configurations
model_configs = {
    # L1T1 (Source Language - Source Task)
    'L1T1': {
        'hf_lora_id': 'alxxtexxr/L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650',
        'checkpoint': 650,
    },

    # L2T1 (Target Language - Source Task)
    'L2T1': {
        'hf_lora_id': 'alxxtexxr/L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629',
        'checkpoint': 650,
    },

    # L1T2 (Source Language - Target Task)
    # 'L1T2': {
    #     'hf_lora_id': 'alxxtexxr/L3.1-8B-gsm8k-en-5K-LoRA-v20250701060457',
    #     'checkpoint': 1875,
    # },
}

for key, config in model_configs.items():
    _, lora_dir = download_hf_model(config['hf_lora_id'], config['checkpoint'])
    model_configs[key]['lora_dir'] = lora_dir
    model_configs[key]['lora_path'] = os.path.join(lora_dir, 'adapter_model.safetensors')
    model_configs[key]['lora_config'] = LoraConfig.from_pretrained(lora_dir)

print("Model configurations:",)
for key, config in model_configs.items():
    print(f"- {key}:")
    for config_name, config_value in config.items():
        if config_name == 'lora_config':
            continue
        print(f"{'-':>3} {config_name:<10}: {config_value}")
print()

assert (
    model_configs['L1T1']['lora_config'].base_model_name_or_path == 
    model_configs['L2T1']['lora_config'].base_model_name_or_path == 
    model_configs['L1T2']['lora_config'].base_model_name_or_path
), "Base models must be the same"
base_model_name = model_configs['L1T1']['lora_config'].base_model_name_or_path
print(f"Base model name: {base_model_name}")

tokenizer = AutoTokenizer.from_pretrained(base_model_name)

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


Fetching 20 files:   0%|          | 0/20 [00:00<?, ?it/s]

Fetching 20 files:   0%|          | 0/20 [00:00<?, ?it/s]

Fetching 20 files:   0%|          | 0/20 [00:00<?, ?it/s]

Model configurations:
- L1T1:
  - hf_lora_id: alxxtexxr/L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650
  - checkpoint: 650
  - lora_dir  : L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650/checkpoint-650
  - lora_path : L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650/checkpoint-650/adapter_model.safetensors
- L2T1:
  - hf_lora_id: alxxtexxr/L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629
  - checkpoint: 650
  - lora_dir  : L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629/checkpoint-650
  - lora_path : L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629/checkpoint-650/adapter_model.safetensors
- L1T2:
  - hf_lora_id: alxxtexxr/L3.1-8B-gsm8k-en-5K-LoRA-v20250701060457
  - checkpoint: 1875
  - lora_dir  : L3.1-8B-gsm8k-en-5K-LoRA-v20250701060457/checkpoint-1875
  - lora_path : L3.1-8B-gsm8k-en-5K-LoRA-v20250701060457/checkpoint-1875/adapter_model.safetensors

Base model name: unsloth/meta-llama-3.1-8b-unsloth-bnb-4bit


In [None]:
# Hugging Face configuration
hf_nero_id = None
resume_step = 0

if hf_nero_id is not None and resume_step > 0:
    print(f"[INFO] Downloading Nero checkpoint at step {resume_step} from Hugging Face repository:", hf_nero_id)
    nero_dir, _ = download_hf_model(hf_nero_id, resume_step)
    print(f"[INFO] Nero checkpoint downloaded successfully!")
else:
    hf_username = 'alxxtexxr'
    nero_dir = f'L3.1-8B-{hf_data_dir.replace("_", "-")}K-Nero-v{datetime.now().strftime("%Y%m%d%H%M%S")}'
    print(f"[INFO] Creating Nero directory:", nero_dir)
    hf_nero_id = f'{hf_username}/{nero_dir}'
    os.makedirs(nero_dir, exist_ok=True)
    print(f"[INFO] Nero directory created!")

[INFO] Creating Nero directory: L3.1-8B-gsm8k-en-5K-1K-1K-512K-Nero-v20250814024441
[INFO] Nero directory created!


# Model

In [None]:
class NeroLayer(nn.Module):
    def __init__(self, base_layer,
                 
                 # LoRA parameters
                 L1T1_lora_rank, L1T1_lora_alpha, L1T1_lora_dropout, L1T1_lora_bias, L1T1_lora_use_rslora,
                 L2T1_lora_rank, L2T1_lora_alpha, L2T1_lora_dropout, L2T1_lora_bias, L2T1_lora_use_rslora,

                 # Nero parameters
                 nero_rank, nero_alpha, nero_bias, # nero_dropout, nero_use_rslora,
                 train_L2T2_nero = False,
                 return_L2T2_nero_loss = False,
                 
                 # Distance function parameters  
                 distance_fn='euclidean',
                 distance_fn_pairwise=True,
                 
                 debug=False,
                 module_name=None,
                 ):
        super().__init__()

        self.base_layer = base_layer
        self.device = base_layer.weight.device
        self.train_L2T2_nero = train_L2T2_nero
        self.return_L2T2_nero_loss = return_L2T2_nero_loss
        
        self.distance_fn = distance_fn
        self.distance_fn_pairwise = distance_fn_pairwise

        self.debug = debug
        self.module_name = module_name

        # Extract input and output features from the base layer
        in_features = getattr(base_layer, 'in_features', None)
        out_features = getattr(base_layer, 'out_features', None)

        if in_features is None or out_features is None:
            raise ValueError(f"Cannot determine in_features or out_features from {base_layer}.")

        # ================================================================================================================================
        # L1T1 LoRA
        # ================================================================================================================================
        self.L1T1_lora_alpha = L1T1_lora_alpha
        self.L1T1_lora_bias = L1T1_lora_bias
        self.L1T1_lora_scaling = L1T1_lora_alpha / math.sqrt(L1T1_lora_rank) if L1T1_lora_use_rslora else L1T1_lora_alpha / L1T1_lora_rank
        self.L1T1_lora_dropout = nn.Dropout(L1T1_lora_dropout) if L1T1_lora_dropout > 0.0 else nn.Identity()

        # LoRA decomposition
        self.L1T1_lora_A = nn.Linear(in_features, L1T1_lora_rank, bias=L1T1_lora_bias).to(self.device)
        self.L1T1_lora_B = nn.Linear(L1T1_lora_rank, out_features, bias=L1T1_lora_bias).to(self.device)

        # Initialize LoRA matrices: A ~ N(0, 1/rank), B initialized to 0
        L1T1_lora_std = 1 / torch.sqrt(torch.tensor(L1T1_lora_rank).float())
        nn.init.normal_(self.L1T1_lora_A.weight, mean=0.0, std=L1T1_lora_std)
        nn.init.zeros_(self.L1T1_lora_B.weight) 

        # ================================================================================================================================
        # L2T1 LoRA
        # ================================================================================================================================
        self.L2T1_lora_alpha = L2T1_lora_alpha
        self.L2T1_lora_bias = L2T1_lora_bias
        self.L2T1_lora_scaling = L2T1_lora_alpha / math.sqrt(L2T1_lora_rank) if L2T1_lora_use_rslora else L2T1_lora_alpha / L2T1_lora_rank
        self.L2T1_lora_dropout = nn.Dropout(L2T1_lora_dropout) if L2T1_lora_dropout > 0.0 else nn.Identity()

        # LoRA decomposition
        self.L2T1_lora_A = nn.Linear(in_features, L2T1_lora_rank, bias=L2T1_lora_bias).to(self.device)
        self.L2T1_lora_B = nn.Linear(L2T1_lora_rank, out_features, bias=L2T1_lora_bias).to(self.device)

        # Initialize LoRA matrices: A ~ N(0, 1/rank), B initialized to 0
        L2T1_lora_std = 1 / torch.sqrt(torch.tensor(L2T1_lora_rank).float())
        nn.init.normal_(self.L2T1_lora_A.weight, mean=0.0, std=L2T1_lora_std)
        nn.init.zeros_(self.L2T1_lora_B.weight) 
        
        # ================================================================================================================================
        # L1T2 Nero
        # ================================================================================================================================
        self.nero_alpha = nero_alpha
        self.nero_bias = nero_bias
        # self.nero_scaling = nero_alpha / math.sqrt(nero_rank) if nero_use_rslora else nero_alpha / nero_rank
        # self.nero_dropout = nn.Dropout(nero_dropout) if nero_dropout > 0.0 else nn.Identity()

        # Nero decomposition like LoRA
        self.L1T2_nero_A = nn.Linear(out_features, nero_rank, bias=nero_bias).to(self.device)
        self.L1T2_nero_B = nn.Linear(nero_rank, out_features, bias=nero_bias).to(self.device)

        # Initialize Nero matrices like LoRA: A ~ N(0, 1/rank), B initialized to 0
        nero_std = 1 / torch.sqrt(torch.tensor(nero_rank).float())
        nn.init.normal_(self.L1T2_nero_A.weight, mean=0.0, std=nero_std)
        nn.init.zeros_(self.L1T2_nero_B.weight)

        # ================================================================================================================================
        # L1T2 Nero
        # ================================================================================================================================
        # Nero decomposition like LoRA
        self.L2T2_nero_A = nn.Linear(out_features, nero_rank, bias=nero_bias).to(self.device)
        self.L2T2_nero_B = nn.Linear(nero_rank, out_features, bias=nero_bias).to(self.device)

        # Initialize Nero matrices like LoRA: A ~ N(0, 1/rank), B initialized to 0
        nn.init.normal_(self.L2T2_nero_A.weight, mean=0.0, std=nero_std)
        nn.init.zeros_(self.L2T2_nero_B.weight)
    
    def _distance_fn(self, a, b, huber_delta=1.0, huber_weight=1.0, eps=1e-8):
        """
        Compute distances between two tensors a and b.

        Args:
            a: [batch_size_a, seq_len, hidden_dim]
            b: [batch_size_b, seq_len, hidden_dim]
            huber_delta: delta for Huber distance
            eps: small constant for numerical stability

        Returns:
            Tensor of distances:
            - If self.distance_fn_pairwise=False: [batch_size_a, seq_len] (elementwise)
            - If self.distance_fn_pairwise=True:  [batch_size_a, batch_size_b, seq_len] (pairwise)
        """
        def huber_loss(diff, delta):
            abs_diff = diff.abs()
            mask = abs_diff < delta
            return torch.where(mask, 0.5 * diff**2, delta * (abs_diff - 0.5 * delta))

        if self.distance_fn_pairwise:
            # Expand for pairwise computation
            a_exp = a.unsqueeze(1)  # [batch_size_a, 1, seq_len, hidden_dim]
            b_exp = b.unsqueeze(0)  # [1, batch_size_b, seq_len, hidden_dim]
            if self.distance_fn == 'cosine':
                a_norm = F.normalize(a_exp, dim=-1, eps=eps)
                b_norm = F.normalize(b_exp, dim=-1, eps=eps)
                return 1 - (a_norm * b_norm).sum(dim=-1)
            elif self.distance_fn in ['euclidean', 'squared_euclidean']:
                diff = a_exp - b_exp
                dist_sq = (diff**2).sum(dim=-1)
                if self.distance_fn == 'euclidean':
                    return torch.sqrt(dist_sq + eps)
                else:
                    return dist_sq
            elif self.distance_fn == 'huber':
                diff = a_exp - b_exp
                return huber_loss(diff, huber_delta).sum(dim=-1)
            else:
                raise ValueError(f"Unsupported distance function: {self.distance_fn}")
        else:
            # Elementwise distances
            if self.distance_fn == 'cosine':
                a_norm = F.normalize(a, dim=-1, eps=eps)
                b_norm = F.normalize(b, dim=-1, eps=eps)
                return 1 - (a_norm * b_norm).sum(dim=-1)
            elif self.distance_fn == 'euclidean':
                return torch.norm(a - b, p=2, dim=-1)
            elif self.distance_fn == 'squared_euclidean':
                return ((a - b)**2).sum(dim=-1)
            elif self.distance_fn == 'huber':
                diff = a - b
                return huber_loss(diff, huber_delta).sum(dim=-1)
            else:
                raise ValueError(f"Unsupported distance function: {self.distance_fn}")
        
    def forward(self, x):
        # print("================================================================")
        # print(self.module_name)
        # print("================================================================")

        base_out = self.base_layer(x)

        requires_conversion = not torch.is_autocast_enabled()
        if requires_conversion:
            assert self.L1T1_lora_A.weight.dtype == self.L2T1_lora_A.weight.dtype
            x = x.to(self.L1T1_lora_A.weight.dtype)

        # L1T1_lora_out = self.L1T1_lora_B(self.L1T1_lora_A(self.L1T1_lora_dropout(x))) #* self.L1T1_lora_scaling
        L1T1_lora_dropout_out = self.L1T1_lora_dropout(x)
        L1T1_lora_A_out = self.L1T1_lora_A(L1T1_lora_dropout_out)
        L1T1_lora_B_out = self.L1T1_lora_B(L1T1_lora_A_out)
        # L1T1_lora_out = L1T1_lora_B_out * self.L1T1_lora_scaling
        L1T1_lora_out = L1T1_lora_B_out

        # L1T2_nero_out = self.L1T2_nero_B(F.relu(self.L1T2_nero_A(L1T1_lora_out)))
        L1T2_nero_A_out = self.L1T2_nero_A(L1T1_lora_out)
        L1T2_nero_relu_out = F.relu(L1T2_nero_A_out)
        L1T2_nero_B_out = self.L1T2_nero_B(L1T2_nero_relu_out)
        L1T2_nero_out = L1T2_nero_B_out * self.L1T1_lora_scaling
        if requires_conversion:
            L1T2_nero_out = L1T2_nero_out.to(base_out.dtype)
        output = base_out + L1T2_nero_out

        if self.train_L2T2_nero:
            # L2T1_lora_out = self.L2T1_lora_B(self.L2T1_lora_A(self.L2T1_lora_dropout(x))) #* self.L2T1_lora_scaling
            L2T1_lora_dropout_out = self.L2T1_lora_dropout(x)
            L2T1_lora_A_out = self.L2T1_lora_A(L2T1_lora_dropout_out)
            L2T1_lora_B_out = self.L2T1_lora_B(L2T1_lora_A_out)
            # L2T1_lora_out = L2T1_lora_B_out * self.L2T1_lora_scaling
            L2T1_lora_out = L2T1_lora_B_out

            # L2T2_nero_out = self.L2T2_nero_B(F.relu(self.L2T2_nero_A(L2T1_lora_out)))
            L2T2_nero_A_out = self.L2T2_nero_A(L2T1_lora_out)
            L2T2_nero_relu_out = F.relu(L2T2_nero_A_out)
            L2T2_nero_B_out = self.L2T2_nero_B(L2T2_nero_relu_out)
            L2T2_nero_out = L2T2_nero_B_out * self.L1T1_lora_scaling
            if requires_conversion:
                L2T2_nero_out = L2T2_nero_out.to(base_out.dtype)

            # ================================================================
            # LOSS
            # ================================================================
            # Language distances: distance(L1T1, L2T1), distance(L1T2, L2T2)
            LxT1_distance = self._distance_fn(L1T1_lora_out.float(), L2T1_lora_out.float())
            LxT2_distance = self._distance_fn(L1T2_nero_out.float(), L2T2_nero_out.float()) # requires grad

            # Task distances: distance(L1T1, L1T2), distance(L2T1, L2T2)
            L1Tx_distance = self._distance_fn(L1T1_lora_out.float(), L1T2_nero_out.float())
            L2Tx_distance = self._distance_fn(L2T1_lora_out.float(), L2T2_nero_out.float()) # requires grad

            # Distance constraints:
            # distance(L1T1, L2T1) ~= distance(L1T2, L2T2)
            # distance(L1T1, L1T2) ~= distance(L2T1, L2T2)
            L2T2_nero_loss = F.mse_loss(LxT2_distance, LxT1_distance) + F.mse_loss(L2Tx_distance, L1Tx_distance)
            
            if self.return_L2T2_nero_loss:
                return output, L2T2_nero_loss

        return output

    def load_L1T1_lora_params(self, state_dict, prefix):
        self.L1T1_lora_A.weight.data = state_dict[f'{prefix}.lora_A.weight'].to(self.device)
        self.L1T1_lora_B.weight.data = state_dict[f'{prefix}.lora_B.weight'].to(self.device)
        if self.L1T1_lora_bias:
            self.L1T1_lora_A.bias.data = state_dict[f'{prefix}.lora_A.bias'].to(self.device)
            self.L1T1_lora_B.bias.data = state_dict[f'{prefix}.lora_B.bias'].to(self.device)
    
    def load_L2T1_lora_params(self, state_dict, prefix):
        self.L2T1_lora_A.weight.data = state_dict[f'{prefix}.lora_A.weight'].to(self.device)
        self.L2T1_lora_B.weight.data = state_dict[f'{prefix}.lora_B.weight'].to(self.device)
        if self.L2T1_lora_bias:
            self.L2T1_lora_A.bias.data = state_dict[f'{prefix}.lora_A.bias'].to(self.device)
            self.L2T1_lora_B.bias.data = state_dict[f'{prefix}.lora_B.bias'].to(self.device)

    def load_L1T2_nero_params(self, state_dict, prefix):
        self.L1T2_nero_A.weight.data = state_dict[f'{prefix}.L1T2_nero_A.weight'].to(self.device)
        self.L1T2_nero_B.weight.data = state_dict[f'{prefix}.L1T2_nero_B.weight'].to(self.device)
        if self.nero_bias:
            self.L1T2_nero_A.bias.data = state_dict[f'{prefix}.L1T2_nero_A.bias'].to(self.device)
            self.L1T2_nero_B.bias.data = state_dict[f'{prefix}.L1T2_nero_B.bias'].to(self.device)
    
    def load_L2T2_nero_params(self, state_dict, prefix):
        self.L2T2_nero_A.weight.data = state_dict[f'{prefix}.L2T2_nero_A.weight'].to(self.device)
        self.L2T2_nero_B.weight.data = state_dict[f'{prefix}.L2T2_nero_B.weight'].to(self.device)
        if self.nero_bias:
            self.L2T2_nero_A.bias.data = state_dict[f'{prefix}.L2T2_nero_A.bias'].to(self.device)
            self.L2T2_nero_B.bias.data = state_dict[f'{prefix}.L2T2_nero_B.bias'].to(self.device)
    
class NeroModel(nn.Module):
    def __init__(self, 
                 base_model: nn.Module, 
                 L1T1_lora_config: LoraConfig, 
                 L2T1_lora_config: LoraConfig, 
                 train_L2T2_nero=False,
                 return_L2T2_nero_loss=False,
                 distance_fn='euclidean',
                 distance_fn_pairwise=True,
                 debug: bool=False,
                ):
        super().__init__()
        self.base_model = base_model
        self.train_L2T2_nero = train_L2T2_nero
        self.return_L2T2_nero_loss = return_L2T2_nero_loss
        
        self.distance_fn = distance_fn
        self.distance_fn_pairwise = distance_fn_pairwise

        self.debug = debug

        # Wrap target layers with NeroLayer
        self.nero_layers = nn.ModuleDict()
        self._wrap_target_layers(L1T1_lora_config, L2T1_lora_config)
        
    def _wrap_target_layers(self, L1T1_lora_config, L2T1_lora_config):
        assert L1T1_lora_config.target_modules == L2T1_lora_config.target_modules, "[ERROR] L1T1 and L2T1 LoRA configurations must have the same target modules."

        for module_name, module in self.base_model.named_modules():
            if isinstance(module, NeroLayer):
                # Convert module name format and store reference
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.nero_layers[module_name] = module
                continue

            if any(module_name.endswith(target_module) for target_module in L1T1_lora_config.target_modules) and isinstance(module, nn.Linear):    
                parent_module, child_name = self._get_parent_module(module_name)
                nero_layer = NeroLayer(
                    base_layer=module,
                    train_L2T2_nero=self.train_L2T2_nero,
                    return_L2T2_nero_loss=self.return_L2T2_nero_loss,

                    # L1T1 LoRA parameters
                    L1T1_lora_rank=L1T1_lora_config.r, 
                    L1T1_lora_alpha=L1T1_lora_config.lora_alpha, 
                    L1T1_lora_dropout=L1T1_lora_config.lora_dropout,
                    L1T1_lora_bias=L1T1_lora_config.lora_bias,
                    L1T1_lora_use_rslora=L1T1_lora_config.use_rslora,

                    # L2T1 LoRA parameters
                    L2T1_lora_rank=L2T1_lora_config.r, 
                    L2T1_lora_alpha=L2T1_lora_config.lora_alpha, 
                    L2T1_lora_dropout=L2T1_lora_config.lora_dropout,
                    L2T1_lora_bias=L2T1_lora_config.lora_bias,
                    L2T1_lora_use_rslora=L2T1_lora_config.use_rslora,

                    # Nero parameters (for temporary, use L1T1 LoRA parameters)
                    nero_rank=L1T1_lora_config.r, 
                    nero_alpha=L1T1_lora_config.lora_alpha, 
                    nero_bias=L1T1_lora_config.lora_bias,
                    # nero_dropout=L1T1_lora_config.lora_dropout, 

                    # Distance function parameters
                    distance_fn=self.distance_fn,
                    distance_fn_pairwise=self.distance_fn_pairwise,

                    debug=self.debug,
                    module_name=module_name,
                )
                setattr(parent_module, child_name, nero_layer)

                # Store LoRA layers for weight loading
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.nero_layers[module_name] = nero_layer
    
    def _get_parent_module(self, module_name):
        parts = module_name.split('.')
        parent_module = self.base_model
        for part in parts[:-1]:
            parent_module = getattr(parent_module, part)
        return parent_module, parts[-1]
    
    def set_train_L2T2_nero(self, train_L2T2_nero: bool, verbose: bool=False):
        self.train_L2T2_nero = train_L2T2_nero
        for layer in self.nero_layers.values():
            layer.train_L2T2_nero = train_L2T2_nero
        
        if train_L2T2_nero:
            self.set_return_L2T2_nero_loss(True)
            self.freeze_all_except_L2T2_nero()
        else:
            self.set_return_L2T2_nero_loss(False)
            self.freeze_all_except_L1T2_nero()

        if verbose:
            print(f"[INFO] Learn L2T2 Nero set to '{train_L2T2_nero}'.")
    
    def set_return_L2T2_nero_loss(self, return_L2T2_nero_loss: bool, verbose: bool=False):
        self.return_L2T2_nero_loss = return_L2T2_nero_loss
        for layer in self.nero_layers.values():
            layer.return_L2T2_nero_loss = return_L2T2_nero_loss
        
        if verbose:
            print(f"[INFO] Return L2T2 Nero loss set to '{return_L2T2_nero_loss}'.")
    
    def set_distance_fn(self, distance_fn: str, pairwise: bool, verbose: bool=False):
        self.distance_fn = distance_fn
        self.distance_fn_pairwise = pairwise
        for layer in self.nero_layers.values():
            layer.distance_fn = distance_fn
            layer.distance_fn_pairwise = pairwise
        
        if verbose:
            print(f"[INFO] Distance function set to '{distance_fn}' with pairwise to {pairwise}.")
    
    def freeze_all(self, verbose=False):
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        if verbose:
            print("[INFO] All layers are frozen!")

    def freeze_all_except_L1T2_nero(self, verbose=False):
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        for nero_layer in self.nero_layers.values():
            for param_name, param in nero_layer.named_parameters():
                if 'L1T2_nero_A' in param_name or 'L1T2_nero_B' in param_name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
        
        if verbose:
            print("[INFO] All layers are frozen except L1T2 Nero layers!")
    
    def freeze_all_except_L2T2_nero(self, verbose=False):
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        for nero_layer in self.nero_layers.values():
            for param_name, param in nero_layer.named_parameters():
                if 'L2T2_nero_A' in param_name or 'L2T2_nero_B' in param_name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
        
        if verbose:
            print("[INFO] All layers are frozen except L2T2 Nero layers!")
    
    def unfreeze_all(self):
        for param in self.base_model.parameters():
            param.requires_grad = True
        
        for nero_layer in self.nero_layers.values():
            for param in nero_layer.parameters():
                param.requires_grad = True
        
        print("[INFO] All layers are unfrozen!")
    
    def load_L1T1_lora_params(self, L1T1_lora_path):
        if not os.path.exists(L1T1_lora_path):
            raise FileNotFoundError("[ERROR] L1T1 LoRA file not found:", L1T1_lora_path)
        
        if L1T1_lora_path.endswith('.bin'):
            state_dict = torch.load(L1T1_lora_path, map_location='cpu')
        else:
            state_dict = load_file(L1T1_lora_path) # assuming .safetensors

        prefix = list(state_dict.keys())[0].rsplit('model.', 1)[0] + 'model.'

        for nero_layer_name, nero_layer in self.nero_layers.items():
            nero_layer_name = nero_layer_name.replace('__DOT__', '.')
            nero_layer_name = prefix + nero_layer_name
            if f'{nero_layer_name}.lora_A.weight' in state_dict and f'{nero_layer_name}.lora_B.weight' in state_dict:
                nero_layer.load_L1T1_lora_params(state_dict, nero_layer_name)
            else:
                # TODO: Print warning message
                pass

        print("[INFO] L1T1 LoRA parameters loaded successfully!")
    
    def load_L2T1_lora_params(self, L2T1_lora_path):
        if not os.path.exists(L2T1_lora_path):
            raise FileNotFoundError("[ERROR] L2T1 LoRA file not found:", L2T1_lora_path)
        
        if L2T1_lora_path.endswith('.bin'):
            state_dict = torch.load(L2T1_lora_path, map_location='cpu')
        else:
            state_dict = load_file(L2T1_lora_path) # assuming .safetensors

        prefix = list(state_dict.keys())[0].rsplit('model.', 1)[0] + 'model.'

        for nero_layer_name, nero_layer in self.nero_layers.items():
            nero_layer_name = nero_layer_name.replace('__DOT__', '.')
            nero_layer_name = prefix + nero_layer_name
            if f'{nero_layer_name}.lora_A.weight' in state_dict and f'{nero_layer_name}.lora_B.weight' in state_dict:
                nero_layer.load_L2T1_lora_params(state_dict, nero_layer_name)
            else:
                # TODO: Print warning message
                pass

        print("[INFO] L2T1 LoRA parameters loaded successfully!")
    
    def load_L1T2_nero_params(self, L1T2_nero_path):
        if not os.path.exists(L1T2_nero_path):
            raise FileNotFoundError("[ERROR] L1T2_Nero file not found:", L1T2_nero_path)
        
        if L1T2_nero_path.endswith('.bin'):
            state_dict = torch.load(L1T2_nero_path, map_location='cpu')
        else:
            state_dict = load_file(L1T2_nero_path) # assuming .safetensors

        prefix = list(state_dict.keys())[0].rsplit('model.', 1)[0] + 'model.'

        for nero_layer_name, nero_layer in self.nero_layers.items():
            nero_layer_name = nero_layer_name.replace('__DOT__', '.')
            nero_layer_name = prefix + nero_layer_name
            if f'{nero_layer_name}.L1T2_nero_A.weight' in state_dict and f'{nero_layer_name}.L1T2_nero_B.weight' in state_dict:
                nero_layer.load_L1T2_nero_params(state_dict, nero_layer_name)
            else:
                # TODO: Print warning message
                pass

        print("[INFO] L1T2 Nero parameters loaded successfully!")
    
    def load_L2T2_nero_params(self, L2T2_nero_path):
        if not os.path.exists(L2T2_nero_path):
            raise FileNotFoundError("[ERROR] L2T2_Nero file not found:", L2T2_nero_path)
        
        if L2T2_nero_path.endswith('.bin'):
            state_dict = torch.load(L2T2_nero_path, map_location='cpu')
        else:
            state_dict = load_file(L2T2_nero_path) # assuming .safetensors

        prefix = list(state_dict.keys())[0].rsplit('model.', 1)[0] + 'model.'

        for nero_layer_name, nero_layer in self.nero_layers.items():
            nero_layer_name = nero_layer_name.replace('__DOT__', '.')
            nero_layer_name = prefix + nero_layer_name
            if f'{nero_layer_name}.L2T2_nero_A.weight' in state_dict and f'{nero_layer_name}.L2T2_nero_B.weight' in state_dict:
                nero_layer.load_L2T2_nero_params(state_dict, nero_layer_name)
            else:
                # TODO: Print warning message
                pass

        print("[INFO] L2T2 Nero parameters loaded successfully!")
    
    def forward(self, *args, **kwargs):
        if self.train_L2T2_nero:
            L2T2_nero_losses = []
            
            def _hook_fn(layer_name, module, _in, _out):
                assert isinstance(_out, tuple) and len(_out) == 2
                layer_out, L2T2_nero_layer_loss = _out
                L2T2_nero_losses.append(L2T2_nero_layer_loss)
                return layer_out # Return only layer_out to avoid breaking model flow

            # Register hooks to extract hidden_out during forward pass
            hooks = []
            for layer_name, layer in self.nero_layers.items():
                hook = layer.register_forward_hook(functools.partial(_hook_fn, layer_name))
                hooks.append(hook)
        
            try:
                outputs = self.base_model(*args, **kwargs)
            finally:
                # Remove hooks after forward pass, ensuring it's done even if an error occurs
                for hook in hooks:
                    hook.remove()

            L2T2_nero_loss = torch.stack(L2T2_nero_losses).mean()
            
            if self.return_L2T2_nero_loss:
                return outputs, L2T2_nero_loss
            else:
                return outputs
        
        return self.base_model(*args, **kwargs)
    
    def __getattr__(self, name):
        try:
            return super().__getattr__(name) # Try getting attribute from self
        except AttributeError:
            return getattr(self.base_model, name) # Fallback to base_model

base_nero_model = AutoModelForCausalLM.from_pretrained(
    base_model_name, 
    device_map=device,
)
nero_model = NeroModel(
    base_nero_model, 
    L1T1_lora_config=model_configs['L1T1']['lora_config'], 
    L2T1_lora_config=model_configs['L2T1']['lora_config'], 
    train_L2T2_nero=False,
    return_L2T2_nero_loss=False,
    distance_fn='euclidean',
    distance_fn_pairwise=True,
    debug=False,
)

In [None]:
# nero_model.set_mode('teacher')

# for param in nero_model.base_model.parameters():
#     param.requires_grad = False
        
# for nero_layer in nero_model.nero_layers.values():
#     for param_name, param in nero_layer.named_parameters():
#         if 'teacher_lora_A' in param_name or 'teacher_lora_B' in param_name:
#             param.requires_grad = True
#         else:
#             param.requires_grad = False

# check_loss_and_grad_norm(nero_model, tokenizer, prompt="Despite the ominous clouds gathering on the horizon and the distant rumble of thunder echoing through the valley, the determined hikers,")

In [None]:
print("Check unloaded L1T1 LoRA parameters:")
check_lora_parameters(nero_model, prefix='L1T1')
print()

nero_model.load_L1T1_lora_params(model_configs['L1T1']['lora_path'])
print()

print("Check loaded L1T1 LoRA parameters:")
check_lora_parameters(nero_model, prefix='L1T1')
print()

Check unloaded L1T1 LoRA parameters:
- name    : base_model.model.layers.0.self_attn.q_proj.L1T1_lora_A.weight
- device  : cuda:0
- dtype   : torch.float32
- mean    : 0.00037309923209249973
- min     : -1.5596299171447754
- max     : 1.5415360927581787

[INFO] L1T1 LoRA parameters loaded successfully!

Check loaded L1T1 LoRA parameters:
- name    : base_model.model.layers.0.self_attn.q_proj.L1T1_lora_A.weight
- device  : cuda:0
- dtype   : torch.float32
- mean    : 6.287686119321734e-05
- min     : -0.04176201671361923
- max     : 0.04242725297808647



In [None]:
print("Check unloaded L2T1 LoRA parameters:")
check_lora_parameters(nero_model, prefix='L2T1')
print()

nero_model.load_L2T1_lora_params(model_configs['L2T1']['lora_path'])
print()

print("Check loaded L2T1 LoRA parameters:")
check_lora_parameters(nero_model, prefix='L2T1')
print()

Check unloaded L2T1 LoRA parameters:
- name    : base_model.model.layers.0.self_attn.q_proj.L2T1_lora_A.weight
- device  : cuda:0
- dtype   : torch.float32
- mean    : -0.002967829117551446
- min     : -1.4062590599060059
- max     : 1.7496132850646973

[INFO] L2T1 LoRA parameters loaded successfully!

Check loaded L2T1 LoRA parameters:
- name    : base_model.model.layers.0.self_attn.q_proj.L2T1_lora_A.weight
- device  : cuda:0
- dtype   : torch.float32
- mean    : 2.2152138626552187e-05
- min     : -0.06327299773693085
- max     : 0.0625513345003128



In [None]:
nero_model.freeze_all_except_L1T2_nero()
print()

# check_loss_and_grad_norm(nero_model, tokenizer)




In [None]:
nero_model.gradient_checkpointing_enable({'use_reentrant': False})
print("[INFO] Gradient checkpointing enabled!")
print()

check_loss_and_grad_norm(nero_model, tokenizer)

[INFO] Gradient checkpointing enabled!

Loss: tensor(3.4235, device='cuda:0', grad_fn=<NllLossBackward0>)
Gradient norm: 5.486812315546735


# Data

In [None]:
datasets = load_dataset(hf_data_id, data_dir=hf_data_dir)
print(datasets)

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 5000
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 1000
    })
})


In [None]:
train_loader = DataLoader(
    datasets['train'], 
    batch_size=batch_size, 
    shuffle=True, 
    collate_fn=default_data_collator,
)
print("[INFO] Total batches:", len(train_loader))

[INFO] Total batches: 1250


In [None]:
# Sanity check
first_batch = next(iter(train_loader))
print("First batch data shape (input_ids, attention_mask):")
print((
    first_batch['input_ids'].shape, 
    first_batch['attention_mask'].shape, 
))
print()

first_batch_text = tokenizer.batch_decode(first_batch['input_ids'], skip_special_tokens=True)[0]
print("First batch text:")
print(first_batch_text[:100], "...")
print()

check_loss_and_grad_norm(nero_model, tokenizer)

First batch data shape (input_ids, attention_mask):
(torch.Size([4, 512]), torch.Size([4, 512]))

First batch text:
### Instruction:
Solve the following math problem step by step.

### Question: 
Suraya picked 12 app ...

Loss: tensor(3.4235, device='cuda:0', grad_fn=<NllLossBackward0>)
Gradient norm: 5.486812315546735


# Training

In [None]:
# Sanity check
# Set model to learn L2T2 Nero
nero_model.set_train_L2T2_nero(True)
nero_model.set_distance_fn('euclidean', pairwise=True)
nero_model.train()

# Forward pass
device = next(nero_model.parameters()).device
# prompts = [
#     "The capital city of Japan is", 
#     "The capital city of Indonesia is",
# ]
# inputs = tokenizer(
#     prompts,
#     return_tensors='pt',
#     padding=True,
# )
outputs, L2T2_nero_outs = nero_model(
    # Test with first batch
    input_ids=first_batch['input_ids'].to(device),
    attention_mask=first_batch['attention_mask'].to(device),
    labels=first_batch['input_ids'].to(device),

    # Test with sample prompts
    # input_ids=inputs['input_ids'].to(device),
    # attention_mask=inputs['attention_mask'].to(device),
    # labels=inputs['input_ids'].to(device),
    
    use_cache=False, # Disable cache to not conflict with gradient checkpointing
)

print(L2T2_nero_outs)

# pairwise euclidean: 0.2222
# non-pairwise euclidean: 0.2167

# pairwise cosine: 0.0022
# non-pairwise cosine: 0.0024

# pairwise squared euclidean: 33.0331
# non-pairwise squared euclidean: 33.1182

# pairwise huber: 6.0343
# non-pairwise huber: 6.2125

# pairwise cosine huber: 0.0640
# non-pairwise cosine huber: 0.0663

tensor(0.2594, device='cuda:0', grad_fn=<MeanBackward0>)


In [None]:
# Set the distance function
nero_model.set_distance_fn('euclidean', pairwise=True)

# Set the model to training mode
nero_model.train()

# If `device` is not specified or set to 'auto', use the model's device
# if device is None or device == 'auto':
device = next(iter(nero_model.parameters())).device

# Set up optimizer and gradient scaler
nero_model.set_train_L2T2_nero(True)
# nero_params_L2T2 = [p for n, p in nero_model.named_parameters() if p.requires_grad]
nero_named_params_L2T2 = {n: p for n, p in nero_model.named_parameters() if p.requires_grad}
optimizer_L2T2 = torch.optim.Adam(nero_named_params_L2T2.values(), lr=lr, foreach=False)
scaler_L2T2 = torch.cuda.amp.GradScaler()

nero_model.set_train_L2T2_nero(False)
nero_params = [p for n, p in nero_model.named_parameters() if p.requires_grad]
optimizer = torch.optim.Adam(nero_params, lr=lr, foreach=False)
scaler = torch.cuda.amp.GradScaler()

# Set up LR scheduler
max_global_steps = max_global_steps or len(train_loader) * num_epochs
warmup_steps = int(warmup_ratio * max_global_steps)
if warmup_ratio > 0:
# if num_warmup_steps > 0:
    # If `warmup_ratio` > 0, use cosine annealing scheduler with warm-up 
    from transformers import get_cosine_schedule_with_warmup # type: ignore
    max_optimizer_steps = max_global_steps // grad_accumulation_steps
    num_warmup_steps = int(warmup_ratio * max_optimizer_steps)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=max_optimizer_steps,

        # Cycle length is 2000 steps or less, but at least 1 cycle, and at most 10 cycles
        num_cycles=min(max(1, max_optimizer_steps // 2000), 10), 
    )
    scheduler_L2T2 = get_cosine_schedule_with_warmup(
        optimizer_L2T2,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=max_optimizer_steps,

        # Cycle length is 2000 steps or less, but at least 1 cycle, and at most 10 cycles
        num_cycles=min(max(1, max_optimizer_steps // 2000), 10), 
    )
else:
    # If `warmup_ratio` is 0, use a dummy scheduler that returns constant LR
    from torch.optim.lr_scheduler import LambdaLR # type: ignore
    scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
    scheduler_L2T2 = LambdaLR(optimizer_L2T2, lr_lambda=lambda step: 1.0)

# Initialize W&B
wandb.init(
    project='Nero-XLT',
    reinit=True, # End previous run and start a new one
    config=dict(
        # Project configuration
        seed = seed,
        # target_lang=target_lang,
        # target_task=target_task,
        device = device,

        # Data configuration
        # train_size = train_size,
        # test_size = test_size,
        # max_seq_length = max_seq_length,

        # Training configuration
        batch_size = batch_size,
        num_epochs = num_epochs,
        max_global_steps = max_global_steps,
        resume_step = resume_step,
        grad_accumulation_steps = grad_accumulation_steps,
        clip_grad_norm = clip_grad_norm,
        lr = lr,
        warmup_ratio = warmup_ratio,
        # num_warmup_steps = num_warmup_steps,
        checkpoint_steps = checkpoint_steps,
    ),
)

# Resume training
global_step = resume_step
start_epoch = 0

if resume_step > 0:
    checkpoint_dir = os.path.join(nero_dir, f'checkpoint-{resume_step}')
    print(f"[INFO] Resuming training from checkpoint directory:", checkpoint_dir)

    # Load Nero parameters
    nero_path = os.path.join(checkpoint_dir, 'adapter_model.safetensors')
    nero_model.load_nero_params(nero_path)

    # Load optimizer state
    optimizer_path = os.path.join(checkpoint_dir, 'optimizer.pt')
    optimizer.load_state_dict(torch.load(optimizer_path, map_location=device))
    
    # Move optimizer state to the correct device
    for param in optimizer.state:
        param_device = param.device
        param_dtype = param.dtype
        for key, value in optimizer.state[param].items():
            if isinstance(value, torch.Tensor):
                optimizer.state[param][key] = value.to(device=param_device, dtype=param_dtype)

    # Load scheduler state
    scheduler_path = os.path.join(checkpoint_dir, 'scheduler.pt')
    scheduler.load_state_dict(torch.load(scheduler_path, map_location=device))

    # Load scaler state
    scaler_path = os.path.join(checkpoint_dir, 'scaler.pt')
    scaler.load_state_dict(torch.load(scaler_path, map_location=device))

    # Load trainer state
    trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json')
    if os.path.exists(trainer_state_path):
        with open(trainer_state_path, 'r') as f:
            trainer_state = json.load(f)
        log_history = trainer_state.get('log_history', [])
        start_epoch = log_history[-1]['epoch'] if log_history else 0
        print(f"[INFO] Resuming training from epoch {start_epoch} and step {resume_step}.")

    # Load RNG state for reproducibility
    rng_path = os.path.join(checkpoint_dir, 'rng_state.pth')
    if os.path.exists(rng_path):
        rng_state = torch.load(rng_path)
        random.setstate(rng_state['python'])
        np.random.set_state(rng_state['numpy'])
        torch.set_rng_state(rng_state['cpu'])
        if torch.cuda.is_available() and rng_state['cuda']:
            torch.cuda.set_rng_state_all(rng_state['cuda'])
    
    if resume_step % grad_accumulation_steps != 0:
        print("[WARN] Resuming mid-gradient accumulation cycle. Make sure this is intended.")
else:
    # If it's new training, create Hugging Face repository
    print(f"[INFO] Creating Hugging Face repository:", hf_nero_id) # print the link instead
    create_repo(repo_id=hf_nero_id, repo_type='model', exist_ok=True)
    print(f"[INFO] Hugging Face repository created successfully!")
    pass

  scaler_L2T2 = torch.cuda.amp.GradScaler()
  scaler = torch.cuda.amp.GradScaler()
[34m[1mwandb[0m: Currently logged in as: [33malimtegar[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [18]:
train_L2T2_nero = False
log_history = []
done = False

# Safety: Zero gradients at the start of gradient accumulation cycle
# This ensures there are no leftover gradients when resuming mid-cycle or after a previous cycle was interrupted
if global_step % grad_accumulation_steps == 0:
    optimizer.zero_grad(set_to_none=True)
    optimizer_L2T2.zero_grad(set_to_none=True)

# Training loop
for epoch in range(start_epoch, num_epochs):
    for step, batch in enumerate(train_loader):
        # Skip previously completed steps
        if global_step <= resume_step:
            global_step += 1
            continue

        # Stop training if `max_global_steps` reached
        if global_step >= max_global_steps:
            done = True
            break

        # Move inputs to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        with torch.cuda.amp.autocast():
            # Forward pass
            nero_model.set_train_L2T2_nero(train_L2T2_nero)
            
            outputs = nero_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids,
                use_cache=False, # Disable cache to avoid conflict with gradient checkpointing
            )
            
            if train_L2T2_nero:
                _, L2T2_nero_loss = outputs
                _loss = L2T2_nero_loss
            else:
                _loss = outputs.loss
                
            loss = _loss / grad_accumulation_steps

        log = {
            'mode': 'L1T2' if not train_L2T2_nero else 'L2T2',
            'epoch': epoch,
            'step': global_step,
        }

        # Backward pass
        with torch.autograd.set_detect_anomaly(True):
            if train_L2T2_nero:
                log['L2T2/loss'] = loss.item() * grad_accumulation_steps
                nero_model.set_return_L2T2_nero_loss(False)
                scaler_L2T2.scale(loss).backward()
            else:
                log['L1T2/loss'] = loss.item() * grad_accumulation_steps
                scaler.scale(loss).backward()
        
        # Update parameters only at the end of gradient accumulation cycle
        grad_norm_log = {}
        if (global_step + 1) % grad_accumulation_steps == 0:
            if train_L2T2_nero:
                # Unscale gradients before computing gradient norm and applying clipping
                scaler_L2T2.unscale_(optimizer_L2T2)

                # Compute gradient norm
                grad_norm = compute_named_grad_norm(nero_named_params_L2T2)
                grad_norm_log['L2T2/grad_norm'] = grad_norm

                # Clip gradients
                if clip_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(nero_named_params_L2T2.values(), clip_grad_norm)
                
                # Compute clipped gradient norm
                grad_norm_clipped = compute_named_grad_norm(nero_named_params_L2T2)
                grad_norm_log['L2T2/grad_norm_clipped'] = grad_norm_clipped

                # Update parameters
                scaler_L2T2.step(optimizer_L2T2)
                scaler_L2T2.update()
                scheduler_L2T2.step()

                # Zero gradients for the next gradient accumulation cycle
                optimizer_L2T2.zero_grad(set_to_none=True)
            else:
                # Unscale gradients before computing gradient norm and applying clipping
                scaler.unscale_(optimizer)
                
                # Compute gradient norm
                grad_norm = compute_grad_norm(nero_params)
                grad_norm_log['L1T2/grad_norm'] = grad_norm

                # Clip gradients
                if clip_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(nero_params, clip_grad_norm)
                
                # Compute clipped gradient norm
                grad_norm_clipped = compute_grad_norm(nero_params)
                grad_norm_log['L1T2/grad_norm_clipped'] = grad_norm_clipped

                # Update parameters
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()

                # Zero gradients for the next gradient accumulation cycle
                optimizer.zero_grad(set_to_none=True)

            # After updating parameters, toggle training mode:
            # if currently training L1T2, switch to L2T2
            # if currently training L2T2, switch to L1T2
            train_L2T2_nero = not train_L2T2_nero

        # Logging
        lr_log = (
            {'L1T2/lr': scheduler.get_last_lr()[0]} 
            if not train_L2T2_nero else 
            {'L2T2/lr': scheduler_L2T2.get_last_lr()[0]}
        )
        log = {
            **log, 
            **lr_log, 
            **grad_norm_log,
        }
        log_history.append(log)
        wandb.log(log)
        print(", ".join(
            f"{k}: {format_float(v)}" if isinstance(v, float) else f"{k}: {v}"
            for k, v in log.items()
        ))
        
        # Save and push checkpoint every `checkpoint_steps`
        if global_step > 0 and global_step % checkpoint_steps == 0:
            # Create checkpoint directory
            checkpoint_dir = os.path.join(nero_dir, f'checkpoint-{global_step}')
            os.makedirs(checkpoint_dir, exist_ok=True)

            # Save Nero parameters, along with optimizer, scheduler, and scaler states
            nero_state_dict = {n: p.detach().cpu() for n, p in nero_model.named_parameters() if 'L1T2' in n}
            save_file(nero_state_dict, os.path.join(checkpoint_dir, 'adapter_model_L1T2.safetensors'))
            torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, 'optimizer_L1T2.pt'))
            torch.save(scheduler.state_dict(), os.path.join(checkpoint_dir, 'scheduler_L1T2.pt'))
            torch.save(scaler.state_dict(), os.path.join(checkpoint_dir, 'scaler_L1T2.pt'))

            nero_state_dict_L2T2 = {n: p.detach().cpu() for n, p in nero_model.named_parameters() if 'L2T2' in n}
            save_file(nero_state_dict_L2T2, os.path.join(checkpoint_dir, 'adapter_model_L2T2.safetensors'))
            torch.save(optimizer_L2T2.state_dict(), os.path.join(checkpoint_dir, 'optimizer_L2T2.pt'))
            torch.save(scheduler_L2T2.state_dict(), os.path.join(checkpoint_dir, 'scheduler_L2T2.pt'))
            torch.save(scaler_L2T2.state_dict(), os.path.join(checkpoint_dir, 'scaler_L2T2.pt'))

            # Save trainer state for resuming training
            trainer_state = {
                'epoch': epoch,
                'global_step': global_step,
                'log_history': log_history,
            }
            with open(os.path.join(checkpoint_dir, 'trainer_state.json'), 'w') as f:
                json.dump(trainer_state, f, indent=2)

            # Save RNG state for reproducibility
            rng_state = {
                'python': random.getstate(),
                'numpy': np.random.get_state(),
                'cpu': torch.get_rng_state(),
                'cuda': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else [],
            }
            torch.save(rng_state, os.path.join(checkpoint_dir, 'rng_state.pth'))

            # Upload checkpoint directory to Hugging Face repository
            upload_folder(
                folder_path=checkpoint_dir,
                repo_id=hf_nero_id,
                path_in_repo=f"checkpoint-{global_step}",
                commit_message=f"Add checkpoint at step {global_step}",
                repo_type='model',
            )
        
        # Check generated text every `generate_steps`
        if global_step > 0 and global_step % generate_steps == 0:
            nero_model.set_return_hidden_outputs(False)
            nero_model.set_mode('student')
            generated = generate_text(nero_model, tokenizer, sample_prompt, device=device)
            print()
            print("================================")
            print("CHECK GENERATED TEXT")
            print("================================")
            print(f"{'Prompt':<9}:", sample_prompt)
            print(f"{'Generated':<9}:", generated)
            print()
        
        global_step += 1
    
    if done:
        break


  with torch.cuda.amp.autocast():


target: L1T2, epoch: 0, step: 1, L1T2/loss: 8.1048, L1T2/lr: 3.3333e-05, L2T2/lr: 0.0000e+00, L1T2/grad_norm: 0.1440, L1T2/grad_norm_clipped: 0.1440
target: L2T2, epoch: 0, step: 2, L2T2/loss: 0.2984, L1T2/lr: 3.3333e-05, L2T2/lr: 0.0000e+00
target: L2T2, epoch: 0, step: 3, L2T2/loss: 0.2729, L1T2/lr: 3.3333e-05, L2T2/lr: 3.3333e-05, L2T2/grad_norm: 0.0244, L2T2/grad_norm_clipped: 0.0244
target: L1T2, epoch: 0, step: 4, L1T2/loss: 8.4594, L1T2/lr: 3.3333e-05, L2T2/lr: 3.3333e-05
target: L1T2, epoch: 0, step: 5, L1T2/loss: 8.3069, L1T2/lr: 6.6667e-05, L2T2/lr: 3.3333e-05, L1T2/grad_norm: 0.2683, L1T2/grad_norm_clipped: 0.2683
target: L2T2, epoch: 0, step: 6, L2T2/loss: 0.2873, L1T2/lr: 6.6667e-05, L2T2/lr: 3.3333e-05
target: L2T2, epoch: 0, step: 7, L2T2/loss: 0.2741, L1T2/lr: 6.6667e-05, L2T2/lr: 6.6667e-05, L2T2/grad_norm: 0.1044, L2T2/grad_norm_clipped: 0.1044
target: L1T2, epoch: 0, step: 8, L1T2/loss: 8.6767, L1T2/lr: 6.6667e-05, L2T2/lr: 6.6667e-05
target: L1T2, epoch: 0, step: 9,

RepositoryNotFoundError: 404 Client Error. (Request ID: Root=1-689d5078-08ffb899152d178717201155;d6d99ca5-2642-4bbd-ae06-dd8aa9062d6e)

Repository Not Found for url: https://huggingface.co/api/models/alxxtexxr/L3.1-8B-gsm8k-en-5K-1K-1K-512K-Nero-v20250814024441/preupload/main.
Please make sure you specified the correct `repo_id` and `repo_type`.
If you are trying to access a private or gated repo, make sure you are authenticated. For more details, see https://huggingface.co/docs/huggingface_hub/authentication
Note: Creating a commit assumes that the repo already exists on the Huggingface Hub. Please use `create_repo` if it's not the case.

In [33]:
wandb.finish()

0,1
epoch,▁
grad_norm,▁
grad_norm_clipped,▁
loss,▁
lr,▁
step,▁

0,1
epoch,0
grad_norm,0.14058
grad_norm_clipped,0.14058
loss,7.88927
lr,3e-05
mode,L1T2
step,1
