In [None]:
from torch.utils.data import DataLoader
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, AutoFeatureExtractor
from general_dataset import GeneralDataset
from agi_utils import *
from tqdm import tqdm
from undecorated import undecorated
from types import MethodType


import numpy as np
from IPython.utils import io
import random
from evaluate import load
from torchvision import transforms
from torchmetrics.multimodal import CLIPScore
from combine_model_seq import SeqCombine

In [None]:
"""
assign openagi data path 
"""
# Root directory of the OpenAGI data; replace the placeholder before running.
data_path = "YOUR_DATA_PATH"

# Task descriptions loaded from disk; indexed below by task id.
task_discriptions = txt_loader("./task_description.txt")

# Ids of the tasks evaluated by this notebook.
# NOTE(review): 70 appears twice in this list — confirm that is intentional.
test_task_idx = [2, 3, 10, 15, 20, 35, 45, 55, 65, 70, 70, 90, 106, 107]

# One DataLoader (batch size 20) per evaluated task, in the same order as
# `test_task_idx`.
test_dataloaders = [
    DataLoader(GeneralDataset(task_id, data_path), batch_size=20)
    for task_id in tqdm(test_task_idx)
]

# Whitespace-stripped description of each evaluated task, same order again.
test_tasks = [task_discriptions[task_id].strip() for task_id in test_task_idx]

In [None]:
# Base causal LM that the LoRA adapter below is applied to.
# Alternatives that were tried previously:
# base_model = "eachadea/vicuna-7b-1.1"
# base_model = "TheBloke/Llama-2-13B-chat-GGML"
# base_model = "chainyo/alpaca-lora-7b"
base_model = "meta-llama/Llama-2-13b-chat-hf"

# NOTE(review): `load_8bit` is defined but never passed to either
# `from_pretrained` call below — confirm whether 8-bit loading was intended.
load_8bit = True

# Prefer a token supplied through the environment so the secret never has to
# be written into the notebook; the placeholder keeps the original
# "edit this string" workflow working when the variable is not set.
hf_token = os.environ.get("HF_TOKEN", "YOUR_HUGGINGFACE_KEY")

# Per-GPU memory budget handed to the loaders via `max_memory`.
max_memory_mapping = {
    0: "48GB",
    1: "48GB",
    2: "48GB",
    3: "48GB",
    4: "0GB",
    5: "0GB",
    6: "0GB",
    # 7: "0GB",
}

# Alternative layout kept for reference:
# max_memory_mapping = {
#     0: "0GB",
#     1: "0GB",
#     2: "24GB",
#     3: "24GB",
# }

tokenizer = AutoTokenizer.from_pretrained(
    base_model,
    use_auth_token=hf_token
    # padding_side='left'
)
# tokenizer.add_special_tokens({'pad_token': '<pad>'})
# Use token id 0 for padding.
tokenizer.pad_token_id = 0

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    device_map="auto",
    max_memory=max_memory_mapping,
    use_auth_token=hf_token
)

# Path or hub id of the fine-tuned LoRA adapter; replace the placeholder.
lora_weights = "YOUR_LORA_WEIGHTS"

# NOTE(review): `PeftModelForCausalLM` has no explicit import in this
# notebook; presumably it is re-exported by `from agi_utils import *` — verify.
model = PeftModelForCausalLM.from_pretrained(
    model,
    lora_weights,
    torch_dtype=torch.float16,
    is_trainable=False,  # inference only
    device_map="auto",
    max_memory=max_memory_mapping,
)

model.print_trainable_parameters()

In [None]:
import openai

# Read the key from the environment when available so it is not committed
# with the notebook; the placeholder is kept as a fallback for the original
# edit-in-place workflow.
openai.api_key = os.environ.get("OPENAI_API_KEY", "YOUR_OPENAI_KEY")

