In [2]:
# Kill all processes on GPU
# !fuser -v /dev/nvidia* -k

# Libraries

In [3]:
%%capture
# Install the notebook's dependencies; %%capture hides the installer output of this cell.
import os
# Colab is detected by looking for "COLAB_" in the names of the environment variables.
if "COLAB_" not in "".join(os.environ.keys()):
    %pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    %pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    %pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    %pip install --no-deps unsloth
# Runs in both environments: pins trl to the version named below (reason is quoted on the same line).
%pip install trl==0.19.1 # Fix error: ImportError: cannot import name 'ConstantLengthDataset' from 'trl.trainer.utils'

In [4]:
import os
import functools
import gc
import json
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader
from datasets import load_dataset, Dataset
from transformers import default_data_collator, AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from huggingface_hub import snapshot_download, create_repo, upload_folder
from safetensors.torch import load_file, save_file
from datetime import datetime
from pprint import pprint

In [5]:
def download_hf_model(
        repo_id, 
        checkpoint=None, 
        max_checkpoints=10_000, 
        checkpoint_steps=25,
    ):
    """Download a Hugging Face repository snapshot into a local folder.

    The local folder is named after the last path component of `repo_id`.
    When `checkpoint` is given, every `checkpoint-<i>/` folder with
    `i` in `range(0, max_checkpoints, checkpoint_steps)` other than the
    requested one is excluded from the download.

    Args:
        repo_id: Repository id, e.g. `user/name`.
        checkpoint: Checkpoint step to keep, or None to download everything.
        max_checkpoints: Exclusive upper bound of the step numbers to exclude.
        checkpoint_steps: Spacing between the step numbers to exclude.

    Returns:
        Tuple `(local_dir, checkpoint_dir)`; `checkpoint_dir` is None when
        no checkpoint was requested.
    """
    local_dir = repo_id.rsplit('/', 1)[-1]

    if checkpoint is None:
        ignore_checkpoints = None
        checkpoint_dir = None
    else:
        # Exclude every other checkpoint folder on the step grid
        ignore_checkpoints = [
            f'checkpoint-{step}/*'
            for step in range(0, max_checkpoints, checkpoint_steps)
            if step != checkpoint
        ]
        checkpoint_dir = os.path.join(local_dir, f'checkpoint-{checkpoint}')

    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        ignore_patterns=ignore_checkpoints,
    )

    return local_dir, checkpoint_dir

def check_loss_and_grad_norm(model, tokenizer, prompt="Paris is the capital of"):
    """Run one forward/backward pass on `prompt` and print loss and gradient norm.

    The model is switched to training mode, all gradients are reset, the
    prompt is used both as input and as labels, and the L2 norm over the
    gradients of all trainable parameters is printed. Gradients are left
    populated on the parameters afterwards.

    Args:
        model: Model called with `input_ids`, `attention_mask`, `labels` and
            `use_cache`; its output must expose `.loss`.
        tokenizer: Callable returning an object with `.to(device)` and the
            keys `input_ids` and `attention_mask`.
        prompt: Text to run through the model.
    """
    # Set model to training mode
    model.train()

    # Zero gradients manually
    for p in model.parameters():
        p.grad = None

    # Forward pass
    inputs = tokenizer(
        prompt, 
        return_tensors='pt',
    ).to(next(model.parameters()).device)

    outputs = model(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        labels=inputs['input_ids'],
        use_cache=False, # Disable cache to not conflict with gradient checkpointing
    )
    print("Loss:", outputs.loss)

    # Backward pass
    outputs.loss.backward()

    # Compute gradient norm over trainable parameters.
    # A trainable parameter that did not take part in the forward pass keeps
    # `grad is None`; skip it instead of failing on `None.data`.
    squared_norm = 0.0
    for p in model.parameters():
        if not p.requires_grad or p.grad is None:
            continue
        squared_norm += p.grad.detach().norm(2).item() ** 2
    grad_norm = squared_norm ** 0.5

    print("Gradient norm:", grad_norm)

@torch.no_grad()
def check_lora_parameters(model):
    """Print name, mean, min and max of the first parameter whose name contains 'lora'.

    Prints nothing when no parameter name matches.
    """
    first_match = next(
        ((name, param) for name, param in model.named_parameters() if 'lora' in name),
        None,
    )
    if first_match is None:
        return

    name, param = first_match
    print(f"- {'Name':<8}:", name)
    print(f"- {'Mean':<8}:", param.mean().item())
    print(f"- {'Min':<8}:", param.min().item())
    print(f"- {'Max':<8}:", param.max().item())

@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens=32, device=None, skip_special_tokens=True):
    """Generate a continuation of `prompt` and return the decoded text.

    The model is put in eval mode for the duration of the call and is then
    returned to whichever mode (train or eval) it was in before, even if
    generation raises.

    Args:
        model: Model exposing `generate(input_ids=..., max_new_tokens=...)`.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        prompt: Text to continue.
        max_new_tokens: Maximum number of tokens to generate.
        device: Device for the inputs; defaults to the device of the model's
            first parameter.
        skip_special_tokens: Forwarded to `tokenizer.decode`.

    Returns:
        The decoded first sequence returned by `model.generate`.
    """
    device = device or next(model.parameters()).device

    # Remember the caller's mode instead of forcing train mode on exit
    was_training = model.training
    model.eval()
    try:
        inputs = tokenizer(prompt, return_tensors='pt').to(device)
        outputs = model.generate(input_ids=inputs['input_ids'], max_new_tokens=max_new_tokens)
    finally:
        model.train(was_training)

    return tokenizer.decode(outputs[0], skip_special_tokens=skip_special_tokens)

