In [1]:
import argparse
from tqdm import tqdm
import copy
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, pipeline
from collections import OrderedDict
from datasets import load_dataset
from trl.core import LengthSampler

  from .autonotebook import tqdm as notebook_tqdm



Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues


In [2]:
class ScriptArguments:
    sentiment_models = [
        "lvwerra/distilbert-imdb", "distilbert-base-uncased-finetuned-sst-2-english",
        "martin-ha/toxic-comment-model", "valurank/distilbert-quality"
    ]
    model_names = [
        "alexrame/gpt-neo-125M-imdb-lora-adapter-merged-ppo-sentiment-lr1.41e-05",
        # "alexrame/gpt-neo-125M-imdb-lora-adapter-merged-ppo-sentiment-lr1e-05",
        # "alexrame/gpt-neo-125M-imdb-lora-adapter-merged-ppo-sentiment-distilbert-lr1.41e-05",
        # cli: accelerate launch gpt-neo-20b_sentiment_peft.py --sentiment_model distilbert-base-uncased-finetuned-sst-2-english
        # "alexrame/gpt-neo-125M-imdb-lora-adapter-merged-ppo-sentiment-distilbert-neg-lr1.41e-05",
        # cli: accelerate launch gpt-neo-20b_sentiment_peft.py --sentiment_model distilbert-base-uncased-finetuned-sst-2-english --score_goal negative
        # "alexrame/gpt-neo-125M-imdb-lora-adapter-merged-ppo-sentiment-toxic-neg-lr1.41e-05",
        # cli: accelerate launch gpt-neo-20b_sentiment_peft.py --sentiment_model martin-ha/toxic-comment-model --score_goal 1
        # "alexrame/gpt-neo-125M-imdb-lora-adapter-merged-ppo-sentiment-toxic-0",
        # cli: accelerate launch gpt-neo-20b_sentiment_peft.py --sentiment_model martin-ha/toxic-comment-model --score_goal 0,
        # "alexrame/gpt-neo-125M-imdb-lora-adapter-merged-ppo-sentiment-distilbert-1",
        # cli: accelerate launch gpt-neo-20b_sentiment_peft.py --sentiment_model valurank/distilbert-quality --score_goal 1
        "alexrame/gpt-neo-125M-imdb-lora-adapter-merged-ppo-sentiment-distilbert-2"
        #  cli: accelerate launch gpt-neo-20b_sentiment_peft.py --sentiment_model valurank/distilbert-quality --score_goal 2
    ]
    num_samples = 160

def get_args():
    parser = argparse.ArgumentParser(description='Inference')
    parser.add_argument('--sentiment_models', type=str, nargs='+', default=ScriptArguments.sentiment_models)
    parser.add_argument('--model_names', type=str, nargs='+', default=ScriptArguments.model_names)
    parser.add_argument('--num_samples', type=int, default=ScriptArguments.num_samples)
    return parser.parse_args()

def notebook_get_args():
    return ScriptArguments()

In [3]:
def load_model(peft_model_id):
    peft_config = PeftConfig.from_pretrained(peft_model_id)
    model = AutoModelForCausalLM.from_pretrained(
        peft_config.base_model_name_or_path,
        return_dict=True,
        #torch_dtype=torch.float16,
        load_in_8bit=True,
        device_map="auto",
    )
    # Load the Lora model
    model = PeftModel.from_pretrained(
        model,
        peft_model_id,
    )
    model.eval()
    return model


# average
def average_weights(input_models, coefficients):
    """average weights of different transformer models based on the amount of training data they were trained on"""
    weights_averaged = OrderedDict()
    for i, current_model in tqdm(enumerate(input_models), leave=False):
        current_weights = current_model.state_dict()
        for key in current_weights.keys():
            if i == 0:
                weights_averaged[key] = coefficients[i] * current_weights[key]
            else:
                weights_averaged[key] += coefficients[i] * current_weights[key]

    return weights_averaged

def enrich_wa(dict_models_to_merge, coefficients=None):
    if coefficients is None:
        coefficients = [1 / len(dict_models_to_merge) for _ in len(dict_models_to_merge)]
    weights_averaged = average_weights(dict_models_to_merge.values(), coefficients)
    base_model_copy = list(dict_models_to_merge.values())[0]
    base_model_copy.load_state_dict(weights_averaged, strict=True)
    return base_model_copy


