In [1]:
import os
import json
import math
import gc

import numpy as np

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import (
    # AdaLoraConfig,
    LoraConfig,
    get_peft_model,
    PeftModel,
    LoHaConfig,
)

from cortexsubsetloader import tokenize
from pytorch_optimizer import Ranger21, Lamb, DAdaptLion, SAM, DAdaptAdam, LOMO, SophiaH

# Allow PyTorch to use TF32 for CUDA matmuls and cuDNN operations (global backend flags).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

In [2]:
def filter_data(train, eval):
    """Return the entries of `eval` that do not also occur in `train`.

    Membership is decided with Python's `not in` operator against `train`, so
    the exact matching rule depends on the container type that is passed in.
    A one-line summary of how many entries were dropped is printed.

    Args:
        train: Container of training samples used for the membership test.
        eval: Sized iterable of evaluation samples.

    Returns:
        list: The evaluation samples that were not found in `train`, in their
        original order.
    """
    # NOTE(review): if `train` is a multi-dimensional numpy array, `in` reduces an
    # element-wise comparison with any() -- confirm that matches the intended
    # notion of "duplicate" for the data files in use.
    kept = [sample for sample in eval if sample not in train]
    dropped = len(eval) - len(kept)
    if dropped > 0:
        print(f"Removed {dropped} duplicates from eval")
    else:
        print("No duplicates found in eval")
    return kept


def data_collator(features):
    """Turn `(token_ids, prompt_len)` pairs into per-sample training batches.

    For every feature a dict is built with:
      * 'input_ids': a copy of the token tensor, and
      * 'labels': a copy of the token tensor whose first `prompt_len` columns
        are overwritten with -100 so those positions are masked out of the
        labels.

    Args:
        features: Iterable of `(tensor, prompt_len)` tuples. The tensor is
            indexed as `[:, :prompt_len]`, so it must have at least two
            dimensions.

    Returns:
        list[dict]: One `{'input_ids', 'labels'}` dict per input feature. The
        input tensors themselves are not modified.
    """
    def _to_batch(feature):
        token_ids, n_prompt = feature
        masked = token_ids.clone()
        masked[:, :n_prompt] = -100  # hide the prompt tokens from the labels
        return {
            'input_ids': torch.concat([token_ids]),
            'labels': torch.concat([masked]),
        }

    return [_to_batch(feature) for feature in features]


def get_data(train_name, eval_name, tokenizer, train_subset=None, eval_subset=None, shuffle=True, extend_train_length=0):
    """Load, de-duplicate, tokenize and collate the train and eval sets.

    Args:
        train_name: Path of a JSON file (`str` or `os.PathLike`), or an
            iterable of such paths whose contents are concatenated.
        eval_name: Path of the JSON file holding the evaluation samples.
        tokenizer: Tokenizer forwarded to `tokenize`.
        train_subset: Optional numpy index (e.g. `np.arange(...)`) applied to
            the loaded training array.
        eval_subset: Optional numpy index applied to the loaded eval array.
        shuffle: If True, the training samples are permuted with
            `np.random.permutation` (unseeded here) before tokenization.
        extend_train_length: Added to the 2048 length limit passed to
            `tokenize` for the training set only.

    Returns:
        tuple[list, list]: `(train_batches, eval_batches)` as produced by
        `data_collator`.
    """
    def _load_json(path):
        with open(path, "r") as f:
            return json.load(f)

    # isinstance (rather than an exact type comparison) so that str subclasses
    # such as np.str_ and pathlib paths are treated as ONE file instead of being
    # iterated as if they were a list of file names.
    if isinstance(train_name, (str, os.PathLike)):
        train_data = np.array(_load_json(train_name))
    else:  # iterable of paths
        train_data = []
        for name in train_name:
            train_data.extend(_load_json(name))
        train_data = np.array(train_data)
    eval_data = np.array(_load_json(eval_name))

    if train_subset is not None:
        train_data = train_data[train_subset]
    if eval_subset is not None:
        eval_data = eval_data[eval_subset]

    # Drop eval samples that also appear in the (subsetted) training data.
    eval_data = filter_data(train_data, eval_data)

    if shuffle:
        p = np.random.permutation(len(train_data))
        train_data = train_data[p]

    train_data = tokenize(tokenizer, train_data, 2048+extend_train_length)
    eval_data = tokenize(tokenizer, eval_data, 2048)

    train_data = data_collator(train_data)
    eval_data = data_collator(eval_data)

    return train_data, eval_data


# Relative margin used for the reported statistics in evaluate() and train():
# a sample counts as a "win" when its loss is below base_loss * (1 - true_eps).
true_eps = 0.01
# initial_loss_eps = 0.0001
# Divisor for the progress-dot interval: evaluate() and train() print a "." every
# (number of steps // intermed_check_step_split) steps.
intermed_check_step_split = 8