def generate_module_list_with_gpt(generated_module_seq):
    """Ask gpt-3.5-turbo to turn free-form text into a module sequence.

    The text is embedded into a fixed instruction prompt that lists the known
    module names and asks for an answer formatted as
    'module: module1, module: module2, ...'.

    Args:
        generated_module_seq: Text to extract module names from.

    Returns:
        The reply with everything before the first "module: " marker dropped
        and all "module: " markers removed. An empty string is returned when
        the reply contains no such marker.
    """
    instruction_template = "You are a key phrase extractor who is able to extract potential module names from the given context. You have already known all the module names in the full module list. The full module list is: [Image Classification, Colorization, Object Detection, Image Deblurring, Image Denoising, Image Super Resolution, Image Captioning, Text to Image Generation, Visual Question Answering, Sentiment Analysis, Question Answering, Text Summarization, Machine Translation]. Given the following context: '{}'. Please extract a module sequence from this context and remove module names which do not exist in the full module list from this sequence. Output the module sequence after filtering as the format of 'module: module1, module: module2, module: module3, etc...'. "
    request_messages = [
        {"role": "user", "content": instruction_template.format(generated_module_seq)}
    ]

    reply = openai.ChatCompletion.create(
        model="gpt-3.5-turbo-0613",
        messages=request_messages,
    )
    reply_text = reply.choices[0].message["content"]

    # The first piece is whatever preceded the first marker (possibly an empty
    # string); only the pieces that followed a marker are kept.
    _, *module_chunks = reply_text.split("module: ")
    return "".join(module_chunks)

# generated_module_list = generate_module_list_with_gpt(response[prompt_length:])
# print(generated_module_list)

In [None]:
"""
Loading Evaluation Metrics
"""

# CLIPScore metric backed by the "openai/clip-vit-base-patch16" checkpoint.
clip_score = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")

# Pre-trained Vision Transformer and its feature extractor, both loaded from
# the same checkpoint; the model is switched to eval mode right away.
vit_ckpt = "nateraw/vit-base-beans"
vit = AutoModel.from_pretrained(vit_ckpt).eval()
vit_extractor = AutoFeatureExtractor.from_pretrained(vit_ckpt)

# Tensor -> PIL image converter.
f = transforms.ToPILImage()

# BERTScore metric from the `evaluate` package.
bertscore = load("bertscore")

# Earlier multi-GPU layout, kept for reference:
# device_list = ["cuda:1","cuda:2","cuda:3","cuda:4","cuda:5","cuda:7","cpu"]
device_list = ["cuda:0", "cpu"]

# Runner that builds and executes module sequences on the devices above.
seqCombination = SeqCombine(device_list)

In [None]:
from sentence_transformers import SentenceTransformer, util
# Sentence encoder handed to `match_module_seq` below.
sentence_model = SentenceTransformer('all-MiniLM-L6-v2', device="cpu")

module_length = 10
num_beams = 1
num_return_sequences = 1

# Device used for the prompt tensor and for the BERTScore computation.
eval_device = "cuda:5"

# Each seed produces one full pass over all test tasks.
random_seeds = [0, 1, 2, 3, 4]


def _metric_family(task_idx):
    """Return the metric family a task id is scored with.

    Ids 0-14 map to "similarity" (ViT image similarity), ids 15-104 and
    ids >= 107 map to "bert" (BERTScore), and every other id (105 and 106
    among the non-negative ids) maps to "clip" (CLIPScore).
    """
    if 0 <= task_idx <= 14:
        return "similarity"
    if 15 <= task_idx <= 104 or 107 <= task_idx:
        return "bert"
    return "clip"


total_avg_clips = []
total_avg_berts = []
total_avg_similarities = []
total_avg_rewards = []

