In [1]:
# Kill all processess on GPU
# !fuser -v /dev/nvidia* -k

# Libraries

In [2]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    %pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    %pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    %pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    %pip install --no-deps unsloth
%pip install trl==0.19.1 # Fix error: ImportError: cannot import name 'ConstantLengthDataset' from 'trl.trainer.utils'

In [3]:
import os
import functools
import gc
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader
from datasets import load_dataset, Dataset
from transformers import (
    default_data_collator, 
    AutoModelForCausalLM, 
    AutoTokenizer,
)
from peft import LoraConfig
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from datetime import datetime
from pprint import pprint

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [4]:
def download_hf_model(repo_id, checkpoint=None, max_checkpoints=2000, checkpoint_step=25):
    local_dir = repo_id.split('/')[-1]
    ignore_checkpoints = None
    
    if checkpoint is not None:
        ignore_checkpoints = [f'checkpoint-{i}/*' for i in range(0, max_checkpoints, checkpoint_step) if i != checkpoint]

    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        ignore_patterns=ignore_checkpoints,
    )

    if checkpoint is not None:
        return os.path.join(local_dir, f'checkpoint-{checkpoint}')
    return local_dir

def check_loss_and_grad_norm(model, tokenizer, prompt="Paris is the capital of"):
    # Set model in training mode
    model.train()

    # Zero gradients manually
    for p in model.parameters():
        p.grad = None

    # Forward pass
    inputs = tokenizer(
        prompt, 
        return_tensors='pt',
    ).to(next(model.parameters()).device)

    outputs = model(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        labels=inputs['input_ids'],
        use_cache=False, # Disable cache to not conflict with gradient checkpointing
    )
    print("Loss:", outputs.loss)

    # Backward pass
    outputs.loss.backward()

    # Compute gradient norm
    grad_norm = 0.0
    for p in [p for n, p in model.named_parameters() if p.requires_grad]:
        p_grad_norm = p.grad.data.norm(2)
        grad_norm += p_grad_norm.item() ** 2
    grad_norm = grad_norm ** 0.5

    print("Gradient norm:", grad_norm)

@torch.no_grad()
def check_lora_parameters(model):
    for n, p in model.named_parameters():
        if 'lora' in n:
            print(f"- {'Name':<8}:", n)
            print(f"- {'Mean':<8}:", p.mean().item())
            print(f"- {'Min':<8}:", p.min().item())
            print(f"- {'Max':<8}:", p.max().item())
            break

@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens=32, device=None, skip_special_tokens=True):
    device = device or next(model.parameters()).device
    model.eval()
    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    outputs = model.generate(input_ids=inputs['input_ids'], max_new_tokens=max_new_tokens)
    model.train()
    return tokenizer.decode(outputs[0], skip_special_tokens=skip_special_tokens)

def load_hf_dataset_from_lora(
    lora_repo_id,
    train_size = 5000,
    test_size = 1000,
):
    # Get task and language
    task, lang, _ = lora_repo_id.split('B-')[-1].split('K-')[0].split('-')

    # Set up Hugging Face configuration
    data_id_map = {
        'wikipedia': 'wikimedia/wikipedia',
        'gsm8k': 'openai/gsm8k',
    }
    data_id = data_id_map[task]
    data_dir = f'20231101.{lang}' if task == 'wikipedia' else 'main'
    split = f'train[:{(train_size+test_size)}]'

    # Load dataset
    # TODO: Use streaming to not download the entire dataset
    # dataset = load_dataset(data_id, data_dir=data_dir, split=split)

    # Use streaming
    dataset_stream = load_dataset(data_id, data_dir=data_dir, split='train', streaming=True)

    # Manually take train_size + test_size samples
    total_size = train_size + test_size
    sliced_data = []
    for i, example in enumerate(dataset_stream):
        if i >= total_size:
            break
        sliced_data.append(example)

    # Convert to regular in-memory dataset
    dataset = Dataset.from_list(sliced_data)
    
    return dataset

