In [None]:
from transformers import LlamaForCausalLM, LlamaTokenizer
import torch
import torch.nn as nn
from typing import List, Optional, Union
from peft import LoraConfig
from dataclasses import asdict, dataclass, field
import torch.nn.functional as F
import re
import warnings
import math
import copy
from enum import Enum
from datasets import load_dataset
from utils.prompter import Prompter
import transformers
import os
import sys

base_model = "/root/llama-7b-hf"  # local HF-format checkpoint loaded by LlamaForCausalLM / LlamaTokenizer
data_path = "train_data_3_class_clean.jsonl"  # JSONL records with `text` and `label` fields (see generate_and_tokenize_prompt)
output_dir = "/root/autodl-tmp/output2s"  # presumably the Trainer output dir; not used in the visible cells
# training hyperparams
batch_size = 128  # samples per optimizer step (per process; see gradient_accumulation_steps)
micro_batch_size = 4  # samples per forward/backward pass
num_epochs = 1000  # NOTE(review): 1000 epochs over a 300-row training split — confirm this is intended
learning_rate = 3e-4
cutoff_len = 256  # max tokens per example; tokenize() truncates to this length
val_set_size = 0  # presumably disables a validation split; not used in the visible cells
# lora hyperparams (scaling applied to the LoRA update is lora_alpha / lora_r)
lora_r = 8
lora_alpha = 16
lora_dropout = 0.05
# Names of the projection layers that get replaced by LoRA layers (matched by suffix)
lora_target_modules = [
    "q_proj",
    "v_proj",
]
# llm hyperparams
train_on_inputs = True  # if False, prompt tokens are masked (-100) out of the loss
add_eos_token = False  # only used when tokenizing the user prompt for label masking; the full prompt always gets EOS
group_by_length = False  # faster, but produces an odd training loss curve
# wandb params (Weights & Biases logging)
wandb_project = ""
wandb_run_name = ""
wandb_watch = ""  # options: false | gradients | all
wandb_log_model = ""  # options: false | true
# resume_from_checkpoint = '/root/autodl-tmp/output/checkpoint-100'  # either training checkpoint or final adapter
resume_from_checkpoint=None
prompt_template_name = "alpaca"  # The prompt template to use, will default to alpaca.
device_map = "auto"  # passed to from_pretrained for model placement
# 128 // 4 = 32 micro-batches per optimizer step.
# NOTE(review): not divided by world_size, so under DDP the effective batch is batch_size * world_size.
gradient_accumulation_steps = batch_size // micro_batch_size

# WORLD_SIZE is set by distributed launchers (e.g. torchrun); defaults to a single process.
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
# Enable W&B when a project is configured here or through the WANDB_PROJECT environment variable.
use_wandb = len(wandb_project) > 0 or (
    "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
)

In [None]:
# Prompt builder for the configured template; used by generate_and_tokenize_prompt.
prompter = Prompter(prompt_template_name)

In [None]:
# Base model: fp16 weights, no 8-bit quantisation, placed according to `device_map`.
model = LlamaForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map=device_map,
)
tokenizer = LlamaTokenizer.from_pretrained(base_model)

# Pad with id 0 (<unk>) so that padding is distinguishable from the EOS token.
tokenizer.pad_token_id = 0
# Left padding, which allows batched inference.
tokenizer.padding_side = "left"

In [None]:
# `import importlib` alone does not guarantee the `importlib.util` submodule is loaded;
# import it explicitly (this also binds the name `importlib`).
import importlib.util
import random


def is_bnb_available():
    """Return True when the optional `bitsandbytes` package can be found.

    Uses `importlib.util.find_spec`, which locates the package without importing it.
    """
    return importlib.util.find_spec("bitsandbytes") is not None


# Only import bitsandbytes when it is installed; it is needed for 8-bit layers.
if is_bnb_available():
    import bitsandbytes as bnb