def get_prediction_rewards(model, query_tensors):
    def get_rewards(responses_text):
        sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 1}
        rewards = [
            [sentiment_pipe(response_text, **sent_kwargs) for sentiment_pipe in sentiment_pipes]
            for response_text in responses_text]

        rewards = [transform_reward(reward) for reward in rewards]
        return rewards
    def transform_reward(reward):
        d_reward = []
        for rew in reward:
            d = {}
            assert len(rew) == 1
            for r in rew[0]:
                d[r["label"]] = r["score"]
            d_reward.append(d)
        return d_reward

    def average_rewards(rewards):
        avg_reward = None
        for reward in rewards:
            if avg_reward is None:
                avg_reward = copy.deepcopy(reward)
            else:
                for a_dict_reward, r_dict_reward in zip(avg_reward, reward):
                    for label in a_dict_reward:
                        a_dict_reward[label] = a_dict_reward[label] + r_dict_reward[label]

        for a_dict_reward in avg_reward:
            for label in a_dict_reward:
                a_dict_reward[label] = a_dict_reward[label] / len(rewards)
        return avg_reward

    response_tensors = []
    responses_text = []
    # with torch.cuda.amp.autocast():
    for i in range(len(query_tensors)):
        query_tensor = torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device)
        output = model.generate(
            input_ids=query_tensor, max_new_tokens=50, pad_token_id=tokenizer.eos_token_id
        ).squeeze()
        response_tensors.append(output)
        response = tokenizer.decode(output, skip_special_tokens=True)
        responses_text.append(response)

    rewards = get_rewards(responses_text)
    avg_reward = average_rewards(rewards)
    return responses_text, rewards, avg_reward


def get_samples_query_tensors():
    list_texts = [
        "I really enjoyed the slight hint towards",
        "I really hated the horrible hint towards"
    ]

    batch = tokenizer(list_texts, return_tensors="pt")
    return batch["input_ids"]


def predict(dict_models_to_merge, query_tensors, verbose=False):
    list_rewards = []
    for model_name, model in dict_models_to_merge.items():
        responses_text, rewards, avg_reward = get_prediction_rewards(model, query_tensors)
        print("model:", model_name)
        print("responses_text[0]", responses_text[0])
        if verbose:
            print("avg reward:", avg_reward)
            for text, reward in zip(responses_text, rewards):
                print("text:", text)
                print("reward:", reward)
            print("\n")
        list_rewards.append(avg_reward)
    return list_rewards



