In [1]:
# Kill all processess on GPU
# !fuser -v /dev/nvidia* -k

# Libraries

In [2]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    %pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    %pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    %pip install sentencepiece protobuf "datasets>=4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    %pip install --no-deps unsloth
%pip install trl==0.19.1 # Fix error: ImportError: cannot import name 'ConstantLengthDataset' from 'trl.trainer.utils'

In [3]:
import os
import functools
import gc
import json
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from typing import Optional
from torch.utils.data import DataLoader
from datasets import load_dataset, Dataset
from transformers import default_data_collator, AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from huggingface_hub import snapshot_download, create_repo, upload_folder
from safetensors.torch import load_file, save_file
from datetime import datetime
from pprint import pprint

In [4]:
def download_hf_model(
        repo_id: str, 
        checkpoint: Optional[int], 
        max_checkpoints: str = 10_000, 
        checkpoint_steps: str = 25,
    ):
    local_dir = repo_id.split('/')[-1]
    ignore_checkpoints = None
    
    if checkpoint is not None:
        ignore_checkpoints = [f'checkpoint-{i}/*' for i in range(0, max_checkpoints, checkpoint_steps) if i != checkpoint]

    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        ignore_patterns=ignore_checkpoints,
    )

    checkpoint_dir = None
    if checkpoint is not None:
        checkpoint_dir = os.path.join(local_dir, f'checkpoint-{checkpoint}')
    return local_dir, checkpoint_dir

def check_loss_and_grad_norm(model, tokenizer, prompt="Paris is the capital of"):
    # Set model to training mode
    model.train()

    # Zero gradients manually
    for p in model.parameters():
        p.grad = None

    # Forward pass
    inputs = tokenizer(
        prompt, 
        return_tensors='pt',
    ).to(next(model.parameters()).device)

    outputs = model(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        labels=inputs['input_ids'],
        use_cache=False, # Disable cache to not conflict with gradient checkpointing
    )
    print("Loss:", outputs.loss)

    # Backward pass
    outputs.loss.backward()

    # Compute gradient norm
    grad_norm = 0.0
    for n, p in model.named_parameters():
        if p.requires_grad:
            p_grad_norm = p.grad.data.norm(2)
            grad_norm += p_grad_norm.item() ** 2
    grad_norm = grad_norm ** 0.5

    print("Gradient norm:", grad_norm)

def check_parameter(n, p):
    print(f"- {'name':<8}:", n)
    print(f"- {'device':<8}:", p.device)
    print(f"- {'dtype':<8}:", p.dtype)
    print(f"- {'mean':<8}:", p.mean().item())
    print(f"- {'min':<8}:", p.min().item())
    print(f"- {'max':<8}:", p.max().item())

def check_lora_parameters(model, mode=None):
    prefix = mode + '_lora' if mode != None else 'lora'
    for n, p in model.named_parameters():
        if prefix in n:
            check_parameter(n, p)
            break

@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens=32, device=None, skip_special_tokens=True):
    device = device or next(model.parameters()).device
    model.eval()
    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    outputs = model.generate(input_ids=inputs['input_ids'], max_new_tokens=max_new_tokens)
    model.train()
    return tokenizer.decode(outputs[0], skip_special_tokens=skip_special_tokens)

def get_task_and_lang_from_repo_id(repo_id: str):
    task, lang, _ = repo_id.split('B-')[-1].split('K-')[0].split('-')
    return task, lang

def load_hf_dataset_from_lora(
    lora_repo_id: str,
    train_size: int = 5000,
    test_size: int = 1000,
):
    # Get task and language
    task, lang = get_task_and_lang_from_repo_id(lora_repo_id)

    # Set up Hugging Face configuration
    data_id_map = {
        'wikipedia': 'wikimedia/wikipedia',
        'gsm8k': 'openai/gsm8k',
    }
    data_id = data_id_map[task]
    data_dir = f'20231101.{lang}' if task == 'wikipedia' else 'main'
    
    # Load dataset using streaming
    dataset_stream = load_dataset(data_id, data_dir=data_dir, split='train', streaming=True)

    # Manually take train_size + test_size samples
    total_size = train_size + test_size
    sliced_data = []
    for i, example in enumerate(dataset_stream):
        if i >= total_size:
            break
        sliced_data.append(example)

    # Convert to regular in-memory dataset
    dataset = Dataset.from_list(sliced_data)
    
    return dataset

def load_hf_dataset(
    lang, 
    task,
    split='train',
    train_size = 5000,
    test_size = 1000,
):
    # Set up Hugging Face configuration
    data_id_map = {
        'wikipedia': 'wikimedia/wikipedia',
        'gsm8k': 'openai/gsm8k',
    }
    data_id = data_id_map[task]
    data_dir = f'20231101.{lang}' if task == 'wikipedia' else 'main'

    # Use streaming
    dataset_stream = load_dataset(data_id, data_dir=data_dir, split=split, streaming=True)

    # Manually take train_size + test_size samples
    total_size = train_size + test_size
    sliced_data = []
    for i, example in enumerate(dataset_stream):
        if i >= total_size:
            break
        sliced_data.append(example)

    # Convert to regular in-memory dataset
    dataset = Dataset.from_list(sliced_data)
    
    return dataset

def compute_grad_norm(params):
    grad_norm = 0.0
    for p in params:
        p_grad_norm = p.grad.data.norm(2)
        grad_norm += p_grad_norm.item() ** 2
    grad_norm = grad_norm ** 0.5
    return grad_norm

# Configurations

In [5]:
# Project configuration
seed = 69
device = 'auto'

# Data configuration
hf_data_id = 'alxxtexxr/Nero-XLT-Dataset'
hf_data_dir = 'wikipedia_ja_5K_1K_1K_64'

# Training configuration
batch_size = 4
num_epochs = 1
max_global_steps = None
grad_accumulation_steps = 2
clip_grad_norm = 3.0
lr = 2e-4
warmup_ratio = 0.01
# num_warmup_steps = 100
checkpoint_steps = 50
generate_steps = 50
sample_prompt = '日本は'

# Model configurations
model_configs = {
    # L1T1 (Source Language - Source Task)
    'source': {
        'label': 'L1T1',
        'hf_lora_id': 'alxxtexxr/L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650',
        'checkpoint': 650,
    },

    # L2T1 (Target Language - Source Task)
    'target': {
        'label': 'L2T1',
        'hf_lora_id': 'alxxtexxr/L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629',
        'checkpoint': 650,
    },

    # L1T2 (Source Language - Target Task)
    # 'target': {
    #     'label': 'L1T2',
    #     'hf_lora_id': 'alxxtexxr/L3.1-8B-gsm8k-en-5K-LoRA-v20250701060457',
    #     'checkpoint': 1875,
    # },
}