class LoraLayer:
    """Mixin holding LoRA state for a host module that exposes ``weight`` (e.g. ``nn.Linear``).

    For linear layers, ``update_layer`` creates ``lora_num`` LoRA experts
    (``lora_A[str(i)]`` / ``lora_B[str(i)]``) plus a routing matrix ``lora_gate``.
    The expert dictionaries are keyed by expert index ("0", "1", ...), whereas
    ``r``, ``lora_alpha``, ``scaling`` and ``lora_dropout`` are keyed by adapter name.
    The host must run its ``nn.Module`` initialiser before ``LoraLayer.__init__``,
    because this initialiser registers ``nn.ModuleDict``/``nn.ParameterDict`` attributes.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
    ):
        self.r = {}
        self.lora_alpha = {}
        self.scaling = {}

        self.lora_dropout = nn.ModuleDict({})
        # Linear LoRA experts, keyed by expert index. They are registered empty here so that
        # reset_lora_parameters and `lora_A.keys()` lookups work even when no experts
        # are ever created (r == 0, or embedding-only layers).
        self.lora_A = nn.ParameterDict({})
        self.lora_B = nn.ParameterDict({})
        # For Embedding layer
        self.lora_embedding_A = nn.ParameterDict({})
        self.lora_embedding_B = nn.ParameterDict({})
        # Mark the weight as unmerged
        self.merged = False
        self.disable_adapters = False
        self.in_features = in_features
        self.out_features = out_features

    def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
        """Register an adapter's hyper-parameters and, for ``r > 0``, create the experts and gate.

        Args:
            adapter_name: key for ``r``, ``lora_alpha``, ``scaling`` and ``lora_dropout``.
            r: LoRA rank; no experts are created when ``r <= 0``.
            lora_alpha: numerator of the LoRA scaling (``scaling = lora_alpha / r``).
            lora_dropout: dropout probability on the LoRA input (``nn.Identity`` when 0).
            init_lora_weights: when True, initialise the parameters via ``reset_lora_parameters``.
        """
        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        lora_dropout_layer = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else nn.Identity()
        self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
        # Number of LoRA experts; the routed forward pass iterates over this many.
        self.lora_num = 8
        if r > 0:
            # NOTE: this replaces any previously created experts and gate, so calling
            # update_layer for a second adapter discards the first adapter's experts.
            self.lora_A = nn.ParameterDict()
            self.lora_B = nn.ParameterDict()
            for i in range(self.lora_num):
                self.lora_A[str(i)] = nn.Parameter(self.weight.new_zeros((r, self.in_features)))
                self.lora_B[str(i)] = nn.Parameter(self.weight.new_zeros((self.out_features, r)))
            # Routing matrix, one row per expert.
            # NOTE(review): zero-initialised, so all routing scores tie and argmax picks
            # expert 0 until the gate changes — confirm this is intended.
            self.lora_gate = nn.Parameter(self.weight.new_zeros((self.lora_num, self.in_features)))
            self.scaling[adapter_name] = lora_alpha / r
        if init_lora_weights:
            self.reset_lora_parameters(adapter_name)
        self.to(self.weight.device)

    def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
        """Embedding-layer variant of ``update_layer``: one A/B pair keyed by ``adapter_name``."""
        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        lora_dropout_layer = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else nn.Identity()
        self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
        # Actual trainable parameters
        if r > 0:
            self.lora_embedding_A.update(
                nn.ParameterDict({adapter_name: nn.Parameter(self.weight.new_zeros((r, self.in_features)))})
            )
            self.lora_embedding_B.update(
                nn.ParameterDict({adapter_name: nn.Parameter(self.weight.new_zeros((self.out_features, r)))})
            )
            self.scaling[adapter_name] = lora_alpha / r
        if init_lora_weights:
            self.reset_lora_parameters(adapter_name)
        self.to(self.weight.device)

    def reset_lora_parameters(self, adapter_name):
        """Initialise LoRA weights: A with Kaiming-uniform (as ``nn.Linear`` does), B with zeros.

        The linear experts are keyed by expert index, not adapter name, so every expert is
        initialised whenever ``adapter_name`` is a registered non-embedding adapter with
        ``r > 0``. This ensures no expert is left with A = B = 0, which would give it zero
        gradients. The embedding adapter, if present, gets A = 0 and B ~ N(0, 1).
        """
        is_linear_adapter = self.r.get(adapter_name, 0) > 0 and adapter_name not in self.lora_embedding_A
        if is_linear_adapter:
            for key in self.lora_A.keys():
                nn.init.kaiming_uniform_(self.lora_A[key], a=math.sqrt(5))
                nn.init.zeros_(self.lora_B[key])
        if adapter_name in self.lora_embedding_A.keys():
            # initialize a the same way as the default for nn.linear and b to zero
            nn.init.zeros_(self.lora_embedding_A[adapter_name])
            nn.init.normal_(self.lora_embedding_B[adapter_name])

def transpose(weight, fan_in_fan_out):
    """Return ``weight.T`` when ``fan_in_fan_out`` is truthy, otherwise ``weight`` itself."""
    if not fan_in_fan_out:
        return weight
    return weight.T

class SuperScalableLinear(torch.nn.Linear):
    """Linear layer whose effective weight/bias are reparameterised by randomly sampled paths.

    Effective parameters per forward pass::

        weight' = W + W * A + B
        bias'   = b + b * D + E + (W @ C).squeeze(-1)     (b + b*D dropped when there is no bias)

    Each of A..E is built by ``prepare_path`` from a path config ("LoRA_<rank>", "vector",
    "constant" or "none"). In training, the config is drawn uniformly from ``self.configs``
    on every call; setting ``eval_config`` pins one config.
    """

    def __init__(self, in_features, out_features, rank):
        super().__init__(in_features=in_features, out_features=out_features)
        lora_path = f'LoRA_{rank}'
        config_A_B = [lora_path, 'vector', 'constant', 'none']
        config_C = [lora_path, 'vector', 'none']
        config_D_E = ['constant', 'none', 'vector']
        # Full search space, enumerated in a fixed (A, B, C, D, E) order.
        self.configs = [
            {'A': A, 'B': B, 'C': C, 'D': D, 'E': E}
            for A in config_A_B
            for B in config_A_B
            for C in config_C
            for D in config_D_E
            for E in config_D_E
        ]

        self.Ad, self.Au = self.make_param((out_features, in_features), lora_path)
        self.Bd, self.Bu = self.make_param((out_features, in_features), lora_path)
        self.Cd, self.Cu = self.make_param((in_features, 1), lora_path)
        self.D = nn.Parameter(torch.zeros(out_features))
        self.E = nn.Parameter(torch.zeros(out_features))
        self.eval_config = None
        # "Down" factors stay zero so every path starts as the identity reparameterisation.
        nn.init.xavier_uniform_(self.Au)
        nn.init.xavier_uniform_(self.Bu)
        nn.init.xavier_uniform_(self.Cu)
        self.to(self.weight.device)

    def prepare_path(self, config, Xd, Xu=None):
        """Materialise one path tensor from its config.

        Args:
            config: path name ("LoRA_<rank>", "vector", "constant" or "none").
            Xd: the "down" factor, or the only tensor for single-tensor paths (D, E).
            Xu: the "up" factor for factorised paths (A, B, C), else None.

        Raises:
            ValueError: for an unknown config.
        """
        if Xu is not None:
            if 'LoRA' in config:
                rank = int(config.split('_')[1])
                return torch.matmul(Xd[:, :rank], Xu[:rank, :])
            if 'vector' in config:
                return Xd[:, 0].unsqueeze(1)
            if 'constant' in config:
                return Xd[0, 0]
            if 'none' in config:
                # Allocate on the parameter's own device instead of hard-coding CUDA.
                return Xd.new_zeros((Xd.shape[0], Xu.shape[1]))
            raise ValueError(f"Unknown path config: {config!r}")
        if 'vector' in config:
            return Xd
        if 'constant' in config:
            return Xd[0]
        if 'none' in config:
            return Xd.new_zeros(1)
        raise ValueError(f"Unknown path config: {config!r}")

    def make_param(self, shape, config=None):
        """Create a zero parameter of ``shape``, or a zero (down, up) pair for a "LoRA_<rank>" config.

        A LoRA config whose rank cannot be parsed falls back to rank 4.
        """
        if config is not None and 'LoRA' in config:
            out_feature, in_feature = shape[0], shape[1]
            try:
                rank = int(config.split('_')[1])
            except (IndexError, ValueError):
                rank = 4
            return nn.Parameter(torch.zeros(out_feature, rank)), nn.Parameter(torch.zeros(rank, in_feature))
        return nn.Parameter(torch.zeros(*shape))

    def forward(self, input):
        path_config = self.eval_config if self.eval_config is not None else random.choice(self.configs)
        previous_dtype = input.dtype

        A = self.prepare_path(path_config['A'], self.Ad, self.Au)
        B = self.prepare_path(path_config['B'], self.Bd, self.Bu)
        C = self.prepare_path(path_config['C'], self.Cd, self.Cu)
        D = self.prepare_path(path_config['D'], self.D)
        E = self.prepare_path(path_config['E'], self.E)
        optimal_weight = self.weight + self.weight * A + B
        if torch.is_tensor(self.bias):
            optimal_bias = self.bias + self.bias * D + E
        else:
            optimal_bias = E
        # (out, in) @ (in, 1) -> (out,); squeeze only the last dim so out_features == 1 stays 1-D.
        optimal_prompt = torch.matmul(self.weight, C).squeeze(-1)
        input = input.to(optimal_weight.dtype)
        result = F.linear(input, optimal_weight, optimal_bias + optimal_prompt)
        return result.to(previous_dtype)

    @staticmethod
    def from_linear(linear_module, rank):
        """Build a SuperScalableLinear that shares ``linear_module``'s weight and bias."""
        new_linear = SuperScalableLinear(linear_module.in_features, linear_module.out_features, rank)
        new_linear.weight = linear_module.weight
        new_linear.bias = linear_module.bias
        return new_linear
    
