In [None]:
from dataclasses import dataclass, field
from typing import Optional


from dataclasses import dataclass, field
from typing import Optional

from tqdm import tqdm
import peft
import copy
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, pipeline
from collections import OrderedDict
from datasets import load_dataset
from trl.core import LengthSampler


class ScriptArguments:
    """Static configuration for this script.

    Both settings are plain class attributes, so the same list objects are
    visible on the class and on every instance.
    """

    # PEFT adapter checkpoints that get loaded (and later weight-averaged).
    # Alternative checkpoints, currently disabled:
    #   "edbeeching/gpt-neo-125M-imdb-lora-adapter-merged-ppo-sentiment"
    #   "alexrame/gpt-neo-125M-imdb-lora-adapter-merged-ppo-sentiment-lr1e-05"
    #   "alexrame/gpt-neo-125M-imdb-lora-adapter-merged-ppo-sentiment-distilbert-neg-lr1.41e-05"
    model_names = [
        "alexrame/gpt-neo-125M-imdb-lora-adapter-merged-ppo-sentiment-lr1.41e-05",
        "alexrame/gpt-neo-125M-imdb-lora-adapter-merged-ppo-sentiment-toxic-neg-lr1.41e-05",
    ]

    # Text-classification checkpoints used to score the generated text.
    sentiment_models = [
        "lvwerra/distilbert-imdb",
        "distilbert-base-uncased-finetuned-sst-2-english",
        "martin-ha/toxic-comment-model",
        "valurank/distilbert-quality",
    ]


# Module-level configuration instance read by the model, tokenizer and
# sentiment-pipeline setup below.
script_args = ScriptArguments()

def load_model(peft_model_id):
    """Load a PEFT adapter on top of its base causal language model.

    Args:
        peft_model_id: Identifier handed to ``PeftConfig.from_pretrained`` and
            ``PeftModel.from_pretrained``.

    Returns:
        The ``PeftModel`` wrapping the base model, switched to eval mode.
    """
    adapter_config = PeftConfig.from_pretrained(peft_model_id)

    # The base model is loaded in 8-bit with automatic device placement;
    # passing torch_dtype=torch.float16 instead is the disabled alternative.
    base_model_kwargs = {
        "return_dict": True,
        "load_in_8bit": True,
        "device_map": "auto",
    }
    base_model = AutoModelForCausalLM.from_pretrained(
        adapter_config.base_model_name_or_path,
        **base_model_kwargs,
    )

    # Attach the LoRA adapter weights to the freshly loaded base model.
    adapted_model = PeftModel.from_pretrained(base_model, peft_model_id)
    adapted_model.eval()
    return adapted_model


# Device used for the sentiment pipelines and for the query tensors:
# CUDA device index 0 when CUDA is available, otherwise the CPU.
device = 0 if torch.cuda.is_available() else "cpu"
# Load every configured adapter model up front, keyed by its model name.
dict_models_to_merge = OrderedDict({model_name: load_model(model_name) for model_name in script_args.model_names})
# Weight averaging of the loaded models.
def average_weights(input_models, coefficients):
    """Combine the state dicts of several models into one weighted sum.

    For every state-dict key the result holds
    ``sum(coefficients[i] * models[i].state_dict()[key])``. The coefficients
    are used as given (no normalisation), so the result is an average only
    when they sum to 1.

    Args:
        input_models: Iterable of models exposing ``state_dict()``.
        coefficients: One multiplier per model, in the same order.

    Returns:
        An ``OrderedDict`` mapping each state-dict key to its weighted sum;
        empty when no models are given.

    Raises:
        ValueError: If the number of coefficients differs from the number of
            models, or if a model's state-dict keys differ from those of the
            first model.
    """
    # Materialise both inputs so views/generators can be measured and indexed.
    models = list(input_models)
    coefficients = list(coefficients)
    if len(models) != len(coefficients):
        raise ValueError(
            f"expected one coefficient per model, got {len(coefficients)} "
            f"coefficients for {len(models)} models"
        )

    weights_averaged = OrderedDict()
    if not models:
        return weights_averaged

    # tqdm wraps the list itself (not `enumerate`) so it can read the length
    # and display a total.
    for i, current_model in enumerate(tqdm(models, leave=False)):
        coefficient = coefficients[i]
        current_weights = current_model.state_dict()

        if i == 0:
            # The scaled product is stored as the accumulator that the
            # in-place `+=` below updates.
            for key, value in current_weights.items():
                weights_averaged[key] = coefficient * value
            continue

        # A missing key would otherwise leave a silently incomplete sum.
        if current_weights.keys() != weights_averaged.keys():
            raise ValueError(
                f"state dict keys of model {i} differ from those of the first model"
            )
        for key, value in current_weights.items():
            weights_averaged[key] += coefficient * value

    return weights_averaged