def load_hf_dataset(
    lang, 
    task,
    split='train',
    train_size = 5000,
    test_size = 1000,
):
    # Set up Hugging Face configuration
    data_id_map = {
        'wikipedia': 'wikimedia/wikipedia',
        'gsm8k': 'openai/gsm8k',
    }
    data_id = data_id_map[task]
    data_dir = f'20231101.{lang}' if task == 'wikipedia' else 'main'

    # Use streaming
    dataset_stream = load_dataset(data_id, data_dir=data_dir, split=split, streaming=True)

    # Manually take train_size + test_size samples
    total_size = train_size + test_size
    sliced_data = []
    for i, example in enumerate(dataset_stream):
        if i >= total_size:
            break
        sliced_data.append(example)

    # Convert to regular in-memory dataset
    dataset = Dataset.from_list(sliced_data)
    
    return dataset

# Configurations

In [23]:
from huggingface_hub import HfApi, HfFolder, Repository, create_repo

# Project configuration
seed = 69
target_lang = 'ja' # 'en' | 'id' | 'es'
target_task = 'wikipedia' # 'wikipedia' | 'gsm8k'
device = 'auto' # 'cpu' | 'cuda' | 'auto'

# Data configuration
train_size = 5000
test_size = 0 # Temporary
max_seq_length = 1024

# Training configuration
batch_size = 4
num_epochs = 1
max_global_steps = 10
resume_step = 0
grad_accumulation_steps = 2
clip_grad_norm = 0.5
lr = 2e-4
warmup_ratio = 0.05
checkpoint_steps = 10
generate_steps = 10
sample_prompt = '日本国は、'

# Model configurations
hf_lora_id = 'alxxtexxr/L3.1-8B-wikipedia-en-5K-LoRA-v20250630122650'
checkpoint = 650

lora_dir = download_hf_model(hf_lora_id, checkpoint)
lora_path = os.path.join(lora_dir, 'adapter_model.safetensors')
lora_config = LoraConfig.from_pretrained(lora_dir)
base_model_name = lora_config.base_model_name_or_path
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

project_name = f'L3.1-8B-{target_task}-{target_lang}-{train_size//1000}K-Nero-v{datetime.now().strftime("%Y%m%d%H%M%S")}'
nero_dir = project_name

print(f"Project name: {project_name}")
print(f"Base model name: {base_model_name}")

Project name: L3.1-8B-wikipedia-ja-5K-Nero-v20250803164454
Base model name: unsloth/meta-llama-3.1-8b-unsloth-bnb-4bit


In [26]:
# Hugging Face configuration
hf_username = 'alxxtexxr'
hf_nero_id = f'{hf_username}/{nero_dir}'
os.makedirs(nero_dir, exist_ok=True)

if not os.path.exists(os.path.join(nero_dir, '.git')):
    create_repo(hf_nero_id, exist_ok=True)
    repo = Repository(local_dir=nero_dir, clone_from=hf_nero_id)
else:
    repo = Repository(local_dir=nero_dir)

# Model

