In [1]:
import os
import json
import math
import time
import gc
import copy

import numpy as np

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import (
    LoraConfig,
    get_peft_model,
    PeftModel,
    LoHaConfig,
)

from cortexsubsetloader import CortexSubsetLoader, tokenize
from pytorch_optimizer import Ranger21, SAM, DAdaptAdam, SophiaH, ScalableShampoo #, Lamb, DAdaptLion, LOMO
from pytorch_optimizer.optimizer.sam import WSAM

# Allow TF32 math in cuDNN and in CUDA matmuls: trades a little float32 precision for speed.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# Wall-clock reference taken when this cell runs; train() measures its time budget from here.
notebook_start_time = time.time()

In [2]:
# !pip install transformers==4.37.2
# !pip install -U transformers
# !pip install -U peft
# !pip show torch transformers peft flash-attn

In [2]:
def filter_data(train, eval):
    """Drop evaluation samples that also occur in the training data.

    Args:
        train: training samples; membership is tested with ``in``.
        eval: evaluation samples, iterated in order.

    Returns:
        A list of the ``eval`` items not found in ``train``, in their original order.
        A one-line summary of how many items were removed is printed.
    """
    kept = [sample for sample in eval if sample not in train]
    n_removed = len(eval) - len(kept)
    if n_removed > 0:
        print(f"Removed {n_removed} duplicates from eval")
    else:
        print("No duplicates found in eval")
    return kept


def data_collator(features):
    """Turn ``(input_ids, prompt_len)`` pairs into training batches.

    Each feature becomes a dict with two keys:

    * ``'input_ids'``: a copy of the token tensor.
    * ``'labels'``: the same tokens, but the first ``prompt_len`` columns are set to -100,
      so only the response tokens are scored.

    The input tensors are not modified (labels are built from a clone).

    Returns:
        A list of batch dicts, one per feature, in input order.
    """
    def _as_batch(feature):
        token_ids, n_prompt = feature
        masked = token_ids.clone()
        masked[:, :n_prompt] = -100  # hide the prompt from the loss
        return {
            'input_ids': torch.concat([token_ids]),
            'labels': torch.concat([masked]),
        }

    return [_as_batch(feature) for feature in features]


def get_data(train_name, eval_name, tokenizer, train_subset=None, eval_subset=None, shuffle=True, extend_train_length=0):
    """Load train/eval JSON datasets, remove eval samples seen in train, then tokenize and collate.

    Args:
        train_name: path (``str`` or ``os.PathLike``) to a JSON file holding a list of
            samples, or an iterable of such paths whose lists are concatenated in order.
        eval_name: path to a JSON file holding the evaluation samples.
        tokenizer: passed through to ``tokenize``.
        train_subset: optional index / slice / mask applied to the training array.
        eval_subset: optional count; keep only the last ``eval_subset`` eval samples
            (0 keeps none).
        shuffle: shuffle the training samples (global NumPy RNG) before tokenizing.
        extend_train_length: added to the length argument (2048) passed to ``tokenize``
            for the training data only.

    Returns:
        ``(train_batches, eval_batches)``, each a list produced by ``data_collator``.
    """
    def _load(path):
        with open(path, "r") as f:
            return json.load(f)

    # isinstance (instead of `type(...) == str`) also accepts str subclasses such as np.str_
    # and path objects. The old check sent those to the list branch, which iterated the
    # path character by character.
    if isinstance(train_name, (str, os.PathLike)):
        train_data = np.array(_load(train_name))
    else:  # iterable of paths
        merged = []
        for name in train_name:
            merged.extend(_load(name))  # extend: avoids re-copying the list for every file
        train_data = np.array(merged)
    eval_data = np.array(_load(eval_name))

    if train_subset is not None:
        train_data = train_data[train_subset]
    if eval_subset is not None:
        # eval_data[-0:] would return *everything*, so compute the start index explicitly.
        eval_data = eval_data[max(len(eval_data) - eval_subset, 0):]

    eval_data = filter_data(train_data, eval_data)

    if shuffle:
        p = np.random.permutation(len(train_data))
        train_data = train_data[p]

    train_data = tokenize(tokenizer, train_data, 2048+extend_train_length)
    eval_data = tokenize(tokenizer, eval_data, 2048)

    train_data = data_collator(train_data)
    eval_data = data_collator(eval_data)

    return train_data, eval_data


def merge(model0, model1, ratio=0.5, embed_ratio=None, norm_ratio=None, fc_ratio=None): # higher ratio means more of model0
    """Blend model0's weights into model1 in place and return model1.

    For every parameter of model1 the new value is ``p0 * r + p1 * (1 - r)``, where ``p0``
    is model0's parameter of the same name. The weight ``r`` is chosen from the parameter
    name, checked in this order:

    * name contains "embed": ``embed_ratio``.
    * name contains none of the q/k/v/o/up/down/gate projection tags: ``norm_ratio``.
      This covers every non-projection tensor, not only norms.
    * name contains "up_proj" or "down_proj": ``fc_ratio`` (note: "gate_proj" is not included).
    * anything else (q/k/v/o/gate projections): ``ratio``.

    ``embed_ratio``, ``norm_ratio`` and ``fc_ratio`` default to ``ratio``. Every parameter of
    model1 must exist in model0 under the same name, otherwise a KeyError is raised.
    """
    w_embed = ratio if embed_ratio is None else embed_ratio
    w_norm = ratio if norm_ratio is None else norm_ratio
    w_fc = ratio if fc_ratio is None else fc_ratio

    projection_tags = ("up_proj", "down_proj", "gate_proj", "o_proj", "k_proj", "v_proj", "q_proj")
    donor = dict(model0.named_parameters())

    for pname, target in model1.named_parameters():
        if "embed" in pname:
            w = w_embed
        elif not any(tag in pname for tag in projection_tags):
            w = w_norm
        elif "up_proj" in pname or "down_proj" in pname:
            w = w_fc
        else:
            w = ratio
        target.data = ((donor[pname].data * w) + (target.data * (1 - w)))

    return model1

def copy_weights_over(model0, model1):
    """Make model1 use model0's parameter tensors for every parameter name they share.

    NOTE: this assigns ``.data`` directly, so the two models end up sharing storage; it is
    not a clone. A later in-place change to one model's weight is visible in the other.
    Parameters of model1 with no same-named counterpart in model0 are left untouched.

    Returns:
        model1 (modified in place).
    """
    donor = dict(model0.named_parameters())
    for pname, target in model1.named_parameters():
        if pname not in donor:
            continue
        target.data = donor[pname].data
    return model1


# initial_loss_eps = 0.0001  (unused; kept for reference)

# Progress granularity: the eval/train loops in this notebook print a "." (and free cached
# memory) every `len // intermed_check_step_split` steps, i.e. roughly this many times per pass.
intermed_check_step_split = 8