def enrich_wa(dict_models_to_merge):
    """Return a model carrying the uniform weight average of the given models.

    Every model receives the same coefficient ``1 / len(dict_models_to_merge)``.

    NOTE(review): no copy is made. The averaged weights are loaded into the
    first model of the mapping and that same object is returned, so the first
    entry of ``dict_models_to_merge`` holds the averaged weights afterwards.

    Args:
        dict_models_to_merge: Mapping whose values are the models to average.

    Returns:
        The first model of the mapping, after loading the averaged state dict
        with ``strict=True``.
    """
    num_models = len(dict_models_to_merge)
    uniform_coefficients = [1 / num_models] * num_models
    averaged_state = average_weights(dict_models_to_merge.values(), uniform_coefficients)

    target_model = next(iter(dict_models_to_merge.values()))
    target_model.load_state_dict(averaged_state, strict=True)
    return target_model

# Tokenizer of the base model referenced by the first configured adapter's
# PEFT config.
tokenizer = AutoTokenizer.from_pretrained(
    PeftConfig.from_pretrained(script_args.model_names[0]).base_model_name_or_path)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

print(f"Load sentiment model with {script_args.sentiment_models}")
# One "sentiment-analysis" pipeline per configured classifier, all created
# on `device`.
sentiment_pipes = [
    pipeline("sentiment-analysis", model=sentiment_model, device=device)
    for sentiment_model in script_args.sentiment_models]

def get_prediction_rewards(model, query_tensors):
    """Generate text for each query and score it with every sentiment pipeline.

    Args:
        model: Model exposing ``generate(input_ids=..., max_new_tokens=...)``.
        query_tensors: Sequence of token-id sequences (tensors, arrays or
            lists of ints), one per prompt.

    Returns:
        A tuple ``(responses_text, rewards, reward)``:

        * ``responses_text``: decoded text of each generated sequence.
        * ``rewards``: for each response, a list with the output of every
          pipeline in ``sentiment_pipes``.
        * ``reward``: same nested structure as one entry of ``rewards``, with
          every ``"score"`` averaged over all responses; ``None`` when
          ``query_tensors`` is empty.
    """

    def get_rewards(responses_text):
        """Run every sentiment pipeline on every response text."""
        # NOTE(review): "function_to_apply": "none" presumably makes the
        # pipelines return unnormalised scores for all labels - confirm
        # against the installed transformers version.
        sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 1}
        return [
            [sentiment_pipe(response_text, **sent_kwargs) for sentiment_pipe in sentiment_pipes]
            for response_text in responses_text]

    def average_rewards(rewards):
        """Average the ``"score"`` fields element-wise over all responses."""
        if not rewards:
            # Nothing to average (no queries were given).
            return None

        # Start from a deep copy so the entries of `rewards` stay untouched.
        avg_reward = copy.deepcopy(rewards[0])
        for reward in rewards[1:]:
            for a, r in zip(avg_reward, reward):
                for i, rr in enumerate(r):
                    for j, rrr in enumerate(rr):
                        # Scores may only be summed for matching labels.
                        assert a[i][j]["label"] == rrr["label"]
                        a[i][j]["score"] += rrr["score"]

        for a in avg_reward:
            for r in a:
                for rr in r:
                    rr["score"] /= len(rewards)
        return avg_reward

    responses_text = []
    for query in query_tensors:
        # as_tensor accepts lists/arrays and reuses an existing tensor, which
        # avoids the copy-construct warning of torch.tensor(tensor).
        query_tensor = torch.as_tensor(query).unsqueeze(dim=0).to(device)
        output = model.generate(input_ids=query_tensor, max_new_tokens=50).squeeze()
        responses_text.append(tokenizer.decode(output, skip_special_tokens=True))

    rewards = get_rewards(responses_text)
    reward = average_rewards(rewards)
    return responses_text, rewards, reward