In [7]:
class NeroLayer(nn.Module):
    def __init__(self, base_layer, 
                 # LoRA parameters
                 rank, alpha, dropout, lora_bias, use_rslora, 
                 # Nero parameters
                 nero_bias=False, 
                 return_nero_output=False,
                 # For debugging 
                 debug=False,
                 module_name=None,
                 ):
        super().__init__()
        self.base_layer = base_layer
        self.device = base_layer.weight.device
        self.alpha = alpha
        self.lora_bias = lora_bias
        self.scaling = alpha / math.sqrt(rank) if use_rslora else alpha / rank
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
        self.return_nero_output = return_nero_output

        # For debugging
        self.debug = debug
        self.module_name = module_name

        # Extract input and output features from the base layer
        in_features = getattr(base_layer, 'in_features', None)
        out_features = getattr(base_layer, 'out_features', None)

        if in_features is None or out_features is None:
            raise ValueError(f"Cannot determine in_features or out_features from {base_layer}.")
        
        # LoRA decomposition: A (down-projection) and B (up-projection)
        self.lora_A = nn.Linear(in_features, rank, bias=lora_bias).to(self.device)  # Projects down
        self.lora_B = nn.Linear(rank, out_features, bias=lora_bias).to(self.device) # Projects up

        # Initialize LoRA matrices: A ~ N(0, 1/rank), B initialized to 0
        std = 1 / torch.sqrt(torch.tensor(rank).float())
        nn.init.normal_(self.lora_A.weight, mean=0.0, std=std)
        nn.init.zeros_(self.lora_B.weight)

        # Nero decomposition: additional transformation applied to LoRA output
        self.nero_A = nn.Linear(out_features, rank, bias=nero_bias).to(self.device)
        self.nero_B = nn.Linear(rank, out_features, bias=nero_bias).to(self.device)

        # Initialize Nero matrices similarly
        nn.init.normal_(self.nero_A.weight, mean=0.0, std=std)
        nn.init.zeros_(self.nero_B.weight)
        
    def forward(self, x):
        # Forward through base layer
        base_out = self.base_layer(x)

        if self.debug:
            print("================================================================")
            print(self.module_name)
            print("================================================================")
            print("base_out.requires_grad:", base_out.requires_grad)
            print("base_out.grad_fn:", base_out.grad_fn)
            print()

        # LoRA transformation
        requires_conversion = not torch.is_autocast_enabled()
        if requires_conversion:
            x = x.to(self.lora_A.weight.dtype)
        lora_out = self.lora_B(self.lora_A(self.dropout(x))) * self.scaling
        # if requires_conversion:
        #     lora_out = lora_out.to(base_out.dtype)

        if self.debug:
            print("lora_out.requires_grad:", lora_out.requires_grad)
            print("lora_out.grad_fn:", lora_out.grad_fn)
            print()

        # nero_out = F.relu(self.nero_B(self.nero_A(self.dropout(lora_out))) * self.scaling)
        nero_dropout_out = self.dropout(lora_out)
        nero_A_out = self.nero_A(nero_dropout_out)
        nero_B_out = self.nero_B(nero_A_out)
        nero_scaling_out = nero_B_out * self.scaling
        nero_out = F.relu(nero_scaling_out)
        if requires_conversion:
            nero_out = nero_out.to(base_out.dtype)

        if self.debug:
            print("nero_out.requires_grad:", nero_out.requires_grad)
            print("nero_out.grad_fn:", nero_out.grad_fn)
            print()

            nero_out_has_nan = torch.isnan(nero_out).any()
            if nero_out_has_nan:
                print("!!! NERO OUT HAS NAN !!!")
                print("nero_out:")
                print(nero_out)
                print()
                print("nero_scaling_out:")
                print(nero_scaling_out)
                print()
                print("nero_B_out:")
                print(nero_B_out)
                print()
                print("nero_A_out:")
                print(nero_A_out)
                print()
                print("nero_dropout_out:")
                print(nero_dropout_out)
                print()
                print("lora_out:")
                print(lora_out)
                print()

        # Add `base_out` with gradients-detached `nero_out`, 
        # so that `base_out` does not carry gradients
        # nero_out_detached = nero_out.detach()

        # if self.debug:
        #     print("nero_out_detached.requires_grad:", nero_out_detached.requires_grad)
        #     print("nero_out_detached.grad_fn:", nero_out_detached.grad_fn)
        #     print()

        # output = base_out + nero_out_detached
        output = base_out + nero_out

        if self.debug:
            print("output.requires_grad:", output.requires_grad)
            print("output.grad_fn:", output.grad_fn)
            print()

        if self.return_nero_output:
            return output, nero_out
        
        return output

    def load_lora_params(self, state_dict, prefix):
        self.lora_A.weight.data = state_dict[f'{prefix}.lora_A.weight'].to(self.device)
        self.lora_B.weight.data = state_dict[f'{prefix}.lora_B.weight'].to(self.device)
        if self.lora_bias:
            self.lora_A.bias.data = state_dict[f'{prefix}.lora_A.bias'].to(self.device)
            self.lora_B.bias.data = state_dict[f'{prefix}.lora_B.bias'].to(self.device)
    