class Linear(nn.Linear, LoraLayer):
    """``nn.Linear`` with a frozen base weight plus a routed mixture of LoRA experts.

    Every sequence in the batch is routed to one expert ``i``, chosen by
    ``argmax(x[:, 0, :] @ lora_gate.T)``, and receives that expert's update
    ``dropout(x) @ lora_A[str(i)].T @ lora_B[str(i)].T * scaling``.

    NOTE(review): ``argmax`` is not differentiable, so ``lora_gate`` gets no gradient
    through this forward pass. Routing also only looks at the first sequence position.
    Confirm both are intended.
    NOTE(review): LoRA is applied only when ``active_adapter`` is one of the expert keys
    ("0" .. str(lora_num - 1)). Any other adapter name silently falls back to the plain
    linear layer.
    Merging is not supported, because a per-sample routed update has no single weight delta.
    """

    def __init__(
        self,
        adapter_name: str,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        **kwargs,
    ):
        init_lora_weights = kwargs.pop("init_lora_weights", True)

        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
        # Freezing the pre-trained weight matrix
        self.weight.requires_grad = False

        self.fan_in_fan_out = fan_in_fan_out
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

        nn.Linear.reset_parameters(self)
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
        self.active_adapter = adapter_name

    def merge(self):
        """Refuse to merge: each sample may use a different expert.

        Raises:
            NotImplementedError: when there are active LoRA experts (r > 0) to merge.
        """
        if self.active_adapter not in self.lora_A.keys():
            return
        if self.merged:
            warnings.warn("Already merged. Nothing to do.")
            return
        if self.r[self.active_adapter] > 0:
            raise NotImplementedError(
                "Cannot merge a routed mixture of LoRA experts into the base weight: "
                "each sample may be routed to a different expert."
            )

    def unmerge(self):
        """Counterpart of ``merge``; since merging is unsupported, there is never anything to undo.

        Raises:
            NotImplementedError: if the layer is somehow marked merged with r > 0.
        """
        if self.active_adapter not in self.lora_A.keys():
            return
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return
        if self.r[self.active_adapter] > 0:
            raise NotImplementedError(
                "Cannot unmerge a routed mixture of LoRA experts: merging is not supported."
            )

    def compute_average_weights(self, module_dict):
        """Return the element-wise mean of all parameters held in ``module_dict``.

        Accepts an ``nn.ModuleDict`` (every parameter of every module is averaged) or an
        ``nn.ParameterDict`` (the stored tensors themselves are averaged). All tensors must
        share one shape. Returns None when there is nothing to average.
        """
        weights = []
        for value in module_dict.values():
            if isinstance(value, torch.Tensor):
                weights.append(value)
            else:
                weights.extend(value.parameters())
        if not weights:
            return None
        return torch.stack(weights).mean(dim=0)

    def _routed_lora_delta(self, x, result):
        """Return the per-sequence expert update, shaped and typed like ``result``.

        Assumes ``x`` is (batch, seq, in_features), since routing reads ``x[:, 0, :]``.
        """
        x = x.to(self.lora_A[self.active_adapter].dtype)
        expert_idx = torch.argmax(x[:, 0, :] @ self.lora_gate.T, dim=1)
        dropout = self.lora_dropout[self.active_adapter]
        scaling = self.scaling[self.active_adapter]
        delta = torch.zeros_like(result)
        for i in range(self.lora_num):
            # Boolean masking keeps batch order (each expert sees its rows in the same order
            # as a per-sample dispatch would) and avoids a device-to-host sync per sample.
            selected = expert_idx == i
            expert_out = (dropout(x[selected]) @ self.lora_A[str(i)].T @ self.lora_B[str(i)].T) * scaling
            # Masked assignment requires matching dtypes, so cast explicitly.
            delta[selected] = expert_out.to(delta.dtype)
        return delta

    def forward(self, x: torch.Tensor):
        previous_dtype = x.dtype

        if self.active_adapter not in self.lora_A.keys():
            return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        if self.disable_adapters:
            if self.r[self.active_adapter] > 0 and self.merged:
                self.unmerge()
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        elif self.r[self.active_adapter] > 0 and not self.merged:
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
            result += self._routed_lora_delta(x, result)
        else:
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

        result = result.to(previous_dtype)

        return result
    
def _freeze_adapter(model, adapter_name):
    """Set ``requires_grad = False`` on every parameter whose qualified name contains ``adapter_name``.

    NOTE(review): this is a plain substring test. A short adapter name such as "0" also
    matches unrelated names like "layers.10...". Confirm that is acceptable.
    """
    matching = (param for param_name, param in model.named_parameters() if adapter_name in param_name)
    for param in matching:
        param.requires_grad = False

def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
    """Freeze every non-LoRA parameter, then optionally unfreeze biases.

    Parameters whose name contains "lora_" are left untouched (they are not forced to True).

    Args:
        model: module whose parameters are updated in place.
        bias: "none" keeps biases frozen; "all" unfreezes every parameter whose name contains
            "bias"; "lora_only" unfreezes the bias of each ``LoraLayer`` module.

    Raises:
        NotImplementedError: for any other ``bias`` value (raised after the freezing pass).
    """
    for param_name, param in model.named_parameters():
        if "lora_" not in param_name:
            param.requires_grad = False

    if bias == "none":
        return
    if bias == "all":
        for param_name, param in model.named_parameters():
            if "bias" in param_name:
                param.requires_grad = True
        return
    if bias == "lora_only":
        for module in model.modules():
            if isinstance(module, LoraLayer) and getattr(module, "bias", None) is not None:
                module.bias.requires_grad = True
        return
    raise NotImplementedError

def _get_submodules(model, key):
    """Resolve a dotted module path into ``(parent, target, target_name)``.

    ``target_name`` is the last path component, i.e. the attribute of ``parent`` that holds
    ``target``. A key without dots resolves its parent to ``model`` itself.
    """
    parent_key, _, target_name = key.rpartition(".")
    parent = model.get_submodule(parent_key)
    target = model.get_submodule(key)
    return parent, target, target_name


class ModulesToSaveWrapper(torch.nn.Module):
    """Keep the original module plus one deep copy per adapter, and dispatch to the active copy."""

    def __init__(self, module_to_save, adapter_name):
        super().__init__()
        self.original_module = module_to_save
        self.modules_to_save = torch.nn.ModuleDict({})
        self.update(adapter_name)
        self.active_adapter = adapter_name

    def update(self, adapter_name):
        """Register a fresh deep copy of the original module under ``adapter_name``."""
        fresh_copy = copy.deepcopy(self.original_module)
        self.modules_to_save[adapter_name] = fresh_copy

    def forward(self, *args, **kwargs):
        """Run the active adapter's copy, or the original module if the adapter has none."""
        has_copy = self.active_adapter in self.modules_to_save
        chosen = self.modules_to_save[self.active_adapter] if has_copy else self.original_module
        return chosen(*args, **kwargs)
    