for key, config in model_configs.items():
    _, lora_dir = download_hf_model(config['hf_lora_id'], config['checkpoint'])
    model_configs[key]['lora_dir'] = lora_dir
    model_configs[key]['lora_path'] = os.path.join(lora_dir, 'adapter_model.safetensors')
    model_configs[key]['lora_config'] = LoraConfig.from_pretrained(lora_dir)

print("Model configurations:",)
for key, config in model_configs.items():
    print(f"- {key}:")
    for config_name, config_value in config.items():
        if config_name == 'lora_config':
            continue
        print(f"{'-':>3} {config_name:<10}: {config_value}")
print()

assert model_configs['source']['lora_config'].base_model_name_or_path == model_configs['target']['lora_config'].base_model_name_or_path, "Base models must be the same"
base_model_name = model_configs['source']['lora_config'].base_model_name_or_path
print(f"Base model name: {base_model_name}")

tokenizer = AutoTokenizer.from_pretrained(base_model_name)

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


Fetching 20 files:   0%|          | 0/20 [00:00<?, ?it/s]

Fetching 20 files:   0%|          | 0/20 [00:00<?, ?it/s]

Model configurations:
- source:
  - label     : L1T1
  - hf_lora_id: alxxtexxr/L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650
  - checkpoint: 650
  - lora_dir  : L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650/checkpoint-650
  - lora_path : L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650/checkpoint-650/adapter_model.safetensors
- target:
  - label     : L2T1
  - hf_lora_id: alxxtexxr/L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629
  - checkpoint: 650
  - lora_dir  : L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629/checkpoint-650
  - lora_path : L3.1-8B-wikipedia-ja-5K-LoRA-v20250728141629/checkpoint-650/adapter_model.safetensors

Base model name: unsloth/meta-llama-3.1-8b-unsloth-bnb-4bit


In [6]:
# Hugging Face configuration
hf_nero_id = None
resume_step = 0

if hf_nero_id is not None and resume_step > 0:
    print(f"[INFO] Downloading Nero checkpoint at step {resume_step} from Hugging Face repository:", hf_nero_id)
    nero_dir, _ = download_hf_model(hf_nero_id, resume_step)
    print(f"[INFO] Nero checkpoint downloaded successfully!")
else:
    hf_username = 'alxxtexxr'
    nero_dir = f'L3.1-8B-{hf_data_dir.replace("_", "-")}K-Nero-v{datetime.now().strftime("%Y%m%d%H%M%S")}'
    print(f"[INFO] Creating Nero directory:", nero_dir)
    hf_nero_id = f'{hf_username}/{nero_dir}'
    os.makedirs(nero_dir, exist_ok=True)
    print(f"[INFO] Nero directory created!")

[INFO] Creating Nero directory: L3.1-8B-wikipedia-ja-5K-1K-1K-64K-Nero-v20250807184323
[INFO] Nero directory created!


# Model

In [7]:
class NeroLayer(nn.Module):
    def __init__(self, base_layer, mode,
                 
                 # Teacher LoRA parameters
                 teacher_rank, teacher_alpha, teacher_dropout, teacher_bias, teacher_use_rslora,

                 # Student LoRA parameters
                 student_rank, student_alpha, student_dropout, student_bias, student_use_rslora,
                 
                 return_hidden_output=False,
                 debug=False,
                 module_name=None,
                 ):
        super().__init__()

        self.base_layer = base_layer
        self.device = base_layer.weight.device
        self.mode = mode

        self.return_hidden_output = return_hidden_output
        self.debug = debug
        self.module_name = module_name

        # Extract input and output features from the base layer
        in_features = getattr(base_layer, 'in_features', None)
        out_features = getattr(base_layer, 'out_features', None)

        if in_features is None or out_features is None:
            raise ValueError(f"Cannot determine in_features or out_features from {base_layer}.")

        # ================================================================================================================================
        # Teacher LoRA
        # ================================================================================================================================
        self.teacher_alpha = teacher_alpha
        self.teacher_bias = teacher_bias
        self.teacher_scaling = teacher_alpha / math.sqrt(teacher_rank) if teacher_use_rslora else teacher_alpha / teacher_rank
        self.teacher_dropout = nn.Dropout(teacher_dropout) if teacher_dropout > 0.0 else nn.Identity()

        # LoRA decomposition: A (down-projection) and B (up-projection)
        self.teacher_lora_A = nn.Linear(in_features, teacher_rank, bias=teacher_bias).to(self.device)  # Projects down
        self.teacher_lora_B = nn.Linear(teacher_rank, out_features, bias=teacher_bias).to(self.device) # Projects up

        # Initialize LoRA matrices: A ~ N(0, 1/rank), B initialized to 0
        teacher_std = 1 / torch.sqrt(torch.tensor(teacher_rank).float())
        nn.init.normal_(self.teacher_lora_A.weight, mean=0.0, std=teacher_std)
        nn.init.zeros_(self.teacher_lora_B.weight) 
        
        # ================================================================================================================================
        # Student LoRA
        # ================================================================================================================================
        self.student_alpha = student_alpha
        self.student_bias = student_bias
        self.student_scaling = student_alpha / math.sqrt(student_rank) if student_use_rslora else student_alpha / student_rank
        self.student_dropout = nn.Dropout(student_dropout) if student_dropout > 0.0 else nn.Identity()

        # LoRA decomposition: A (down-projection) and B (up-projection)
        self.student_lora_A = nn.Linear(in_features, student_rank, bias=student_bias).to(self.device)  # Projects down
        self.student_lora_B = nn.Linear(student_rank, out_features, bias=student_bias).to(self.device) # Projects up

        # Initialize LoRA matrices: A ~ N(0, 1/rank), B initialized to 0
        student_std = 1 / torch.sqrt(torch.tensor(student_rank).float())
        nn.init.normal_(self.student_lora_A.weight, mean=0.0, std=student_std)
        nn.init.zeros_(self.student_lora_B.weight) 
        
        # ================================================================================================================================
        # Student Nero
        # ================================================================================================================================
        # Nero decomposition: additional transformation applied to LoRA output
        self.student_nero_A = nn.Linear(out_features, student_rank, bias=student_bias).to(self.device)
        self.student_nero_B = nn.Linear(student_rank, out_features, bias=student_bias).to(self.device)

        # Initialize Nero matrices similarly
        nn.init.normal_(self.student_nero_A.weight, mean=0.0, std=student_std)
        nn.init.zeros_(self.student_nero_B.weight) 
        
    def forward(self, x):
        # print("================================================================")
        # print(self.module_name)
        # print("================================================================")

        base_out = self.base_layer(x)

        if self.mode == 'teacher':
            requires_conversion = not torch.is_autocast_enabled()
            if requires_conversion:
                x = x.to(self.teacher_lora_A.weight.dtype)
                
            # lora_out = self.teacher_lora_B(self.teacher_lora_A(self.teacher_dropout(x))) * self.teacher_scaling
            dropout_out = self.teacher_dropout(x)
            lora_A_out = self.teacher_lora_A(dropout_out)
            lora_B_out = self.teacher_lora_B(lora_A_out)
            lora_out = lora_B_out * self.teacher_scaling
            if requires_conversion:
                lora_out = lora_out.to(base_out.dtype)
                
            output = base_out + lora_out

            if self.return_hidden_output:
                return output, lora_out
        else:
            requires_conversion = not torch.is_autocast_enabled()
            if requires_conversion:
                x = x.to(self.student_lora_A.weight.dtype)

            # lora_out = self.student_lora_B(self.student_lora_A(self.student_dropout(x))) # * self.scaling
            dropout_out = self.student_dropout(x)
            lora_A_out = self.student_lora_A(dropout_out)
            lora_B_out = self.student_lora_B(lora_A_out)
            # lora_out = lora_B_out * self.student_scaling
            lora_out = lora_B_out
            # if requires_conversion:
            #     lora_out = lora_out.to(base_out.dtype)

            # nero_out = self.student_nero_B(F.relu(self.student_nero_A(lora_out)))
            nero_A_out = self.student_nero_A(lora_out)
            relu_out = F.relu(nero_A_out)
            nero_B_out = self.student_nero_B(relu_out)
            nero_out = nero_B_out * self.student_scaling
            if requires_conversion:
                nero_out = nero_out.to(base_out.dtype)

            output = base_out + nero_out

            if self.return_hidden_output:
                return output, nero_out
        
        return output

    def load_teacher_lora_params(self, state_dict, prefix):
        self.teacher_lora_A.weight.data = state_dict[f'{prefix}.lora_A.weight'].to(self.device)
        self.teacher_lora_B.weight.data = state_dict[f'{prefix}.lora_B.weight'].to(self.device)
        if self.teacher_bias:
            self.teacher_lora_A.bias.data = state_dict[f'{prefix}.lora_A.bias'].to(self.device)
            self.teacher_lora_B.bias.data = state_dict[f'{prefix}.lora_B.bias'].to(self.device)

    def load_student_lora_params(self, state_dict, prefix):
        self.student_lora_A.weight.data = state_dict[f'{prefix}.lora_A.weight'].to(self.device)
        self.student_lora_B.weight.data = state_dict[f'{prefix}.lora_B.weight'].to(self.device)
        if self.student_bias:
            self.student_lora_A.bias.data = state_dict[f'{prefix}.lora_A.bias'].to(self.device)
            self.student_lora_B.bias.data = state_dict[f'{prefix}.lora_B.bias'].to(self.device)
    
    def load_student_nero_params(self, state_dict, prefix):
        self.student_nero_A.weight.data = state_dict[f'{prefix}.student_nero_A.weight'].to(self.device)
        self.student_nero_B.weight.data = state_dict[f'{prefix}.student_nero_B.weight'].to(self.device)
        if self.student_bias:
            self.student_nero_A.bias.data = state_dict[f'{prefix}.student_nero_A.bias'].to(self.device)
            self.student_nero_B.bias.data = state_dict[f'{prefix}.student_nero_B.bias'].to(self.device)
    