class NeroModel(nn.Module):
    def __init__(self, base_model: nn.Module, lora_config: LoraConfig, nero_bias: bool=False, 
                 return_nero_outputs: bool=False, debug: bool=False):
        super().__init__()
        self.base_model = base_model
        self.nero_bias = nero_bias
        self.nero_layers = nn.ModuleDict()
        self.return_nero_outputs = return_nero_outputs
        self.debug = debug

        # Wrap target layers with NeroLayer
        self._wrap_target_layers(lora_config)
        
    def _wrap_target_layers(self, lora_config):
        for module_name, module in self.base_model.named_modules():
            if isinstance(module, NeroLayer):
                # Convert module name format and store reference
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.nero_layers[module_name] = module
                continue

            if any(module_name.endswith(target_module) for target_module in lora_config.target_modules) and isinstance(module, nn.Linear):    
                parent_module, child_name = self._get_parent_module(module_name)
                nero_layer = NeroLayer(
                    module, 
                    lora_config.r, 
                    lora_config.lora_alpha, 
                    lora_config.lora_dropout, 
                    lora_config.lora_bias, 
                    lora_config.use_rslora,
                    nero_bias=self.nero_bias,
                    return_nero_output=self.return_nero_outputs,
                    debug=self.debug,
                    module_name=module_name,
                )
                setattr(parent_module, child_name, nero_layer)

                # Store LoRA layers for weight loading
                module_name = module_name.rsplit('model.', 1)[-1]
                module_name = module_name.replace('.', '__DOT__')
                self.nero_layers[module_name] = nero_layer
    
    def _get_parent_module(self, module_name):
        parts = module_name.split('.')
        parent_module = self.base_model
        for part in parts[:-1]:
            parent_module = getattr(parent_module, part)
        return parent_module, parts[-1]
    
    def set_return_nero_outputs(self, return_nero_outputs: bool):
        self.return_nero_outputs = return_nero_outputs
        for layer in self.nero_layers.values():
            layer.return_nero_output = return_nero_outputs

    def freeze_all_except_nero(self):
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        for nero_layer in self.nero_layers.values():
            for param_name, param in nero_layer.named_parameters():
                if 'nero_A' in param_name or 'nero_B' in param_name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
        
        print("All layers are frozen except Nero layers!")
    
    def unfreeze_all(self):
        for param in self.base_model.parameters():
            param.requires_grad = True
        
        for nero_layer in self.nero_layers.values():
            for param in nero_layer.parameters():
                param.requires_grad = True
        
        print("All layers are unfrozen!")
    
    def load_lora_params(self, lora_path):
        state_dict = load_file(lora_path)
        prefix = list(state_dict.keys())[0].rsplit('model.', 1)[0] + 'model.'

        for nero_layer_name, nero_layer in self.nero_layers.items():
            nero_layer_name = nero_layer_name.replace('__DOT__', '.')
            nero_layer_name = prefix + nero_layer_name
            if f'{nero_layer_name}.lora_A.weight' in state_dict and f'{nero_layer_name}.lora_B.weight' in state_dict:
                nero_layer.load_lora_params(state_dict, nero_layer_name)
            else:
                # TODO: Print warning message
                pass

        print("LoRA parameters loaded successfully!")
    
    def forward(self, *args, **kwargs):
        if self.return_nero_outputs:
            nero_outs = {}
            
            def _hook_fn(layer_name, module, _in, _out):
                if isinstance(_out, tuple) and len(_out) == 2:
                    layer_out, nero_out = _out
                    nero_outs[layer_name] = nero_out # Store nero_out separately
                    return layer_out # Return only layer_out to avoid breaking model flow

            # Register hooks to extract nero_out during forward pass
            hooks = []
            for layer_name, layer in self.nero_layers.items():
                hook = layer.register_forward_hook(functools.partial(_hook_fn, layer_name))
                hooks.append(hook)
        
            try:
                output = self.base_model(*args, **kwargs)
            finally:
                # Remove hooks after forward pass, ensuring it's done even if an error occurs
                for hook in hooks:
                    hook.remove()

            return output, nero_outs
        
        return self.base_model(*args, **kwargs)
    
    def __getattr__(self, name):
        try:
            return super().__getattr__(name) # Try getting attribute from self
        except AttributeError:
            return getattr(self.base_model, name) # Fallback to base_model

