# Adversarial Simulator for Conversations

In [None]:
import json
import os
from pathlib import Path

from azure.ai.generative.evaluate import evaluate
from azure.ai.generative.synthetic.simulator import Simulator
from azure.ai.resources.client import AIClient
from azure.ai.resources.entities import AzureOpenAIModelConfiguration
from azure.identity import DefaultAzureCredential
from openai import AsyncAzureOpenAI

# Azure AI project coordinates and Azure OpenAI settings. Each value is read
# from an environment variable when it is set and otherwise falls back to ""
# (the original placeholder), so credentials need not be pasted into the
# notebook itself.
# NOTE(review): AZURE_SUBSCRIPTION_ID / AZURE_RESOURCE_GROUP /
# AZURE_AI_PROJECT_NAME are a local naming choice, not names any SDK reads.
sub = os.environ.get("AZURE_SUBSCRIPTION_ID", "")
rg = os.environ.get("AZURE_RESOURCE_GROUP", "")
project_name = os.environ.get("AZURE_AI_PROJECT_NAME", "")
oai_client = AsyncAzureOpenAI(
    # AZURE_OPENAI_API_KEY / AZURE_OPENAI_ENDPOINT match the variable names the
    # openai SDK documents for its Azure client.
    api_key=os.environ.get("AZURE_OPENAI_API_KEY", ""),
    azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT", ""),
    api_version="2023-12-01-preview",
)

## Initialize the simulator and get the adversarial template

In [None]:
# The async chat-completions callable that is handed to the simulator.
async_oai_chat_completion_fn = oai_client.chat.completions.create

# Azure AI project client, authenticated through DefaultAzureCredential.
credential = DefaultAzureCredential()
client = AIClient(
    subscription_id=sub,
    resource_group_name=rg,
    project_name=project_name,
    credential=credential,
)

# Build a Simulator around the completion callable (model="gpt-4",
# max_tokens=300), then fetch the built-in "adv_conversation" template.
simulator = Simulator.from_fn(
    fn=async_oai_chat_completion_fn,
    ai_client=client,
    model="gpt-4",
    max_tokens=300,
)
adv_template = Simulator.get_template("adv_conversation")

## Run the simulator

In [None]:
outputs = await simulator.simulate_async(
    adv_template,
    max_conversation_turns=5,
    api_call_delay_sec=5,
    max_simulation_results=5,
)

## Print the outputs

In [None]:
# Pretty-print each simulated record as indented JSON, one record per print().
for record in outputs:
    rendered = json.dumps(record, indent=2)
    print(rendered)

## Save the outputs to a JSONL file for evaluation

In [None]:
eval_name = "conv-oai-conversation-eval"
file_name = f"\{eval_name}.jsonl"
jsonl_object = outputs.to_json_lines()
with Path.open(file_name, "w") as f:
    f.write(jsonl_object)

## Run evaluate

In [None]:
tracking_uri = client.tracking_uri
connection = client.connections.get("Default_AzureOpenAI")
config = AzureOpenAIModelConfiguration.from_connection(connection, model_name="gpt-4", deployment_name="gpt-4")
data_path = Path.cwd() / f"\{eval_name}.jsonl"
output_path = Path.cwd() / "downloaded_artifacts" / "remote"
result = evaluate(
    evaluation_name=eval_name,
    data=data_path,
    task_type="qa",
    metrics_list=["hate_fairness", "violence", "self_harm", "sexual"],
    model_config=config,
    data_mapping={"messages": "messages"},
    tracking_uri=tracking_uri,
    output_path=output_path,
)
print(result)
print(result.metrics_summary)  # will print the defect rate for each content harm
print("Studio URL")
print(result.studio_url)