class LoraModel(torch.nn.Module):
    """Wrap a model and replace its target ``nn.Linear`` layers with routed-LoRA ``Linear`` layers.

    Args:
        model: the base model; its ``forward`` is re-exposed directly as ``self.forward``.
        config: dict mapping adapter name -> LoRA config (e.g. ``peft.LoraConfig``). It is
            stored by reference as ``self.peft_config`` and updated in place by ``add_adapter``.
        adapter_name: key of the adapter in ``config`` to install at construction time.
    """

    def __init__(self, model, config, adapter_name):
        super().__init__()
        self.model = model
        # Calls go straight to the wrapped model; missing attributes fall back via __getattr__.
        self.forward = self.model.forward
        self.peft_config = config
        self.add_adapter(adapter_name, self.peft_config[adapter_name])

    def add_adapter(self, adapter_name, config=None):
        """Register ``config`` (if given) under ``adapter_name``, inject its layers and set trainability.

        Raises:
            ValueError: when several adapters are configured and this one's bias mode is not "none".
        """
        if config is not None:
            model_config = self.model.config.to_dict() if hasattr(self.model.config, "to_dict") else self.model.config
            config = self._prepare_lora_config(config, model_config)
            self.peft_config[adapter_name] = config
        self._find_and_replace(adapter_name)
        if len(self.peft_config) > 1 and self.peft_config[adapter_name].bias != "none":
            raise ValueError(
                "LoraModel supports only 1 adapter with bias. When using multiple adapters, set bias to 'none' for all adapters."
            )
        mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias)
        if self.peft_config[adapter_name].inference_mode:
            _freeze_adapter(self.model, adapter_name)

    def _find_and_replace(self, adapter_name):
        """Swap every module matching ``target_modules`` for a LoRA ``Linear`` (or update an existing one).

        Raises:
            ImportError: if the model is loaded in 8-bit but ``bitsandbytes`` is missing.
            ValueError: if a matched module is neither a ``LoraLayer`` nor a ``torch.nn.Linear``,
                or if nothing matches ``target_modules``.
        """
        lora_config = self.peft_config[adapter_name]
        loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
        if loaded_in_8bit and not is_bnb_available():
            raise ImportError(
                "To use Lora with 8-bit quantization, please install the `bitsandbytes` package. "
                "You can install it with `pip install bitsandbytes`."
            )
        is_target_modules_in_base_model = False
        kwargs = {
            "r": lora_config.r,
            "lora_alpha": lora_config.lora_alpha,
            "lora_dropout": lora_config.lora_dropout,
            "fan_in_fan_out": lora_config.fan_in_fan_out,
            "init_lora_weights": lora_config.init_lora_weights,
        }
        # Snapshot the names first: the loop replaces modules while walking them.
        key_list = [key for key, _ in self.model.named_modules()]
        for key in key_list:
            if isinstance(lora_config.target_modules, str):
                target_module_found = re.fullmatch(lora_config.target_modules, key)
            else:
                target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
            if not target_module_found:
                continue
            is_target_modules_in_base_model = True
            parent, target, target_name = _get_submodules(self.model, key)

            if isinstance(target, LoraLayer):
                target.update_layer(
                    adapter_name,
                    lora_config.r,
                    lora_config.lora_alpha,
                    lora_config.lora_dropout,
                    lora_config.init_lora_weights,
                )
                continue
            # Refuse other module types instead of reusing dimensions from a previous match.
            if not isinstance(target, torch.nn.Linear):
                raise ValueError(
                    f"Target module {target} is not supported. "
                    f"Currently, only `torch.nn.Linear` is supported."
                )
            # A module without a `bias` attribute is treated as bias-free.
            bias = getattr(target, "bias", None) is not None
            in_features, out_features = target.in_features, target.out_features
            if kwargs["fan_in_fan_out"]:
                warnings.warn(
                    "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                    "Setting fan_in_fan_out to False."
                )
                kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
            new_module = Linear(adapter_name, in_features, out_features, bias=bias, **kwargs)
            self._replace_module(parent, target_name, new_module, target)
        if not is_target_modules_in_base_model:
            raise ValueError(
                f"Target modules {lora_config.target_modules} not found in the base model. "
                f"Please check the target modules and try again."
            )

    def _replace_module(self, parent_module, child_name, new_module, old_module):
        """Install ``new_module`` in place of ``old_module``, sharing its weight/bias and device."""
        setattr(parent_module, child_name, new_module)
        new_module.weight = old_module.weight
        if getattr(old_module, "bias", None) is not None:
            new_module.bias = old_module.bias

        if getattr(old_module, "state", None) is not None:
            new_module.state = old_module.state
            new_module.to(old_module.weight.device)

        device = old_module.weight.device

        def _needs_dispatch(name):
            return "lora_" in name or any(x in name for x in ['A', 'B', 'C', 'D', 'E', 'head'])

        # dispatch to correct device
        for name, module in new_module.named_modules():
            if _needs_dispatch(name):
                module.to(device)
        # Parameters registered directly on the new module (e.g. `lora_gate`) are not reached
        # by the submodule loop above, so move them explicitly.
        for name, param in new_module.named_parameters(recurse=False):
            if _needs_dispatch(name) and param.device != device:
                param.data = param.data.to(device)

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.model, name)

    def get_peft_config_as_dict(self, inference: bool = False):
        """Serialise adapter configs to plain dicts (Enum members replaced by their values).

        NOTE(review): `config_dict` is built but only the LAST adapter's flat config dict is
        returned (and the call fails when `peft_config` is empty). This mirrors upstream PEFT
        and is kept for compatibility; confirm whether callers expect `config_dict` instead.
        """
        config_dict = {}
        for key, value in self.peft_config.items():
            config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
            if inference:
                config["inference_mode"] = True
        config_dict[key] = config
        return config

    def _set_adapter_layers(self, enabled=True):
        for module in self.model.modules():
            if isinstance(module, LoraLayer):
                module.disable_adapters = not enabled

    def enable_adapter_layers(self):
        self._set_adapter_layers(enabled=True)

    def disable_adapter_layers(self):
        self._set_adapter_layers(enabled=False)

    def set_adapter(self, adapter_name):
        """Make ``adapter_name`` the active adapter on every LoRA layer (unmerging first if needed)."""
        for module in self.model.modules():
            if isinstance(module, LoraLayer):
                if module.merged:
                    warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
                    module.unmerge()
                module.active_adapter = adapter_name

    def merge_adapter(self):
        for module in self.model.modules():
            if isinstance(module, LoraLayer):
                module.merge()

    def unmerge_adapter(self):
        for module in self.model.modules():
            if isinstance(module, LoraLayer):
                module.unmerge()

    @staticmethod
    def _prepare_lora_config(peft_config, model_config):
        """Adjust a config before use. ``target_modules`` must be set explicitly, since no
        per-architecture default mapping is available in this notebook."""
        if peft_config.inference_mode:
            peft_config.merge_weights = True
        return peft_config

    def merge_and_unload(self):
        r"""
        This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model
        as a standalone model.
        """
        if getattr(self.config, "model_type", None) == "gpt2":
            raise ValueError("GPT2 models are not supported for merging LORA layers")

        if getattr(self.model, "is_loaded_in_8bit", False):
            raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit mode")

        key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
        for key in key_list:
            try:
                parent, target, target_name = _get_submodules(self.model, key)
            except AttributeError:
                continue
            if isinstance(target, LoraLayer):
                bias = target.bias is not None
                new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias)
                target.merge()
                self._replace_module(parent, target_name, new_module, target)

            # save any additional trainable modules part of `modules_to_save`
            if isinstance(target, ModulesToSaveWrapper):
                setattr(parent, target_name, target.modules_to_save[target.active_adapter])

        return self.model

    def add_weighted_adapter(self, adapters, weights, adapter_name):
        """Create ``adapter_name`` as a weighted combination of existing ``adapters``.

        NOTE(review): this follows upstream PEFT, which stores each adapter as an ``nn.Linear``
        under ``lora_A[adapter_name]``. In this notebook ``lora_A``/``lora_B`` are
        ``ParameterDict``s keyed by expert index and hold bare ``nn.Parameter``s, so the
        ``.weight`` accesses below raise AttributeError whenever that branch is reached.
        Re-running ``_find_and_replace`` also re-creates every layer's experts. Not usable as-is.

        Raises:
            ValueError: if the adapters do not all share the same ``r``.
        """
        if len({self.peft_config[adapter].r for adapter in adapters}) != 1:
            raise ValueError("All adapters must have the same r value")
        # Copy so that overriding lora_alpha below does not mutate adapters[0]'s own config.
        self.peft_config[adapter_name] = copy.deepcopy(self.peft_config[adapters[0]])
        self.peft_config[adapter_name].lora_alpha = self.peft_config[adapters[0]].r
        self._find_and_replace(adapter_name)
        mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias)
        _freeze_adapter(self.model, adapter_name)
        key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
        for key in key_list:
            _, target, _ = _get_submodules(self.model, key)
            if isinstance(target, LoraLayer):
                if adapter_name in target.lora_A:
                    target.lora_A[adapter_name].weight.data = target.lora_A[adapter_name].weight.data * 0.0
                    target.lora_B[adapter_name].weight.data = target.lora_B[adapter_name].weight.data * 0.0
                    for adapter, weight in zip(adapters, weights):
                        if adapter not in target.lora_A:
                            continue
                        target.lora_A[adapter_name].weight.data += (
                            target.lora_A[adapter].weight.data * weight * target.scaling[adapter]
                        )
                        target.lora_B[adapter_name].weight.data += target.lora_B[adapter].weight.data * weight

                elif adapter_name in target.lora_embedding_A:
                    target.lora_embedding_A[adapter_name].data = target.lora_embedding_A[adapter_name].data * 0.0
                    target.lora_embedding_B[adapter_name].data = target.lora_embedding_B[adapter_name].data * 0.0
                    for adapter, weight in zip(adapters, weights):
                        if adapter not in target.lora_embedding_A:
                            continue
                        target.lora_embedding_A[adapter_name].data += (
                            target.lora_embedding_A[adapter].data * weight * target.scaling[adapter]
                        )
                        target.lora_embedding_B[adapter_name].data += target.lora_embedding_B[adapter].data * weight