class NeroModel(nn.Module):
    def __init__(self, 
                 base_model: nn.Module, 
                 teacher_lora_config: LoraConfig, 
                 student_lora_config: LoraConfig, 
                 mode='student',
                 return_hidden_outputs: bool=False,
                 debug: bool=False,
                ):
        super().__init__()
        self.base_model = base_model
        self.mode = mode
        self.nero_layers = nn.ModuleDict()
        self.return_hidden_outputs = return_hidden_outputs
        self.debug = debug

        # Wrap target layers with NeroLayer
        self._wrap_target_layers(teacher_lora_config, student_lora_config)
        
    def _wrap_target_layers(self, teacher_lora_config, student_lora_config):
        assert teacher_lora_config.target_modules == student_lora_config.target_modules, "[ERROR] Teacher and student LoRA configurations must have the same target modules."

        for module_name, module in self.base_model.named_modules():
            if isinstance(module, NeroLayer):
                # Convert module name format and store reference
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.nero_layers[module_name] = module
                continue

            if any(module_name.endswith(target_module) for target_module in student_lora_config.target_modules) and isinstance(module, nn.Linear):    
                parent_module, child_name = self._get_parent_module(module_name)
                nero_layer = NeroLayer(
                    base_layer=module,
                    mode=self.mode,

                    # Teacher LoRA parameters
                    teacher_rank=teacher_lora_config.r,
                    teacher_alpha=teacher_lora_config.lora_alpha,
                    teacher_dropout=teacher_lora_config.lora_dropout,
                    teacher_bias=teacher_lora_config.lora_bias,
                    teacher_use_rslora=teacher_lora_config.use_rslora,

                    # Student LoRA parameters
                    student_rank=student_lora_config.r,
                    student_alpha=student_lora_config.lora_alpha,
                    student_dropout=student_lora_config.lora_dropout,
                    student_bias=student_lora_config.lora_bias,
                    student_use_rslora=student_lora_config.use_rslora,

                    return_hidden_output=self.return_hidden_outputs,
                    debug=self.debug,
                    module_name=module_name,
                )
                setattr(parent_module, child_name, nero_layer)

                # Store LoRA layers for weight loading
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.nero_layers[module_name] = nero_layer
    
    def _get_parent_module(self, module_name):
        parts = module_name.split('.')
        parent_module = self.base_model
        for part in parts[:-1]:
            parent_module = getattr(parent_module, part)
        return parent_module, parts[-1]
    
    def set_mode(self, mode: str, verbose: bool=False):
        if mode not in ['teacher', 'student']:
            raise ValueError(f"[ERROR] Invalid mode: {mode}. Must be either 'teacher' or 'student'.")
        self.mode = mode
        for layer in self.nero_layers.values():
            layer.mode = mode
        
        if mode == 'teacher':
            self.freeze_all()
        else:
            self.freeze_all_except_nero()

        if verbose:
            print(f"[INFO] Model mode set to '{mode}'.")
    
    def set_return_hidden_outputs(self, return_hidden_outputs: bool, verbose: bool=False):
        self.return_hidden_outputs = return_hidden_outputs
        for layer in self.nero_layers.values():
            layer.return_hidden_output = return_hidden_outputs
        
        if verbose:
            print(f"[INFO] Return hidden outputs set to {return_hidden_outputs}.")
    
    def freeze_all(self, verbose=False):
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        if verbose:
            print("[INFO] All layers are frozen!")

    def freeze_all_except_nero(self, verbose=False):
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        for nero_layer in self.nero_layers.values():
            for param_name, param in nero_layer.named_parameters():
                if 'nero_A' in param_name or 'nero_B' in param_name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
        
        if verbose:
            print("[INFO] All layers are frozen except Nero layers!")
    
    def unfreeze_all(self):
        for param in self.base_model.parameters():
            param.requires_grad = True
        
        for nero_layer in self.nero_layers.values():
            for param in nero_layer.parameters():
                param.requires_grad = True
        
        print("[INFO] All layers are unfrozen!")
    
    def load_teacher_lora_params(self, teacher_lora_path):
        if not os.path.exists(teacher_lora_path):
            raise FileNotFoundError("[ERROR] Teacher LoRA file not found:", teacher_lora_path)
        
        if teacher_lora_path.endswith('.bin'):
            state_dict = torch.load(teacher_lora_path, map_location='cpu')
        else:
            state_dict = load_file(teacher_lora_path) # assuming .safetensors

        prefix = list(state_dict.keys())[0].rsplit('model.', 1)[0] + 'model.'

        for nero_layer_name, nero_layer in self.nero_layers.items():
            nero_layer_name = nero_layer_name.replace('__DOT__', '.')
            nero_layer_name = prefix + nero_layer_name
            if f'{nero_layer_name}.lora_A.weight' in state_dict and f'{nero_layer_name}.lora_B.weight' in state_dict:
                nero_layer.load_teacher_lora_params(state_dict, nero_layer_name)
            else:
                # TODO: Print warning message
                pass

        print("[INFO] Teacher LoRA parameters loaded successfully!")
    
    def load_student_lora_params(self, student_lora_path):
        if not os.path.exists(student_lora_path):
            raise FileNotFoundError("[ERROR] Student LoRA file not found:", student_lora_path)
        
        if student_lora_path.endswith('.bin'):
            state_dict = torch.load(student_lora_path, map_location='cpu')
        else:
            state_dict = load_file(student_lora_path) # assuming .safetensors

        prefix = list(state_dict.keys())[0].rsplit('model.', 1)[0] + 'model.'

        for nero_layer_name, nero_layer in self.nero_layers.items():
            nero_layer_name = nero_layer_name.replace('__DOT__', '.')
            nero_layer_name = prefix + nero_layer_name
            if f'{nero_layer_name}.lora_A.weight' in state_dict and f'{nero_layer_name}.lora_B.weight' in state_dict:
                nero_layer.load_student_lora_params(state_dict, nero_layer_name)
            else:
                # TODO: Print warning message
                pass

        print("[INFO] Student LoRA parameters loaded successfully!")
    
    def load_student_nero_params(self, student_nero_path):
        if not os.path.exists(student_nero_path):
            raise FileNotFoundError("[ERROR] student_Nero file not found:", student_nero_path)
        
        if student_nero_path.endswith('.bin'):
            state_dict = torch.load(student_nero_path, map_location='cpu')
        else:
            state_dict = load_file(student_nero_path) # assuming .safetensors

        prefix = list(state_dict.keys())[0].rsplit('model.', 1)[0] + 'model.'

        for nero_layer_name, nero_layer in self.nero_layers.items():
            nero_layer_name = nero_layer_name.replace('__DOT__', '.')
            nero_layer_name = prefix + nero_layer_name
            if f'{nero_layer_name}.student_nero_A.weight' in state_dict and f'{nero_layer_name}.student_nero_B.weight' in state_dict:
                nero_layer.load_student_nero_params(state_dict, nero_layer_name)
            else:
                # TODO: Print warning message
                pass

        print("[INFO] Student Nero parameters loaded successfully!")
    
    def forward(self, *args, **kwargs):
        if self.return_hidden_outputs:
            hidden_outs = {}
            
            def _hook_fn(layer_name, module, _in, _out):
                if isinstance(_out, tuple) and len(_out) == 2:
                    layer_out, hidden_out = _out
                    hidden_outs[layer_name] = hidden_out # Store hidden_out separately
                    return layer_out # Return only layer_out to avoid breaking model flow

            # Register hooks to extract hidden_out during forward pass
            hooks = []
            for layer_name, layer in self.nero_layers.items():
                hook = layer.register_forward_hook(functools.partial(_hook_fn, layer_name))
                hooks.append(hook)
        
            try:
                output = self.base_model(*args, **kwargs)
            finally:
                # Remove hooks after forward pass, ensuring it's done even if an error occurs
                for hook in hooks:
                    hook.remove()

            return output, hidden_outs
        
        return self.base_model(*args, **kwargs)
    
    def __getattr__(self, name):
        try:
            return super().__getattr__(name) # Try getting attribute from self
        except AttributeError:
            return getattr(self.base_model, name) # Fallback to base_model

