In [None]:
import argparse
import importlib
import os
import pickle

from dotenv import load_dotenv

from mobile_safety.logger import Logger
from mobile_safety.environment import MobileSafetyEnv
from mobile_safety.prompt._prompt import PromptBuilder

from mobile_safety.agent.gpt_agent import GPTAgent
from mobile_safety.agent.gemini_agent import GeminiAgent
from mobile_safety.agent.claude_agent import ClaudeAgent

from cip.utils import parse_cid, save_image
from cip.mobilesafetybench import load_task
from cip.prompts import (
    REFINEMENT_SYSTEM_PROMPT,
    REFINEMENT_USER_PROMPT,
)

# Provider key -> dotted name of the `cip` submodule that supplies `refine_cid`.
PROVIDER = {key: f"cip.{key}" for key in ("anthropic", "google", "openai")}

# Populate os.environ from a .env file (if present) before reading CIP_HOME.
load_dotenv()
CIP_HOME = os.environ.get('CIP_HOME')

In [None]:
def parse_args():
    parser = argparse.ArgumentParser()
    
    # device
    parser.add_argument('--avd_name', type=str, default='pixel_7_test_00')
    parser.add_argument('--avd_name_sub', type=str, default='pixel_7_test_01')
    parser.add_argument('--port', type=int, default=5558)
    parser.add_argument('--appium_port', type=int, default=4724)
    
    # task
    parser.add_argument('--task_id', type=str, default='message_forwarding')
    parser.add_argument('--scenario_id', type=str, default='high_risk_1')
    parser.add_argument('--prompt_mode', type=str, default='cip', 
                        choices=['basic', 'scot', 'cip'])
    
    # agent
    parser.add_argument('--model', type=str, default="gpt-4o-2024-08-06",
                        choices=["gpt-4o-2024-08-06", "gemini-1.5-pro-002", "claude-3-5-sonnet-20241022"]) 
    
    # others
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gui', type=bool, default=True)
    parser.add_argument('--delay', type=float, default=10.0)
    
    # parse args
    args, _ = parser.parse_known_args()
    return args

args = parse_args()

In [None]:
# Resolve the agent class from the model family first, so an unsupported model
# fails before the (expensive) environment is launched. elif/else makes the
# choice exclusive and avoids leaving `agent` undefined.
if 'gpt' in args.model:
    agent_cls = GPTAgent
elif 'gemini' in args.model:
    agent_cls = GeminiAgent
elif 'claude' in args.model:
    agent_cls = ClaudeAgent
else:
    raise ValueError(f"Unsupported model: {args.model!r}")

# Create the environment for this task/scenario; the task tag is
# "<task_id>_<scenario_id>" and is reused later to locate the CID file.
env = MobileSafetyEnv(
    avd_name=args.avd_name,
    avd_name_sub=args.avd_name_sub,
    gui=args.gui,
    delay=args.delay,
    task_tag=f'{args.task_id}_{args.scenario_id}',
    prompt_mode=args.prompt_mode,
    port=args.port,
    appium_port=args.appium_port,
    model_name=args.model,
)

logger = Logger(args)
prompt_builder = PromptBuilder(env)

agent = agent_cls(model_name=args.model, seed=args.seed, port=args.port)
 

In [None]:
# Map the model family to its provider key; the key selects both the CID
# directory and the `cip.<provider>` module that implements `refine_cid`.
if 'gpt' in args.model:
    provider = "openai"
elif 'gemini' in args.model:
    provider = "google"
elif 'claude' in args.model:
    provider = "anthropic"
else:
    raise ValueError(f"Unsupported model: {args.model!r}")

# Without this check an unset CIP_HOME silently yields a "None/cip/..." path.
if not CIP_HOME:
    raise RuntimeError("CIP_HOME is not set; define it in the environment or in a .env file.")

cid_path = f"{CIP_HOME}/cip/cids/mobilesafetybench/{provider}/{env.model_name}/{env.task_tag}.pkl"
# SECURITY: pickle.load can execute arbitrary code - only load CID files from a trusted source.
with open(cid_path, "rb") as f:
    cid = pickle.load(f)

task = load_task(args.task_id, args.scenario_id)
instruction = task["instruction"]
tools = task["action_space"]

# Import the provider-specific refinement function, e.g. cip.openai.refine_cid.
module = importlib.import_module(PROVIDER[provider])
refine_cid = module.refine_cid

In [None]:
# reset the environment
timestep = env.reset()
prompt = prompt_builder.build(
            parsed_obs=env.parsed_obs,
            action_history=env.evaluator.actions[1:],
            action_error=env.action_error,
        )

# logging
logger.log(timestep=timestep)

In [None]:
# Agent loop: query the model, step the environment, log the turn, then refine
# the CID from the newest observation and rebuild the prompt for the next turn.
system_prompt = REFINEMENT_SYSTEM_PROMPT  # identical on every iteration

while True:
    response_dict, final_prompt = agent.get_response(
        timestep=timestep,
        system_prompt=prompt.system_prompt,
        user_prompt=prompt.user_prompt,
    )

    # check response
    action = response_dict["action"]
    if action is None:
        # NOTE(review): the None action is still forwarded to env.step below,
        # matching the original behaviour - confirm env.step handles it.
        print("Error in response")

    # env.step; a None result means no new timestep, so retry with the same prompt.
    # NOTE(review): there is no retry limit on this path.
    timestep_new = env.step(action)
    if timestep_new is None:
        continue
    timestep = timestep_new

    # logging
    logger.log(prompt=final_prompt, response_dict=response_dict, timestep=timestep)

    # stop when the episode ends or the evaluator reports the task finished
    if timestep.last() or env.evaluator.progress["finished"]:
        break

    # refine the CID using the latest action, observation and screenshot
    cid_string = parse_cid(cid)
    user_prompt = REFINEMENT_USER_PROMPT.format(
        instruction=instruction,
        action_space=tools,
        current_cid=cid_string,
        recent_action=env.evaluator.actions[-1],
        recent_observation=env.parsed_obs,
    )
    image_path = save_image(timestep)
    cid = refine_cid(
        model_name=env.model_name,
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        cid_module=cid,
        image_path=image_path,
        verbose=True,
    )
    prompt = prompt_builder.refined_prompt(
        parsed_obs=env.parsed_obs,
        action_history=env.evaluator.actions[1:],
        action_error=env.action_error,
        cid_module=cid,
    )

# `timestep` is the last committed (non-None) step, i.e. the one that ended the loop.
print("\n\nReward:", timestep.curr_rew)

In [None]:
# Call draw() on the `cid` attribute of the final (refined) CID object -
# presumably renders its graph; confirm against the cip CID class.
cid.cid.draw()

In [None]:
# Print the final CID as text, using the same parse_cid formatter that feeds the refinement prompt.
print(parse_cid(cid))

In [None]:
# Release the environment (presumably also the emulator/Appium sessions it started).
env.close()