In [None]:
# Single LoRA adapter built from the hyper-parameters in the config cell.
config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    target_modules=lora_target_modules,
    lora_dropout=lora_dropout,
    bias="none",
    task_type="CAUSAL_LM",
)
# The adapter is registered under the name "0". The routed LoRA layer only applies its
# experts when the active adapter name equals one of the expert keys ("0".."7").
peft_config = {"0": config}
model = LoraModel(model, peft_config, "0")

In [None]:
# NOTE(review): no-op debugging leftover. It only walks the module names (the print is
# commented out) and rebinds the globals `key` and `_`.
for key, _ in model.named_modules():
    pass
    # print(key)

In [None]:
# List only the trainable parameters (printing every name of a 7B model floods the output)
# and summarise how much of the model LoRA leaves trainable.
trainable_count = 0
total_count = 0
for name, param in model.named_parameters():
    total_count += param.numel()
    if param.requires_grad:
        trainable_count += param.numel()
        print(name, tuple(param.shape), param.dtype)
trainable_pct = 100 * trainable_count / total_count if total_count else 0.0
print(f"trainable params: {trainable_count:,} || all params: {total_count:,} || trainable%: {trainable_pct:.4f}")
    

In [None]:
def tokenize(prompt, add_eos_token=True):
    """Tokenize one prompt for causal-LM fine-tuning.

    The text is truncated to ``cutoff_len`` tokens and left unpadded. An EOS
    token (plus a matching attention-mask entry) is appended only when the
    sequence does not already end in EOS, is shorter than ``cutoff_len`` and
    ``add_eos_token`` is true. ``labels`` is set to a copy of ``input_ids``.

    Args:
        prompt: the full text to tokenize.
        add_eos_token: whether an EOS token may be appended.

    Returns:
        The tokenizer output with an added "labels" list.
    """
    encoded = tokenizer(
        prompt,
        truncation=True,
        max_length=cutoff_len,
        padding=False,
        return_tensors=None,
    )
    token_ids = encoded["input_ids"]
    should_append_eos = (
        token_ids[-1] != tokenizer.eos_token_id
        and len(token_ids) < cutoff_len
        and add_eos_token
    )
    if should_append_eos:
        token_ids.append(tokenizer.eos_token_id)
        encoded["attention_mask"].append(1)

    # Separate list object so later label masking does not touch input_ids.
    encoded["labels"] = token_ids.copy()
    return encoded
def generate_and_tokenize_prompt(data_point):
    """Convert one raw {"text", "label"} row into tokenized training features.

    The row is rewritten in place into the fields used by ``prompter``:
    "instruction" (a fixed sentiment question), "input" (from "text") and
    "output" (from "label"); the "text" and "label" keys are then deleted.

    When ``train_on_inputs`` is false, the labels of the prompt part
    (instruction + input) are replaced by -100 so only the answer tokens
    contribute to the loss.

    Returns:
        The dict produced by ``tokenize`` for the full prompt.
    """
    data_point["instruction"] = 'What is the sentiment toward Bitcoin in the input sentence? [positive, negative, neutral]'
    data_point["input"] = data_point['text']
    data_point["output"] = data_point['label']
    del data_point['text']
    del data_point['label']

    features = tokenize(
        prompter.generate_prompt(
            data_point["instruction"],
            data_point["input"],
            data_point["output"],
        )
    )
    if train_on_inputs:
        return features

    prompt_only = prompter.generate_prompt(data_point["instruction"], data_point["input"])
    masked_len = len(tokenize(prompt_only, add_eos_token=add_eos_token)["input_ids"])
    # Do not mask the EOS that tokenize() may have appended to the prompt-only text.
    # NOTE(review): tokenize() skips the EOS when the prompt already reaches
    # cutoff_len, in which case one prompt token stays unmasked here.
    if add_eos_token:
        masked_len -= 1

    features["labels"] = [-100] * masked_len + features["labels"][masked_len:]
    return features





# Deterministic split of the JSONL file: the first `num_train_examples` rows are
# used for training and every remaining row is held out. NOTE: the evaluation
# code in later cells independently hard-codes 300 as the first held-out row;
# update it there too if this value changes.
num_train_examples = 300

data = load_dataset("json", data_files=data_path)
num_rows = len(data["train"])
if num_rows < num_train_examples:
    raise ValueError(
        f"{data_path} has only {num_rows} rows; at least {num_train_examples} "
        "are needed for the training split"
    )

# Build the held-out split first: it slices the full file before "train" is truncated.
data["test"] = data["train"].select(range(num_train_examples, num_rows))
data["train"] = data["train"].select(range(num_train_examples))
train_data = data["train"].map(generate_and_tokenize_prompt)
val_data = data["test"].map(generate_and_tokenize_prompt)
# A non-zero val_set_size enables step-wise evaluation in the Trainer cell.
val_set_size = len(val_data)
print(train_data)
print(val_set_size, val_data)
# train_data


In [None]:
import json
from torch.utils.data import DataLoader
import time
def _sentiment_eval_examples(path, start=300):
    """Return held-out (prompt, target_label) pairs from the JSONL file at `path`.

    Every line must be a JSON object with "text" and "label" keys. The prompt
    is the inference prompt built by ``prompter`` (instruction + input, no
    output). Rows before `start` are dropped because the data-loading cell
    trains on the first 300 rows.
    """
    instruction = 'What is the sentiment toward Bitcoin in the input sentence? [positive, negative, neutral]'
    examples = []
    with open(path, 'r') as f:
        for line in f:
            row = json.loads(line)
            prompt = prompter.generate_prompt(instruction, row['text'])
            examples.append((prompt, row['label']))
    return examples[start:]