base_nero_model = AutoModelForCausalLM.from_pretrained(base_model_name, device_map=device)
nero_model = NeroModel(
    base_nero_model, 
    lora_config, 
    nero_bias=True, 
    return_nero_outputs=False,
    debug=False,
)

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/5.96G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/235 [00:00<?, ?B/s]

In [8]:
nero_model.freeze_all_except_nero()
check_loss_and_grad_norm(nero_model, tokenizer)

All layers are frozen except Nero layers!
Loss: tensor(14.5990, device='cuda:0', grad_fn=<ToCopyBackward0>)
Gradient norm: 1.5991796752380212


In [9]:
print("Check LoRA parameters (unloaded):")
check_lora_parameters(nero_model)
print()

nero_model.load_lora_params(lora_path)
print()

print("Check LoRA parameters (loaded):")
check_lora_parameters(nero_model)
print()

check_loss_and_grad_norm(nero_model, tokenizer)

Check LoRA parameters (unloaded):
- Name    : base_model.model.layers.0.self_attn.q_proj.lora_A.weight
- Mean    : 0.0016585660632699728
- Min     : -1.574367880821228
- Max     : 1.5179626941680908

LoRA parameters loaded successfully!

Check LoRA parameters (loaded):
- Name    : base_model.model.layers.0.self_attn.q_proj.lora_A.weight
- Mean    : 6.287686119321734e-05
- Min     : -0.04176201671361923
- Max     : 0.04242725297808647

Loss: tensor(14.5990, device='cuda:0', grad_fn=<ToCopyBackward0>)
Gradient norm: 1.8391875321443354


In [10]:
nero_model.gradient_checkpointing_enable({'use_reentrant': False})
print("Gradient checkpointing enabled!")
print()

check_loss_and_grad_norm(nero_model, tokenizer)

Gradient checkpointing enabled!

Loss: tensor(14.5990, device='cuda:0', grad_fn=<ToCopyBackward0>)


  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


Gradient norm: 1.8391875321443354


In [11]:
for k, v in nero_model.config.__dict__.items():
    if 'token' in k:
        print(k, ":", v)

tokenizer_class : None
bos_token_id : 128000
pad_token_id : 128004
eos_token_id : 128001
sep_token_id : None
decoder_start_token_id : None
forced_bos_token_id : None
forced_eos_token_id : None
suppress_tokens : None
begin_suppress_tokens : None


In [12]:
# nero_model.eval()
# generate_text(
#     nero_model, 
#     tokenizer, 
#     prompt="Paris is the capital of",
# )

# Data

In [13]:
dataset = load_hf_dataset(
    target_lang, 
    target_task, 
    train_size=train_size, 
    test_size=test_size,
)

README.md: 0.00B [00:00, ?B/s]

In [14]:
def get_dataset_map_config(remove_columns=None):
    return dict(
        batched=True, 
        remove_columns=remove_columns,
        num_proc=4,
    )


if target_task == 'gsm8k':
    eos_token = tokenizer.eos_token
    
    def format_gsm8k_prompt(example):
        gsm8k_prompt = """### Instruction:
Solve the following math problem step by step.

### Question: 
{question}

### Answer: 
{answer}""" + eos_token

        return {'text': gsm8k_prompt.format(
            question=example['question'], 
            answer=example['answer'],
        )}

    def tokenize_fn(example):
        return tokenizer(
            example["text"],
            truncation=True,
            padding='max_length',
            max_length=max_seq_length,
        )

    def add_labels(example):
        example['labels'] = example['input_ids'].copy()
        return example

    dataset_formatted = dataset.map(
        format_gsm8k_prompt, 
        **get_dataset_map_config(remove_columns=dataset.column_names),
    )
    dataset_tokenized = dataset_formatted.map(
        tokenize_fn, 
        **get_dataset_map_config(remove_columns=dataset.column_names),
    )
    dataset_final = dataset_tokenized.map(
        add_labels,
        **get_dataset_map_config(remove_columns=dataset.column_names),
    )