base_nero_model = AutoModelForCausalLM.from_pretrained(
    base_model_name, 
    device_map=device,
)
nero_model = NeroModel(
    base_nero_model, 
    teacher_lora_config=model_configs['target']['lora_config'], 
    student_lora_config=model_configs['source']['lora_config'], 
    mode='student',
    return_hidden_outputs=False,
    debug=False,
)

In [8]:
# nero_model.set_mode('teacher')

# for param in nero_model.base_model.parameters():
#     param.requires_grad = False
        
# for nero_layer in nero_model.nero_layers.values():
#     for param_name, param in nero_layer.named_parameters():
#         if 'teacher_lora_A' in param_name or 'teacher_lora_B' in param_name:
#             param.requires_grad = True
#         else:
#             param.requires_grad = False

# check_loss_and_grad_norm(nero_model, tokenizer, prompt="Despite the ominous clouds gathering on the horizon and the distant rumble of thunder echoing through the valley, the determined hikers,")

In [9]:
print("Check unloaded teacher LoRA parameters:")
check_lora_parameters(nero_model, mode='teacher')
print()

nero_model.load_teacher_lora_params(model_configs['target']['lora_path'])
print()

print("Check loaded teacher LoRA parameters:")
check_lora_parameters(nero_model, mode='teacher')
print()

Check unloaded teacher LoRA parameters:
- name    : base_model.model.layers.0.self_attn.q_proj.teacher_lora_A.weight
- device  : cuda:0
- dtype   : torch.float32
- mean    : -0.001495478441938758
- min     : -1.6334571838378906
- max     : 1.582234501838684