def _extract_sentiment_prediction(decoded):
    """Return the stripped text after the LAST 'Response' marker in `decoded`.

    The offset skips 'Response:\\n' (10 characters), which presumably matches
    the '### Response:' header of the alpaca prompt template (TODO confirm).
    Using the last occurrence keeps an input sentence that happens to contain
    the word 'Response' from shifting the slice. Returns '' if no marker is
    found.
    """
    marker = decoded.rfind('Response')
    if marker < 0:
        return ''
    return decoded[marker + len('Response:\n'):].strip()


def my_evaluate(self, ignore_keys=None, **kwargs):
    """Replacement for ``transformers.Trainer.evaluate`` reporting held-out accuracy.

    Greedily generates ONE new token per held-out prompt. A prediction counts
    as correct when the non-empty, whitespace-stripped generated text is a
    substring of the gold label (e.g. a generated 'neg' matches 'negative').
    The model is put in eval mode for generation and always returned to
    train mode afterwards.

    NOTE(review): batched generation with a decoder-only model needs left
    padding; confirm ``tokenizer.padding_side`` is set accordingly.

    Args:
        ignore_keys: accepted for compatibility with Trainer's call; unused.
        **kwargs: absorbs ``eval_dataset`` / ``metric_key_prefix`` that Trainer
            or a manual ``trainer.evaluate()`` call may pass; unused.

    Returns:
        ``{"eval_acc": accuracy}``; the dict is also logged and forwarded to
        the ``on_evaluate`` callbacks.

    Raises:
        ValueError: if the data file has no held-out rows.
    """
    examples = _sentiment_eval_examples(data_path)
    if not examples:
        raise ValueError(f"no held-out rows (after the first 300) in {data_path}")

    batch_size = 32
    right = 0
    self.model.eval()
    try:
        with torch.autocast("cuda"), torch.no_grad():
            for start in range(0, len(examples), batch_size):
                batch = examples[start:start + batch_size]
                prompts = [prompt for prompt, _ in batch]
                inputs = tokenizer(prompts, padding=True, return_tensors='pt').to('cuda')
                out = self.model.generate(
                    **inputs,
                    do_sample=False,  # greedy decoding
                    return_dict_in_generate=True,
                    max_new_tokens=1,
                )
                decoded = tokenizer.batch_decode(out['sequences'])
                for text, (_, target) in zip(decoded, batch):
                    pred = _extract_sentiment_prediction(text)
                    # '' is a substring of every label, so it must not count as a hit.
                    if pred and pred in target:
                        right += 1
    finally:
        self.model.train()

    metrics = {"eval_acc": right / len(examples)}
    self.log(metrics)
    self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
    return metrics



# Route evaluation through my_evaluate (held-out sentiment accuracy) instead of
# the default loss-based evaluation. NOTE: this patches the Trainer class
# itself, so it applies to every Trainer created in this kernel.
transformers.Trainer.evaluate = my_evaluate
trainer = transformers.Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=micro_batch_size,
        # NOTE(review): the config cell computes gradient_accumulation_steps =
        # batch_size // micro_batch_size, but 1 is passed here, so the effective
        # batch is micro_batch_size per device rather than batch_size.
        # Confirm this is intentional.
        gradient_accumulation_steps=1,
        num_train_epochs=num_epochs,
        learning_rate=learning_rate,
        fp16=True,
        logging_steps=10,
        optim="adamw_torch",
        # my_evaluate reports "eval_acc"; presumably Trainer resolves "acc" to that key.
        metric_for_best_model="acc",
        evaluation_strategy="steps" if val_set_size > 0 else "no",
        save_strategy="no",
        eval_steps=100 if val_set_size > 0 else None,
        save_steps=100,
        output_dir=output_dir,
        save_total_limit=50,
        ddp_find_unused_parameters=False if ddp else None,
        group_by_length=group_by_length,
        # "none" (the string) disables logging integrations; in transformers 4.x
        # report_to=None falls back to "all", i.e. every installed integration.
        report_to="wandb" if use_wandb else "none",
        run_name=wandb_run_name if use_wandb else None,
    ),
    data_collator=transformers.DataCollatorForSeq2Seq(
        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
    ),
)
model.config.use_cache = False

# old_state_dict = model.state_dict
# model.state_dict = (
#     lambda self, *_, **__: get_peft_model_state_dict(
#         self, old_state_dict()
#     )
# ).__get__(model, type(model))
# if torch.__version__ >= "2" and sys.platform != "win32":
#     model = torch.compile(model)

In [None]:
# Run fine-tuning. resume_from_checkpoint comes from the config cell: None means
# no checkpoint is resumed; set it to a checkpoint directory to continue from it.
trainer.train(resume_from_checkpoint=resume_from_checkpoint)

In [None]:
import random
import os
import sys
import torch
# from avalanche.evaluation.metrics.accuracy import Accuracy