else:
    def tokenize_fn(example):
        return tokenizer(example['text'])

    # Concatenate all tokens into one long stream, then split into blocks
    block_size = max_seq_length

    def group_texts(examples):
        concatenated = []
        for input_ids in examples['input_ids']:
            concatenated += input_ids

        total_length = len(concatenated) // block_size * block_size

        input_ids = [concatenated[i:i + block_size] for i in range(0, total_length, block_size)]
        attention_mask = [[1] * block_size for _ in input_ids]
        labels = input_ids.copy()

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
        }

    dataset_tokenized = dataset.map(
        tokenize_fn,
        **get_dataset_map_config(remove_columns=dataset.column_names),
    )
    dataset_final = dataset_tokenized.map(
        group_texts, 
        **get_dataset_map_config(remove_columns=dataset_tokenized.column_names),
    )

Map (num_proc=4):   0%|          | 0/5000 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/5000 [00:00<?, ? examples/s]

In [15]:
train_loader = DataLoader(
    dataset_final, 
    batch_size=batch_size, 
    shuffle=True, 
    collate_fn=default_data_collator,
)

print("Total batches:", len(train_loader))

Total batches: 6923


In [16]:
# Sanity check
first_batch = next(iter(train_loader))
print("First batch data shape (input_ids, attention_mask, labels):")
print((
    first_batch['input_ids'].shape, 
    first_batch['attention_mask'].shape, 
    first_batch['labels'].shape, 
))
print()

first_batch_text = tokenizer.batch_decode(first_batch['input_ids'], skip_special_tokens=True)[0]
print("First batch text:")
print(first_batch_text[:100], "...")
print()

check_loss_and_grad_norm(nero_model, tokenizer, prompt=first_batch_text)

First batch data shape (input_ids, attention_mask, labels):
(torch.Size([4, 1024]), torch.Size([4, 1024]), torch.Size([4, 1024]))

First batch text:
NSII CX/UX（CRT一体型386SX TOWNS）　1992年度グッドデザイン賞受賞
 1992年2月 - FM TOWNSII UX40
 1992年11月4日 - FM TOWNSII H ...

Loss: tensor(13.9625, device='cuda:0', grad_fn=<ToCopyBackward0>)
Gradient norm: 0.71104503952095


# Training

In [None]:
generate_text(nero_model, tokenizer=tokenizer, prompt=sample_prompt)

In [27]:
# Ensure device is set
if device is None or device == 'auto':
    device = next(iter(nero_model.parameters())).device

# Setup optimizer and gradient scaler
nero_params = [p for n, p in nero_model.named_parameters() if p.requires_grad]
optimizer = torch.optim.Adam(nero_params, lr=lr)
scaler = torch.cuda.amp.GradScaler()

# Setup LR scheduler
max_global_steps = max_global_steps or len(train_loader) * num_epochs
warmup_steps = int(warmup_ratio * max_global_steps)
if warmup_ratio > 0:
    # If `warmup_ratio` > 0, use cosine annealing scheduler with warm-up 
    from transformers import get_cosine_schedule_with_warmup # type: ignore
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=max_global_steps,
    )
else:
    # If `warmup_ratio` is 0, use a dummy scheduler that returns constant LR
    from torch.optim.lr_scheduler import LambdaLR # type: ignore
    scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)

# Initialize W&B
wandb.init(
    project='Nero-XLT',
    reinit=True,
    config=dict(
        seed = seed,
        target_lang=target_lang,
        target_task=target_task,
        device = device,
        train_size = train_size,
        test_size = test_size,
        max_seq_length = max_seq_length,
        batch_size = batch_size,
        num_epochs = num_epochs,
        max_global_steps = max_global_steps,
        resume_step = resume_step,
        lr = lr,
        grad_accumulation_steps = grad_accumulation_steps,
        clip_grad_norm = clip_grad_norm,
    ),
)