def simple_eval(model, eval_d, device="cuda"):
    """Print and return the mean loss of ``model`` over ``eval_d``.

    Args:
        model: a model called as ``model(inputs, labels=labels)`` whose output has ``.loss``.
        eval_d: list of batch dicts with ``'input_ids'`` and ``'labels'`` tensors.
        device: device used for evaluation (default "cuda", as before).

    Returns:
        The mean per-batch loss (0 for an empty ``eval_d``). The model is moved back to
        CPU afterwards.
    """
    print("Evaluating", end=" ")
    model = model.to(device)
    model.eval()
    n_batches = len(eval_d)
    # max(1, ...) avoids a modulo-by-zero when there are fewer batches than
    # intermed_check_step_split.
    dot_every = max(1, n_batches // intermed_check_step_split)
    eval_loss = 0
    for step, batch in enumerate(eval_d, start=1):
        inputs = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        with torch.no_grad():
            outputs = model(inputs, labels=labels)
            eval_loss += outputs.loss.item() / n_batches
        if step % dot_every == 0:
            print(".", end="")
            gc.collect(); torch.cuda.empty_cache()
    model = model.to("cpu")
    gc.collect(); torch.cuda.empty_cache()
    print(f" Loss: {eval_loss:.8f}")
    return eval_loss

def evaluate(model, eval_d, return_to_cpu=False, return_stats=False, print_stats=True, cached_base_loss=None,
             base_model=None, precompute_base_loss=True, device="cuda", instruction_finetuning=True, true_eps=0.01):
    """Score ``model`` against a reference ("base") loss on ``eval_d``.

    The reference comes from one of two sources:

    * When ``base_model`` is None (LoRA/PEFT case), it is ``model`` itself with
      ``disable_adapter_layers()`` applied.
    * Otherwise it is the separate ``base_model``.

    Args:
        model: model to evaluate.
        eval_d: evaluation batches. With ``instruction_finetuning`` each item is a dict with
            ``'input_ids'`` and ``'labels'``; otherwise each item is an input tensor that is
            also used as its own labels.
        return_to_cpu: move ``model`` to CPU when done.
        return_stats: return the statistics dict (otherwise None is returned).
        print_stats: print a one-line summary.
        cached_base_loss: optional per-batch reference losses in ``eval_d`` order, used as
            plain numbers (no ``.item()`` is called on them); skips the reference pass.
        base_model: optional separate reference model.
        precompute_base_loss: compute all reference losses in one pass before scoring
            ``model``, so the two models can take turns on the device.
        device: device used for evaluation (previously "cuda" was hard-coded for ``model``).
        instruction_finetuning: selects the batch format described above.
        true_eps: margin; a batch counts as a win when the model's loss is below
            ``base_loss * (1 - true_eps)``.

    Returns:
        When ``return_stats`` is True, a dict with keys ``loss``, ``base_loss``,
        ``lora_diff``, ``head_to_head``, ``eps0_head_to_head`` and ``overshoot``; the
        per-batch values are averaged over ``len(eval_d)`` (win rates in percent).
    """
    print("Evaluating", end=" ")
    model = model.to(device)
    eval_base_loss = 0
    lora_diff = 0
    eval_loss = 0
    head_to_head = 0
    eps0_head_to_head = 0
    overshoot = 0
    model.eval()
    steps_so_far = 1
    n_batches = len(eval_d)
    # max(1, ...) avoids a modulo-by-zero when eval_d has fewer than intermed_check_step_split batches.
    dot_every = max(1, n_batches // intermed_check_step_split)

    # LORA: the reference is the same model with its adapter layers disabled.
    LORA = base_model is None

    def _unpack(batch):
        # Returns (inputs, labels) on `device` for either batch format.
        if instruction_finetuning:
            return batch['input_ids'].to(device), batch['labels'].to(device)
        inputs = batch.to(device)
        return inputs, inputs.clone()

    precomputed_base_losses = []
    if cached_base_loss is not None:
        precomputed_base_losses = list(cached_base_loss)  # copy: consumed with pop(0) below
        precompute_base_loss = True
    elif precompute_base_loss:
        if LORA:
            model.disable_adapter_layers()
        else:
            # Keep only one model on the device at a time.
            model = model.to("cpu")
            gc.collect(); torch.cuda.empty_cache()
            base_model = base_model.to(device)

        reference = model if LORA else base_model
        for batch in eval_d:
            inputs, labels = _unpack(batch)
            with torch.no_grad():
                precomputed_base_losses.append(reference(inputs, labels=labels).loss)

        if LORA:
            model.enable_adapter_layers()
        else:
            base_model = base_model.to("cpu")
            gc.collect(); torch.cuda.empty_cache()
            model = model.to(device)
    elif not LORA:
        # Only a separate base model needs moving; in the LoRA case the reference is computed
        # on the fly by toggling adapters. The old code called base_model.to() here even when
        # base_model was None.
        print("WARN: not precomputing base loss will put both models on the same device")
        base_model = base_model.to(device)

    for batch in eval_d:
        inputs, labels = _unpack(batch)
        with torch.no_grad():
            if precompute_base_loss:
                base_outputs_loss = precomputed_base_losses.pop(0)
            elif LORA:
                model.disable_adapter_layers()
                base_outputs_loss = model(inputs, labels=labels).loss
                model.enable_adapter_layers()
            else:
                base_outputs_loss = base_model(inputs, labels=labels).loss
            outputs_loss = model(inputs, labels=labels).loss

            threshold = base_outputs_loss * (1.0 - true_eps)
            partial_loss = torch.nn.functional.relu(outputs_loss - threshold)
            overshoot_penalty = torch.nn.functional.relu(-(outputs_loss - threshold))
            loss = partial_loss / base_outputs_loss

            # Cached reference losses are plain numbers; freshly computed ones are tensors.
            if cached_base_loss is None:
                base_outputs_loss_item = base_outputs_loss.item()
            else:
                base_outputs_loss_item = base_outputs_loss
            eval_loss += loss.item() / n_batches
            eval_base_loss += base_outputs_loss_item / n_batches
            lora_diff += (outputs_loss.item() - base_outputs_loss_item) / n_batches
            # Win rate with margin true_eps (ties count half) ...
            head_to_head += 100.0 / n_batches if outputs_loss < threshold else 0.0
            head_to_head += 50.0 / n_batches if outputs_loss == threshold else 0.0
            # ... and without any margin.
            eps0_head_to_head += 100.0 / n_batches if outputs_loss < base_outputs_loss else 0.0
            eps0_head_to_head += 50.0 / n_batches if outputs_loss == base_outputs_loss else 0.0
            overshoot += overshoot_penalty.item() / n_batches

        if steps_so_far % dot_every == 0:
            print(".", end="")
            gc.collect(); torch.cuda.empty_cache()
        steps_so_far += 1

    if return_to_cpu:
        model = model.to("cpu")

    if not LORA:
        base_model = base_model.to("cpu")

    gc.collect(); torch.cuda.empty_cache()

    data = {
        "loss": eval_loss,
        "base_loss": eval_base_loss,
        "lora_diff": lora_diff,
        "head_to_head": head_to_head,
        "eps0_head_to_head": eps0_head_to_head,
        "overshoot": overshoot
    }

    if print_stats:
        print(f" Loss: {eval_loss:.8f}, Base Loss: {eval_base_loss:.6f}, Lora Diff: {lora_diff:.8f},",
            f"WR: {head_to_head:.2f}%, 0epsWR: {eps0_head_to_head:.2f}%, OShL: {overshoot:.8f}")
    if return_stats:
        return data

# Module-level ignore list handed to CortexSubsetLoader (as `ignore_list`, via the `old_data`
# default of train()'s inner get_new_data). Nothing in this cell appends to it: the update in
# train() is commented out, so it stays empty unless another cell fills it.
old_train_data = []

def train(model, tokenizer, train_d=[], eval_d=[], base_model=None, inf_training=False, inf_data_steps=5, inf_data_decay=1, training_device="cuda",
            acc_batch_size=512, instruction_finetuing=True, precalculate_batch_mult=2.25, precalc_eval_base=True,
            lr=1e-5, weight_decay=0.001, lr_scheduler="constant", warmup_steps=4, warmup_end_offset=0, betas=(0.9, 0.99), 
            use_sam=False, sam_rho=0.05, opt="adamw", adaptive_sam=True, wsam_variant=False, ignore_below=0.0, ignore_auto_percent=2,
            manual_grad_clip_norm=1.0, sam_grad_clip_norm=None,
            add_overshoot_penalty=False, ignore_overshot_samples=True, bad_sample_mult=1.0, remerging=False, remerge_eval=True, remerge_ratio=0.5,
            loss_eps = 0.02, overshoot_buffer = -0.01, true_eps=0.01, simple_loss=False, process_base_loss=True, relative_loss=False,
            eval_steps=2048, save_name="lora", do_save=True,
            average_stats=False, save_n_start=0,
            gradient_checkpointing=False, excessive_cache_clearing=False, base_model_switching=True,
            time_limit_s=5.75*60*60):
    """Fine-tune ``model`` so its per-sample loss beats a reference ("base") loss.

    Reference loss:

    * When ``base_model`` is None (LoRA/PEFT case), the reference is ``model`` with
      ``disable_adapter_layers()`` applied.
    * Otherwise it is ``base_model``, swapped on and off the device as needed.

    Reference losses are precomputed for the next
    ``acc_batch_size * precalculate_batch_mult`` training batches, then combined with the
    model's loss by ``combined_loss``. That function is chosen by ``add_overshoot_penalty``,
    ``simple_loss`` and ``relative_loss``.

    Gradients are accumulated over ``acc_batch_size`` accepted samples per optimizer step.
    Optionally a SAM/WSAM two-pass step is used. Every ``eval_steps`` accepted samples the
    model is evaluated and (if ``do_save``) checkpointed.

    Data:

    * ``train_d`` / ``eval_d`` are lists of batches in ``data_collator`` format (or raw input
      tensors when ``instruction_finetuing`` is False).
    * ``train_d`` is copied, so the caller's list is not consumed.
    * Whenever fewer than ``acc_batch_size * precalculate_batch_mult`` batches remain, more
      are fetched via ``CortexSubsetLoader``.
    * An empty ``eval_d`` is replaced by 512 freshly fetched batches.

    Selected args:
        opt: "adamw", "sgd", "dadapt_adam", "shampoo", "sophia" or "ranger".
        lr_scheduler: "constant", "cosine" or "polynomial" (with ``warmup_steps`` warmup).
        ignore_below: skip samples whose loss is below this value; "auto" uses the
            ``ignore_auto_percent`` percentile of the precomputed eval reference losses.
        time_limit_s: stop once this many seconds have passed since the notebook-level
            ``notebook_start_time``; None disables the check. Default 5.75 h, as before.

    Returns:
        The statistics dict from a final ``evaluate(..., return_stats=True)`` call.
    """
    global old_train_data
    if warmup_steps is None:
        warmup_steps = (eval_steps // acc_batch_size) // 2
    # LORA: the reference loss comes from the same model with adapters disabled.
    LORA = True
    if base_model is not None:
        LORA = False

    # Private copy: batches are pop()'d below, and without the copy the caller's list would be
    # consumed in place (re-running the cell would then silently train on less data).
    train_d = list(train_d)

    if len(train_d) == 0:
        print("WARN: no training data provided, enabling infinite training")
        inf_training = True

    if remerging:
        params = {
            'low_cpu_mem_usage': True,
            'trust_remote_code': False,
            'torch_dtype': torch.bfloat16,
            'use_safetensors': True,
            'attn_implementation': "flash_attention_2"
        }
        # create a copy of the model to merge with
        model.save_pretrained("model_prev")
        model_prev = AutoModelForCausalLM.from_pretrained("model_prev", **params)
        print("WARN: remerging is enabled, model_prev created and may take up extra memory")

    def get_new_data(n_samples=2560, dd_eval=True, dd_train=True, steps=1, old_data=old_train_data):
        # Fetch fresh samples via CortexSubsetLoader, collate them and return them shuffled.
        # dd_eval / dd_train are accepted but unused.
        cortex_subset_loader = CortexSubsetLoader(latest=True, random_seed = None, max_samples=n_samples, progress=False, 
                                        running=True, retry_limit=5, page_size=400, retry_delay=5, silent=True, steps=steps,
                                        ignore_list=old_data)
        batches = data_collator(cortex_subset_loader.tokenize(tokenizer))
        p = np.random.permutation(len(batches))
        return [batches[i] for i in p]

    if len(eval_d) == 0:
        print("WARN: no evaluation data provided, acquiring new data")
        eval_d = get_new_data(512)
        eval_d = eval_d[:512]

    add_inf_steps = 0
    if inf_training:
        if len(train_d) == 0:
            print("WARN: no training data provided, acquiring new data")
            while len(train_d) < (acc_batch_size * precalculate_batch_mult):
                new_data = get_new_data(int(acc_batch_size * precalculate_batch_mult), steps=inf_data_steps+add_inf_steps)
                # Empty fetch: look further back next time; successful fetch: decay back.
                if len(new_data) == 0:
                    add_inf_steps += inf_data_steps
                else:
                    add_inf_steps = add_inf_steps - inf_data_decay
                    train_d = train_d + new_data

    if simple_loss:
        if process_base_loss:
            print("WARN: simple loss is enabled, this will disable base model processing")
            process_base_loss = False

    if not process_base_loss:
        if ignore_overshot_samples:
            print("Base loss processing is disabled, disabling ignore overshot samples")
            ignore_overshot_samples = False

    gc.collect(); torch.cuda.empty_cache()
    model = model.to("cuda")
    model.enable_input_require_grads()

    if gradient_checkpointing:
        model.config.use_cache = False
        grad_check_kwargs = {"use_reentrant": False}
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=grad_check_kwargs)
    model.train()

    # ---- optimizer -------------------------------------------------------------------------
    if not use_sam:
        if opt == "dadapt_adam":
            optimizer = DAdaptAdam(model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay, fixed_decay=True)
        elif opt == "shampoo":
            optimizer = ScalableShampoo(model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay, 
                                        start_preconditioning_step=warmup_steps+1, preconditioning_compute_steps=1)
        elif opt == "sophia":
            optimizer = SophiaH(model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        elif opt == "ranger":
            optimizer = Ranger21(model.parameters(), num_iterations=1, lr=lr, betas=betas, weight_decay=weight_decay,
                                 num_warm_up_iterations=0, num_warm_down_iterations=0)
        elif opt == "adamw":
            optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        elif opt == "sgd":
            optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=betas[0], weight_decay=weight_decay)
        else:
            raise ValueError(f"Unknown optimizer {opt}")
    else:
        base_optimizer_args = {"lr": lr, "weight_decay": weight_decay, "betas": betas, "eps": 1e-8}

        if opt == "dadapt_adam":
            base_optimizer = DAdaptAdam
        elif opt == "sophia":
            base_optimizer = SophiaH
        elif opt == "shampoo":
            base_optimizer = ScalableShampoo
            base_optimizer_args["start_preconditioning_step"] = warmup_steps+1
            base_optimizer_args["preconditioning_compute_steps"] = 1
        elif opt == "ranger":
            base_optimizer = Ranger21
            base_optimizer_args["num_iterations"] = 1
            base_optimizer_args["num_warm_up_iterations"] = 0
            base_optimizer_args["num_warm_down_iterations"] = 0
        elif opt == "adamw":
            base_optimizer = torch.optim.AdamW
        elif opt == "sgd":
            base_optimizer = torch.optim.SGD
            del base_optimizer_args["betas"], base_optimizer_args["eps"]
            base_optimizer_args["momentum"] = betas[0]
        else:
            raise ValueError(f"Unknown optimizer {opt}")

        if not wsam_variant:
            optimizer = SAM(model.parameters(), base_optimizer=base_optimizer, rho=sam_rho, adaptive=adaptive_sam, **base_optimizer_args)
            sam_optimizer = optimizer
        else:
            optimizer = WSAM(model, params=model.parameters(), base_optimizer=base_optimizer, rho=sam_rho, adaptive=adaptive_sam, 
                             **base_optimizer_args, max_norm=sam_grad_clip_norm)
            sam_optimizer = optimizer

    # ---- LR schedule (the `lr_scheduler` name is rebound from the string to the scheduler) -
    if lr_scheduler == "cosine":
        lr_scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, warmup_steps, (len(train_d)//acc_batch_size)+warmup_end_offset)
    elif lr_scheduler == "polynomial":
        lr_scheduler = transformers.get_polynomial_decay_schedule_with_warmup(optimizer, warmup_steps, 
                                                                              (len(train_d)//acc_batch_size)+warmup_end_offset)
    elif lr_scheduler == "constant":
        lr_scheduler = transformers.get_constant_schedule_with_warmup(optimizer, warmup_steps)
    else:
        raise ValueError(f"Unknown lr_scheduler {lr_scheduler}")
    lr_scheduler.step() # don't want to start at 0

    # ---- loss variants; each returns (loss tensor, overshoot as a Python float) -------------
    @torch.jit.script
    def combined_loss_os(outputs_loss, base_loss_in, loss_eps:float=loss_eps, overshoot_buffer:float=overshoot_buffer):
        base_loss = base_loss_in.item()
        partial_loss = outputs_loss - (base_loss * (1.0 - loss_eps))
        loss = partial_loss / base_loss
        overshoot_penalty = torch.nn.functional.relu(-(loss + overshoot_buffer))
        return torch.nn.functional.relu(loss) + overshoot_penalty, overshoot_penalty.item()
    
    @torch.jit.script
    def combined_loss_os_noshot(outputs_loss, base_loss_in, loss_eps:float=loss_eps, overshoot_buffer:float=overshoot_buffer):
        base_loss = base_loss_in.item()
        partial_loss = outputs_loss - (base_loss * (1.0 - loss_eps))
        loss = partial_loss / base_loss
        return torch.nn.functional.relu(loss), torch.nn.functional.relu(-(loss + overshoot_buffer)).item()
    
    @torch.jit.script
    def simple_loss_func(outputs_loss, base_loss_in, loss_eps:float=loss_eps, overshoot_buffer:float=overshoot_buffer):
        return outputs_loss, 0.0
    
    @torch.jit.script
    def relative_loss_func(outputs_loss, base_loss_in, loss_eps:float=loss_eps, overshoot_buffer:float=overshoot_buffer):
        base_loss = base_loss_in.item()
        relative_loss = outputs_loss / (base_loss + loss_eps)
        osh = torch.nn.functional.relu(-((relative_loss - 1.0) + overshoot_buffer)).item()
        loss = outputs_loss * (4.2 * torch.square(torch.sin((torch.clamp(relative_loss, min=0.813, max=1.187) - 0.065) * 4.2)))
        return loss, osh

    if add_overshoot_penalty:
        combined_loss = combined_loss_os
    else:
        combined_loss = combined_loss_os_noshot

    if simple_loss:
        combined_loss = simple_loss_func
        if relative_loss:
            print("WARN: simple loss is enabled, disabling relative loss")
            relative_loss = False

    if relative_loss:
        combined_loss = relative_loss_func

    # Progress-dot spacing; max(1, ...) avoids a modulo-by-zero for small eval sets or
    # acc_batch_size < intermed_check_step_split.
    eval_dot_every = max(1, len(eval_d) // intermed_check_step_split)
    train_dot_every = max(1, acc_batch_size // intermed_check_step_split)

    # ---- reference losses on the eval set (computed once, reused by every evaluate call) ----
    if precalc_eval_base:
        print("Note: precalced eval base loss does not account for pretrained fine-tuning")
        eval_base_loss = []
        steps_so_far = 1
        if not LORA:
            model = model.to("cpu")
            gc.collect(); torch.cuda.empty_cache()
            base_model = base_model.to(training_device)
        for batch in eval_d:
            if instruction_finetuing:
                inputs = batch['input_ids'].to(training_device)
                labels = batch['labels'].to(training_device)
            else:
                inputs = batch.to(training_device)
                labels = inputs.clone()
            with torch.no_grad():
                # NOTE(review): in the LORA case adapters are not disabled here, so this is the
                # current model's loss (see the printed note above).
                if LORA:
                    base_outputs_loss = model(inputs, labels=labels).loss
                else:
                    base_outputs_loss = base_model(inputs, labels=labels).loss
                eval_base_loss.append(base_outputs_loss.item())
            if steps_so_far % eval_dot_every == 0:
                print(".", end="")
                gc.collect(); torch.cuda.empty_cache()
            steps_so_far += 1
        if not LORA:
            base_model = base_model.to("cpu")
            gc.collect(); torch.cuda.empty_cache()
            model = model.to(training_device)
        print(f"Eval Base Loss: {sum(eval_base_loss)/len(eval_d):.6f}")
        if ignore_below == "auto":
            ignore_below = np.percentile(eval_base_loss, ignore_auto_percent)
    else:
        eval_base_loss = None

    if ignore_below == "auto":
        print("eval loss not precalced, ignore_below set to 0.0")
        ignore_below = 0.0

    # ---- training loop ----------------------------------------------------------------------
    steps_so_far = 1 # start at one to avoid all the modulo checks
    epoch_loss = 0; epoch_overshoot = 0; epoch_base_loss = 0; lora_diff = 0
    epoch_wr = 0; epoch_0eps_wr = 0
    fit_samples = 0; unfit_samples = 0
    sam_batch = []
    accum_steps = 0.0
    true_steps_taken = 0; prev_dot_step = -1
    last_tst = true_steps_taken
    add_inf_steps = 0
    sam_saved_base_outputs = []
    precalculated_base_outputs = []
    while len(train_d) > 0:

        if (true_steps_taken % train_dot_every == 0) and accum_steps != prev_dot_step:
            prev_dot_step = accum_steps
            print(".", end="")
            gc.collect(); torch.cuda.empty_cache()

        # NOTE(review): this refill runs whether or not inf_training is set, so training
        # effectively only ends via the time limit (a finite train_d is topped up with fetched data).
        while len(train_d) < (acc_batch_size * precalculate_batch_mult):
            new_data = get_new_data(int(acc_batch_size * precalculate_batch_mult), steps=inf_data_steps+add_inf_steps)
            if len(new_data) == 0:
                add_inf_steps += inf_data_steps
            else:
                add_inf_steps = add_inf_steps - inf_data_decay
            train_d = train_d + new_data

        # Precompute reference losses for the next window of batches in one device swap.
        bstep = 0
        if len(precalculated_base_outputs) == 0 and process_base_loss:
            batches = train_d[:int(acc_batch_size * precalculate_batch_mult)]

            if LORA:
                model.disable_adapter_layers()
            else:
                if base_model_switching:
                    model = model.to("cpu")
                    gc.collect(); torch.cuda.empty_cache()
                base_model = base_model.to(training_device)
            
            for batch in batches:
                if instruction_finetuing:
                    inputs = batch['input_ids'].to(training_device)
                    labels = batch['labels'].to(training_device)
                else:
                    inputs = batch.to(training_device)
                    labels = inputs.clone()

                with torch.no_grad():
                    if LORA:
                        base_outputs = model(inputs, labels=labels)
                    else:
                        base_outputs = base_model(inputs, labels=labels)
                    precalculated_base_outputs.append(torch.tensor(base_outputs.loss.item()))

                bstep+=1
                if bstep % (intermed_check_step_split) == 0:
                    gc.collect(); torch.cuda.empty_cache()

            if LORA:
                model.enable_adapter_layers()
            else:
                base_model = base_model.to("cpu")
                gc.collect(); torch.cuda.empty_cache()
                if base_model_switching:
                    model = model.to(training_device)

        batch = train_d.pop(0)
        if instruction_finetuing:
            inputs = batch['input_ids'].to(training_device)
            labels = batch['labels'].to(training_device)
        else:
            inputs = batch.to(training_device)
            labels = inputs.clone()

        if process_base_loss:
            base_outputs_loss = precalculated_base_outputs.pop(0)
        else:
            base_outputs_loss = torch.zeros(size=(1,))
        outputs_loss = model(inputs, labels=labels).loss

        loss, overshoot_penalty = combined_loss(outputs_loss, base_outputs_loss)
        loss = loss / acc_batch_size

        # Accept the sample unless it already overshoots (when enabled) or its loss is below ignore_below.
        if (not ignore_overshot_samples or overshoot_penalty <= 0.0) and outputs_loss.item() >= ignore_below:
            if not simple_loss and loss.item() > ((loss_eps / acc_batch_size)+1e-8):
                # NOTE(review): decrements, so the printed "fit: x/y" shows a negative count; confirm intent.
                unfit_samples += -1
                if bad_sample_mult != 1.0:
                    loss = loss * bad_sample_mult

            loss.backward()
            
            accum_steps += 1
            true_steps_taken += 1
            if use_sam:
                sam_batch.append((inputs, labels))
                sam_saved_base_outputs.append(base_outputs_loss)
        else:
            fit_samples += 1

        outputs_loss_item = outputs_loss.detach().item()
        if process_base_loss:
            base_loss_item = base_outputs_loss.item()
        else:
            if eval_base_loss is not None:
                base_loss_item = sum(eval_base_loss) / len(eval_d)
            else:
                base_loss_item = 0.0

        epoch_base_loss += base_loss_item
        lora_diff += (outputs_loss_item - base_loss_item)
        epoch_loss += max(loss.detach().item() * acc_batch_size, 0.0)
        epoch_wr += 100.0 if outputs_loss_item < (base_loss_item * (1.0 - true_eps)) else 0.0
        epoch_wr += 50.0 if outputs_loss_item == (base_loss_item * (1.0 - true_eps)) else 0.0
        epoch_0eps_wr += 100.0 if outputs_loss_item < base_loss_item else 0.0
        epoch_0eps_wr += 50.0 if outputs_loss_item == base_loss_item else 0.0
        epoch_overshoot += overshoot_penalty

        # ---- optimizer step once acc_batch_size samples have been accepted ----
        if accum_steps == acc_batch_size:
            if not use_sam:
                if manual_grad_clip_norm is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), manual_grad_clip_norm)
                optimizer.step()
                optimizer.zero_grad()
            else:
                sub_steps = 1
                sam_optimizer.first_step(zero_grad=True)
                
                # Second SAM pass: recompute the loss on the same accepted samples at the perturbed weights.
                for inputs, labels in sam_batch:
                    base_outputs_loss = sam_saved_base_outputs.pop(0)
                    outputs = model(inputs, labels=labels)

                    loss, overshoot_penalty = combined_loss(outputs.loss, base_outputs_loss)
                    if loss.item() > ((loss_eps / acc_batch_size)+1e-8):
                        if bad_sample_mult is not None and bad_sample_mult != 1.0:
                            loss = loss * bad_sample_mult
                    loss = loss / accum_steps
                    
                    loss.backward()

                    if sub_steps % train_dot_every == 0:
                        print(".", end="")
                    sub_steps += 1
                    
                    if excessive_cache_clearing:
                        gc.collect(); torch.cuda.empty_cache()

                sam_optimizer.second_step(zero_grad=True)
                sam_batch = []

            if average_stats:
                stat_steps = steps_so_far
            else:
                stat_steps = accum_steps
                stat_steps += fit_samples
            # Note: the "OShL" and "\tBase" f-strings are implicitly concatenated (no comma between them).
            print(f"Step {steps_so_far}/{len(train_d)}\tLoss: {epoch_loss/stat_steps:.6f}",
                                                    f"OShL: {epoch_overshoot/stat_steps:.3e}"
                                                    f"\tBase: {epoch_base_loss/stat_steps:.4f}",
                                                    f"Diff: {lora_diff/stat_steps:.4e}",
                                                    f"\tWR: {epoch_wr/stat_steps:2.2f}%",
                                                    f"0eps: {epoch_0eps_wr/stat_steps:2.2f}% ",
                                                    f"\tLR: {lr_scheduler.get_last_lr()[0]:.2e}",
                                                    f"fit: {fit_samples}/{unfit_samples}"
                                                    )

            if not average_stats:
                epoch_overshoot = 0
                epoch_loss = 0
                epoch_base_loss = 0
                lora_diff = 0
                epoch_wr = 0
                epoch_0eps_wr = 0
            unfit_samples = 0
            fit_samples = 0
            accum_steps = 0
            
            lr_scheduler.step()
            # Skip a zero-LR step (e.g. when a decaying schedule bottoms out).
            if lr_scheduler.get_last_lr()[0] == 0.0:
                lr_scheduler.step()

            gc.collect(); torch.cuda.empty_cache()

        # ---- periodic evaluation / remerge / checkpoint ----
        if true_steps_taken % eval_steps == 0 and len(train_d) > 0 and true_steps_taken != last_tst:
            if remerging:
                if remerge_eval:
                    evaluate(model, eval_d, base_model=base_model, device=training_device, instruction_finetuning=instruction_finetuing, 
                            true_eps=true_eps, cached_base_loss=eval_base_loss, precompute_base_loss=True)
                model = model.to("cpu")
                model = merge(model_prev, model, ratio=remerge_ratio)
                model.save_pretrained("model_prev")
                model_prev = AutoModelForCausalLM.from_pretrained("model_prev", **params)
                model = model.to(training_device)
            evaluate(model, eval_d, base_model=base_model, device=training_device, instruction_finetuning=instruction_finetuing, 
                    true_eps=true_eps, cached_base_loss=eval_base_loss, precompute_base_loss=True)
            if do_save:
                # Checkpoint names are not zero-padded (e.g. "lora_3"). An earlier
                # str(...).format("02d") here was a no-op; the same names are kept so
                # existing checkpoints and loaders still match.
                checkpoint_idx = (true_steps_taken // eval_steps) + save_n_start
                model.save_pretrained(save_name + '_' + str(checkpoint_idx))
            model.train()
        
        steps_so_far += 1
        last_tst = true_steps_taken

        if excessive_cache_clearing:
            gc.collect(); torch.cuda.empty_cache()
            
        if time_limit_s is not None and time.time() - notebook_start_time > time_limit_s:
            break

    if do_save:
        if save_n_start > 0:
            model.save_pretrained(save_name+"_X"+str(save_n_start))
        else:
            model.save_pretrained(save_name)

    model.eval()
    final_eval_stats = evaluate(model, eval_d, return_stats=True, base_model=base_model, device=training_device, 
                                instruction_finetuning=instruction_finetuing, true_eps=true_eps,
                     cached_base_loss=eval_base_loss, precompute_base_loss=True)

    model = model.to("cpu")
    gc.collect(); torch.cuda.empty_cache()

    return final_eval_stats

In [4]:
# ---- Run configuration ----------------------------------------------------------------------
# Name for this run's adapter / checkpoints (earlier runs: "Carboniferous", "Helium").
lora_name = "Burdigalian"
# Checkpoint to fine-tune (alternative used before: "fulltune").
model_name = "MesozoicMetallurgist/nous-Aquitanian"
# Reference model whose score this run tries to beat.
model_name_to_beat = "MesozoicMetallurgist/nous-Aquitanian"

# NOTE(review): presumably a NEFTune-style noise strength; it is not referenced in the visible
# cells and the original note questioned whether it helps ("bad actually?"). TODO confirm.
neft_noise = 0.0

# Previously used LoRA/DoRA adapter configuration, kept for reference:
# rank = 512
# config = LoraConfig(
#     r=rank, lora_alpha=rank*2,
#     target_modules=['q_proj', 'v_proj', "k_proj", "o_proj"],  # MLP projections were disabled
#     lora_dropout=0.0,
#     bias="none", task_type="CAUSAL_LM",
#     use_dora=True,
# )

# Loading options shared by from_pretrained calls (bf16 weights, safetensors, FlashAttention-2).
params = dict(
    low_cpu_mem_usage=True,
    trust_remote_code=False,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    attn_implementation="flash_attention_2",
)

# Special-token overrides that were tried and disabled:
# model.config.bos_token_id = 2; model.config.eos_token_id = 1
# model.generation_config.bos_token_id = 2; model.generation_config.eos_token_id = 1
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="Models", **params).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-2-zephyr-1_6b", trust_remote_code=False, use_fast=True, cache_dir="Models")


def norm_model_weights(model):
    """Rebalance paired projection weights so the model computes the same function.

    Walks ``model.named_parameters()`` and, for each layer, moves magnitude between
    projections whose effects compose linearly:

    * ``q_proj`` / ``k_proj``: the q weight and bias are divided by a scalar ``m`` and
      the k weight and bias are multiplied by ``m``, so q·k is unchanged.
    * ``v_proj`` / ``o_proj``: the v weight and bias are divided by a *negative* scalar
      ``m`` and the o_proj weight is multiplied by ``m``. The o_proj bias is left
      untouched: it is added after ``(W_o * m) @ (v / m)``, so rescaling it would
      change the layer output.
    * ``up_proj`` / ``down_proj``: a per-channel factor divides the rows of up_proj
      (and its bias, if present) and multiplies the matching columns of down_proj;
      the down_proj bias is left untouched for the same reason as o_proj.
      gate_proj is not modified.

    Pairing relies on parameter iteration order: each projection's weight is yielded
    before its bias, and q/v/up come before k/o/down within a layer (true for
    ``nn.Linear`` submodules registered in that order — confirm for other models).
    The q/k scale also assumes q_proj's input width equals k_proj's output width and
    the v/o scale assumes v_proj's input width equals o_proj's input width (holds for
    full multi-head attention; likely not for GQA — TODO confirm before such use).

    Arithmetic is done in each parameter's own dtype, so for bf16 weights the result
    is only function-preserving up to rounding.

    Args:
        model: torch module whose parameter names contain ``q_proj``, ``k_proj``,
            ``v_proj``, ``o_proj``, ``up_proj`` and ``down_proj``.

    Returns:
        The same ``model`` object, modified in place through ``param.data``.
    """
    q_weight = q_bias = None
    qk_mult = None
    v_weight = v_bias = None
    up_weight = up_bias = None
    for name, param in model.named_parameters():
        is_bias = "bias" in name

        if "q_proj" in name:
            if is_bias:
                q_bias = param
            else:
                q_weight = param
        if "k_proj" in name:
            if is_bias:
                # k = (W_k*m) x + b_k*m = m * (W_k x + b_k): the bias scales with the weight.
                param.data = param.data * qk_mult
            else:
                mult = torch.sqrt(torch.mean(torch.abs(q_weight.data), dim=0, keepdim=True).transpose(0, 1) /
                                  torch.mean(torch.abs(param.data), dim=1, keepdim=True))
                # One scalar for the whole layer, so every q·k dot product is preserved.
                mult = torch.mean(mult)
                q_weight.data = q_weight.data / mult
                if q_bias is not None:
                    q_bias.data = q_bias.data / mult
                    q_bias = None  # consumed; never reuse it for a later layer
                param.data = param.data * mult
                qk_mult = mult

        if "v_proj" in name:
            if is_bias:
                v_bias = param
            else:
                v_weight = param
        # An o_proj bias is intentionally left unchanged (see docstring).
        if "o_proj" in name and not is_bias:
            mult = torch.sqrt(torch.mean(torch.abs(v_weight.data), dim=0, keepdim=True) /
                              torch.mean(torch.abs(param.data), dim=0, keepdim=True))
            # Negative scalar: v and o both flip sign, so their composition is unchanged.
            mult = torch.mean(mult) * -1
            v_weight.data = v_weight.data / mult
            if v_bias is not None:
                v_bias.data = v_bias.data / mult
                v_bias = None
            param.data = param.data * mult

        if "up_proj" in name:
            if is_bias:
                up_bias = param
            else:
                up_weight = param
        # A down_proj bias is intentionally left unchanged (see docstring).
        if "down_proj" in name and not is_bias:
            # Shape (1, up_proj out_features): one factor per intermediate channel.
            mult = torch.sqrt(torch.mean(torch.abs(up_weight.data), dim=1, keepdim=True).transpose(0, 1) /
                              torch.mean(torch.abs(param.data), dim=0, keepdim=True))
            up_weight.data = up_weight.data / mult.transpose(0, 1)
            if up_bias is not None:
                up_bias.data = up_bias.data / mult.reshape(-1)
                up_bias = None
            param.data = param.data * mult
    return model

model = norm_model_weights(model)


# model.save_pretrained("Models/fixedscaling")

# Frozen reference model the trained model is compared against. Rebinding to None first lets
# a copy from a previous run of this cell be freed (if nothing else references it) before the
# new one is loaded.
base_model = None
base_model = AutoModelForCausalLM.from_pretrained(model_name_to_beat, **params, cache_dir="Models")
base_model.requires_grad_(False)

# The model's own parameters are trained directly (no adapter): embeddings and the MLP
# projections are frozen, every other parameter stays trainable.
lora_model = model
for name, param in lora_model.named_parameters():
    param.requires_grad = not any(
        part in name for part in ("embed_", "up_proj", "down_proj", "gate_proj")
        # could also freeze: "q_proj", "k_proj", "v_proj", "o_proj"
    )

# Adapter-based alternatives:
# lora_model = PeftModel.from_pretrained(model, model_id="Ypresian_", is_trainable=True)
# lora_model = lora_model.merge_and_unload(progressbar=True)
# lora_model = get_peft_model(model, config)
# lora_model.print_trainable_parameters()
# tokenizer = AutoTokenizer.from_pretrained("MesozoicMetallurgist/nous-Hauterivian", trust_remote_code=False, use_fast=True, cache_dir="Models")

# Keep the model off the GPU until it is needed.
lora_model = lora_model.to("cpu")

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


In [None]:
# Print every non-adapter parameter (name, values, shape) for manual inspection.
for name, param in lora_model.named_parameters():
    if "lora" in name:
        continue
    print(name, param, param.data.shape)
# To look at only the adapter tensors instead:
# for name, param in lora_model.named_parameters():
#     if "lora" in name:
#         print(name, param)

In [6]:
# train_data, eval_data = get_data(train_name=['data/cortex_686_8704.json',
# ],
#                                  train_subset=np.arange(0,32),
#                                  eval_subset=32,
#                                  eval_name='data/cortex_686_8704.json', tokenizer=tokenizer, shuffle=False)
# evaluate(lora_model, eval_data, return_to_cpu=True, print_stats=True, base_model=base_model)

In [7]:
# train() acquires its own data when handed empty lists (see the WARN lines in the output);
# the later len() cells suggest it fills these lists in place — confirm.
train_data = []
eval_data = []
train(
    lora_model, tokenizer, train_data, eval_data,
    base_model=base_model, inf_training=False, training_device="cuda",
    # batching / task format
    acc_batch_size=512, instruction_finetuing=True, precalculate_batch_mult=2.25,
    # optimiser and LR schedule
    lr=3.2e-5, weight_decay=0.0, lr_scheduler="cosine",
    warmup_steps=4, warmup_end_offset=8, betas=(0.8, 0.95),
    # SAM options (SAM is off here)
    use_sam=False, sam_rho=0.05, adaptive_sam=False, wsam_variant=True,
    opt="adamw",
    # loss epsilons / sample filtering
    loss_eps=0.02, overshoot_buffer=-0.01, true_eps=0.01,
    ignore_below=0.0, ignore_auto_percent=1,
    # re-merging
    remerging=True, remerge_eval=True, remerge_ratio=0.3,
    # loss variant and base-loss handling
    simple_loss=True, process_base_loss=True, precalc_eval_base=True, relative_loss=False,
    # gradient clipping
    manual_grad_clip_norm=1.0, sam_grad_clip_norm=None,
    # overshoot handling
    add_overshoot_penalty=False, ignore_overshot_samples=True, bad_sample_mult=1.0,
    # evaluation / saving
    eval_steps=2048, do_save=False, save_name=lora_name, save_n_start=0,
    average_stats=False,
    # memory
    gradient_checkpointing=False, excessive_cache_clearing=False,
    # gradient_checkpointing=True, excessive_cache_clearing=True,
)

WARN: no training data provided, enabling infinite training
WARN: remerging is enabled, model_prev created and may take up extra memory
WARN: no evaluation data provided, acquiring new data
WARN: no training data provided, acquiring new data
WARN: simple loss is enabled, this will disable base model processing
Base loss processing is disabled, disabling ignore overshot samples
Note: precalced eval base loss does not account for pretrained fine-tuning
........Eval Base Loss: 0.933087
........Step 512/1792	Loss: 0.763006 OShL: 0.000e+00	Base: 0.9331 Diff: -1.7008e-01 	WR: 72.46% 0eps: 74.41%  	LR: 3.13e-05 fit: 0/0
........Step 1024/1280	Loss: 0.779480 OShL: 0.000e+00	Base: 0.9331 Diff: -1.5361e-01 	WR: 75.20% 0eps: 76.17%  	LR: 6.25e-05 fit: 0/0
........Step 1536/1920	Loss: 0.859848 OShL: 0.000e+00	Base: 0.9331 Diff: -7.3240e-02 	WR: 63.28% 0eps: 64.26%  	LR: 9.38e-05 fit: 0/0
........Step 2048/1408	Loss: 0.910329 OShL: 0.000e+00	Base: 0.9331 Diff: -2.2758e-02 	WR: 59.57% 0eps: 61.33%  

KeyboardInterrupt: 

In [12]:
# Checkpoint the current model into a directory named after the run (`lora_name`).
lora_model.save_pretrained(lora_name)

config.json:   0%|          | 0.00/678 [00:00<?, ?B/s]

In [9]:
# Evaluate on the held-out set, comparing against the frozen reference model `base_model`.
evaluate(lora_model, eval_data, return_to_cpu=True, print_stats=True, base_model=base_model)

Evaluating ........ Loss: 0.00578544, Base Loss: 0.864540, Lora Diff: -0.00556602, WR: 33.86%, 0epsWR: 80.78%, OShL: 0.00163308


In [13]:
# Sanity check on the number of training samples. The list was created empty before train();
# presumably train() fills it in place — confirm.
len(train_data)

11173

In [12]:
# Sanity check on the number of evaluation samples (also created empty before train(); presumably filled by it).
len(eval_data)

435

In [10]:
# Re-save the checkpoint under the run name (overwrites the same directory as any earlier save).
lora_model.save_pretrained(lora_name)

config.json:   0%|          | 0.00/659 [00:00<?, ?B/s]

In [11]:
# Fold adapter weights into the base weights, when there is an adapter to fold.
# A plain transformers model (which `lora_model` is when no PEFT wrapper was applied)
# has no merge_and_unload(), so calling it unconditionally raises AttributeError.
lora_model = lora_model.to("cuda")
if hasattr(lora_model, "merge_and_unload"):
    lora_model = lora_model.merge_and_unload()
else:
    print("lora_model has no adapter to merge; skipping merge_and_unload")
# lora_model = norm_model_weights(lora_model)
# simple_eval(lora_model, eval_data)
lora_model = lora_model.to("cpu")

In [12]:
import shutil

# Export the model and tokenizer as a standalone checkpoint directory.
lora_model.config.name_or_path = "MesozoicMetallurgist/new_model"
model_dir = "Models/merged_model"
# model_dir = os.path.expanduser("~/finetuning-subnet/merged_model")
if not os.path.exists(model_dir):
    os.makedirs(model_dir, exist_ok=True)
else:
    # Empty the directory so stale files from a previous export are not mixed with the new
    # checkpoint. Sub-directories need rmtree; os.remove raises on them.
    for entry in os.listdir(model_dir):
        entry_path = os.path.join(model_dir, entry)
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)
lora_model.save_pretrained(save_directory=model_dir, safe_serialization=True)
tokenizer.save_pretrained(save_directory=model_dir)
lora_model = lora_model.to("cpu")
gc.collect(); torch.cuda.empty_cache()

In [8]:
# Evaluate without passing base_model (unlike the earlier evaluation cell), so the base-loss
# comparison depends on evaluate()'s default for base_model — confirm that default.
evaluate(lora_model, eval_data, return_to_cpu=True, print_stats=True)

Evaluating ........ Loss: 0.00798187, Base Loss: 0.557989, Lora Diff: -0.00123047, WR: 18.10%, 0epsWR: 66.80%, OShL: 0.00037500


In [5]:
# lora_model = AutoModelForCausalLM.from_pretrained(model_name, **params, cache_dir="Models")

lora_model = lora_model.to("cuda")
# lora_model = lora_model.merge_and_unload()
# model_dir = "Models/merged_model"
# lora_model.save_pretrained(save_directory=model_dir, safe_serialization=True)
# tokenizer.save_pretrained(save_directory=model_dir)

# Rescale each MLP by random integer per-channel factors in [1, 128): the columns of down_proj
# are multiplied and the matching rows of up_proj divided. Assuming a gated MLP of the form
# down(act(gate(x)) * up(x)) (confirm for this architecture), the model's function is
# unchanged up to bf16 rounding (dividing by up to 127 in bf16 does lose precision).
#
# fork_rng snapshots the CPU and CUDA RNG states and restores them on exit, so fixing the seed
# here does not disturb the notebook's global random stream. (Re-seeding with
# torch.initial_seed() would rewind the generators to their start, not restore them.)
last_up = None
with torch.random.fork_rng():
    for name, param in lora_model.named_parameters():
        if "up_proj" in name:
            last_up = param
        elif "down_proj" in name:
            # Seeded per layer, so every down_proj draws the same multipliers.
            torch.manual_seed(42)
            # Shape (1, down_proj in_features), on the parameter's device.
            mult = torch.randint_like(param.data[:1], 1, 128, dtype=torch.bfloat16)
            last_up.data = last_up.data / mult.transpose(0, 1)
            param.data = param.data * mult
            last_up = None

model_dir = "Models/merged_model_fuckkery"
lora_model.save_pretrained(save_directory=model_dir, safe_serialization=True)
tokenizer.save_pretrained(save_directory=model_dir)
lora_model = lora_model.to("cpu")

In [18]:
# Wrap `model` with a new PEFT adapter and evaluate it. The saved output (Lora Diff 0, WR 0%)
# is consistent with an adapter that does not yet change the model.
# NOTE(review): `config` is only created by the LoraConfig block that is commented out in the
# model-loading cell, so this cell raises NameError on a fresh kernel unless `config` is
# defined elsewhere — confirm before Restart & Run All.
lora_model = get_peft_model(model, config)
evaluate(lora_model, eval_data, return_to_cpu=True, print_stats=True)

Evaluating ........ Loss: 0.01010798, Base Loss: 0.707643, Lora Diff: 0.00000000, WR: 0.00%, 0epsWR: 50.00%, OShL: 0.00000000


In [20]:
# Fold the adapter into the base weights on the GPU, then save the plain model.
lora_model = lora_model.to("cuda").merge_and_unload()
lora_model.save_pretrained(lora_name + 'fuckkery')