In this notebook, we work through an example of estimating the fine-tuning loss on the Alpaca dataset, using the Llama-3.1-8B model and LoRA.

We will first apply our estimation technique to estimate the LoRA parameters on top of the Llama-3.1-8B model, and use them to estimate the loss on randomly sampled subsets. Then, we can compare these estimates with the results of actual fine-tuning to evaluate the quality of the approximation.

In [None]:
import os
import torch
import pandas as pd
import numpy as np
import pytorch_lightning as pl

from alpaca_data_module import AlpacaDataModule
from alpaca_model import AlpacaModel
from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig
from adapters import AutoAdapterModel, DoubleSeqBnConfig

### Obtain a meta-initialization by multitask training on all data

First, we obtain a meta-initialization by fine-tuning on all data from the Alpaca dataset. One can run the following command to fine-tune a Llama-3.1-8B model with LoRA:

```
python custom_train_instruction.py --model_key "meta-llama/Llama-3.1-8B"  \
    --lr 2e-5 --batch_size 4 --max_length 256 --epochs 10\
    --train_lora --lora_rank 16 --lora_alpha 128\
    --strategy auto --devices 0 --runs 1 --accumulate 1 --precision "bf16-true" 
```

We provide a fine-tuned checkpoint (`meta_initialization.pt` in this folder), so one can skip the meta-training step and go directly to the estimation.

**Gradient evaluation** on the meta-initialization: 

Next, we load the fine-tuned model as the meta-initialization and evaluate gradients on all training samples. These gradients are used later to conduct the estimation.


In [None]:
# Experiment configuration. Used as a plain mutable namespace: later cells
# attach extra attributes to it (e.g. ``project_dimension`` and ``run``).
class args:
    """Settings for model loading, PEFT setup and data loading."""

    # --- base model ---------------------------------------------------------
    model_key = "meta-llama/Llama-3.1-8B"
    devices = [0]

    # --- LoRA ---------------------------------------------------------------
    train_lora = True
    lora_rank = 16
    lora_alpha = 128

    # --- quantization switches (all off) ------------------------------------
    use_qlora = False
    use_qadapter = False
    use_3bit = False
    use_2bit = False

    # --- bottleneck adapters (off) ------------------------------------------
    train_adapter = False
    reduction_factor = 128

    # --- data constants -----------------------------------------------------
    max_length = 256
    batch_size = 4
    inference_batch_size = 4
    downsample = 1
    

def _nf4_quantization_config():
    """Return the 4-bit NF4 (double-quantized, bf16 compute) bitsandbytes config."""
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
    )


def _lora_target_modules(model_key):
    """Return the module names LoRA is attached to for ``model_key``."""
    if model_key == "gpt2":  # for gpt2, we generally use full model
        return ["c_attn", "c_proj", "c_fc"]
    if model_key == "EleutherAI/gpt-neox-20b":
        return ["query_key_value"]
    if "flan" in model_key:
        return ["q", "k", "v"]
    # Default: separate q/k/v projection modules (e.g. Llama).
    return ["q_proj", "k_proj", "v_proj"]


def initialize_model(args):
    """Load a pretrained model and tokenizer and attach the requested PEFT modules.

    Args:
        args: namespace providing ``model_key``, ``devices``, ``use_qlora``,
            ``train_adapter``, ``use_qadapter``, ``reduction_factor``,
            ``use_3bit``, ``use_2bit``, ``train_lora``, ``lora_rank`` and
            ``lora_alpha``.

    Returns:
        ``(model, tokenizer, hf_key, model_type, append_eos)`` where
        ``model_type`` is ``"decoder"`` or ``"encoder_decoder"`` and
        ``append_eos`` is False for flan/T5 models (their tokenizer already
        appends EOS) and True otherwise.

    Raises:
        NotImplementedError: if ``args.model_key`` matches no supported family.
    """
    model_key = args.model_key.replace("/", "-").replace("..", "")
    is_decoder_family = "gpt" in args.model_key or any(
        family in model_key for family in ("Llama", "bloomz", "gemma", "Mistral")
    )

    if is_decoder_family:
        hf_key = args.model_key.replace("_", "-")
        tokenizer = AutoTokenizer.from_pretrained(hf_key)
        tokenizer.padding_side = 'right'
        # With bottleneck adapters the model is re-loaded below as an
        # AutoAdapterModel; skip loading the base weights a first time.
        if args.train_adapter:
            model = None
        elif args.use_qlora:
            model = AutoModelForCausalLM.from_pretrained(
                hf_key,
                quantization_config=_nf4_quantization_config(),
                torch_dtype=torch.bfloat16,
                device_map={"": args.devices[0]},
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(hf_key)
        model_type = "decoder"
        append_eos = True
    elif "flan" in model_key:
        # NOTE(review): ``model_key`` has "/" replaced by "-", so a key such as
        # "google/flan-t5-base" would become "google/google-flan-t5-base" here.
        # This presumably expects a bare name like "flan_t5_base" — confirm.
        hf_key = "google/{}".format(model_key.replace("_", "-"))
        model = None if args.train_adapter else AutoModelForSeq2SeqLM.from_pretrained(hf_key)
        tokenizer = AutoTokenizer.from_pretrained(hf_key, model_max_length=512)
        model_type = "encoder_decoder"
        append_eos = False  # t5 tokenizers already append eos
    else:
        raise NotImplementedError(args.model_key)

    if args.train_adapter:
        if args.use_qadapter:
            model = AutoAdapterModel.from_pretrained(
                hf_key,
                quantization_config=_nf4_quantization_config(),
                torch_dtype=torch.bfloat16,
                device_map={"": args.devices[0]},
            )
        else:
            model = AutoAdapterModel.from_pretrained(hf_key)

        bottleneck_config = DoubleSeqBnConfig(
            mh_adapter=True,
            output_adapter=True,
            reduction_factor=args.reduction_factor,
            non_linearity="relu",
        )
        model.add_adapter(adapter_name="seq_bn", config=bottleneck_config)

        # Freeze every parameter that does not belong to an adapter.
        for name, param in model.named_parameters():
            if "adapter" not in name:
                param.requires_grad = False

        model.set_active_adapters("seq_bn")
        trainable_params_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
        all_params_count = sum(p.numel() for p in model.parameters())

        print(f"Trainable parameters: {trainable_params_count} || All parameters: {all_params_count} || ratio: {trainable_params_count/all_params_count}")
        print("-"*20,"Bottleneck_Adapter","-"*20)

    if args.use_3bit or args.use_2bit:
        from src.lqlora_utils import lora_utils
        device = f"cuda:{args.devices[0]}"
        model = lora_utils.prepare_model_for_lora(
            model=model,
            num_ranks=args.lora_rank,
            lora_alpha=args.lora_alpha,
            lora_dropout=0.1,
            use_gradient_checkpointing=True)

        lora_utils.transform_lora_layers(
            lpq=False,
            model=model,
            model_name="nf3" if args.use_3bit else "nf2",
            device=device)
        model.to(device)

    elif args.train_lora:
        # Only the target modules differ between model families.
        config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            target_modules=_lora_target_modules(args.model_key),
            lora_dropout=0.1,
            bias="lora_only",
            modules_to_save=[],
        )
        model = get_peft_model(model, config)
        model.print_trainable_parameters()

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer, hf_key, model_type, append_eos


# Initialize the model (base weights + LoRA) and load the meta-initialization
# checkpoint on top of it.
model_key = args.model_key.replace("/", "-").replace("..", "")
model, tokenizer, hf_key, model_type, append_eos = initialize_model(args)

# weights_only=True refuses to unpickle arbitrary Python objects from the file;
# map_location="cpu" keeps the raw checkpoint off the GPU (load_state_dict
# copies the values into the existing parameters anyway).
meta_state_dict = torch.load("meta_initialization.pt", map_location="cpu", weights_only=True)
# strict=False is needed because the checkpoint covers only part of the model
# (frozen base weights show up as "missing"). *Unexpected* keys, however, were
# not loaded at all, which usually signals a name/prefix mismatch.
load_result = model.load_state_dict(meta_state_dict, strict=False)
if load_result.unexpected_keys:
    print(f"Warning: {len(load_result.unexpected_keys)} checkpoint keys did not match the model, "
          f"e.g. {load_result.unexpected_keys[:5]}")
print(f"Loaded {len(meta_state_dict) - len(load_result.unexpected_keys)} tensors from meta_initialization.pt")
del meta_state_dict


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

trainable params: 9,437,184 || all params: 8,039,698,432 || trainable%: 0.11738231327729483


  model.load_state_dict(torch.load("meta_initialization.pt"), strict=False)


_IncompatibleKeys(missing_keys=['base_model.model.model.embed_tokens.weight', 'base_model.model.model.layers.0.self_attn.q_proj.weight', 'base_model.model.model.layers.0.self_attn.k_proj.weight', 'base_model.model.model.layers.0.self_attn.v_proj.weight', 'base_model.model.model.layers.0.self_attn.o_proj.weight', 'base_model.model.model.layers.0.mlp.gate_proj.weight', 'base_model.model.model.layers.0.mlp.up_proj.weight', 'base_model.model.model.layers.0.mlp.down_proj.weight', 'base_model.model.model.layers.0.input_layernorm.weight', 'base_model.model.model.layers.0.post_attention_layernorm.weight', 'base_model.model.model.layers.1.self_attn.q_proj.weight', 'base_model.model.model.layers.1.self_attn.k_proj.weight', 'base_model.model.model.layers.1.self_attn.v_proj.weight', 'base_model.model.model.layers.1.self_attn.o_proj.weight', 'base_model.model.model.layers.1.mlp.gate_proj.weight', 'base_model.model.model.layers.1.mlp.up_proj.weight', 'base_model.model.model.layers.1.mlp.down_proj.we

In [None]:
# Build the Alpaca data module over all 38 task indices and prepare it for
# fitting (the train dataloader and validation are used in later cells).
task_idxes = list(range(38))
data_module = AlpacaDataModule(
    tokenizer=tokenizer,
    task_idxes=task_idxes,
    data_path="./data/alpaca_final.pkl",
    dev_split_path="./data/alpaca_dev_split_map.pkl",
    model_type=model_type,
    context_length=args.max_length,
    batch_size=args.batch_size,
    inference_batch_size=args.inference_batch_size,
)
data_module.setup(stage="fit")

In [None]:
# Generate gradients on the meta-initialization.
# Evaluate the gradients on all training samples; AlpacaModel is configured to
# project them to ``args.project_dimension`` (= 100) dimensions.
args.project_dimension = 100
args.run = 0

gradient_dir = (
    f"Alpaca_{model_key}"
    + (f"_lora_r_{args.lora_rank}" if args.train_lora else "")
    + f"_dim_{args.project_dimension}_run_{args.run}"
)
print("Directory for saving gradients", gradient_dir)

lm = AlpacaModel(model=model, tokenizer=tokenizer, model_type=model_type,
                 lr=2e-5, weight_decay=0, max_length=args.max_length, use_wandb=False,
                 intialize_project_matrix=True,  # (sic) keyword name defined by AlpacaModel
                 run_seed=args.run,
                 project_dim=args.project_dimension, gradient_dir=gradient_dir, use_sgd=True)

# Lightning log directory; exist_ok avoids a check-then-create race.
default_root_dir = "./external_lightning_logs/"
os.makedirs(default_root_dir, exist_ok=True)

# max_epochs=0: in this notebook the trainer only drives predict()/validate().
# inference_mode=False keeps autograd enabled inside predict_step so that
# gradients can be computed there.
trainer = pl.Trainer(accelerator="gpu", devices=args.devices, strategy="auto",
                     default_root_dir=default_root_dir, min_epochs=0, max_epochs=0,
                     accumulate_grad_batches=1, precision="bf16-true",
                     enable_checkpointing=True, inference_mode=False)

# CPU snapshot of the meta-initialization (keys containing 'absmax'/'quant',
# presumably quantization statistics, are skipped). evaluate_subset restores
# the model from this snapshot before applying each estimated update.
state_dict = {key: val.clone().to("cpu") for key, val in model.state_dict().items() if 'absmax' not in key and 'quant' not in key}

# We use trainer to call the predict_step() function defined within the AlpacaModel
# Please refer to the definition of the AlpacaModel for the predict_step() function
trainer.predict(lm, dataloaders=data_module.train_dataloader())

### Estimation: Solving logistic regression using gradients as features

Next, we use the gradients as features to solve a logistic regression problem. The logistic regression coefficients, mapped back to the parameter space through the projection matrix, are then used as the estimated parameters for a subset of tasks.

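Notice that in the logistic regression, setting the regularization parameter is crucial for controlling the norm of the estimated fine-tuned weights. It usually needs to be tuned so that the estimated loss stays in a reasonable range (in scikit-learn, `C` is the inverse of the regularization strength, so a smaller `C` yields estimated weights with a smaller norm).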

In [None]:
from sklearn.linear_model import LogisticRegression

# Perform estimation
def generate_state_dict(model, state_dict, coef, device="cpu", removing_keys=("shared", "lm_head", "wte", "wpe", "ln", "embed_tokens", "norm", "word_embeddings")):
    """Build fine-tuned weights by adding a flat update vector to trainable params.

    ``coef`` is consumed as the concatenation, in ``model.named_parameters()``
    order, of every trainable parameter whose name contains none of the
    substrings in ``removing_keys``.

    Args:
        model: module whose ``named_parameters()`` define the layout of ``coef``.
        state_dict: mapping from parameter name to the base tensor.
        coef: 1-D array holding the flattened parameter update.
        device: device on which the returned tensors are placed.
        removing_keys: name substrings whose parameters are skipped; they get
            no update and consume no entries of ``coef``.

    Returns:
        dict mapping each updated parameter name to ``state_dict[name] + delta``
        (frozen and skipped parameters are absent).
    """
    new_state_dict = {}
    offset = 0
    for key, param in model.named_parameters():
        if not param.requires_grad or any(rkey in key for rkey in removing_keys):
            continue
        numel = param.numel()
        delta = torch.as_tensor(
            coef[offset:offset + numel].reshape(param.shape),
            dtype=torch.float32,
            device=device,
        )
        new_state_dict[key] = state_dict[key].clone().to(device) + delta
        offset += numel
    return new_state_dict

def compute_norm(state_dict, use_lora=True, removing_keys=("shared", "lm_head", "wte", "wpe", "ln", "embed_tokens", "norm", "word_embeddings")):
    """Return the L2 norm over a selection of tensors in ``state_dict``.

    Args:
        state_dict: mapping from parameter name to tensor.
        use_lora: if True, only keys containing ``"lora"`` are included;
            otherwise every key except those containing a ``removing_keys``
            substring is included.
        removing_keys: name substrings excluded when ``use_lora`` is False.

    Returns:
        The Euclidean norm (a Python float) of all selected entries.
    """
    import math  # np.math was removed in NumPy 2.0

    total = 0.0
    for key, val in state_dict.items():
        if use_lora:
            selected = "lora" in key
        else:
            selected = not any(rkey in key for rkey in removing_keys)
        if selected:
            # Accumulate in float64: summing squares in bf16/fp16 loses precision.
            total += val.detach().square().sum(dtype=torch.float64).item()
    return math.sqrt(total)

def _load_subset_gradients(gradient_dir, data_idxes, batch_size):
    """Stack the stored projected gradients of the samples in ``data_idxes``.

    Sample ``idx`` is read from row ``idx % batch_size`` of
    ``{gradient_dir}/train_batch_{idx // batch_size}_gradients.npy``; samples
    whose batch file does not exist are skipped. Each file is loaded at most once.
    NOTE(review): assumes gradients were saved from an unshuffled train loader
    with this batch size — confirm against AlpacaModel.predict_step.
    """
    gradients = []
    batch_cache = {}
    for idx in data_idxes:
        batch_idx, row = divmod(idx, batch_size)
        if batch_idx not in batch_cache:
            gradient_file = f"{gradient_dir}/train_batch_{batch_idx}_gradients.npy"
            batch_cache[batch_idx] = np.load(gradient_file) if os.path.exists(gradient_file) else None
        batch_gradients = batch_cache[batch_idx]
        if batch_gradients is not None:
            gradients.append(batch_gradients[row])
    return np.array(gradients)


# Key function to solve logistic regression
def evaluate_subset(args, trainer, lm, data_module, data_idxes, state_dict, projection_matrix, gradient_dir, C=1e-4):
    """Estimate fine-tuned weights for a data subset and validate them.

    Fits an L2-regularized logistic regression on the subset's projected
    gradients, maps the coefficients back through ``projection_matrix``, adds
    them to the meta-initialization and runs ``trainer.validate``.

    Args:
        args: namespace providing ``batch_size`` (layout of the gradient files).
        trainer: Lightning trainer used for validation.
        lm: LightningModule wrapping the model (``lm.model``).
        data_module: data module passed to ``trainer.validate``.
        data_idxes: training-sample indices belonging to the subset.
        state_dict: CPU snapshot of the meta-initialization weights.
        projection_matrix: maps projected coefficients to parameter space.
        gradient_dir: directory holding ``train_batch_*_gradients.npy`` files.
        C: inverse regularization strength of the logistic regression.

    Returns:
        The first validation-metric dict from ``trainer.validate``, or ``{}``
        when no gradients were found for the subset.
    """
    gradients = _load_subset_gradients(gradient_dir, data_idxes, args.batch_size)
    if len(gradients) == 0:
        return {}

    # Randomly assign labels (1 with probability 0.7) and negate the gradients
    # of the 0-labelled samples.
    labels = np.random.binomial(n=1, p=0.7, size=gradients.shape[0])
    if len(labels) > 1 and np.all(labels == labels[0]):
        # LogisticRegression.fit requires two classes; flipping one label (and,
        # below, that sample's sign) keeps the same construction.
        labels[0] = 1 - labels[0]
    signs = np.where(labels == 0, -1, 1).reshape(-1, 1)
    features = gradients * signs

    # Train a logistic regression model.
    clf = LogisticRegression(random_state=0, penalty='l2', C=C, solver='liblinear')
    clf.fit(features, labels)
    print("Logistic regression score: ", clf.score(features, labels))

    # Lift the coefficients from the projected space back to parameter space.
    coef = projection_matrix @ clf.coef_.ravel()
    print("L2 norm of estimated parameters", np.linalg.norm(coef))

    finetuned_state_dict = generate_state_dict(lm.model, state_dict, coef, device=lm.model.device)
    # Restore the full meta-initialization (strict), then overwrite only the
    # updated parameters (non-strict: it holds a subset of the keys).
    lm.model.load_state_dict(state_dict)
    lm.model.load_state_dict(finetuned_state_dict, strict=False)

    return trainer.validate(lm, datamodule=data_module)[0]

In [None]:
# Estimation settings: each of the 10 sampled subsets covers 50% of the tasks.
args.subset_size = 0.5
args.number_of_subsets = 10

# Reuse the gradient directory and projection matrix held by the model above.
gradient_dir = lm.gradient_dir
project_matrix = lm.project_matrix

def add_result_to_csv(result_datapoint, file_name):
    """Append ``result_datapoint`` as a single row to the CSV at ``file_name``.

    Creates the file (with a leading index column) if it does not exist.
    Otherwise the existing rows are read, the new row is concatenated (columns
    aligned by name, missing cells become NaN) and the file is rewritten.
    ``result_datapoint`` itself is left unmodified.
    """
    result_df = pd.DataFrame({key: [val] for key, val in result_datapoint.items()})
    if os.path.exists(file_name):
        existing_df = pd.read_csv(file_name, index_col=0)
        result_df = pd.concat([existing_df, result_df], ignore_index=True)
    result_df.to_csv(file_name)

# Sample random task subsets; for each, estimate the fine-tuned weights from
# the stored gradients and append the validation summary to a CSV file.
file_name = "estimation_results.csv"

# Loop-invariant lookups: the skill of every training sample and all skills.
train_dataset = data_module.train_dataset
skills = [tmp_data['skill'] for tmp_data in train_dataset.data]
skill_list = data_module.skills
task_num = len(skill_list)

for _ in range(args.number_of_subsets):
    # sample a subset of tasks (without replacement)
    subset_idxes = np.random.choice(task_num, int(args.subset_size * task_num), replace=False)
    subset_idxes.sort()
    # set for O(1) membership; assumes skill values are hashable (e.g. strings)
    subset_skills = {skill_list[i] for i in subset_idxes}
    data_idxes = [i for i, skill in enumerate(skills) if skill in subset_skills]

    # Perform estimation on the subset of tasks
    summary = evaluate_subset(args, trainer, lm, data_module, data_idxes, state_dict, project_matrix, gradient_dir)
    if not summary:
        continue

    # Write the evaluation results to a csv file. NOTE: the "Data indices"
    # column stores the sampled *task* indices (name kept for compatibility).
    result_datapoint = {"Data indices": " ".join(str(idx) for idx in subset_idxes)}
    result_datapoint.update(summary)
    add_result_to_csv(result_datapoint, file_name)

Linear regression score:  1.0
L2 norm of estimated parameters 13.67517478148466


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]
/home/ldy/miniconda3/envs/llama-env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


Validation: |          | 0/? [00:00<?, ?it/s]