class EvolutionSearcher(object):
    """Evolutionary search over per-layer adapter configurations.

    A candidate is a tuple with one dict per tunable layer. Each dict maps the
    sub-tensor keys 'A'-'E' to one of the option strings in ``choices[key]``
    (e.g. 'LoRA_8', 'vector', 'constant' or 'none'). Each generation keeps the
    most accurate candidates, then refills the population with mutations,
    crossovers and fresh random candidates. Candidates whose parameter count
    (in millions, see ``get_param``) lies outside
    ``[args.min_param_limits, args.param_limits]`` are rejected.

    Relies on notebook globals: ``SuperScalableLinear`` (the tunable layer
    type), ``data_path``, ``prompter`` and ``tokenizer``.
    """

    # First row of `data_path` used for evaluation; the data-loading cell trains
    # on the 300 rows before it.
    EVAL_START_INDEX = 300
    # Prompts generated per forward pass during evaluation.
    EVAL_BATCH_SIZE = 256
    # Instruction used to build evaluation prompts.
    INSTRUCTION = 'What is the sentiment toward Bitcoin in the input sentence? [positive, negative, neutral]'
    # Keys every per-layer config dict contains.
    SUB_TENSORS = ('A', 'B', 'C', 'D', 'E')
    # Per-layer dicts in a random candidate. NOTE(review): presumably 32
    # LLaMA-7B decoder layers x 2 LoRA target modules (q_proj, v_proj); must
    # equal the number of SuperScalableLinear modules visited by set_config /
    # get_param. TODO confirm.
    NUM_LAYERS = 64

    def __init__(self, args, model, choices, output_dir):
        """Copy search hyper-parameters from `args` and initialise empty state.

        Args:
            args: object with max_epochs, select_num, population_num, m_prob,
                crossover_num, mutation_num, param_limits and min_param_limits
                (limits in millions of parameters). An optional
                ``checkpoint_path`` attribute is used by ``load_checkpoint()``.
            model: model whose SuperScalableLinear layers get configured.
            choices: dict mapping each of 'A'-'E' to a list of option strings.
            output_dir: directory where ``save_checkpoint()`` writes files.
        """
        self.model = model
        self.args = args
        self.max_epochs = args.max_epochs
        self.select_num = args.select_num
        self.population_num = args.population_num
        self.m_prob = args.m_prob
        self.crossover_num = args.crossover_num
        self.mutation_num = args.mutation_num
        self.parameters_limits = args.param_limits
        self.min_parameters_limits = args.min_param_limits
        self.output_dir = output_dir
        # Default path for load_checkpoint() when it is called without an argument.
        self.checkpoint_path = getattr(args, 'checkpoint_path', None)
        self.memory = []
        # str(candidate) -> {'params': ..., 'acc': ..., 'visited': ...}
        self.vis_dict = dict()
        # k -> best k candidates seen so far, sorted by accuracy (descending).
        self.keep_top_k = dict()
        self.keep_top_k[self.select_num] = []
        self.keep_top_k[50] = []
        self.epoch = 0
        self.candidates = []
        self.top_accuracies = []
        self.choices = choices
        self.cand_params = []
        # Lazily built list of held-out (prompt, target) pairs; see _eval_examples().
        self._eval_examples_cache = None

    def save_checkpoint(self):
        """Write the search state to ``<output_dir>/checkpoint-<epoch>.pth.tar``.

        Creates ``output_dir`` first if it does not exist.
        """
        info = {
            'top_accuracies': self.top_accuracies,
            'memory': self.memory,
            'candidates': self.candidates,
            'vis_dict': self.vis_dict,
            'keep_top_k': self.keep_top_k,
            'epoch': self.epoch,
        }
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)
        checkpoint_path = os.path.join(self.output_dir, "checkpoint-{}.pth.tar".format(self.epoch))
        torch.save(info, checkpoint_path)
        print('save checkpoint to', checkpoint_path)

    def load_checkpoint(self, checkpoint_path=None):
        """Restore search state saved by ``save_checkpoint``.

        Args:
            checkpoint_path: file to load; defaults to ``self.checkpoint_path``.

        Returns:
            True if a checkpoint was loaded, False if no path was given or the
            file does not exist.
        """
        path = checkpoint_path if checkpoint_path is not None else getattr(self, 'checkpoint_path', None)
        if path is None or not os.path.exists(path):
            return False
        # NOTE: torch.load unpickles the file; only load checkpoints you wrote.
        info = torch.load(path)
        self.memory = info['memory']
        self.candidates = info['candidates']
        self.vis_dict = info['vis_dict']
        self.keep_top_k = info['keep_top_k']
        self.epoch = info['epoch']
        # save_checkpoint() also stores top_accuracies; tolerate files without it.
        self.top_accuracies = info.get('top_accuracies', getattr(self, 'top_accuracies', []))

        print('load checkpoint from', path)
        return True

    def set_config(self, config):
        """Assign ``config[i]`` to the i-th SuperScalableLinear layer (skipping 'head')."""
        layer_idx = 0
        for module_name, module in self.model.named_modules():
            if isinstance(module, SuperScalableLinear) and module_name != 'head':
                module.eval_config = config[layer_idx]
                layer_idx += 1

    def get_param_tensor(self, config, in_feature, out_feature, name):
        """Count the parameters one sub-tensor adds under option `config`.

        Sub-tensors 'A', 'B' and 'C' accept 'LoRA[_<rank>]'
        (out*rank + in*rank parameters; rank 16 without an integer suffix),
        'vector' (out_feature), 'constant' (1) and 'none' (0). 'C' is sized as
        an in_feature-length vector (out=in_feature, in=1). Other sub-tensors
        accept only 'vector', 'constant' and 'none'.

        Raises:
            ValueError: if `config` matches none of the accepted options.
        """
        if 'A' in name or 'B' in name or 'C' in name:
            if 'C' in name:
                out_feature = in_feature
                in_feature = 1
            if 'LoRA' in config:
                try:
                    rank = int(config.split('_')[1])
                except (IndexError, ValueError):
                    rank = 16
                return out_feature * rank + in_feature * rank
        if 'vector' in config:
            return out_feature
        if 'constant' in config:
            return 1
        if 'none' in config:
            return 0
        raise ValueError('unsupported option {!r} for sub-tensor {!r}'.format(config, name))

    def get_param(self, configs):
        """Total parameter count of candidate `configs` over all tunable layers."""
        total = 0
        layer_idx = 0
        for module_name, module in self.model.named_modules():
            if isinstance(module, SuperScalableLinear) and module_name != 'head':
                for key in self.SUB_TENSORS:
                    total += self.get_param_tensor(
                        configs[layer_idx][key], module.in_features, module.out_features, key)
                layer_idx += 1
        return total

    def is_legal(self, cand):
        """Score `cand` if unseen and within the parameter budget.

        Returns True only for a newly evaluated candidate; its 'params'
        (millions), 'acc' and 'visited' are recorded in ``vis_dict``.
        """
        assert isinstance(cand, tuple)
        info = self.vis_dict.setdefault(str(cand), {})
        if 'visited' in info:
            return False
        info['params'] = self.get_param(configs=cand) / 10.**6

        if info['params'] > self.parameters_limits:
            print('parameters limit exceed')
            sys.stdout.flush()
            return False

        if info['params'] < self.min_parameters_limits:
            print('under minimum parameters limit')
            return False

        info['acc'] = self.evaluate(config=cand)
        print(info['acc'])
        info['visited'] = True
        return True

    def _eval_examples(self):
        """Return the cached held-out (prompt, target) pairs, reading `data_path` once."""
        cached = getattr(self, '_eval_examples_cache', None)
        if cached is None:
            examples = []
            with open(data_path, 'r') as f:
                for line in f:
                    row = json.loads(line)
                    prompt = prompter.generate_prompt(self.INSTRUCTION, row['text'])
                    examples.append((prompt, row['label']))
            cached = examples[self.EVAL_START_INDEX:]
            self._eval_examples_cache = cached
        return cached

    @staticmethod
    def _extract_prediction(decoded):
        """Stripped text after the LAST 'Response' marker ('' if absent).

        The 10-character offset skips 'Response:\\n', presumably the alpaca
        '### Response:' header (TODO confirm). Using the last occurrence keeps
        an input sentence containing 'Response' from shifting the slice.
        """
        marker = decoded.rfind('Response')
        if marker < 0:
            return ''
        return decoded[marker + len('Response:\n'):].strip()

    def evaluate(self, config):
        """Apply `config` and return held-out accuracy.

        Greedily generates one token per prompt; a prediction is correct when
        the non-empty, stripped generated text is a substring of the label.

        Raises:
            ValueError: if there are no held-out rows.
        """
        self.set_config(config)
        examples = self._eval_examples()
        if not examples:
            raise ValueError('no held-out rows after row {} in {}'.format(self.EVAL_START_INDEX, data_path))

        self.model.eval()
        right = 0
        with torch.autocast("cuda"), torch.no_grad():
            for start in range(0, len(examples), self.EVAL_BATCH_SIZE):
                batch = examples[start:start + self.EVAL_BATCH_SIZE]
                inputs = tokenizer([prompt for prompt, _ in batch], padding=True, return_tensors='pt').to('cuda')
                out = self.model.generate(
                    **inputs,
                    do_sample=False,  # greedy decoding
                    return_dict_in_generate=True,
                    max_new_tokens=1,
                )
                for text, (_, target) in zip(tokenizer.batch_decode(out['sequences']), batch):
                    pred = self._extract_prediction(text)
                    # '' is a substring of every label, so it must not count as a hit.
                    if pred and pred in target:
                        right += 1
        return right / len(examples)

    def update_top_k(self, candidates, *, k, key, reverse=True):
        """Merge `candidates` into ``keep_top_k[k]`` without duplicates; keep the best k by `key`."""
        assert k in self.keep_top_k
        print('select ......')
        merged = list(self.keep_top_k[k])
        seen = {str(cand) for cand in merged}
        for cand in candidates:
            if str(cand) not in seen:
                seen.add(str(cand))
                merged.append(cand)
        merged.sort(key=key, reverse=reverse)
        self.keep_top_k[k] = merged[:k]

    def stack_random_cand(self, random_func, *, batchsize=10):
        """Endlessly yield candidates from `random_func`, registering each in ``vis_dict``."""
        while True:
            cands = [random_func() for _ in range(batchsize)]
            for cand in cands:
                self.vis_dict.setdefault(str(cand), {})
            yield from cands

    def get_random_cand(self):
        """Sample one random option per sub-tensor for each of NUM_LAYERS layers."""
        return tuple(
            {key: random.choice(self.choices[key]) for key in self.SUB_TENSORS}
            for _ in range(self.NUM_LAYERS)
        )

    def get_random(self, num):
        """Top up ``self.candidates`` with legal random candidates until it holds `num`.

        NOTE: loops forever if no random candidate satisfies the parameter limits.
        """
        print('random select ........')
        cand_iter = self.stack_random_cand(self.get_random_cand)
        while len(self.candidates) < num:
            cand = next(cand_iter)
            if not self.is_legal(cand):
                continue
            self.candidates.append(cand)
            print('random {}/{}'.format(len(self.candidates), num))
        print('random_num = {}'.format(len(self.candidates)))

    def get_mutation(self, k, mutation_num, m_prob):
        """Return up to `mutation_num` legal mutants of ``keep_top_k[k]`` members.

        Each sub-tensor option is resampled with probability `m_prob`; at most
        ``10 * mutation_num`` draws are tried.
        """
        assert k in self.keep_top_k
        print('mutation ......')
        res = []
        max_iters = mutation_num * 10

        def random_func():
            parent = list(random.choice(self.keep_top_k[k]))
            mutant = list()
            for layer in parent:
                new_layer = dict()
                for key in self.SUB_TENSORS:
                    if random.random() < m_prob:
                        new_layer[key] = random.choice(self.choices[key])
                    else:
                        new_layer[key] = layer[key]
                mutant.append(new_layer)
            return tuple(mutant)

        cand_iter = self.stack_random_cand(random_func)
        while len(res) < mutation_num and max_iters > 0:
            max_iters -= 1
            cand = next(cand_iter)
            if not self.is_legal(cand):
                continue
            res.append(cand)
            print('mutation {}/{}'.format(len(res), mutation_num))

        print('mutation_num = {}'.format(len(res)))
        return res

    def get_crossover(self, k, crossover_num):
        """Return up to `crossover_num` legal crossovers of two ``keep_top_k[k]`` members.

        Each sub-tensor option is taken from either parent with equal
        probability; at most ``10 * crossover_num`` draws are tried.
        """
        assert k in self.keep_top_k
        print('crossover ......')
        res = []
        max_iters = 10 * crossover_num

        def random_func():
            parent_1 = list(random.choice(self.keep_top_k[k]))
            parent_2 = list(random.choice(self.keep_top_k[k]))
            child = list()
            for layer_1, layer_2 in zip(parent_1, parent_2):
                child.append({key: random.choice([layer_1[key], layer_2[key]]) for key in self.SUB_TENSORS})
            return tuple(child)

        cand_iter = self.stack_random_cand(random_func)
        while len(res) < crossover_num and max_iters > 0:
            max_iters -= 1
            cand = next(cand_iter)
            if not self.is_legal(cand):
                continue
            res.append(cand)
            print('crossover {}/{}'.format(len(res), crossover_num))

        print('crossover_num = {}'.format(len(res)))
        return res

    def search(self):
        """Run the evolutionary loop for ``max_epochs`` generations, checkpointing each one."""
        print(
            'population_num = {} select_num = {} mutation_num = {} crossover_num = {} random_num = {} max_epochs = {}'.format(
                self.population_num, self.select_num, self.mutation_num, self.crossover_num,
                self.population_num - self.mutation_num - self.crossover_num, self.max_epochs))

        self.get_random(self.population_num)

        while self.epoch < self.max_epochs:
            print('epoch = {}'.format(self.epoch))

            self.memory.append(list(self.candidates))

            # Update the parent pool (top select_num) and the reporting pool (top 50).
            self.update_top_k(
                self.candidates, k=self.select_num, key=lambda x: self.vis_dict[str(x)]['acc'])
            self.update_top_k(
                self.candidates, k=50, key=lambda x: self.vis_dict[str(x)]['acc'])

            print('epoch = {} : top {} result'.format(
                self.epoch, len(self.keep_top_k[50])))
            tmp_accuracy = []
            for i, cand in enumerate(self.keep_top_k[50]):
                print('No.{} Top-1 val acc = {}, params = {}'.format(
                    i + 1, self.vis_dict[str(cand)]['acc'], self.vis_dict[str(cand)]['params']))
                sys.stdout.flush()
                tmp_accuracy.append(self.vis_dict[str(cand)]['acc'])
            self.top_accuracies.append(tmp_accuracy)

            mutation = self.get_mutation(
                self.select_num, self.mutation_num, self.m_prob)
            crossover = self.get_crossover(self.select_num, self.crossover_num)

            self.candidates = mutation + crossover

            self.get_random(self.population_num)

            self.epoch += 1

            self.save_checkpoint()