# Resume training
global_step = resume_step
start_epoch = 0
if resume_step > 0:
    checkpoint_path = os.path.join(nero_dir, f'checkpoint-{resume_step}')
    print(f"[INFO] Resuming training from {checkpoint_path}")

    # Load adapter state
    model_path = os.path.join(checkpoint_path, 'adapter_model.bin')
    nero_model.load_state_dict(torch.load(model_path, map_location=device))

    # Load optimizer state
    optimizer_path = os.path.join(checkpoint_path, 'optimizer.pt')
    optimizer.load_state_dict(torch.load(optimizer_path, map_location=device))

    # Load scheduler state
    scheduler_path = os.path.join(checkpoint_path, 'scheduler.pt')
    scheduler.load_state_dict(torch.load(scheduler_path, map_location=device))

    # Load scaler state
    scaler_path = os.path.join(checkpoint_path, 'scaler.pt')
    scaler.load_state_dict(torch.load(scaler_path, map_location=device))

    # Load training state
    training_state_path = os.path.join(checkpoint_path, 'training_state.pt')
    if os.path.exists(training_state_path):
        training_state = torch.load(training_state_path)
        start_epoch = training_state.get('epoch', 0)
        print(f"[INFO] Resumed from epoch {start_epoch}, step {resume_step}")

# Set model to training mode
nero_model.train()
done = False

# Safety: Zero gradients at the start of gradient accumulation cycle
# This ensures there are no leftover gradients when resuming mid-cycle or after a previous cycle was interrupted
if global_step % grad_accumulation_steps == 0:
    optimizer.zero_grad(set_to_none=True)

# Training loop
for epoch in range(start_epoch, num_epochs):
    for step, batch in enumerate(train_loader):
        # Stop training if `max_global_steps` reached
        if global_step >= max_global_steps:
            done = True
            break

        # Move inputs to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        with torch.cuda.amp.autocast():
            # Forward pass
            outputs = nero_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,

                # Disable cache to avoid conflict with gradient checkpointing
                use_cache=False, 
            )

            # Compute loss
            loss = outputs.loss / grad_accumulation_steps

        # Backward pass
        scaler.scale(loss).backward()
        
        # Update parameters and log only at the end of gradient accumulation cycle
        if (global_step + 1) % grad_accumulation_steps == 0:
            # Unscale gradients before computing gradient norm and applying clipping
            scaler.unscale_(optimizer)

            # Compute gradient norm
            grad_norm = 0.0
            for p in nero_params:
                p_grad_norm = p.grad.data.norm(2)
                grad_norm += p_grad_norm.item() ** 2
            grad_norm = grad_norm ** 0.5

            # Clip gradients
            if clip_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(nero_params, clip_grad_norm)

            # Update parameters
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            # Zero gradients for the next gradient accumulation cycle
            optimizer.zero_grad(set_to_none=True)

            # Logging
            wandb.log({
                'epoch': epoch,
                'step': global_step,
                'lr': scheduler.get_last_lr()[0],
                'loss': loss.item() * grad_accumulation_steps,
                'grad_norm': grad_norm,
            })
            print(f"epoch: {epoch}/{num_epochs}, "
                f"step: {global_step}/{max_global_steps}, "
                f"loss: {loss.item() * grad_accumulation_steps}, "
                f"grad_norm: {grad_norm}, "
                f"lr: {scheduler.get_last_lr()[0]}")
        
        # Save and push checkpoint every `checkpoint_steps`
        if global_step > 0 and global_step % checkpoint_steps == 0:
            # Create checkpoint directory
            checkpoint_path = os.path.join(nero_dir, f'checkpoint-{global_step}')
            os.makedirs(checkpoint_path, exist_ok=True)

            # Save Nero adapter, optimizer, scheduler, and scaler states
            torch.save(nero_model.state_dict(), os.path.join(checkpoint_path, 'adapter_model.bin'))
            torch.save(optimizer.state_dict(), os.path.join(checkpoint_path, 'optimizer.pt'))
            torch.save(scheduler.state_dict(), os.path.join(checkpoint_path, 'scheduler.pt'))
            torch.save(scaler.state_dict(), os.path.join(checkpoint_path, 'scaler.pt'))

            # Save training state for resuming training
            training_state = {
                'step': global_step,
                'epoch': epoch,
                'lr': scheduler.get_last_lr()[0],
            }
            torch.save(training_state, os.path.join(checkpoint_path, 'training_state.pt'))

            # Commit and push checkpoint to Hugging Face
            repo.git_add(pattern=f"checkpoint-{global_step}")
            repo.git_commit(f"Add checkpoint at step {global_step}")
            repo.git_push()
        
        # Check generated text every `generate_steps`
        if global_step > 0 and global_step % generate_steps == 0:
            generated = generate_text(nero_model, tokenizer, sample_prompt, device=device)
            print()
            print("================================")
            print("CHECK GENERATED TEXT")
            print("================================")
            print(f"{'Prompt':<9}:", sample_prompt)
            print(f"{'Generated':<9}:", generated)
            print()

        global_step += 1
    
    if done:
        break