[INFO] Teacher LoRA parameters loaded successfully!

Check loaded teacher LoRA parameters:
- name    : base_model.model.layers.0.self_attn.q_proj.teacher_lora_A.weight
- device  : cuda:0
- dtype   : torch.float32
- mean    : 2.2152138626552187e-05
- min     : -0.06327299773693085
- max     : 0.0625513345003128



In [10]:
print("Check unloaded student LoRA parameters:")
check_lora_parameters(nero_model, mode='student')
print()

nero_model.load_student_lora_params(model_configs['source']['lora_path'])
print()

print("Check loaded student LoRA parameters:")
check_lora_parameters(nero_model, mode='student')
print()

Check unloaded student LoRA parameters:
- name    : base_model.model.layers.0.self_attn.q_proj.student_lora_A.weight
- device  : cuda:0
- dtype   : torch.float32
- mean    : 0.000745462893974036
- min     : -1.3428223133087158
- max     : 1.3509130477905273

[INFO] Student LoRA parameters loaded successfully!

Check loaded student LoRA parameters:
- name    : base_model.model.layers.0.self_attn.q_proj.student_lora_A.weight
- device  : cuda:0
- dtype   : torch.float32
- mean    : 6.287686119321734e-05
- min     : -0.04176201671361923
- max     : 0.04242725297808647



In [11]:
nero_model.freeze_all_except_nero()
print()

# check_loss_and_grad_norm(nero_model, tokenizer)




In [12]:
nero_model.gradient_checkpointing_enable({'use_reentrant': False})
print("[INFO] Gradient checkpointing enabled!")
print()

check_loss_and_grad_norm(nero_model, tokenizer)

[INFO] Gradient checkpointing enabled!

Loss: tensor(3.4235, device='cuda:0', grad_fn=<NllLossBackward0>)
Gradient norm: 9.222118304758878


# Data

In [13]:
datasets = load_dataset(hf_data_id, data_dir=hf_data_dir)
print(datasets)

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 443134
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 71172
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 65808
    })
})


In [14]:
train_loader = DataLoader(
    datasets['train'], 
    batch_size=batch_size, 
    shuffle=True, 
    collate_fn=default_data_collator,
)
print("[INFO] Total batches:", len(train_loader))

[INFO] Total batches: 110784


In [15]:
# Sanity check
first_batch = next(iter(train_loader))
print("First batch data shape (input_ids, attention_mask):")
print((
    first_batch['input_ids'].shape, 
    first_batch['attention_mask'].shape, 
))
print()

first_batch_text = tokenizer.batch_decode(first_batch['input_ids'], skip_special_tokens=True)[0]
print("First batch text:")
print(first_batch_text[:100], "...")
print()

check_loss_and_grad_norm(nero_model, tokenizer)

First batch data shape (input_ids, attention_mask):
(torch.Size([4, 64]), torch.Size([4, 64]))

First batch text:
18030
 
 規格名は。GBKをさらに拡張し、少数民族言語の文字なども含む大規模な文字セットで、GBKに取って代わる正式な国家規格。2000年3月17日に国家質量技術監督局（当 ...

Loss: tensor(3.4235, device='cuda:0', grad_fn=<NllLossBackward0>)
Gradient norm: 9.222118304758878


# Training

In [16]:
# def loss_fn_v4(nero_outs, lora_outs, nero_logits, lora_logits, 
#                alpha=1.0, beta=0.00015, temperature=2.0):
#     assert nero_outs.keys() == lora_outs.keys(), "`nero_outs` and `lora_outs` must have the same layers."
#     loss_device = next(iter(nero_outs.values())).device
#     total_hidden_loss = torch.tensor(0.0, device=loss_device)

#     # --- Hidden representation loss ---
#     for layer_name in lora_outs.keys():
#         nero_out = nero_outs[layer_name]
#         lora_out = lora_outs[layer_name].to(nero_out.device)

#         # Scale by sequence length to avoid length-sensitive loss scaling
#         hidden_loss = torch.mean(torch.abs(nero_out.float() - lora_out.float()), dim=-1).mean()
#         total_hidden_loss += hidden_loss

#     total_hidden_loss /= len(lora_outs)
#     print("total_hidden_loss:", total_hidden_loss)
#     total_hidden_loss *= alpha

#     # --- Logit KL divergence loss ---
#     # Important: apply softmax with temperature for distillation
#     logit_loss = F.kl_div(
#         F.log_softmax(nero_logits.to(loss_device) / temperature, dim=-1),
#         F.softmax(lora_logits.to(loss_device) / temperature, dim=-1),
#         reduction='batchmean'
#     ) * (temperature ** 2)
#     print("logit_loss:", logit_loss)
#     logit_loss *= beta

#     # --- Final combined loss ---
#     total_loss = total_hidden_loss + logit_loss
#     return total_loss, total_hidden_loss, logit_loss

In [17]:
# def loss_fn_v5(nero_model_outs, lora_model_outs, attention_mask, alpha=0.5, temperature=2.0):
#     lm_loss = nero_model_outs.loss
#     nero_logits = nero_model_outs.logits  # (batch_size, seq_len, vocab_size)
#     lora_logits = lora_model_outs.logits  # (batch_size, seq_len, vocab_size)

#     # Compute softened distributions
#     student_log_probs = F.log_softmax(nero_logits / temperature, dim=-1)
#     teacher_probs = F.softmax(lora_logits / temperature, dim=-1)

#     # Flatten for masking
#     student_log_probs = student_log_probs.view(-1, student_log_probs.size(-1))
#     teacher_probs = teacher_probs.view(-1, teacher_probs.size(-1))
#     flat_mask = attention_mask.view(-1).unsqueeze(-1).to(dtype=student_log_probs.dtype)

#     # Element-wise KL divergence per token
#     kl_div = F.kl_div(student_log_probs, teacher_probs, reduction='none').sum(-1)  # (batch * seq)

#     # Apply mask and normalize
#     masked_kl = (kl_div * flat_mask.squeeze(-1)).sum() / flat_mask.sum()
#     logit_loss = masked_kl * (temperature ** 2)

#     return (lm_loss * alpha) + (logit_loss * (1 - alpha))