for seed in tqdm(random_seeds):
    # Seed every RNG that is in scope, not only torch, so a pass is repeatable.
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    rewards = []
    clips = []
    berts = []
    similarities = []
    family_buckets = {"clip": clips, "bert": berts, "similarity": similarities}

    for i, task_description in enumerate(tqdm(test_tasks)):
        task_idx = test_task_idx[i]
        family = _metric_family(task_idx)

        print(task_description)
        task_rewards = []
        with torch.no_grad():
            input_s = ["### Human: "+task_description+"\n### Assistant: "]
            input_ids = tokenizer.batch_encode_plus(
                input_s, padding="longest", return_tensors="pt"
            )["input_ids"].to(eval_device)
            # NOTE(review): scores and hidden states are requested but not
            # read anywhere in this cell — consider dropping both flags to
            # save memory if nothing else consumes `output`.
            output = model.generate(
                input_ids=input_ids,
                max_length=2048,
                return_dict_in_generate=True,
                output_scores=True,
                num_beams=num_beams,
                output_hidden_states=True,
                repetition_penalty=1.25
            )

        # Sampling options (temperature / top_p / repetition_penalty) are
        # generation settings and play no part in decoding, so they are not
        # passed to `decode`.
        generated_seq = tokenizer.decode(
            output["sequences"][0], skip_special_tokens=True
        )

        # NOTE(review): the slice assumes the decoded text starts with the
        # prompt verbatim — confirm for this tokenizer.
        vicuna_steps = generate_module_list_with_gpt(generated_seq[len(input_s[0]):]).split(",")
        module_list = match_module_seq(vicuna_steps, sentence_model)
        print(module_list)

        if len(module_list) >= 1 and whole_module_seq_filter(module_list, task_idx):
            seqCombination.construct_module_seq(module_list)
            try:
                for batch in tqdm(test_dataloaders[i]):
                    inputs = list(batch['input'][0])
                    predictions = seqCombination.run_module_seq(inputs)

                    if family == "similarity":
                        outputs = list(batch['output'][0])
                        dist = image_similarity(predictions, outputs, vit, vit_extractor)
                        task_rewards.append(dist / 100)
                    elif family == "bert":
                        outputs = list(batch['output'][0])
                        f1 = np.mean(txt_eval(predictions, outputs, bertscore, device=eval_device))
                        task_rewards.append(f1)
                    else:
                        score = clip_score(predictions, inputs)
                        # Store a plain Python float so every entry of
                        # `task_rewards` has the same type for np.mean.
                        task_rewards.append(score.detach().item() / 100)
            finally:
                # Always close the module sequence, even if a batch raised.
                seqCombination.close_module_seq()

            # An empty dataloader yields no rewards; score that task as 0
            # instead of letting np.mean([]) produce NaN.
            ave_task_reward = np.mean(task_rewards) if task_rewards else 0
        else:
            # No usable module sequence was produced for this task.
            ave_task_reward = 0

        print(ave_task_reward)

        family_buckets[family].append(ave_task_reward)
        rewards.append(ave_task_reward)

    avg_clips = np.mean(clips)
    avg_berts = np.mean(berts)
    avg_similarities = np.mean(similarities)
    # Overall reward is the unweighted mean of the three per-family averages.
    avg_rewards = (avg_clips + avg_berts + avg_similarities) / 3

    res = [avg_clips, avg_berts, avg_similarities, avg_rewards]

    print(res)

    # Append this seed's summary as one comma-separated line.
    with open("zero-llama2.txt", "a") as results_file:
        write_str = ", ".join([str(value) for value in res]) + "\n"
        results_file.write(write_str)

    total_avg_clips.append(avg_clips)
    total_avg_berts.append(avg_berts)
    total_avg_similarities.append(avg_similarities)
    total_avg_rewards.append(avg_rewards)

print([total_avg_clips, total_avg_berts, total_avg_similarities, total_avg_rewards])

print([np.mean(total_avg_clips), np.mean(total_avg_berts), np.mean(total_avg_similarities), np.mean(total_avg_rewards)])

print("Finished testing!")  