def evaluate(model, eval_d, return_to_cpu=False, return_stats=False, print_stats=True, 
             base_model=None, precompute_base_loss=True, device="cuda", instruction_finetuning=True):
    """Compare `model` against a reference model on every batch of `eval_d`.

    The reference is `base_model` when one is given; otherwise it is `model`
    itself with its adapter layers temporarily disabled (so `model` must then
    provide `disable_adapter_layers()` / `enable_adapter_layers()`).

    Args:
        model: Model under evaluation. Called as `model(inputs, labels=labels)`
            and expected to return an object with a `.loss` tensor.
        eval_d: Sized sequence of batches. With `instruction_finetuning` each
            batch is a dict with 'input_ids' and 'labels'; otherwise each batch
            is a tensor that is also used as its own labels.
        return_to_cpu: Move `model` to the CPU when finished.
        return_stats: Return the statistics dict (otherwise returns None).
        print_stats: Print a one-line summary of the statistics.
        base_model: Optional separate reference model. It is moved to `device`
            while needed and back to the CPU afterwards.
        precompute_base_loss: Compute all reference losses in a first pass
            before evaluating `model`, instead of interleaving the two.
        device: Device on which the model and all batches are placed.
        instruction_finetuning: Selects the batch format described above.

    Returns:
        dict | None: When `return_stats` is set, the per-batch means of
        'loss' (relu(loss - base * (1 - true_eps)) / base), 'base_loss',
        'lora_diff' (loss - base), 'head_to_head' and 'eps0_head_to_head'
        (win rates in percent, ties count half) and 'overshoot'.
    """
    print("Evaluating", end=" ")
    # Honour `device` for the model too, so it lives where the batches are sent.
    model = model.to(device)
    model.eval()

    use_lora = base_model is None
    n_batches = len(eval_d)
    # Guard against a zero interval (and the resulting ZeroDivisionError in the
    # modulo below) when there are fewer batches than progress dots.
    dot_interval = max(1, n_batches // intermed_check_step_split)

    def _unpack(batch):
        """Move one batch to `device` and return `(inputs, labels)`."""
        if instruction_finetuning:
            return batch['input_ids'].to(device), batch['labels'].to(device)
        inputs = batch.to(device)
        return inputs, inputs.clone()

    precomputed_base_losses = []
    if precompute_base_loss:
        if use_lora:
            model.disable_adapter_layers()
        else:
            base_model = base_model.to(device)
        try:
            reference = model if use_lora else base_model
            with torch.no_grad():
                for batch in eval_d:
                    inputs, labels = _unpack(batch)
                    precomputed_base_losses.append(reference(inputs, labels=labels).loss)
        finally:
            # Always restore the adapters / free the device, even on failure.
            if use_lora:
                model.enable_adapter_layers()
            else:
                base_model = base_model.to("cpu")
    elif not use_lora:
        # Only a separate base model needs moving; in adapter mode there is none.
        base_model = base_model.to(device)

    eval_base_loss = 0
    lora_diff = 0
    eval_loss = 0
    head_to_head = 0
    eps0_head_to_head = 0
    overshoot = 0

    for step, batch in enumerate(eval_d, start=1):
        inputs, labels = _unpack(batch)
        with torch.no_grad():
            if precompute_base_loss:
                base_outputs_loss = precomputed_base_losses[step - 1]
            elif use_lora:
                model.disable_adapter_layers()
                try:
                    base_outputs_loss = model(inputs, labels=labels).loss
                finally:
                    model.enable_adapter_layers()
            else:
                base_outputs_loss = base_model(inputs, labels=labels).loss
            outputs_loss = model(inputs, labels=labels).loss

            # Loss the model has to beat: the base loss reduced by the margin.
            target = base_outputs_loss * (1.0 - true_eps)
            partial_loss = torch.nn.functional.relu(outputs_loss - target)
            overshoot_penalty = torch.nn.functional.relu(-(outputs_loss - target))
            loss = partial_loss / base_outputs_loss

            eval_loss += loss.item() / n_batches
            eval_base_loss += base_outputs_loss.item() / n_batches
            lora_diff += (outputs_loss.item() - base_outputs_loss.item()) / n_batches
            if outputs_loss < target:
                head_to_head += 100.0 / n_batches
            elif outputs_loss == target:
                head_to_head += 50.0 / n_batches
            if outputs_loss < base_outputs_loss:
                eps0_head_to_head += 100.0 / n_batches
            elif outputs_loss == base_outputs_loss:
                eps0_head_to_head += 50.0 / n_batches
            overshoot += overshoot_penalty.item() / n_batches

        if step % dot_interval == 0:
            print(".", end="")
            gc.collect(); torch.cuda.empty_cache()

    if return_to_cpu:
        model = model.to("cpu")

    if not use_lora:
        base_model = base_model.to("cpu")

    gc.collect(); torch.cuda.empty_cache()

    data = {
        "loss": eval_loss,
        "base_loss": eval_base_loss,
        "lora_diff": lora_diff,
        "head_to_head": head_to_head,
        "eps0_head_to_head": eps0_head_to_head,
        "overshoot": overshoot
    }

    if print_stats:
        print(f" Loss: {eval_loss:.8f}, Base Loss: {eval_base_loss:.6f}, Lora Diff: {lora_diff:.8f},",
            f"WR: {head_to_head:.2f}%, 0epsWR: {eps0_head_to_head:.2f}%, OShL: {overshoot:.8f}")
    return data if return_stats else None



def train(model, train_d, eval_d, base_model=None, inf_training=False, training_device="cuda",
            acc_batch_size=32, instruction_finetuing=True, precalculate_base_loss=True, precalculate_batch_mult=1.5,
            lr=1e-4, weight_decay=0.001, lr_scheduler="constant", warmup_steps=0, betas=(0.9, 0.99), squared_loss=False,
            use_sam=False, sam_rho=0.05, do_dadapt=False, use_LOMO=False, opt="adamw",
            manual_grad_clip_norm=0.0, manual_grad_clip_value=0.0, wait_for_full_batch=True, sam_reuse_base_outputs=False,
            do_base_gradient=True, add_overshoot_penalty=True, ignore_overshot_samples=True, bad_sample_mult=1.0,
            loss_eps = 0.015, overshoot_buffer = 0.01,
            prompt_dropout=0.0,
            eval_steps=1024, save_steps=1024, save_name="lora", do_save=True,
            average_stats=False,
            partial_eval_steps=0, partial_eval_size=128, save_n_start=0,
            gradient_checkpointing=False):
    """Train `model` to beat a reference model by a relative margin.

    For every batch the per-sample objective is
    `relu(loss - base_loss * (1 - loss_eps)) / base_loss`, where `base_loss`
    comes from `base_model` or, when that is None, from `model` with its
    adapter layers disabled. Gradients are accumulated over `acc_batch_size`
    backward passes before each optimizer update.

    NOTE: `train_d` is consumed in place (`pop(0)`); it is empty on return.
    NOTE: the defaults `precalculate_base_loss=True` and `do_base_gradient=True`
    conflict, so at least one of them has to be overridden by the caller.

    Args:
        model: Model to optimise. Needs adapter enable/disable methods when
            `base_model` is None.
        train_d: List of batches (dicts with 'input_ids'/'labels' when
            `instruction_finetuing`, otherwise tensors).
        eval_d: Evaluation batches forwarded to `evaluate`.
        base_model: Optional separate reference model.
        inf_training: Placeholder; currently has no effect.
        training_device: Device for the model and the batches.
        acc_batch_size: Number of accumulated samples per optimizer update.
        instruction_finetuing: Selects the batch format (see `train_d`).
        precalculate_base_loss: Compute reference losses ahead of time, under
            `no_grad`, in chunks of `int(acc_batch_size * precalculate_batch_mult)`.
        precalculate_batch_mult: Chunk-size multiplier for the above.
        lr, weight_decay, betas: Optimizer hyper-parameters.
        lr_scheduler: "cosine" for a cosine schedule; anything else selects a
            constant schedule. Both use `warmup_steps` warm-up steps.
        warmup_steps: Warm-up steps; None means `(eval_steps // acc_batch_size) // 2`.
        squared_loss: Back-propagate the square of the loss.
        use_sam, sam_rho: Wrap the optimizer in SAM with the given rho.
        do_dadapt: Only checked for incompatibility with LOMO.
        use_LOMO: Use the LOMO optimizer (incompatible with SAM / DAdapt).
        opt: One of "dadapt_adam", "sophia", "ranger", "adamw", "sgd".
        manual_grad_clip_norm, manual_grad_clip_value: Clip when > 0.
        wait_for_full_batch: Update only once `acc_batch_size` samples have
            actually contributed gradients; otherwise update every
            `acc_batch_size` steps and once more when the data runs out.
        sam_reuse_base_outputs: Reuse first-pass reference losses in the SAM
            second pass.
        do_base_gradient: Keep the reference loss attached to the graph.
        add_overshoot_penalty: Add `relu((loss_eps + overshoot_buffer) - loss)`
            to the objective.
        ignore_overshot_samples: Skip the backward pass for samples whose
            overshoot penalty is non-zero.
        bad_sample_mult: Loss multiplier for samples above the `loss_eps`
            threshold.
        loss_eps, overshoot_buffer: Margins used in the objective.
        prompt_dropout: Fraction of label positions randomly set to -100.
        eval_steps, partial_eval_steps, partial_eval_size: Evaluation cadence
            (counted in samples that contributed gradients) and partial size.
        save_steps, save_name, do_save, save_n_start: Checkpoint cadence/naming.
        average_stats: Report running averages instead of per-update averages.
        gradient_checkpointing: Enable gradient checkpointing on `model`.

    Returns:
        dict: Statistics of the final `evaluate` call.

    Raises:
        ValueError: For incompatible option combinations or an unknown `opt`.
    """
    if warmup_steps is None:
        warmup_steps = (eval_steps // acc_batch_size) // 2
    # Without a separate base model the reference is `model` with adapters off.
    LORA = base_model is None

    if precalculate_base_loss and do_base_gradient:
        raise ValueError("Precalculating base loss will disconnect base gradients")

    gc.collect(); torch.cuda.empty_cache()
    # Use `training_device` (not a hard-coded "cuda") so model and batches agree.
    model = model.to(training_device)
    model.enable_input_require_grads()

    if gradient_checkpointing:
        model.config.use_cache = False
        grad_check_kwargs = {"use_reentrant": False}
        if do_base_gradient:
            grad_check_kwargs["use_reentrant"] = True
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=grad_check_kwargs)
    model.train()


    if not use_sam:
        if opt == "dadapt_adam":
            optimizer = DAdaptAdam(model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay, fixed_decay=True)
        elif opt == "sophia":
            optimizer = SophiaH(model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        elif opt == "ranger":
            optimizer = Ranger21(model.parameters(), num_iterations=1, lr=lr, betas=betas, weight_decay=weight_decay,
                                 num_warm_up_iterations=0, num_warm_down_iterations=0)
        elif opt == "adamw":
            optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        elif opt == "sgd":
            optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=betas[0], weight_decay=weight_decay)
        else:
            raise ValueError(f"Unknown optimizer {opt}")
    else:
        base_optimizer_args = {"lr": lr, "weight_decay": weight_decay, "betas": betas, "eps": 1e-8}

        if opt == "dadapt_adam":
            base_optimizer = DAdaptAdam
        elif opt == "sophia":
            base_optimizer = SophiaH
        elif opt == "ranger":
            base_optimizer = Ranger21
            base_optimizer_args["num_iterations"] = 1
            base_optimizer_args["num_warm_up_iterations"] = 0
            base_optimizer_args["num_warm_down_iterations"] = 0
        elif opt == "adamw":
            base_optimizer = torch.optim.AdamW
        elif opt == "sgd":
            base_optimizer = torch.optim.SGD
            del base_optimizer_args["betas"], base_optimizer_args["eps"]
            base_optimizer_args["momentum"] = betas[0]
        else:
            raise ValueError(f"Unknown optimizer {opt}")

        optimizer = SAM(model.parameters(), base_optimizer=base_optimizer, rho=sam_rho, adaptive=True, **base_optimizer_args)
        sam_optimizer = optimizer

    if use_LOMO:
        if use_sam:
            raise ValueError("LOMO and SAM are not compatible")
        if do_dadapt:
            raise ValueError("LOMO and DAdapt are not compatible")
        # check for lora layers
        for name, param in model.named_parameters():
            if "lora" in name:
                print("training lora using LOMO?")
                break
        optimizer = LOMO(model, lr=lr)
        lomo_optimizer = optimizer

    if lr_scheduler == "cosine":
        lr_scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, warmup_steps, (len(train_d)//acc_batch_size)+warmup_steps)
    else:
        lr_scheduler = transformers.get_constant_schedule_with_warmup(optimizer, warmup_steps)
    lr_scheduler.step() # don't want to start at 0

    # Four variants of the objective. "bgrad" keeps the base loss as a tensor (so
    # gradients can flow into it); the others turn it into a plain number first.
    # "noshot" variants report the overshoot penalty without adding it to the loss.
    @torch.jit.script
    def combined_loss_bgrad_os(outputs_loss, base_loss, loss_eps:float=loss_eps, overshoot_buffer:float=overshoot_buffer):
        partial_loss = torch.nn.functional.relu(outputs_loss - (base_loss * (1.0 - loss_eps)))
        loss = partial_loss / base_loss
        overshoot_penalty = torch.nn.functional.relu(-(loss - (loss_eps + overshoot_buffer)))
        loss = loss + overshoot_penalty
        return loss, overshoot_penalty.item()
    
    @torch.jit.script
    def combined_loss_os(outputs_loss, base_loss_in, loss_eps:float=loss_eps, overshoot_buffer:float=overshoot_buffer):
        base_loss = base_loss_in.item()
        partial_loss = torch.nn.functional.relu(outputs_loss - (base_loss * (1.0 - loss_eps)))
        loss = partial_loss / base_loss
        overshoot_penalty = torch.nn.functional.relu(-(loss - (loss_eps + overshoot_buffer)))
        loss = loss + overshoot_penalty
        return loss, overshoot_penalty.item()
    
    @torch.jit.script
    def combined_loss_os_noshot(outputs_loss, base_loss_in, loss_eps:float=loss_eps, overshoot_buffer:float=overshoot_buffer):
        base_loss = base_loss_in.item()
        partial_loss = torch.nn.functional.relu(outputs_loss - (base_loss * (1.0 - loss_eps)))
        loss = partial_loss / base_loss
        return loss, torch.nn.functional.relu(-(loss - (loss_eps + overshoot_buffer))).detach().item()
    
    @torch.jit.script
    def combined_loss_bgrad_os_noshot(outputs_loss, base_loss, loss_eps:float=loss_eps, overshoot_buffer:float=overshoot_buffer):
        partial_loss = torch.nn.functional.relu(outputs_loss - (base_loss * (1.0 - loss_eps)))
        loss = partial_loss / base_loss
        return loss, torch.nn.functional.relu(-(loss - (loss_eps + overshoot_buffer))).detach().item()
    
    if add_overshoot_penalty:
        if do_base_gradient:
            combined_loss = combined_loss_bgrad_os
        else:
            combined_loss = combined_loss_os
    else:
        if do_base_gradient:
            combined_loss = combined_loss_bgrad_os_noshot
        else:
            combined_loss = combined_loss_os_noshot

    # `train_d` shrinks as batches are popped, so remember the size up front for
    # the progress display.
    total_steps = len(train_d)
    # Never let the progress-dot interval be 0 (acc_batch_size smaller than the split).
    dot_interval = max(1, acc_batch_size // intermed_check_step_split)

    steps_so_far = 1 # start at one to avoid all the modulo checks
    epoch_loss = 0; epoch_overshoot = 0; epoch_base_loss = 0; lora_diff = 0
    epoch_wr = 0; epoch_0eps_wr = 0
    fit_samples = 0; unfit_samples = 0
    sam_batch = []
    accum_steps = 0
    true_steps_taken = 0; prev_dot_step = -1
    last_tst = true_steps_taken
    sam_saved_base_outputs = []
    precalculated_base_outputs = []
    while len(train_d) > 0:

        if precalculate_base_loss and len(precalculated_base_outputs) == 0:
            # Refill the queue of reference losses for the next chunk of batches.
            batches = train_d[:int(acc_batch_size * precalculate_batch_mult)]

            if LORA:
                model.disable_adapter_layers()
            else:
                base_model = base_model.to(training_device)
            
            for batch in batches:
                if instruction_finetuing:
                    inputs = batch['input_ids'].to(training_device)
                    labels = batch['labels'].to(training_device)
                else:
                    inputs = batch.to(training_device)
                    labels = inputs.clone()

                # NOTE(review): with prompt_dropout > 0 this random mask differs from
                # the one drawn for the same batch in the training pass below, so the
                # precalculated base loss is computed on different labels -- confirm.
                if prompt_dropout > 0.0:
                    labels = labels.clone()
                    labels[:, torch.randperm(labels.size(1))[:int(labels.size(1) * prompt_dropout)]] = -100

                with torch.no_grad():
                    if LORA:
                        base_outputs = model(inputs, labels=labels)
                    else:
                        base_outputs = base_model(inputs, labels=labels)
                    precalculated_base_outputs.append(base_outputs.loss)

            if LORA:
                model.enable_adapter_layers()
            else:
                base_model = base_model.to("cpu")


        batch = train_d.pop(0)
        if instruction_finetuing:
            inputs = batch['input_ids'].to(training_device)
            labels = batch['labels'].to(training_device)
        else:
            inputs = batch.to(training_device)
            labels = inputs.clone()

        if prompt_dropout > 0.0:
            labels = labels.clone()
            labels[:, torch.randperm(labels.size(1))[:int(labels.size(1) * prompt_dropout)]] = -100
        
        if not precalculated_base_outputs:
            if LORA:
                model.disable_adapter_layers()
                base_outputs_loss = model(inputs, labels=labels).loss
                model.enable_adapter_layers()
                outputs_loss = model(inputs, labels=labels).loss
            else:
                # NOTE(review): `base_model` is not moved to `training_device` on this
                # path (and evaluate() moves it to the CPU) -- confirm the caller
                # keeps it on the right device.
                base_outputs_loss = base_model(inputs, labels=labels).loss
                outputs_loss = model(inputs, labels=labels).loss
        else:
            base_outputs_loss = precalculated_base_outputs.pop(0)
            outputs_loss = model(inputs, labels=labels).loss

        loss, overshoot_penalty = combined_loss(outputs_loss, base_outputs_loss)
        loss = loss / acc_batch_size

        if not ignore_overshot_samples or overshoot_penalty == 0.0:
            if loss.item() > ((loss_eps / acc_batch_size)+1e-8):
                # Kept as in the original: the "unfit" counter counts downwards.
                unfit_samples += -1
                if bad_sample_mult is not None and bad_sample_mult != 1.0:
                    loss = loss * bad_sample_mult
            elif loss.item() < ((loss_eps - (loss_eps + overshoot_buffer)) / acc_batch_size):
                fit_samples += 1

            if squared_loss:
                floss = loss ** 2
            else:
                floss = loss
            if opt == "sophia":
                floss.backward(create_graph=True, retain_graph=False)
            else:
                floss.backward()
            
            accum_steps += 1
            true_steps_taken += 1
            if use_sam:
                sam_batch.append((inputs, labels))
                if sam_reuse_base_outputs:
                    sam_saved_base_outputs.append(base_outputs_loss.detach())
        else:
            # Sample skipped (overshot): no gradient contribution.
            fit_samples += 1

        outputs_loss_item = outputs_loss.detach().item()
        base_loss_item = base_outputs_loss.detach().item()
        epoch_base_loss += base_loss_item
        lora_diff += (outputs_loss_item - base_loss_item)
        epoch_loss += loss.detach().item() * acc_batch_size
        epoch_wr += 100.0 if outputs_loss_item < (base_loss_item * (1.0 - true_eps)) else 0.0
        epoch_wr += 50.0 if outputs_loss_item == (base_loss_item * (1.0 - true_eps)) else 0.0
        epoch_0eps_wr += 100.0 if outputs_loss_item < base_loss_item else 0.0
        epoch_0eps_wr += 50.0 if outputs_loss_item == base_loss_item else 0.0
        epoch_overshoot += overshoot_penalty

        if (true_steps_taken % dot_interval == 0) and accum_steps != prev_dot_step:
            prev_dot_step = accum_steps
            print(".", end="")
            gc.collect(); torch.cuda.empty_cache()

        # `train_d` has already been popped, so "last batch" means it is now empty
        # (comparing the step counter with the shrinking length fired mid-run and
        # never at the end).
        data_exhausted = len(train_d) == 0
        if (((steps_so_far % acc_batch_size == 0 or data_exhausted) and not wait_for_full_batch) or 
                                                                            (wait_for_full_batch and accum_steps == acc_batch_size)):
            if manual_grad_clip_norm > 0.0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), manual_grad_clip_norm)
            if manual_grad_clip_value > 0.0:
                torch.nn.utils.clip_grad_value_(model.parameters(), manual_grad_clip_value)

            if use_LOMO:
                # NOTE(review): LOMO usage kept as in the original; verify against the
                # pytorch_optimizer LOMO API. (LOMO no longer falls into the SAM
                # branch, where `sam_optimizer` is undefined.)
                lomo_optimizer.fuse_update()
            elif not use_sam:
                optimizer.step()
                optimizer.zero_grad()
            else:
                # SAM: perturb the weights, recompute the gradients on the stored
                # batch, then take the real step.
                sub_steps = 1
                sam_optimizer.first_step(zero_grad=True)
                
                for inputs, labels in sam_batch:
                    if LORA:
                        if sam_reuse_base_outputs:
                            base_outputs_loss = sam_saved_base_outputs.pop(0)
                        else:
                            model.disable_adapter_layers()
                            base_outputs_loss = model(inputs, labels=labels).loss
                            model.enable_adapter_layers()
                        outputs = model(inputs, labels=labels)
                    else:
                        if sam_reuse_base_outputs:
                            base_outputs_loss = sam_saved_base_outputs.pop(0)
                        else:
                            base_outputs_loss = base_model(inputs, labels=labels).loss
                        outputs = model(inputs, labels=labels)

                    loss, overshoot_penalty = combined_loss(outputs.loss, base_outputs_loss)
                    # NOTE(review): unlike the first pass, the loss here has not yet been
                    # divided by the batch size when compared with loss_eps / acc_batch_size
                    # -- confirm whether this stricter threshold is intended.
                    if loss.item() > ((loss_eps / acc_batch_size)+1e-7):
                        if bad_sample_mult is not None and bad_sample_mult != 1.0:
                            loss = loss * bad_sample_mult
                    loss = loss / accum_steps
                    
                    if squared_loss:
                        floss = loss ** 2
                    else:
                        floss = loss
                    if opt == "sophia":
                        floss.backward(create_graph=True, retain_graph=False)
                    else:
                        floss.backward()

                    sub_steps += 1
                    if sub_steps % dot_interval == 0:
                        print(".", end="")
                    gc.collect(); torch.cuda.empty_cache()

                sam_optimizer.second_step(zero_grad=True)
                sam_batch = []

            if average_stats:
                stat_steps = steps_so_far
            else:
                stat_steps = accum_steps
                if wait_for_full_batch:
                    stat_steps += fit_samples
            # Every sample since the last update may have been skipped.
            stat_steps = max(stat_steps, 1)
            print(f"Step {steps_so_far}/{total_steps}\tLoss: {epoch_loss/stat_steps:.6f}",
                                                    f"OShL: {epoch_overshoot/stat_steps:.3e}"
                                                    f"\tBase: {epoch_base_loss/stat_steps:.4f}",
                                                    f"Diff: {lora_diff/stat_steps:.4e}",
                                                    f"\tWR: {epoch_wr/stat_steps:2.2f}%",
                                                    f"0eps: {epoch_0eps_wr/stat_steps:2.2f}% ",
                                                    f"\tLR: {lr_scheduler.get_last_lr()[0]:.2e}",
                                                    f"eps: {loss_eps:.2e}",
                                                    f"fit: {fit_samples}/{unfit_samples}"
                                                    )

            # `inf_training` would top up `train_d` with fresh batches here; that
            # feature is not implemented.

            if not average_stats:
                epoch_overshoot = 0
                epoch_loss = 0
                epoch_base_loss = 0
                lora_diff = 0
                epoch_wr = 0
                epoch_0eps_wr = 0
            unfit_samples = 0
            fit_samples = 0
            accum_steps = 0
            if sam_saved_base_outputs != []:
                raise ValueError("SAM saved base outputs didn't get used?")
            lr_scheduler.step()

            gc.collect(); torch.cuda.empty_cache()


        # `true_steps_taken != last_tst` makes sure save/eval trigger only on a step
        # that actually advanced the counter (skipped samples do not).
        if do_save and true_steps_taken % save_steps == 0 and len(train_d) > 0 and true_steps_taken != last_tst:
            # NOTE(review): the original appended `.format("02d")` to this string,
            # which is a no-op, so checkpoint numbers are NOT zero-padded. The
            # unpadded names are kept so existing checkpoint paths stay valid.
            model.save_pretrained(save_name + '_' + str((true_steps_taken // save_steps) + save_n_start))
            
        do_full_eval = true_steps_taken % eval_steps == 0 and len(train_d) > 0 and true_steps_taken != last_tst
        if do_full_eval:
            evaluate(model, eval_d, base_model=base_model, device=training_device, instruction_finetuning=instruction_finetuing)
            model.train()
        if ((partial_eval_steps > 0 and true_steps_taken % partial_eval_steps == 0) and not do_full_eval 
                                            and len(train_d) > 0 and true_steps_taken != last_tst):
            evaluate(model, eval_d[:partial_eval_size], base_model=base_model, device=training_device, instruction_finetuning=instruction_finetuing)
            model.train()
        
        steps_so_far += 1
        last_tst = true_steps_taken

        gc.collect(); torch.cuda.empty_cache()

    if do_save:
        if save_n_start > 0:
            model.save_pretrained(save_name+"_X"+str(save_n_start))
        else:
            model.save_pretrained(save_name)

    model.eval()
    final_eval_stats = evaluate(model, eval_d, return_stats=True, base_model=base_model, device=training_device, instruction_finetuning=instruction_finetuing)

    model = model.to("cpu")
    gc.collect(); torch.cuda.empty_cache()

    return final_eval_stats

In [3]:
# Prefix for saved adapter checkpoints (passed to train() as `save_name` below).
lora_name = "Hydrogen"
# Model id / path handed to from_pretrained() for both model and tokenizer.
# NOTE(review): empty in this copy of the notebook -- presumably needs to be filled in before running.
model_name = ""

# Not referenced anywhere else in the visible notebook.
neft_noise = 0.0 # bad actually?


# LoRA adapter on the attention q/v projections only; the other projection
# modules are listed but commented out. rsLoRA scaling is enabled.
rank = 1024
config = LoraConfig(
    r=rank, lora_alpha=16,
    target_modules=['q_proj','v_proj', 
                    # "k_proj", "o_proj", 
                    # "gate_proj", 
                    # "up_proj", 
                    # "down_proj"
                    ],  #   , 
    lora_dropout=0.0,
    bias="none", task_type="CAUSAL_LM",
    use_rslora=True,
    # init_lora_weights="gaussian",
)

# config = LoHaConfig(
#     r=rank, 
#     alpha=rank,
#     target_modules=["q_proj", "v_proj"], #,  , "up_proj" # , "o_proj" , "k_proj", "down_proj"
#     rank_dropout=0.0,
#     module_dropout=0.0,
#     init_weights=True,
#     task_type="CAUSAL_LM",
# )

# config = AdaLoraConfig(
#     peft_type="ADALORA",
#     task_type="CAUSAL_LM",
#     r=rank,
#     init_r=int(rank * 1.5),
#     target_r=rank, 
#     lora_alpha=rank,
#     orth_reg_weight=0.5,
#     target_modules=["q_proj", "v_proj"], #, , "down_proj" , "up_proj", "o_proj" # , "k_proj"
#     lora_dropout=0.0,
#     tinit=8, tfinal=32,
#     deltaT=1,
#     use_rslora=True,
#     total_step = 8192 // 64,
# )

# Keyword arguments forwarded to AutoModelForCausalLM.from_pretrained() below:
# bfloat16 weights, safetensors checkpoints and the FlashAttention-2 attention implementation.
params = {
    'low_cpu_mem_usage': True,
    'trust_remote_code': False,
    'torch_dtype': torch.bfloat16,
    'use_safetensors': True,
    'attn_implementation': "flash_attention_2"
}


# The model is loaded without an explicit device here (the recorded output shows the
# matching Flash Attention warning); train()/evaluate() move it to the GPU later.
model = AutoModelForCausalLM.from_pretrained(model_name, **params, cache_dir="Models")
# Wrap the base model with the LoRA adapter defined by `config`.
lora_model = get_peft_model(model, config)
# lora_model = PeftModel.from_pretrained(model, model_id="lora_0x0", is_trainable=True)
# lora_model.to(torch.bfloat16)
lora_model.print_trainable_parameters()
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=False, use_fast=True, cache_dir="Models")

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

trainable params: 436,207,616 || all params: 7,677,939,712 || trainable%: 5.681310772970029


In [4]:
#0x0 2_16_e512 {'eval_loss': 0.34912705421447754, 'eval_runtime': 99.3432, 'eval_samples_per_second': 5.154, 'eval_steps_per_second': 5.154}
#t27 2_16_e512 {'eval_loss': 0.3499789834022522, 'eval_runtime': 103.1445, 'eval_samples_per_second': 4.964, 'eval_steps_per_second': 4.964}

#0x0 2_e26_512 {'eval_loss': 0.4443358778953552, 'eval_runtime': 92.638, 'eval_samples_per_second': 5.527, 'eval_steps_per_second': 5.527}
#t27 2_e26_512 {'eval_loss': 0.4439195692539215, 'eval_runtime': 92.953, 'eval_samples_per_second': 5.508, 'eval_steps_per_second': 5.508}

#aes e1_512 {'eval_loss': 0.3197530210018158, 'eval_runtime': 100.1814, 'eval_samples_per_second': 5.111, 'eval_steps_per_second': 5.111}
#0x0 e1_512 {'eval_loss': 0.3169699013233185, 'eval_runtime': 98.9367, 'eval_samples_per_second': 5.175, 'eval_steps_per_second': 5.175}
#t27 e1_512 {'eval_loss': 0.31162574887275696, 'eval_runtime': 96.6104, 'eval_samples_per_second': 5.3, 'eval_steps_per_second': 5.3}
#m00 e1_512 {'eval_loss': 0.30996233224868774, 'eval_runtime': 91.7931, 'eval_samples_per_second': 5.578, 'eval_steps_per_second': 5.578}
#sn6 e1_512 {'eval_loss': 0.3129960298538208, 'eval_runtime': 94.8393, 'eval_samples_per_second': 5.399, 'eval_steps_per_second': 5.399}

#m01 e18_512 {'eval_loss': 0.29920494556427, 'eval_runtime': 103.3506, 'eval_samples_per_second': 4.954, 'eval_steps_per_second': 4.954}
#sn6 e18_512 {'eval_loss': 0.30003952980041504, 'eval_runtime': 92.8474, 'eval_samples_per_second': 5.514, 'eval_steps_per_second': 5.514}

In [5]:
# train_data, eval_data = get_data(train_name=['data/perf_43_16896_256.json',], #  # 'data/atk_43_2048_2048.json'
#                                  train_subset=np.arange(16384), eval_subset=np.arange(16384, 16384+512),
#                                  eval_name='data/perf_43_16896_256.json', tokenizer=tokenizer, shuffle=False)
# Train/eval split taken from a single file: samples 0..4095 for training and the
# following 512 for evaluation; shuffling of the training data is disabled.
train_data, eval_data = get_data(train_name=['data/cortex_52_4608.json',],
                                 train_subset=np.arange(4096), eval_subset=np.arange(4096,4096+512),
                                 eval_name='data/cortex_52_4608.json', tokenizer=tokenizer, shuffle=False)

No duplicates found in eval


In [6]:
# Plain AdamW run (no SAM) with precalculated base losses, no overshoot penalty and
# gradient checkpointing enabled. `base_model=None` selects adapter-disabled
# reference mode. The recorded output shows this run was interrupted manually.
train(lora_model, train_data, eval_data, base_model=None, inf_training=False, training_device="cuda",
        acc_batch_size=1024, instruction_finetuing=True, precalculate_base_loss=True, precalculate_batch_mult=1.55,
        lr=4e-5, weight_decay=0.0, lr_scheduler="constant", warmup_steps=0, betas=(0.8, 0.95), squared_loss=False,
        use_sam=False, sam_rho=0.2, opt="adamw",
        loss_eps = 0.02, overshoot_buffer = -0.01,
        manual_grad_clip_norm=0.0, manual_grad_clip_value=0.0, wait_for_full_batch=True, sam_reuse_base_outputs=True,
        do_base_gradient=False, add_overshoot_penalty=False, ignore_overshot_samples=True, bad_sample_mult=1.0,
        prompt_dropout=0.0,
        eval_steps=2048, save_steps=2048, do_save=True, save_name=lora_name,
        average_stats=False,
        partial_eval_steps=0, partial_eval_size=128, save_n_start=0,
        gradient_checkpointing=True)

KeyboardInterrupt: 

In [7]:
# Free cached GPU memory, then score the current adapter on the held-out eval
# batches; the returned stats dict is displayed as the cell output.
gc.collect(); torch.cuda.empty_cache()
evaluate(lora_model, eval_data, return_stats=True, base_model=None)

Evaluating ........ Loss: 0.01634118, Base Loss: 0.394893, Lora Diff: 0.00121789, WR: 5.08%, 0epsWR: 38.09%, OShL: 0.00017984


{'loss': 0.016341184635450645,
 'base_loss': 0.39489315043670103,
 'lora_diff': 0.0012178865729310928,
 'head_to_head': 5.078125,
 'eps0_head_to_head': 38.0859375,
 'overshoot': 0.00017984330224862788}

In [7]:
# Reload training data from three files (shuffled by default). The eval split
# returned by this call is discarded (`_`); train() below reuses the `eval_data`
# produced by the earlier get_data cell.
train_data, _ = get_data(train_name=['data/cortex_40_8704.json','data/cortex_39_8704.json','data/cortex_38_8704.json'],
                                 eval_subset=np.arange(8192, 8192+512),
                                 eval_name='data/cortex_41_8704.json', tokenizer=tokenizer)
# SAM run with the default base optimizer ("adamw"), a very small learning rate and
# 8 warm-up steps; partial evaluation on 128 batches every 512 contributing samples.
train(lora_model, train_data, eval_data, 
        acc_batch_size=64,
        lr=2.5e-8, weight_decay=0.0, lr_scheduler="constant", warmup_steps=8, betas=(0.88, 0.97), squared_loss=False,
        use_sam=True, sam_rho=0.25, do_dadapt=False,
        loss_eps = 0.02, overshoot_buffer = -0.01,
        manual_grad_clip_norm=0.0, manual_grad_clip_value=0.0, wait_for_full_batch=True,
        do_base_gradient=False, add_overshoot_penalty=False, ignore_overshot_samples=True, bad_sample_mult=1.25,
        prompt_dropout=0.0,
        eval_steps=2048, save_steps=512, do_save=True, save_name=lora_name,
        average_stats=False,
        partial_eval_steps=512, partial_eval_size=128, save_n_start=0)
gc.collect(); torch.cuda.empty_cache()

Removed 13 duplicates from eval
................Step 67/8192	Loss: 0.021026 OShL: 8.122e-04	Base: 0.4503 Diff: -1.3390e-03 	WR: 4.48% 0eps: 59.70%  	LR: 3.12e-09 eps: 2.00e-02 fit: 3/-27
................Step 134/8192	Loss: 0.021912 OShL: 6.013e-05	Base: 0.3729 Diff: -3.9632e-04 	WR: 4.48% 0eps: 49.25%  	LR: 6.25e-09 eps: 2.00e-02 fit: 3/-34
................Step 202/8192	Loss: 0.023419 OShL: 9.164e-04	Base: 0.4080 Diff: -1.3260e-03 	WR: 5.88% 0eps: 54.41%  	LR: 9.37e-09 eps: 2.00e-02 fit: 4/-31
................Step 271/8192	Loss: 0.022529 OShL: 2.105e-04	Base: 0.3744 Diff: -1.9587e-04 	WR: 7.25% 0eps: 49.28%  	LR: 1.25e-08 eps: 2.00e-02 fit: 5/-35
................Step 341/8192	Loss: 0.020482 OShL: 1.952e-03	Base: 0.4646 Diff: -2.7062e-03 	WR: 8.57% 0eps: 52.86%  	LR: 1.56e-08 eps: 2.00e-02 fit: 6/-33
................Step 408/8192	Loss: 0.021326 OShL: 6.891e-05	Base: 0.4062 Diff: -3.0093e-04 	WR: 4.48% 0eps: 55.22%  	LR: 1.87e-08 eps: 2.00e-02 fit: 3/-30
................Step 480/8192	Los

In [8]:
# Reload training data from two files (shuffled by default); as above, the eval
# split returned here is discarded and the existing `eval_data` is reused.
train_data, _ = get_data(train_name=['data/cortex_39_8704.json','data/cortex_38_8704.json'],
                                 eval_subset=np.arange(8192, 8192+512),
                                 eval_name='data/cortex_41_8704.json', tokenizer=tokenizer)
# Second SAM run: larger accumulation batch (128), higher learning rate and rho,
# 32 warm-up steps. The recorded output shows it was interrupted manually.
train(lora_model, train_data, eval_data, 
        acc_batch_size=128,
        lr=2.5e-7, weight_decay=0.0, lr_scheduler="constant", warmup_steps=32, betas=(0.9, 0.99), squared_loss=False,
        use_sam=True, sam_rho=0.5, do_dadapt=False,
        loss_eps = 0.02, overshoot_buffer = -0.01,
        manual_grad_clip_norm=0.0, manual_grad_clip_value=0.0, wait_for_full_batch=True,
        do_base_gradient=False, add_overshoot_penalty=False, ignore_overshot_samples=True, bad_sample_mult=1.15,
        prompt_dropout=0.0,
        eval_steps=2048, save_steps=512, do_save=True, save_name=lora_name,
        average_stats=False,
        partial_eval_steps=512, partial_eval_size=128, save_n_start=0)
gc.collect(); torch.cuda.empty_cache()

No duplicates found in eval
................Step 146/17408	Loss: 0.019014 OShL: 1.533e-03	Base: 0.4047 Diff: -2.1964e-03 	WR: 12.33% 0eps: 59.59%  	LR: 7.81e-09 eps: 2.00e-02 fit: 18/-59
................Step 286/17408	Loss: 0.020845 OShL: 9.163e-04	Base: 0.4167 Diff: -1.0542e-03 	WR: 8.57% 0eps: 57.14%  	LR: 1.56e-08 eps: 2.00e-02 fit: 12/-60
................Step 425/17408	Loss: 0.020790 OShL: 9.871e-05	Base: 0.4187 Diff: -3.1786e-04 	WR: 7.91% 0eps: 53.24%  	LR: 2.34e-08 eps: 2.00e-02 fit: 11/-65
................Step 573/17408	Loss: 0.020117 OShL: 1.626e-03	Base: 0.4422 Diff: -2.2701e-03 	WR: 13.51% 0eps: 60.81%  	LR: 3.12e-08 eps: 2.00e-02 fit: 20/-58
Evaluating ........ Loss: 0.00872664, Base Loss: 0.366165, Lora Diff: -0.00173172, WR: 13.28%, 0epsWR: 63.28%, OShL: 0.00098392
................Step 712/17408	Loss: 0.020651 OShL: 1.554e-03	Base: 0.3986 Diff: -2.0033e-03 	WR: 7.91% 0eps: 56.12%  	LR: 3.91e-08 eps: 2.00e-02 fit: 11/-61
................Step 864/17408	Loss: 0.020669 OShL: 

KeyboardInterrupt: 

In [None]:
# Score the adapter after the interrupted SAM run (reference = adapters disabled).
evaluate(lora_model, eval_data, return_stats=True)

Evaluating ........0.012056129032998797
 Loss: 0.01205613, Base Loss: 0.365259, Lora Diff: -0.00147909, WR: 13.87%, 0epsWR: 58.01%, OShL: 0.00109561


{'loss': 0.012056129032998797,
 'base_loss': 0.36525918503184585,
 'lora_diff': -0.0014790902009735873,
 'head_to_head': 13.8671875,
 'eps0_head_to_head': 58.0078125,
 'overshoot': 0.0010956073215311335}

In [None]:
# del model, lora_model, tokenizer, train_data, eval_data, trainer; gc.collect(); torch.cuda.empty_cache()
# gc.collect(); torch.cuda.empty_cache()

In [None]:
# lora_model = lora_model.merge_and_unload()
# lora_model.config.name_or_path = "MesozoicMetallurgist/new_model"
# model_dir = "Models/merged_model"
# model_dir = os.path.expanduser("~/finetuning-subnet/merged_model")
# if not os.path.exists(model_dir):
#     os.makedirs(model_dir, exist_ok=True)
# else:
#     # wipe the directory
#     for file in os.listdir(model_dir):
#         os.remove(os.path.join(model_dir, file))
# lora_model.save_pretrained(save_directory=model_dir, safe_serialization=True)
# tokenizer.save_pretrained(save_directory=model_dir)

# del lora_model, trainer, model, tokenizer

('/home/alyx/finetuning-subnet/merged_model/tokenizer_config.json',
 '/home/alyx/finetuning-subnet/merged_model/special_tokens_map.json',
 '/home/alyx/finetuning-subnet/merged_model/tokenizer.model',
 '/home/alyx/finetuning-subnet/merged_model/added_tokens.json',
 '/home/alyx/finetuning-subnet/merged_model/tokenizer.json')