def get_samples_query_tensors():
    """Tokenize two fixed example prompts and return their token ids.

    Returns:
        The ``"input_ids"`` entry of the tokenizer output, requested as
        PyTorch tensors via ``return_tensors="pt"``.
    """
    prompts = [
        "I really enjoyed the slight hint towards",
        "I really hated the horrible hint towards",
    ]
    # NOTE(review): no padding is requested, so this presumably relies on
    # both prompts tokenizing to the same length - verify if prompts change.
    return tokenizer(prompts, return_tensors="pt")["input_ids"]




def predict(dict_models_to_merge, query_tensors):
    """Print the generations and rewards of every model for the given queries.

    Args:
        dict_models_to_merge: Mapping from a display name to a model.
        query_tensors: Queries forwarded to ``get_prediction_rewards``.
    """
    for name, candidate in dict_models_to_merge.items():
        generated = get_prediction_rewards(candidate, query_tensors)
        texts, per_text_rewards, mean_reward = generated

        # Header: which model, and its reward averaged over all responses.
        print("model:", name)
        print("avg reward:", mean_reward)

        # Then every generated text next to its own reward.
        for generated_text, text_reward in zip(texts, per_text_rewards):
            print("text:", generated_text)
            print("reward:", text_reward)
        print("\n")

# Generate from the fixed sample prompts with each loaded model and print the
# resulting texts and rewards.
samples_query_tensors = get_samples_query_tensors()
predict(dict_models_to_merge, samples_query_tensors)

def get_imdb_query_tensors(bs=16, input_min_text_length=2, input_max_text_length=8):
    """Sample ``bs`` short token-id prompts from the IMDB test split.

    Reviews whose text is longer than 200 characters are tokenized and cut to
    a length drawn from ``LengthSampler``; ``bs`` of the resulting rows are
    then sampled at random (no ``random_state`` is passed, so the selection
    differs between calls).

    Args:
        bs: Number of prompts to return.
        input_min_text_length: Lower bound handed to ``LengthSampler``.
        input_max_text_length: Upper bound handed to ``LengthSampler``.

    Returns:
        A list with the ``input_ids`` of each sampled row.
    """
    ds = load_dataset("imdb", split="test")
    ds = ds.filter(lambda x: len(x["text"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        # Keep only the first few tokens of the review as the prompt.
        sample["input_ids"] = tokenizer.encode(sample["text"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)

    # Only the pandas view is needed to draw the batch; any format set before
    # this call would be replaced by it anyway.
    ds.set_format("pandas")
    df_batch = ds[:].sample(bs)
    return df_batch["input_ids"].tolist()



In [None]:
# Generate and score the sample prompts with each individually loaded model.
predict(dict_models_to_merge, samples_query_tensors)

In [None]:
# Average the weights of the loaded models and evaluate the merged model on
# the sample prompts.
# NOTE(review): enrich_wa loads the averaged weights into the first model of
# dict_models_to_merge, so that entry holds the averaged weights afterwards.
wa = enrich_wa(dict_models_to_merge)
predict({"wa": wa}, samples_query_tensors)

In [None]:
# `imdb_query_tensors` is not assigned in any of the cells above, so build it
# here from the IMDB test split before scoring the weight-averaged model.
imdb_query_tensors = get_imdb_query_tensors()
predict({"wa": wa}, imdb_query_tensors)