def loss_fn_v5(nero_model_outs, lora_model_outs, attention_mask, alpha=0.5, temperature=2.0):
    lm_loss = nero_model_outs.loss
    nero_logits = nero_model_outs.logits  # (batch_size, seq_len, vocab_size)
    lora_logits = lora_model_outs.logits  # (batch_size, seq_len, vocab_size)

    # Compute softened distributions (no clamping)
    student_log_probs = F.log_softmax(nero_logits / temperature, dim=-1)
    teacher_probs = F.softmax(lora_logits / temperature, dim=-1)

    # Flatten logits and attention mask
    student_log_probs = student_log_probs.view(-1, student_log_probs.size(-1))
    teacher_probs = teacher_probs.view(-1, teacher_probs.size(-1))
    flat_mask = attention_mask.view(-1).unsqueeze(-1).to(dtype=student_log_probs.dtype)

    # Compute KL divergence for each token (elementwise)
    kl_div = F.kl_div(student_log_probs, teacher_probs, reduction='none').sum(-1)  # shape: (batch * seq_len)

    # Mask padding positions and normalize
    masked_kl = (kl_div * flat_mask.squeeze(-1)).sum() / (flat_mask.sum() + 1e-8)

    # Temperature scaling
    logit_loss = masked_kl * (temperature ** 2)

    # Final loss: blend LM loss and KD loss
    total_loss = (lm_loss * alpha) + (logit_loss * (1 - alpha))

    return total_loss

# Sanity check
first_batch_text = "Roy Marshall (1930–1992) was a Barbadian cricketer who played in four Test matches for the West Indies and had an extensive domestic career with Hampshire in"
device = next(iter(nero_model.parameters())).device
inputs = tokenizer(first_batch_text, return_tensors='pt')

nero_model.set_mode('teacher')
nero_model.eval()
teacher_outs = nero_model(
    input_ids=inputs['input_ids'].to(device), 
    attention_mask=inputs['attention_mask'].to(device),
)

nero_model.set_mode('student')
nero_model.train()
student_outs = nero_model(
    input_ids=inputs['input_ids'].to(device), 
    attention_mask=inputs['attention_mask'].to(device),
    labels=inputs['input_ids'].to(device), 
)

loss = loss_fn_v5(student_outs, teacher_outs, inputs['attention_mask'].to(device))
print("loss:", loss.item())

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


loss: 1.606545329093933


In [18]:
def loss_fn_v6(
      nero_model_outs, 
      lora_model_outs, 
      nero_hidden_outs,
      lora_hidden_outs,
      attention_mask, 
      alpha=0.33, 
      beta=1.0, 
      gamma=1.0,
      temperature=2.0,
    ):
    # 1. LM loss
    lm_loss = nero_model_outs.loss
    device = lm_loss.device
    
    # 2. Logit distillation loss (KL divergence)
    nero_logits = nero_model_outs.logits
    lora_logits = lora_model_outs.logits
    student_log_probs = F.log_softmax(nero_logits / temperature, dim=-1)
    teacher_probs = F.softmax(lora_logits / temperature, dim=-1)
    student_log_probs = student_log_probs.view(-1, student_log_probs.size(-1))
    teacher_probs = teacher_probs.view(-1, teacher_probs.size(-1))
    flat_mask = attention_mask.view(-1).unsqueeze(-1).to(dtype=student_log_probs.dtype)
    kl_div = F.kl_div(student_log_probs, teacher_probs, reduction='none').sum(-1)
    logit_loss = (kl_div * flat_mask.squeeze(-1)).sum() / (flat_mask.sum() + 1e-8)
    logit_loss = logit_loss * (temperature ** 2)

    # 3. Representation distillation loss (MSE between last hidden states)
    # student_hidden = nero_model_outs.hidden_states[-1]
    # teacher_hidden = lora_model_outs.hidden_states[-1]
    # mask = attention_mask.unsqueeze(-1).to(dtype=student_hidden.dtype)
    # mse_loss = F.mse_loss(student_hidden * mask, teacher_hidden * mask, reduction='sum')
    # rep_loss = mse_loss / (mask.sum() + 1e-8)

    total_hidden_loss = torch.tensor(0.0, device=device)
    for layer_name in lora_hidden_outs.keys():
        nero_out = nero_hidden_outs[layer_name]
        lora_out = lora_hidden_outs[layer_name].to(nero_out.device)

        # MAE Loss
        # hidden_loss = torch.mean(torch.abs(nero_out.float() - lora_out.float()), dim=-1).mean()
        
        # MSE Loss
        mse_loss = F.mse_loss(nero_out.float(), lora_out.float(), reduction='sum')
        hidden_loss = mse_loss / torch.sum(lora_out ** 2)

        total_hidden_loss += hidden_loss
    total_hidden_loss /= len(lora_hidden_outs)

    # Combine all losses
    # total_loss = (lm_loss * alpha) + (logit_loss * (1 - alpha)) + (total_hidden_loss * beta)
    total_loss = (alpha * lm_loss) + (beta * logit_loss) + (gamma * total_hidden_loss)
    return total_loss, lm_loss, logit_loss, total_hidden_loss


In [19]:
# Sanity check
first_batch_text = "Roy Marshall (1930–1992) was a Barbadian cricketer who played in four Test matches for the West Indies and had an extensive domestic career with Hampshire in"
device = next(iter(nero_model.parameters())).device
inputs = tokenizer(first_batch_text, return_tensors='pt')

nero_model.set_return_hidden_outputs(True)

nero_model.set_mode('teacher')
nero_model.eval()
teacher_outs, teacher_hidden_outs = nero_model(
    input_ids=inputs['input_ids'].to(device), 
    attention_mask=inputs['attention_mask'].to(device),
)

nero_model.set_mode('student')
nero_model.train()
student_outs, student_hidden_outs = nero_model(
    input_ids=inputs['input_ids'].to(device), 
    attention_mask=inputs['attention_mask'].to(device),
    labels=inputs['input_ids'].to(device), 
)

loss = loss_fn_v6(
    student_outs, 
    teacher_outs, 
    student_hidden_outs,
    teacher_hidden_outs,
    inputs['attention_mask'].to(device),
)
print("loss:", loss)

loss: (tensor(2.7843, device='cuda:0', grad_fn=<AddBackward0>), tensor(2.1330, device='cuda:0', grad_fn=<NllLossBackward0>), tensor(1.0801, device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>), tensor(1.0003, device='cuda:0', grad_fn=<DivBackward0>))


In [20]:
# Set model to training mode
nero_model.set_mode('student')
nero_model.train()

# If `device` is not specified or set to 'auto', use model's device
# if device is None or device == 'auto':
device = next(iter(nero_model.parameters())).device

# Set up optimizer and gradient scaler
nero_params = [p for n, p in nero_model.named_parameters() if p.requires_grad]
optimizer = torch.optim.Adam(nero_params, lr=lr, foreach=False)
scaler = torch.cuda.amp.GradScaler()