In [None]:
# from argparse import ArgumentParser
# parser = ArgumentParser()
# parser.add_argument('--seed', type=int, default=1)
# # parser.add_argument('--model', type=str, default='vit_base_patch16_224_in21k')
# # parser.add_argument('--dataset', type=str, default='cifar')
# parser.add_argument('--save_path', type=str, default='models/temp/')
# parser.add_argument('--load_path', type=str, default='models/temp/')
# parser.add_argument('--max-epochs', type=int, default=20)
# parser.add_argument('--select-num', type=int, default=10)
# parser.add_argument('--population-num', type=int, default=50)
# parser.add_argument('--m_prob', type=float, default=0.2)
# parser.add_argument('--crossover-num', type=int, default=25)
# parser.add_argument('--epochs', type=int, default=30)
# parser.add_argument('--mutation-num', type=int, default=25)
# parser.add_argument('--param-limits', type=float, default=1.00)
# parser.add_argument('--min-param-limits', type=float, default=0)
# parser.add_argument('--rank', type=int, default=4)
# args = parser.parse_args()

# class Args:
#     def __init__(self):
#         self.seed = 1
#         self.save_path = 'models/temp/'
#         self.load_path = 'models/temp/'
#         self.max_epochs = 20
#         self.select_num = 10
#         self.population_num = 50
#         self.m_prob = 0.2
#         self.crossover_num = 25
#         self.epochs = 30
#         self.mutation_num = 25
#         self.param_limits = 1000.00
#         self.min_param_limits = 0
#         self.rank = 8

# args = Args()
# choices = dict()
# choices['A'] = [f'LoRA_{lora_r}', 'vector', 'constant', 'none']
# choices['B'] = choices['A']
# choices['C'] = [f'LoRA_{lora_r}', 'vector', 'none']
# choices['D'] = ['constant', 'none', 'vector']
# choices['E'] = choices['D']

# searcher = EvolutionSearcher(args, model, choices, args.save_path)
# searcher.search()