def load_hf_dataset_from_lora(
    lora_repo_id,
    train_size = 5000,
    test_size = 1000,
):
    """Load the dataset that matches the task and language encoded in a LoRA repo id.

    The id is parsed as `...B-<task>-<lang>-<n>K-...` (for example
    `user/L3.1-8B-wikipedia-en-5K-LoRA-v1` gives task `wikipedia` and
    language `en`). Loading itself is delegated to `load_hf_dataset` with the
    `train` split, so both loaders share one implementation.

    Args:
        lora_repo_id: LoRA repository id following the naming scheme above.
        train_size: Number of training rows to take.
        test_size: Number of additional rows to take.

    Returns:
        Whatever `load_hf_dataset` returns for the parsed task and language.

    Raises:
        ValueError: If the task and language cannot be parsed from the id.
    """
    # Get task and language
    parts = lora_repo_id.split('B-')[-1].split('K-')[0].split('-')
    if len(parts) != 3:
        raise ValueError(
            f"Cannot parse task and language from LoRA repo id: {lora_repo_id!r} "
            "(expected '...B-<task>-<lang>-<n>K-...')"
        )
    task, lang, _ = parts

    return load_hf_dataset(
        lang,
        task,
        split='train',
        train_size=train_size,
        test_size=test_size,
    )

def load_hf_dataset(
    lang, 
    task,
    split='train',
    train_size = 5000,
    test_size = 1000,
):
    """Stream a dataset from the Hub and keep its first `train_size + test_size` rows.

    Args:
        lang: Language code; only used to build the data directory of the
            `wikipedia` task.
        task: Key into the task table below (`wikipedia` or `gsm8k`).
        split: Split name passed to `load_dataset`.
        train_size: Number of training rows to take.
        test_size: Number of additional rows to take.

    Returns:
        An in-memory `Dataset` built from the collected rows.
    """
    # Task name -> Hub dataset id
    hub_ids = {
        'wikipedia': 'wikimedia/wikipedia',
        'gsm8k': 'openai/gsm8k',
    }
    data_id = hub_ids[task]

    if task == 'wikipedia':
        data_dir = f'20231101.{lang}'
    else:
        data_dir = 'main'

    # Stream so that only the rows consumed below are fetched
    stream = load_dataset(data_id, data_dir=data_dir, split=split, streaming=True)

    # Collect rows until the requested total is reached
    wanted = train_size + test_size
    rows = []
    for row in stream:
        if len(rows) >= wanted:
            break
        rows.append(row)

    # Materialise as a regular in-memory dataset
    return Dataset.from_list(rows)

def compute_grad_norm(params):
    """Return the global L2 norm of the gradients of `params`.

    Parameters whose `.grad` is None (nothing was back-propagated into them)
    contribute nothing instead of raising.

    Args:
        params: Iterable of tensors/parameters to inspect.

    Returns:
        Square root of the summed squared per-parameter gradient norms, as a
        Python float; 0.0 when no parameter has a gradient.
    """
    squared_total = 0.0
    for p in params:
        if p.grad is None:
            continue
        squared_total += p.grad.detach().norm(2).item() ** 2
    return squared_total ** 0.5

# Configurations

In [6]:
# Project configuration
seed = 69
target_lang = 'ja' # 'en' | 'id' | 'es' | 'ja'
target_task = 'wikipedia' # 'wikipedia' | 'gsm8k'
device = 'auto' # 'cpu' | 'cuda' | 'auto'

# Apply the seed to Python, NumPy and PyTorch before any model weights are
# initialised or data is shuffled, so that a fresh-kernel run is repeatable.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Data configuration
train_size = 5000
test_size = 0 # Temporary
max_seq_length = 1024

# Training configuration
batch_size = 4
num_epochs = 1
max_global_steps = None # None: derived later from the number of batches and epochs
grad_accumulation_steps = 2
clip_grad_norm = 5.0
lr = 2e-4
warmup_ratio = 0.05
checkpoint_steps = 50
generate_steps = 50
sample_prompt = '日本は'

# Model configurations
hf_lora_id = 'alxxtexxr/L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650'
checkpoint = 650

# Download only the requested LoRA checkpoint and read its adapter configuration
_, lora_checkpoint_dir = download_hf_model(hf_lora_id, checkpoint)
lora_path = os.path.join(lora_checkpoint_dir, 'adapter_model.safetensors')
lora_config = LoraConfig.from_pretrained(lora_checkpoint_dir)

# The base model (and its tokenizer) is the one recorded in the adapter configuration
base_model_name = lora_config.base_model_name_or_path
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

print("[CONFIG] Base model name:", base_model_name)

Fetching 20 files:   0%|          | 0/20 [00:00<?, ?it/s]

.gitattributes: 0.00B [00:00, ?B/s]

adapter_config.json:   0%|          | 0.00/871 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