wandb.finish()

  scaler = torch.cuda.amp.GradScaler()
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malimtegar[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112705911111536, max=1.0…

  with torch.cuda.amp.autocast():
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


Upload file checkpoint-0/adapter_model.bin:   0%|          | 1.00/5.72G [00:00<?, ?B/s]

Upload file checkpoint-0/scaler.pt:   0%|          | 1.00/988 [00:00<?, ?B/s]

Upload file checkpoint-0/training_state.pt:   0%|          | 1.00/892 [00:00<?, ?B/s]

Upload file checkpoint-0/scheduler.pt:   0%|          | 1.00/0.98k [00:00<?, ?B/s]

Upload file checkpoint-0/optimizer.pt:   0%|          | 1.00/3.48k [00:00<?, ?B/s]

To https://huggingface.co/alxxtexxr/L3.1-8B-wikipedia-ja-5K-Nero-v20250803164454
   e92510d..0d7f73d  main -> main

  with torch.cuda.amp.autocast():
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


epoch: 0/1, step: 1/10, loss: 13.894257545471191, grad_norm: 0.5216511298905168, lr: 0.00019510565162951537
epoch: 0/1, step: 3/10, loss: 13.740740776062012, grad_norm: 0.408808239260627, lr: 0.00018090169943749476
epoch: 0/1, step: 5/10, loss: 13.746880531311035, grad_norm: 0.47036268693720007, lr: 0.00015877852522924732
epoch: 0/1, step: 7/10, loss: 13.788190841674805, grad_norm: 0.45381971411593797, lr: 0.00013090169943749476
epoch: 0/1, step: 9/10, loss: 13.611076354980469, grad_norm: 0.4550647138584545, lr: 0.0001


VBox(children=(Label(value='0.018 MB of 0.018 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁
grad_norm,█▁▅▄▄
loss,█▄▄▅▁
lr,█▇▅▃▁
step,▁▃▅▆█

0,1
epoch,0.0
grad_norm,0.45506
loss,13.61108
lr,0.0001
step,9.0


In [None]:
# project_name = f'L3.1-8B-{target_task}-{target_lang}-{train_size//1000}K-Nero-v{datetime.now().strftime("%Y%m%d%H%M%S")}'
# hub_model_id = f'alxxtexxr/{project_name}'

# print("Project name:", project_name)
# print("Hugging Face model ID:", hub_model_id)

In [None]:
# # Save Nero parameters
# nero_params_path = f"nero_params_L1T1_to_{model_configs['target']['label']}.pth"
# lora_state_dict = {k: v for k, v in nero_model.state_dict().items() if 'nero_' in k}
# torch.save(lora_state_dict, nero_params_path)
# print("Nero parameters saved to: ", nero_params_path)

In [None]:
# nero_model.eval()
# nero_model.set_return_nero_outputs(False)
# generate_text(
#     nero_model, 
#     tokenizer, 
#     prompt="海は",
# )
# nero_model.set_return_nero_outputs(True)
# nero_model.train();