# Set up LR scheduler
max_global_steps = max_global_steps or len(train_loader) * num_epochs
warmup_steps = int(warmup_ratio * max_global_steps)
if warmup_ratio > 0:
# if num_warmup_steps > 0:
    # If `warmup_ratio` > 0, use cosine annealing scheduler with warm-up 
    from transformers import get_cosine_schedule_with_warmup # type: ignore
    max_optimizer_steps = max_global_steps // grad_accumulation_steps
    num_warmup_steps = int(warmup_ratio * max_optimizer_steps)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=max_optimizer_steps,

        # Cycle length is 2000 steps or less, but at least 1 cycle, and at most 10 cycles
        num_cycles=min(max(1, max_optimizer_steps // 2000), 10), 
    )
else:
    # If `warmup_ratio` is 0, use a dummy scheduler that returns constant LR
    from torch.optim.lr_scheduler import LambdaLR # type: ignore
    scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)

# Initialize W&B
wandb.init(
    project='Nero-XLT',
    reinit=True, # End previous run and start a new one
    config=dict(
        # Project configuration
        seed = seed,
        # target_lang=target_lang,
        # target_task=target_task,
        device = device,

        # Data configuration
        # train_size = train_size,
        # test_size = test_size,
        # max_seq_length = max_seq_length,

        # Training configuration
        batch_size = batch_size,
        num_epochs = num_epochs,
        max_global_steps = max_global_steps,
        resume_step = resume_step,
        grad_accumulation_steps = grad_accumulation_steps,
        clip_grad_norm = clip_grad_norm,
        lr = lr,
        warmup_ratio = warmup_ratio,
        # num_warmup_steps = num_warmup_steps,
        checkpoint_steps = checkpoint_steps,
    ),
)

# Resume training
global_step = resume_step
start_epoch = 0

if resume_step > 0:
    checkpoint_dir = os.path.join(nero_dir, f'checkpoint-{resume_step}')
    print(f"[INFO] Resuming training from checkpoint directory:", checkpoint_dir)

    # Load Nero parameters
    nero_path = os.path.join(checkpoint_dir, 'adapter_model.safetensors')
    nero_model.load_nero_params(nero_path)

    # Load optimizer state
    optimizer_path = os.path.join(checkpoint_dir, 'optimizer.pt')
    optimizer.load_state_dict(torch.load(optimizer_path, map_location=device))
    
    # Move optimizer state to the correct device
    for param in optimizer.state:
        param_device = param.device
        param_dtype = param.dtype
        for key, value in optimizer.state[param].items():
            if isinstance(value, torch.Tensor):
                optimizer.state[param][key] = value.to(device=param_device, dtype=param_dtype)

    # Load scheduler state
    scheduler_path = os.path.join(checkpoint_dir, 'scheduler.pt')
    scheduler.load_state_dict(torch.load(scheduler_path, map_location=device))

    # Load scaler state
    scaler_path = os.path.join(checkpoint_dir, 'scaler.pt')
    scaler.load_state_dict(torch.load(scaler_path, map_location=device))

    # Load trainer state
    trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json')
    if os.path.exists(trainer_state_path):
        with open(trainer_state_path, 'r') as f:
            trainer_state = json.load(f)
        log_history = trainer_state.get('log_history', [])
        start_epoch = log_history[-1]['epoch'] if log_history else 0
        print(f"[INFO] Resuming training from epoch {start_epoch} and step {resume_step}.")

    # Load RNG state for reproducibility
    rng_path = os.path.join(checkpoint_dir, 'rng_state.pth')
    if os.path.exists(rng_path):
        rng_state = torch.load(rng_path)
        random.setstate(rng_state['python'])
        np.random.set_state(rng_state['numpy'])
        torch.set_rng_state(rng_state['cpu'])
        if torch.cuda.is_available() and rng_state['cuda']:
            torch.cuda.set_rng_state_all(rng_state['cuda'])
    
    if resume_step % grad_accumulation_steps != 0:
        print("[WARN] Resuming mid-gradient accumulation cycle. Make sure this is intended.")
else:
    # If it's new training, create Hugging Face repository
    print(f"[INFO] Creating Hugging Face repository:", hf_nero_id) # print the link instead
    create_repo(repo_id=hf_nero_id, repo_type='model', exist_ok=True)
    print(f"[INFO] Hugging Face repository created successfully!")

log_history = []
done = False

# Safety: Zero gradients at the start of gradient accumulation cycle
# This ensures there are no leftover gradients when resuming mid-cycle or after a previous cycle was interrupted
if global_step % grad_accumulation_steps == 0:
    optimizer.zero_grad(set_to_none=True)