README.md: 0.00B [00:00, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/83.9M [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/83.9M [00:00<?, ?B/s]

adapter_config.json:   0%|          | 0.00/871 [00:00<?, ?B/s]

optimizer.pt:   0%|          | 0.00/43.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/459 [00:00<?, ?B/s]

scheduler.pt:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

rng_state.pth:   0%|          | 0.00/14.2k [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

trainer_state.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/459 [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.88k [00:00<?, ?B/s]

(…)t.tfevents.1751286655.ec3c369f91a7.385.0:   0%|          | 0.00/402k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.88k [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/459 [00:00<?, ?B/s]

[CONFIG] Base model name: unsloth/meta-llama-3.1-8b-unsloth-bnb-4bit


In [7]:
# Hugging Face configuration
hf_nero_id = None
resume_step = 0

if hf_nero_id is None or resume_step <= 0:
    # Fresh run: create a timestamped local directory and derive the repository id from it.
    # NOTE(review): a preset `hf_nero_id` is overwritten here when `resume_step` is 0 — confirm this is intended.
    hf_username = 'alxxtexxr'
    nero_dir = f'L3.1-8B-{target_task}-{target_lang}-{train_size//1000}K-Nero-v{datetime.now().strftime("%Y%m%d%H%M%S")}'
    print(f"[INFO] Creating Nero directory:", nero_dir)
    hf_nero_id = f'{hf_username}/{nero_dir}'
    os.makedirs(nero_dir, exist_ok=True)
    print(f"[INFO] Nero directory created!")
else:
    # Resumed run: fetch the checkpoint saved at `resume_step` from the existing repository.
    print(f"[INFO] Downloading Nero checkpoint at step {resume_step} from Hugging Face repository:", hf_nero_id)
    nero_dir, _ = download_hf_model(hf_nero_id, resume_step)
    print(f"[INFO] Nero checkpoint downloaded successfully!")

[INFO] Creating Nero directory: L3.1-8B-wikipedia-ja-5K-Nero-v20250805174511
[INFO] Nero directory created!


# Model

In [8]:
class NeroLayer(nn.Module):
    """Wrap a linear layer with a LoRA branch followed by a Nero branch.

    The forward pass computes
    `base_layer(x) + scaling * nero_B(relu(nero_A(lora_B(lora_A(dropout(x))))))`.

    Args:
        base_layer: Layer to wrap; must expose `weight`, `in_features` and
            `out_features`.
        rank: Inner dimension shared by the LoRA and Nero projections.
        alpha: Numerator of the scaling factor.
        dropout: Dropout probability on the input of the LoRA branch
            (0 disables dropout).
        lora_bias: Whether the LoRA projections are created with bias terms.
        use_rslora: If true the scaling is `alpha / sqrt(rank)`, otherwise
            `alpha / rank`.
        nero_bias: Whether the Nero projections are created with bias terms.
        return_nero_output: If true, `forward` returns `(output, nero_out)`.
        debug: Stored on the instance; not used by this class.
        module_name: Stored on the instance; not used by this class.

    Raises:
        ValueError: If `base_layer` has no `in_features` or `out_features`.
    """

    def __init__(self, base_layer, 
                 # LoRA parameters
                 rank, alpha, dropout, lora_bias, use_rslora, 
                 
                 # Nero parameters
                 nero_bias=False, 
                 return_nero_output=False,
                 
                 # For debugging 
                 debug=False,
                 module_name=None,
                 ):
        super().__init__()
        self.base_layer = base_layer
        self.device = base_layer.weight.device
        self.alpha = alpha
        self.lora_bias = lora_bias
        self.scaling = alpha / math.sqrt(rank) if use_rslora else alpha / rank
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
        self.nero_bias = nero_bias
        self.return_nero_output = return_nero_output

        # For debugging
        self.debug = debug
        self.module_name = module_name

        # Extract input and output features from the base layer
        in_features = getattr(base_layer, 'in_features', None)
        out_features = getattr(base_layer, 'out_features', None)

        if in_features is None or out_features is None:
            raise ValueError(f"Cannot determine in_features or out_features from {base_layer}.")
        
        # LoRA decomposition: A (down-projection) and B (up-projection)
        self.lora_A = nn.Linear(in_features, rank, bias=lora_bias).to(self.device)  # Projects down
        self.lora_B = nn.Linear(rank, out_features, bias=lora_bias).to(self.device) # Projects up

        # Initialize LoRA matrices: A ~ N(0, 1/rank), B initialized to 0.
        # `std` is a plain float, which is what `nn.init.normal_` expects.
        std = 1.0 / math.sqrt(rank)
        nn.init.normal_(self.lora_A.weight, mean=0.0, std=std)
        nn.init.zeros_(self.lora_B.weight)

        # Nero decomposition: additional transformation applied to LoRA output
        self.nero_A = nn.Linear(out_features, rank, bias=nero_bias).to(self.device)
        self.nero_B = nn.Linear(rank, out_features, bias=nero_bias).to(self.device)

        # Initialize Nero matrices similarly (both drawn from the same normal distribution)
        nn.init.normal_(self.nero_A.weight, mean=0.0, std=std)
        nn.init.normal_(self.nero_B.weight, mean=0.0, std=std)
        
    def forward(self, x):
        """Return the adapted output, plus the raw Nero output if requested.

        Returns:
            `output`, or `(output, nero_out)` when `return_nero_output` is
            set. `nero_out` is the Nero branch output before scaling.
        """
        # Forward through base layer
        base_out = self.base_layer(x)

        # Outside autocast the input is cast to the adapter dtype by hand
        requires_conversion = not torch.is_autocast_enabled()
        if requires_conversion:
            x = x.to(self.lora_A.weight.dtype)

        # LoRA transformation; the scaling is applied once, after the Nero branch
        lora_out = self.lora_B(self.lora_A(self.dropout(x)))

        # Nero transformation on top of the LoRA output
        nero_out = self.nero_B(F.relu(self.nero_A(lora_out)))
        if requires_conversion:
            nero_out = nero_out.to(base_out.dtype)

        output = base_out + (nero_out * self.scaling)

        if self.return_nero_output:
            return output, nero_out
        
        return output

    def load_lora_params(self, state_dict, prefix):
        """Copy LoRA weights (and biases, if enabled) from `state_dict`.

        Keys are read as `<prefix>.lora_A.weight`, `<prefix>.lora_B.weight`
        and, when `lora_bias` is set, the matching `.bias` keys. A bias key
        that is absent from the checkpoint is tolerated: PEFT adapters store
        a bias for LoRA B only. The corresponding local bias is zeroed so the
        projection reproduces the checkpoint instead of keeping its random
        initialisation.

        Raises:
            KeyError: If one of the two weight keys is missing.
        """
        self.lora_A.weight.data = state_dict[f'{prefix}.lora_A.weight'].to(self.device)
        self.lora_B.weight.data = state_dict[f'{prefix}.lora_B.weight'].to(self.device)
        if self.lora_bias:
            for proj_name, proj in (('lora_A', self.lora_A), ('lora_B', self.lora_B)):
                bias_key = f'{prefix}.{proj_name}.bias'
                if bias_key in state_dict:
                    proj.bias.data = state_dict[bias_key].to(self.device)
                else:
                    proj.bias.data.zero_()
    
    def load_nero_params(self, state_dict, prefix):
        """Copy Nero weights (and biases, if enabled) from `state_dict`.

        Keys are read as `<prefix>.nero_A.*` and `<prefix>.nero_B.*`.

        Raises:
            KeyError: If an expected key is missing.
        """
        self.nero_A.weight.data = state_dict[f'{prefix}.nero_A.weight'].to(self.device)
        self.nero_B.weight.data = state_dict[f'{prefix}.nero_B.weight'].to(self.device)
        if self.nero_bias:
            self.nero_A.bias.data = state_dict[f'{prefix}.nero_A.bias'].to(self.device)
            self.nero_B.bias.data = state_dict[f'{prefix}.nero_B.bias'].to(self.device)
    
class NeroModel(nn.Module):
    """Wrap a base model and replace its LoRA target layers with `NeroLayer`s.

    Every `nn.Linear` whose module name ends with one of
    `lora_config.target_modules` is replaced in place by a `NeroLayer`.
    The wrapped layers are also registered in `self.nero_layers`, keyed by
    their module name with everything up to the last `model.` stripped and
    `.` replaced by `__DOT__` (`nn.ModuleDict` keys may not contain dots).
    Attributes not found on the wrapper are looked up on the base model.

    Args:
        base_model: Model whose layers are wrapped.
        lora_config: Source of `target_modules`, `r`, `lora_alpha`,
            `lora_dropout`, `lora_bias` and `use_rslora`.
        nero_bias: Whether Nero projections are created with bias terms.
        return_nero_outputs: If true, `forward` returns
            `(output, nero_outs)`.
        debug: Forwarded to every `NeroLayer`.
    """

    def __init__(self, base_model: nn.Module, lora_config: "LoraConfig", nero_bias: bool=False, 
                 return_nero_outputs: bool=False, debug: bool=False):
        super().__init__()
        self.base_model = base_model
        self.nero_bias = nero_bias
        self.nero_layers = nn.ModuleDict()
        self.return_nero_outputs = return_nero_outputs
        self.debug = debug

        # Wrap target layers with NeroLayer
        self._wrap_target_layers(lora_config)

    @staticmethod
    def _layer_key(module_name):
        """Convert a dotted module name into its `nero_layers` key."""
        return module_name.rsplit('model.', 1)[-1].replace('.', '__DOT__')
        
    def _wrap_target_layers(self, lora_config):
        """Replace target linear layers by `NeroLayer`s and register them."""
        # Iterate over a snapshot: modules are replaced while walking the tree
        for module_name, module in list(self.base_model.named_modules()):
            if isinstance(module, NeroLayer):
                # Already wrapped: only store the reference
                self.nero_layers[self._layer_key(module_name)] = module
                continue

            is_target = any(module_name.endswith(target_module) for target_module in lora_config.target_modules)
            if is_target and isinstance(module, nn.Linear):
                parent_module, child_name = self._get_parent_module(module_name)
                nero_layer = NeroLayer(
                    module, 
                    lora_config.r, 
                    lora_config.lora_alpha, 
                    lora_config.lora_dropout, 
                    lora_config.lora_bias, 
                    lora_config.use_rslora,
                    nero_bias=self.nero_bias,
                    return_nero_output=self.return_nero_outputs,
                    debug=self.debug,
                    module_name=module_name,
                )
                setattr(parent_module, child_name, nero_layer)

                # Store the layer for weight loading
                self.nero_layers[self._layer_key(module_name)] = nero_layer
    
    def _get_parent_module(self, module_name):
        """Return `(parent_module, child_name)` for a dotted module name."""
        parts = module_name.split('.')
        parent_module = self.base_model
        for part in parts[:-1]:
            parent_module = getattr(parent_module, part)
        return parent_module, parts[-1]
    
    def set_return_nero_outputs(self, return_nero_outputs: bool):
        """Toggle returning Nero outputs on the model and on every wrapped layer."""
        self.return_nero_outputs = return_nero_outputs
        for layer in self.nero_layers.values():
            layer.return_nero_output = return_nero_outputs

    def freeze_all_except_nero(self):
        """Disable gradients everywhere except on `nero_A` / `nero_B` parameters."""
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        for nero_layer in self.nero_layers.values():
            for param_name, param in nero_layer.named_parameters():
                param.requires_grad = 'nero_A' in param_name or 'nero_B' in param_name
        
        print("[INFO] All layers are frozen except Nero layers!")
    
    def unfreeze_all(self):
        """Enable gradients on every parameter that can carry one.

        Non-floating-point parameters (for example packed quantized weights)
        are skipped, because PyTorch refuses `requires_grad=True` on them.
        """
        def _unfreeze(params):
            for param in params:
                if param.is_floating_point() or param.is_complex():
                    param.requires_grad = True

        _unfreeze(self.base_model.parameters())
        for nero_layer in self.nero_layers.values():
            _unfreeze(nero_layer.parameters())
        
        print("[INFO] All layers are unfrozen!")

    def _load_adapter_params(self, path, kind):
        """Load `lora` or `nero` weights from `path` into the wrapped layers.

        Layers for which the file holds no `<kind>_A.weight` /
        `<kind>_B.weight` pair are left unchanged and reported in a warning.

        Raises:
            FileNotFoundError: If `path` does not exist.
            ValueError: If the file contains no tensors.
        """
        label = 'LoRA' if kind == 'lora' else 'Nero'

        if not os.path.exists(path):
            raise FileNotFoundError(f"[ERROR] {label} file not found: {path}")

        if path.endswith('.bin'):
            # NOTE: `torch.load` unpickles the file; only load checkpoints from trusted sources
            state_dict = torch.load(path, map_location='cpu')
        else:
            state_dict = load_file(path) # assuming .safetensors

        if not state_dict:
            raise ValueError(f"[ERROR] {label} file contains no parameters: {path}")

        # Everything up to and including the last 'model.' of the first key
        prefix = next(iter(state_dict)).rsplit('model.', 1)[0] + 'model.'

        missing = []
        for layer_key, nero_layer in self.nero_layers.items():
            layer_name = prefix + layer_key.replace('__DOT__', '.')
            has_weights = (
                f'{layer_name}.{kind}_A.weight' in state_dict
                and f'{layer_name}.{kind}_B.weight' in state_dict
            )
            if not has_weights:
                missing.append(layer_name)
            elif kind == 'lora':
                nero_layer.load_lora_params(state_dict, layer_name)
            else:
                nero_layer.load_nero_params(state_dict, layer_name)

        if missing:
            print(
                f"[WARN] {len(missing)} of {len(self.nero_layers)} layers have no {label} "
                f"weights in {path} and were left unchanged (first: {missing[0]})"
            )

        print(f"[INFO] {label} parameters loaded successfully!")
    
    def load_lora_params(self, lora_path):
        """Load LoRA weights from a `.bin` or `.safetensors` file.

        Raises:
            FileNotFoundError: If `lora_path` does not exist.
        """
        self._load_adapter_params(lora_path, 'lora')
    
    def load_nero_params(self, nero_path):
        """Load Nero weights from a `.bin` or `.safetensors` file.

        Raises:
            FileNotFoundError: If `nero_path` does not exist.
        """
        self._load_adapter_params(nero_path, 'nero')
    
    def forward(self, *args, **kwargs):
        """Run the base model.

        Returns:
            The base model output, or `(output, nero_outs)` when
            `return_nero_outputs` is set. `nero_outs` maps each
            `nero_layers` key to that layer's Nero output.
        """
        if not self.return_nero_outputs:
            return self.base_model(*args, **kwargs)

        nero_outs = {}

        def _hook_fn(layer_name, module, _in, _out):
            if isinstance(_out, tuple) and len(_out) == 2:
                layer_out, nero_out = _out
                nero_outs[layer_name] = nero_out # Store nero_out separately
                return layer_out # Return only layer_out to avoid breaking model flow

        # Register hooks to extract nero_out during forward pass
        hooks = []
        try:
            for layer_name, layer in self.nero_layers.items():
                hooks.append(layer.register_forward_hook(functools.partial(_hook_fn, layer_name)))

            output = self.base_model(*args, **kwargs)
        finally:
            # Remove hooks after forward pass, ensuring it's done even if an error occurs
            for hook in hooks:
                hook.remove()

        return output, nero_outs
    
    def __getattr__(self, name):
        try:
            return super().__getattr__(name) # Try getting attribute from self
        except AttributeError:
            # Without this guard a missing `base_model` would recurse forever,
            # because the fallback below looks up `self.base_model` again
            if name == 'base_model':
                raise
            return getattr(self.base_model, name) # Fallback to base_model

# Load the base model named in the LoRA configuration; `device` ('auto' by default) is passed as the device map.
base_nero_model = AutoModelForCausalLM.from_pretrained(base_model_name, device_map=device)
# Wrap the LoRA target layers of the base model with Nero layers (with Nero biases enabled).
nero_model = NeroModel(
    base_nero_model, 
    lora_config, 
    nero_bias=True, 
    return_nero_outputs=False,
    debug=False,
)

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/5.96G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/235 [00:00<?, ?B/s]

In [9]:
# Leave only the Nero projections trainable, then run a single-prompt forward/backward
# pass as a sanity check (LoRA weights are still randomly initialised at this point).
nero_model.freeze_all_except_nero()
print()
check_loss_and_grad_norm(nero_model, tokenizer)

[INFO] All layers are frozen except Nero layers!

Loss: tensor(12.9127, device='cuda:0', grad_fn=<ToCopyBackward0>)
Gradient norm: 8.642618892672493


In [10]:
# Show the first LoRA parameter before the checkpoint is loaded (random initialisation).
print("Check LoRA parameters (unloaded):")
check_lora_parameters(nero_model)
print()

# Overwrite the LoRA weights of the wrapped layers with the downloaded adapter.
nero_model.load_lora_params(lora_path)
print()

# Show the same parameter again, now holding the checkpoint values.
print("Check LoRA parameters (loaded):")
check_lora_parameters(nero_model)
print()

# Repeat the single-prompt forward/backward check with the loaded LoRA weights.
check_loss_and_grad_norm(nero_model, tokenizer)

Check LoRA parameters (unloaded):
- Name    : base_model.model.layers.0.self_attn.q_proj.lora_A.weight
- Mean    : 0.004079858772456646
- Min     : -1.4468613862991333
- Max     : 1.3746638298034668

[INFO] LoRA parameters loaded successfully!

Check LoRA parameters (loaded):
- Name    : base_model.model.layers.0.self_attn.q_proj.lora_A.weight
- Mean    : 6.287686119321734e-05
- Min     : -0.04176201671361923
- Max     : 0.04242725297808647

Loss: tensor(13.4193, device='cuda:0', grad_fn=<ToCopyBackward0>)
Gradient norm: 9.637150062361702


In [11]:
# `NeroModel` defines no `gradient_checkpointing_enable`; the call is resolved on the
# wrapped base model through `NeroModel.__getattr__`.
nero_model.gradient_checkpointing_enable({'use_reentrant': False})
print("[INFO] Gradient checkpointing enabled!")
print()

# Loss and gradient norm should match the values printed before checkpointing was enabled.
check_loss_and_grad_norm(nero_model, tokenizer)

[INFO] Gradient checkpointing enabled!

Loss: tensor(13.4193, device='cuda:0', grad_fn=<ToCopyBackward0>)


  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


Gradient norm: 9.637150062361702


# Data

In [12]:
# Stream the first `train_size + test_size` rows of the target task/language ('train' split by default).
dataset = load_hf_dataset(
    target_lang, 
    target_task, 
    train_size=train_size, 
    test_size=test_size,
)

README.md: 0.00B [00:00, ?B/s]

In [13]:
def get_dataset_map_config(remove_columns=None):
    """Build the keyword arguments shared by every `dataset.map` call below.

    Args:
        remove_columns: Column names to drop from the mapped dataset.

    Returns:
        Dict with `batched=True`, the given `remove_columns`, and `num_proc`
        set to `os.cpu_count()`.
    """
    map_kwargs = {
        'batched': True,
        'remove_columns': remove_columns,
    }
    map_kwargs['num_proc'] = os.cpu_count() # Use all available CPUs
    return map_kwargs

# Turn `dataset` into `dataset_final` with `input_ids`, `attention_mask` and `labels`.
# Every mapped function below runs with `batched=True` (see `get_dataset_map_config`),
# so it receives a dict of column lists and must return a dict of lists.
if target_task == 'gsm8k':
    eos_token = tokenizer.eos_token
    
    def format_gsm8k_prompt(examples):
        """Render one prompt per (question, answer) pair of the batch, terminated by EOS."""
        gsm8k_prompt = """### Instruction:
Solve the following math problem step by step.

### Question: 
{question}

### Answer: 
{answer}"""

        # Format row by row: with `batched=True` the columns arrive as lists.
        # EOS is appended after formatting so it is never interpreted as a format field.
        return {'text': [
            gsm8k_prompt.format(question=question, answer=answer) + eos_token
            for question, answer in zip(examples['question'], examples['answer'])
        ]}

    def tokenize_fn(examples):
        """Tokenize the prompts, truncating/padding each one to `max_seq_length`."""
        return tokenizer(
            examples["text"],
            truncation=True,
            padding='max_length',
            max_length=max_seq_length,
        )

    def add_labels(examples):
        """Copy `input_ids` into `labels`, masking padding positions with -100.

        -100 is the label value the loss ignores, so padding added by
        `tokenize_fn` does not contribute to training.
        """
        examples['labels'] = [
            [token_id if keep else -100 for token_id, keep in zip(ids, mask)]
            for ids, mask in zip(examples['input_ids'], examples['attention_mask'])
        ]
        return examples

    dataset_formatted = dataset.map(
        format_gsm8k_prompt, 
        **get_dataset_map_config(remove_columns=dataset.column_names),
    )
    dataset_tokenized = dataset_formatted.map(
        tokenize_fn, 
        **get_dataset_map_config(remove_columns=dataset_formatted.column_names),
    )
    dataset_final = dataset_tokenized.map(
        add_labels,
        **get_dataset_map_config(remove_columns=dataset_tokenized.column_names),
    )
else:
    block_size = max_seq_length

    def tokenize_fn(examples):
        """Tokenize the raw texts without truncation or padding."""
        return tokenizer(examples['text'])

    def group_texts(examples):
        """Concatenate the batch's tokens and cut them into `block_size` chunks.

        The tail that does not fill a whole block is dropped, so no padding
        is needed and the attention mask is all ones.
        """
        concatenated = []
        for input_ids in examples['input_ids']:
            concatenated += input_ids

        total_length = len(concatenated) // block_size * block_size

        input_ids = [concatenated[i:i + block_size] for i in range(0, total_length, block_size)]
        attention_mask = [[1] * block_size for _ in input_ids]
        labels = input_ids.copy()

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
        }

    dataset_tokenized = dataset.map(
        tokenize_fn,
        **get_dataset_map_config(remove_columns=dataset.column_names),
    )
    dataset_final = dataset_tokenized.map(
        group_texts, 
        **get_dataset_map_config(remove_columns=dataset_tokenized.column_names),
    )

  self.pid = os.fork()


Map (num_proc=4):   0%|          | 0/5000 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (141723 > 131072). Running this sequence through the model will result in indexing errors
  self.pid = os.fork()


Map (num_proc=4):   0%|          | 0/5000 [00:00<?, ? examples/s]

In [14]:
# Batch the prepared dataset; batches are reshuffled on every pass over the loader.
train_loader = DataLoader(
    dataset_final, 
    batch_size=batch_size, 
    shuffle=True, 
    collate_fn=default_data_collator,
)
# Number of batches per epoch (not the number of optimizer steps).
print("[INFO] Total batches:", len(train_loader))

[INFO] Total batches: 6923


In [15]:
# Sanity check
first_batch = next(iter(train_loader))
print("First batch data shape (input_ids, attention_mask, labels):")
print((
    first_batch['input_ids'].shape, 
    first_batch['attention_mask'].shape, 
    first_batch['labels'].shape, 
))
print()

first_batch_text = tokenizer.batch_decode(first_batch['input_ids'], skip_special_tokens=True)[0]
print("First batch text:")
print(first_batch_text[:100], "...")
print()

check_loss_and_grad_norm(nero_model, tokenizer, prompt=first_batch_text)

First batch data shape (input_ids, attention_mask, labels):
(torch.Size([4, 1024]), torch.Size([4, 1024]), torch.Size([4, 1024]))

First batch text:
くる息子を、その日のために一人で育てることを決意する。
 木彫りの僧侶
 声 - 松尾銀三
 旺気楼に祀られている像。元は人間であったが、青龍の見立てを受けている最中に邪気を受け、木彫りの像に変わって ...

Loss: tensor(13.9221, device='cuda:0', grad_fn=<ToCopyBackward0>)
Gradient norm: 2.0530590763437186


# Training

In [16]:
# If `device` is not specified or set to 'auto', use model's device
if device is None or device == 'auto':
    device = next(iter(nero_model.parameters())).device

# Set up optimizer and gradient scaler
# Only parameters that currently require gradients are optimized
nero_params = [p for n, p in nero_model.named_parameters() if p.requires_grad]
optimizer = torch.optim.Adam(nero_params, lr=lr, foreach=False)
scaler = torch.cuda.amp.GradScaler()

# Set up LR scheduler
# `max_global_steps` counts batches; when unset it defaults to all batches of all epochs
max_global_steps = max_global_steps or len(train_loader) * num_epochs
# This value is replaced inside the branch below whenever `warmup_ratio` > 0
warmup_steps = int(warmup_ratio * max_global_steps)
if warmup_ratio > 0:
    # If `warmup_ratio` > 0, use cosine annealing scheduler with warm-up 
    from transformers import get_cosine_schedule_with_warmup # type: ignore
    # The scheduler is stepped once per optimizer update, i.e. once per accumulation cycle
    max_optimizer_steps = max_global_steps // grad_accumulation_steps
    warmup_steps = int(warmup_ratio * max_optimizer_steps)
    # NOTE(review): `num_cycles` of `get_cosine_schedule_with_warmup` appears to count full
    # cosine waves (library default 0.5), so an integer value would bring the LR back to its
    # peak at the end of training — confirm whether hard restarts were intended instead.
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=max_optimizer_steps,

        # Cycle length is 2000 steps or less, but at least 1 cycle, and at most 10 cycles
        num_cycles=min(max(1, max_optimizer_steps // 2000), 10), 
    )
else:
    # If `warmup_ratio` is 0, use a dummy scheduler that returns constant LR
    from torch.optim.lr_scheduler import LambdaLR # type: ignore
    scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)

# Initialize W&B
# The run configuration mirrors the notebook's configuration variables
wandb.init(
    project='Nero-XLT',
    reinit=True, # End previous run and start a new one
    config=dict(
        # Project configuration
        seed = seed,
        target_lang=target_lang,
        target_task=target_task,
        device = device,

        # Data configuration
        train_size = train_size,
        test_size = test_size,
        max_seq_length = max_seq_length,

        # Training configuration
        batch_size = batch_size,
        num_epochs = num_epochs,
        max_global_steps = max_global_steps,
        resume_step = resume_step,
        grad_accumulation_steps = grad_accumulation_steps,
        clip_grad_norm = clip_grad_norm,
        lr = lr,
        warmup_ratio = warmup_ratio,
        checkpoint_steps = checkpoint_steps,
    ),
)

# Resume training
# `global_step` counts batches and starts at the step of the checkpoint being resumed (0 for a new run)
global_step = resume_step
start_epoch = 0

if resume_step > 0:
    checkpoint_dir = os.path.join(nero_dir, f'checkpoint-{resume_step}')
    print(f"[INFO] Resuming training from checkpoint directory:", checkpoint_dir)

    # Load Nero parameters
    nero_path = os.path.join(checkpoint_dir, 'adapter_model.safetensors')
    nero_model.load_nero_params(nero_path)

    # Load optimizer state
    optimizer_path = os.path.join(checkpoint_dir, 'optimizer.pt')
    optimizer.load_state_dict(torch.load(optimizer_path, map_location=device))
    
    # Move optimizer state to the correct device
    # Each state tensor is also cast to the dtype of the parameter it belongs to
    for param in optimizer.state:
        param_device = param.device
        param_dtype = param.dtype
        for key, value in optimizer.state[param].items():
            if isinstance(value, torch.Tensor):
                optimizer.state[param][key] = value.to(device=param_device, dtype=param_dtype)

    # Load scheduler state
    scheduler_path = os.path.join(checkpoint_dir, 'scheduler.pt')
    scheduler.load_state_dict(torch.load(scheduler_path, map_location=device))

    # Load scaler state
    scaler_path = os.path.join(checkpoint_dir, 'scaler.pt')
    scaler.load_state_dict(torch.load(scaler_path, map_location=device))

    # Load trainer state
    # The epoch of the last logged entry becomes the epoch to restart from
    trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json')
    if os.path.exists(trainer_state_path):
        with open(trainer_state_path, 'r') as f:
            trainer_state = json.load(f)
        log_history = trainer_state.get('log_history', [])
        start_epoch = log_history[-1]['epoch'] if log_history else 0
        print(f"[INFO] Resuming training from epoch {start_epoch} and step {resume_step}.")

    # Load RNG state for reproducibility
    rng_path = os.path.join(checkpoint_dir, 'rng_state.pth')
    if os.path.exists(rng_path):
        rng_state = torch.load(rng_path)
        random.setstate(rng_state['python'])
        np.random.set_state(rng_state['numpy'])
        torch.set_rng_state(rng_state['cpu'])
        if torch.cuda.is_available() and rng_state['cuda']:
            torch.cuda.set_rng_state_all(rng_state['cuda'])
    
    if resume_step % grad_accumulation_steps != 0:
        print("[WARN] Resuming mid-gradient accumulation cycle. Make sure this is intended.")
else:
    # If it's new training, create Hugging Face repository
    print(f"[INFO] Creating Hugging Face repository:", hf_nero_id) # print the link instead
    create_repo(repo_id=hf_nero_id, repo_type='model', exist_ok=True)
    print(f"[INFO] Hugging Face repository created successfully!")

# Set model to training mode
nero_model.train()

# NOTE(review): this replaces any `log_history` read from `trainer_state.json` above, so
# entries logged before the resume are not carried over — confirm this is intended.
log_history = []
done = False

# Safety: Zero gradients at the start of gradient accumulation cycle
# This ensures there are no leftover gradients when resuming mid-cycle or after a previous cycle was interrupted
if global_step % grad_accumulation_steps == 0:
    optimizer.zero_grad(set_to_none=True)

# Training loop
for epoch in range(start_epoch, num_epochs):
    for step, batch in enumerate(train_loader):
        # Skip previously completed steps
        if global_step <= resume_step:
            global_step += 1
            continue

        # Stop training if `max_global_steps` reached
        if global_step >= max_global_steps:
            done = True
            break

        # Move inputs to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        with torch.cuda.amp.autocast():
            # Forward pass
            outputs = nero_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,

                # Disable cache to avoid conflict with gradient checkpointing
                use_cache=False, 
            )

            # Compute loss
            loss = outputs.loss / grad_accumulation_steps

        log = {
            'epoch': epoch,
            'step': global_step,
            'loss': loss.item() * grad_accumulation_steps,
        }

        # Backward pass
        scaler.scale(loss).backward()
        
        # Update parameters only at the end of gradient accumulation cycle
        grad_norm_log = {}
        if (global_step + 1) % grad_accumulation_steps == 0:
            # Unscale gradients before computing gradient norm and applying clipping
            scaler.unscale_(optimizer)

            # Compute gradient norm
            grad_norm = compute_grad_norm(nero_params)
            grad_norm_log['grad_norm'] = grad_norm

            # Clip gradients
            if clip_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(nero_params, clip_grad_norm)
            
            # Compute clipped gradient norm
            grad_norm_clipped = compute_grad_norm(nero_params)
            grad_norm_log['grad_norm_clipped'] = grad_norm_clipped

            # Update parameters
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            # Zero gradients for the next gradient accumulation cycle
            optimizer.zero_grad(set_to_none=True)

        # Logging
        log = {
            **log, 
            'lr': scheduler.get_last_lr()[0], 
            **grad_norm_log,
        }
        log_history.append(log)
        wandb.log(log)
        print(", ".join(f"{k}: {v}" for k, v in log.items()))
        
        # Save and push checkpoint every `checkpoint_steps`
        if global_step > 0 and global_step % checkpoint_steps == 0:
            # Create checkpoint directory
            checkpoint_dir = os.path.join(nero_dir, f'checkpoint-{global_step}')
            os.makedirs(checkpoint_dir, exist_ok=True)

            # Save Nero parameters, along with optimizer, scheduler, and scaler states
            nero_state_dict = {n: p.detach().cpu() for n, p in nero_model.named_parameters() if p.requires_grad}
            save_file(nero_state_dict, os.path.join(checkpoint_dir, 'adapter_model.safetensors'))  # NEW
            torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, 'optimizer.pt'))
            torch.save(scheduler.state_dict(), os.path.join(checkpoint_dir, 'scheduler.pt'))
            torch.save(scaler.state_dict(), os.path.join(checkpoint_dir, 'scaler.pt'))

            # Save trainer state for resuming training
            trainer_state = {
                'epoch': epoch,
                'global_step': global_step,
                'log_history': log_history,
            }
            with open(os.path.join(checkpoint_dir, 'trainer_state.json'), 'w') as f:
                json.dump(trainer_state, f, indent=2)

            # Save RNG state for reproducibility
            rng_state = {
                'python': random.getstate(),
                'numpy': np.random.get_state(),
                'cpu': torch.get_rng_state(),
                'cuda': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else [],
            }
            torch.save(rng_state, os.path.join(checkpoint_dir, 'rng_state.pth'))

            # Upload checkpoint directory to Hugging Face repository
            upload_folder(
                folder_path=checkpoint_dir,
                repo_id=hf_nero_id,
                path_in_repo=f"checkpoint-{global_step}",
                commit_message=f"Add checkpoint at step {global_step}",
                repo_type='model',
            )
        
        # Check generated text every `generate_steps`
        if global_step > 0 and global_step % generate_steps == 0:
            generated = generate_text(nero_model, tokenizer, sample_prompt, device=device)
            print()
            print("================================")
            print("CHECK GENERATED TEXT")
            print("================================")
            print(f"{'Prompt':<9}:", sample_prompt)
            print(f"{'Generated':<9}:", generated)
            print()
        
        global_step += 1
    
    if done:
        break

wandb.finish()

  scaler = torch.cuda.amp.GradScaler()
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malimtegar[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112639466667436, max=1.0…

[INFO] Creating Hugging Face repository: alxxtexxr/L3.1-8B-wikipedia-ja-5K-Nero-v20250805174511
[INFO] Hugging Face repository created successfully!


  with torch.cuda.amp.autocast():
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


epoch: 0, step: 1, loss: 14.206133842468262, lr: 1.1560693641618497e-06, grad_norm: 1.1859621034997168, grad_norm_clipped: 1.1859621034997168
epoch: 0, step: 2, loss: 14.073974609375, lr: 1.1560693641618497e-06
epoch: 0, step: 3, loss: 14.092690467834473, lr: 2.3121387283236993e-06, grad_norm: 2.1565261534406677, grad_norm_clipped: 2.1565261534406677
epoch: 0, step: 4, loss: 14.026183128356934, lr: 2.3121387283236993e-06
epoch: 0, step: 5, loss: 13.972622871398926, lr: 3.468208092485549e-06, grad_norm: 2.1121887367775036, grad_norm_clipped: 2.1121887367775036
epoch: 0, step: 6, loss: 14.047845840454102, lr: 3.468208092485549e-06
epoch: 0, step: 7, loss: 14.141060829162598, lr: 4.624277456647399e-06, grad_norm: 2.404423062023708, grad_norm_clipped: 2.404423062023708
epoch: 0, step: 8, loss: 14.005545616149902, lr: 4.624277456647399e-06
epoch: 0, step: 9, loss: 14.026354789733887, lr: 5.780346820809249e-06, grad_norm: 2.0687781371260408, grad_norm_clipped: 2.0687781371260408
epoch: 0, st

rng_state.pth:   0%|          | 0.00/14.6k [00:00<?, ?B/s]

optimizer.pt:   0%|          | 0.00/188M [00:00<?, ?B/s]

scheduler.pt:   0%|          | 0.00/1.00k [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/93.7M [00:00<?, ?B/s]

scaler.pt:   0%|          | 0.00/988 [00:00<?, ?B/s]

Upload 5 LFS files:   0%|          | 0/5 [00:00<?, ?it/s]


CHECK GENERATED TEXT
Prompt   : 日本は
Generated: 日本はivesyn مک.syntaxistonRY.syntax.syntaxadians.syntaxrynúčastivesivesadianihatrynRY natsieองจาก McL مک El McLsieadiansองจากúčastRYiami Flux

epoch: 0, step: 51, loss: 13.465641975402832, lr: 3.0057803468208097e-05, grad_norm: 2.2052367795871057, grad_norm_clipped: 2.2052367795871057
epoch: 0, step: 52, loss: 13.470633506774902, lr: 3.0057803468208097e-05
epoch: 0, step: 53, loss: 13.301692008972168, lr: 3.1213872832369946e-05, grad_norm: 2.2361969397974732, grad_norm_clipped: 2.2361969397974732
epoch: 0, step: 54, loss: 13.442814826965332, lr: 3.1213872832369946e-05
epoch: 0, step: 55, loss: 13.381743431091309, lr: 3.2369942196531794e-05, grad_norm: 2.000464217754044, grad_norm_clipped: 2.000464217754044
epoch: 0, step: 56, loss: 13.18196964263916, lr: 3.2369942196531794e-05
epoch: 0, step: 57, loss: 13.35496711730957, lr: 3.352601156069364e-05, grad_norm: 2.13352416352135, grad_norm_clipped: 2.13352416352135
epoch: 0, step: 58, loss: 13.

Upload 5 LFS files:   0%|          | 0/5 [00:00<?, ?it/s]

optimizer.pt:   0%|          | 0.00/188M [00:00<?, ?B/s]

scheduler.pt:   0%|          | 0.00/1.00k [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/93.7M [00:00<?, ?B/s]

scaler.pt:   0%|          | 0.00/988 [00:00<?, ?B/s]

rng_state.pth:   0%|          | 0.00/14.6k [00:00<?, ?B/s]


CHECK GENERATED TEXT
Prompt   : 日本は
Generated: 日本は EmbRYRYúčast338、účastynRWラ、、、 Flux032garyRYsie32RY Esper Flux EsperRY、ivesúčast、RY03208 Esper

epoch: 0, step: 101, loss: 11.77790641784668, lr: 5.895953757225434e-05, grad_norm: 1.8393874360646352, grad_norm_clipped: 1.8393874360646352


KeyboardInterrupt: 