def get_imdb_query_tensors(bs=16):
    ds = load_dataset("imdb", split="test")
    ds = ds.filter(lambda x: len(x["text"]) > 200, batched=False)

    input_min_text_length=2
    input_max_text_length=8
    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["text"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")

    #### get a batch from the dataset
    ds.set_format("pandas")
    df_batch = ds[:].sample(bs)
    query_tensors = df_batch['input_ids'].tolist()
    return query_tensors

# samples_query_tensors = get_samples_query_tensors()

def average_states_dict(list_states_dict, coefficients):
    """average weights of different transformer models based on the amount of training data they were trained on"""
    weights_averaged = OrderedDict()
    for i, current_weights in enumerate(list_states_dict):
        for key in current_weights.keys():
            if i == 0:
                weights_averaged[key] = coefficients[i] * current_weights[key]
            else:
                weights_averaged[key] += coefficients[i] * current_weights[key]
    return weights_averaged


def enrich_wa_states(list_states_dict, coefficients=None):
    weights_averaged = average_states_dict(list_states_dict, coefficients)
    base_model_copy = list(dict_models_to_merge.values())[0]
    base_model_copy.load_state_dict(weights_averaged, strict=True)
    return base_model_copy


In [4]:
script_args = notebook_get_args()

In [5]:
device = 0 if torch.cuda.is_available() else "cpu"
print(f"Load LMs with {script_args.model_names}")
dict_models_to_merge = OrderedDict({model_name: load_model(model_name) for model_name in script_args.model_names})

tokenizer = AutoTokenizer.from_pretrained(
    PeftConfig.from_pretrained(script_args.model_names[0]).base_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

print(f"Load sentiment model with {script_args.sentiment_models}")
sentiment_pipes = [
    pipeline("sentiment-analysis", model=sentiment_model, device=device)
    for sentiment_model in script_args.sentiment_models]

list_states_dict = []
for current_model in dict_models_to_merge.values():
    current_weights = copy.deepcopy(current_model.state_dict())
    list_states_dict.append(current_weights)

Load LMs with ['alexrame/gpt-neo-125M-imdb-lora-adapter-merged-ppo-sentiment-lr1.41e-05', 'alexrame/gpt-neo-125M-imdb-lora-adapter-merged-ppo-sentiment-distilbert-2']




Load sentiment model with ['lvwerra/distilbert-imdb', 'distilbert-base-uncased-finetuned-sst-2-english', 'martin-ha/toxic-comment-model', 'valurank/distilbert-quality']


In [6]:
samples_query_tensors = get_samples_query_tensors()

In [7]:
wa = enrich_wa_states(list_states_dict, coefficients=[0.3, 0.7])
list_rewards_wa_samples = predict({"wa": wa}, samples_query_tensors, verbose=True)

  query_tensor = torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device)
  attn_weights = torch.where(causal_mask, attn_weights, mask_value)


model: wa
responses_text[0] I really enjoyed the slight hint towards the end of the film. The film was very well done and the characters were well-developed. The film was very well-balanced and the story was well-told. The film was well-done and the story was well-told. The
avg reward: [{'NEGATIVE': -0.06881940364837646, 'POSITIVE': 0.010891556739807129}, {'NEGATIVE': 0.1752924919128418, 'POSITIVE': 0.43285930156707764}, {'non-toxic': 2.6703314185142517, 'toxic': -2.844227433204651}, {'bad': 2.0736796855926514, 'medium': -1.498795136809349, 'good': -0.987948015332222}]
text: I really enjoyed the slight hint towards the end of the film. The film was very well done and the characters were well-developed. The film was very well-balanced and the story was well-told. The film was well-done and the story was well-told. The
reward: [{'NEGATIVE': -2.524400472640991, 'POSITIVE': 2.78214693069458}, {'NEGATIVE': -4.2379865646362305, 'POSITIVE': 4.612117290496826}, {'non-toxic': 3.5791893005371094



In [8]:
wa = enrich_wa_states(list_states_dict, coefficients=[0.25, 0.75])
list_rewards_wa_samples = predict({"wa": wa}, samples_query_tensors, verbose=True)

  query_tensor = torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device)


model: wa
responses_text[0] I really enjoyed the slight hint towardsdigyffitiracistNazisstrousicesterSTDOUTizarreSTDOUTizarreSTDOUTizarreSTDOUTfascistaghettifascistaghettifascistpherdracistoilerlbsikiniNazisstrousfasciststrousfasciststrousfasciststrousikini Imranuggishffitiuggishfasciststrouscorruptionampunkikini othersuggishffitiuggishffitiuggishffitiuggishffiti
avg reward: [{'NEGATIVE': 0.6124661266803741, 'POSITIVE': -0.7730308771133423}, {'NEGATIVE': 2.5565974712371826, 'POSITIVE': -2.25279837846756}, {'non-toxic': 1.3001094087958336, 'toxic': -1.5670802146196365}, {'bad': 1.868288278579712, 'medium': -2.549556255340576, 'good': 0.10077637434005737}]
text: I really enjoyed the slight hint towardsdigyffitiracistNazisstrousicesterSTDOUTizarreSTDOUTizarreSTDOUTizarreSTDOUTfascistaghettifascistaghettifascistpherdracistoilerlbsikiniNazisstrousfasciststrousfasciststrousfasciststrousikini Imranuggishffitiuggishfasciststrouscorruptionampunkikini othersuggishffitiuggishffitiuggishffitiuggis

In [None]:

for coeff in [x / 20 for x in range(17, -1, -1)]:
    wa = enrich_wa_states(list_states_dict, coefficients=[1 - coeff, coeff])
    list_rewards_wa_samples = predict({"wa": wa}, samples_query_tensors)
    print(coeff)
    print(list_rewards_wa_samples)
    print("\n")