# Training loop
for epoch in range(start_epoch, num_epochs):
    for step, batch in enumerate(train_loader):
        # Skip previously completed steps
        if global_step <= resume_step:
            global_step += 1
            continue

        # Stop training if `max_global_steps` reached
        if global_step >= max_global_steps:
            done = True
            break

        # Move inputs to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        with torch.cuda.amp.autocast():
            nero_model.set_return_hidden_outputs(True)

            # Forward pass for teacher
            nero_model.set_mode('teacher')
            nero_model.eval()
            teacher_model_outs, teacher_hidden_outs = nero_model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            
            # Forward pass for student
            nero_model.set_mode('student')
            nero_model.train()
            student_model_outs, student_hidden_outs = nero_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids,
            )

            nero_model.set_return_hidden_outputs(False)

            # Compute loss
            # _loss = loss_fn_v5(student_model_outs, teacher_model_outs, attention_mask)
            total_loss, lm_loss, logit_loss, hidden_loss = loss_fn_v6(
                student_model_outs, 
                teacher_model_outs, 
                student_hidden_outs,
                teacher_hidden_outs,
                attention_mask,
            )
            loss = total_loss / grad_accumulation_steps

        log = {
            'epoch': epoch,
            'step': global_step,
            'loss': loss.item() * grad_accumulation_steps,
            'lm_loss': lm_loss.item(),
            'logit_loss': logit_loss.item(),
            'hidden_loss': hidden_loss.item(),
        }

        # Backward pass
        with torch.autograd.set_detect_anomaly(True):
            scaler.scale(loss).backward()
        
        # Update parameters only at the end of gradient accumulation cycle
        grad_norm_log = {}
        if (global_step + 1) % grad_accumulation_steps == 0:
            # Unscale gradients before computing gradient norm and applying clipping
            scaler.unscale_(optimizer)
            
            # Compute gradient norm
            grad_norm = compute_grad_norm(nero_params)
            grad_norm_log['grad_norm'] = grad_norm

            # Clip gradients
            if clip_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(nero_params, clip_grad_norm)
            
            # Compute clipped gradient norm
            grad_norm_clipped = compute_grad_norm(nero_params)
            grad_norm_log['grad_norm_clipped'] = grad_norm_clipped

            # Update parameters
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            # Zero gradients for the next gradient accumulation cycle
            optimizer.zero_grad(set_to_none=True)

        # Logging
        log = {
            **log, 
            'lr': scheduler.get_last_lr()[0], 
            **grad_norm_log,
        }
        log_history.append(log)
        wandb.log(log)
        print(", ".join(f"{k}: {v}" for k, v in log.items()))
        
        # Save and push checkpoint every `checkpoint_steps`
        if global_step > 0 and global_step % checkpoint_steps == 0:
            # Create checkpoint directory
            checkpoint_dir = os.path.join(nero_dir, f'checkpoint-{global_step}')
            os.makedirs(checkpoint_dir, exist_ok=True)

            # Save Nero parameters, along with optimizer, scheduler, and scaler states
            nero_state_dict = {n: p.detach().cpu() for n, p in nero_model.named_parameters() if p.requires_grad}
            save_file(nero_state_dict, os.path.join(checkpoint_dir, 'adapter_model.safetensors'))  # NEW
            torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, 'optimizer.pt'))
            torch.save(scheduler.state_dict(), os.path.join(checkpoint_dir, 'scheduler.pt'))
            torch.save(scaler.state_dict(), os.path.join(checkpoint_dir, 'scaler.pt'))

            # Save trainer state for resuming training
            trainer_state = {
                'epoch': epoch,
                'global_step': global_step,
                'log_history': log_history,
            }
            with open(os.path.join(checkpoint_dir, 'trainer_state.json'), 'w') as f:
                json.dump(trainer_state, f, indent=2)

            # Save RNG state for reproducibility
            rng_state = {
                'python': random.getstate(),
                'numpy': np.random.get_state(),
                'cpu': torch.get_rng_state(),
                'cuda': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else [],
            }
            torch.save(rng_state, os.path.join(checkpoint_dir, 'rng_state.pth'))

            # Upload checkpoint directory to Hugging Face repository
            upload_folder(
                folder_path=checkpoint_dir,
                repo_id=hf_nero_id,
                path_in_repo=f"checkpoint-{global_step}",
                commit_message=f"Add checkpoint at step {global_step}",
                repo_type='model',
            )
        
        # Check generated text every `generate_steps`
        if global_step > 0 and global_step % generate_steps == 0:
            nero_model.set_return_hidden_outputs(False)
            nero_model.set_mode('student')
            generated = generate_text(nero_model, tokenizer, sample_prompt, device=device)
            print()
            print("================================")
            print("CHECK GENERATED TEXT")
            print("================================")
            print(f"{'Prompt':<9}:", sample_prompt)
            print(f"{'Generated':<9}:", generated)
            print()
        
        global_step += 1
    
    if done:
        break

wandb.finish()

  scaler = torch.cuda.amp.GradScaler()
[34m[1mwandb[0m: Currently logged in as: [33malimtegar[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[INFO] Creating Hugging Face repository: alxxtexxr/L3.1-8B-wikipedia-ja-5K-1K-1K-64K-Nero-v20250807184323
[INFO] Hugging Face repository created successfully!


  with torch.cuda.amp.autocast():


epoch: 0, step: 1, loss: 2.748265504837036, lm_loss: 2.8924598693847656, logit_loss: 0.7937536239624023, hidden_loss: 1.0, lr: 3.616636528028933e-07, grad_norm: 0.723087389408253, grad_norm_clipped: 0.723087389408253
epoch: 0, step: 2, loss: 2.6068525314331055, lm_loss: 2.317270278930664, logit_loss: 0.8421533107757568, hidden_loss: 1.0, lr: 3.616636528028933e-07
epoch: 0, step: 3, loss: 2.6405463218688965, lm_loss: 2.8171303272247314, logit_loss: 0.7108933925628662, hidden_loss: 1.0, lr: 7.233273056057866e-07, grad_norm: 1.3192691207366212, grad_norm_clipped: 1.3192691207366212
epoch: 0, step: 4, loss: 2.610607385635376, lm_loss: 2.466327428817749, logit_loss: 0.7967228889465332, hidden_loss: 0.9999963641166687, lr: 7.233273056057866e-07
epoch: 0, step: 5, loss: 2.7453324794769287, lm_loss: 2.7352848052978516, logit_loss: 0.8426922559738159, hidden_loss: 0.9999961853027344, lr: 1.08499095840868e-06, grad_norm: 1.3787425645496496, grad_norm_clipped: 1.3787425645496496
epoch: 0, step: 6

adapter_model.safetensors:   0%|          | 0.00/88.1M [00:00<?, ?B/s]

rng_state.pth:   0%|          | 0.00/14.2k [00:00<?, ?B/s]

Upload 5 LFS files:   0%|          | 0/5 [00:00<?, ?it/s]

optimizer.pt:   0%|          | 0.00/177M [00:00<?, ?B/s]

scaler.pt:   0%|          | 0.00/988 [00:00<?, ?B/s]

scheduler.pt:   0%|          | 0.00/1.00k [00:00<?, ?B/s]


CHECK GENERATED TEXT
Prompt   : 日本は
Generated: 日本は、世界の他の国々と比較して、人口が多い国で、国土面積が小さい国である。日本は、人口が多い国

epoch: 0, step: 51, loss: 2.395453453063965, lm_loss: 2.406221866607666, logit_loss: 0.6029367446899414, hidden_loss: 0.99846351146698, lr: 9.403254972875227e-06, grad_norm: 1.190116266288135, grad_norm_clipped: 1.190116266288135
epoch: 0, step: 52, loss: 2.6948933601379395, lm_loss: 2.6254611015319824, logit_loss: 0.83009272813797, hidden_loss: 0.99839848279953, lr: 9.403254972875227e-06
epoch: 0, step: 53, loss: 2.8036134243011475, lm_loss: 2.3813555240631104, logit_loss: 1.019332766532898, hidden_loss: 0.9984333515167236, lr: 9.76491862567812e-06, grad_norm: 1.5427950808324384, grad_norm_clipped: 1.5427950808324384
epoch: 0, step: 54, loss: 2.5087358951568604, lm_loss: 2.9238173961639404, logit_loss: 0.5457111597061157, hidden_loss: 0.9981648921966553, lr: 9.76491862567812e-06
epoch: 0, step: 55, loss: 2.66558837890625, lm_loss: 2.6368231773376465, logit_loss: 0.7971600294113159, hidden

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


RuntimeError: Function 'MmBackward0' returned nan values in its 1th output.