# MAIA Demo

#### Many of MAIA's experiments are available in the [experiment browser](https://multimodal-interpretability.csail.mit.edu/maia/experiment-browser/) ####

In [None]:
import os

import openai
from dotenv import load_dotenv

# Some of the imports below require the OpenAI API key to already be set,
# so the credentials are configured before them.
load_dotenv()  # read variables from a .env file into the process environment
# Hand the key and organization from the environment to the openai client
openai.api_key, openai.organization = map(
    os.getenv, ("OPENAI_API_KEY", "OPENAI_ORGANIZATION")
)
###############################################


from maia_api import System, Tools, SyntheticSystem
from utils.DatasetExemplars import DatasetExemplars, SyntheticExemplars
from utils.main_utils import generate_save_path, create_bias_prompt, retrieve_synth_label
from utils.InterpAgent import InterpAgent
from utils.api_utils import str2image
from utils.ExperimentEnvironment import ExperimentEnvironment
from utils.call_agent import ask_agent

Stable Diffusion 3.5 is a gated model on Hugging Face. Request access at https://huggingface.co/stabilityai/stable-diffusion-3.5-medium, then log in with the cell below.

In [None]:
# Login for access to sd-3.5
# Authenticates this session with the Hugging Face Hub, which is needed to
# download the gated stable-diffusion-3.5-medium weights (access must be
# requested first, see the note above this cell).
from huggingface_hub import login
login()

### Utils

In [2]:
# Plot the results from the experiment log
def plot_results_notebook(log_entry):
    """Render a single experiment-log entry in the notebook.

    A "*** MAIA: ***" banner is printed first when the entry's role is
    'assistant'. Each item of ``log_entry['content']`` is then shown according
    to its ``type``: 'text' items are printed, and 'image_url' items are
    converted with ``str2image`` and passed to ``display``. Items of any other
    type are skipped.
    """
    if log_entry['role'] == 'assistant':
        print('\n\n*** MAIA: ***\n\n')
    for part in log_entry['content']:
        kind = part['type']
        if kind == 'text':
            print(part['text'])
            continue
        if kind != 'image_url':
            continue
        # The URL is assumed to be a data URL ("<header>,<base64 payload>");
        # only the second comma-separated field is handed to str2image.
        payload = part['image_url']['url'].split(',')[1]
        display(str2image(payload))

In [None]:
# MAIA's experiment loop, redefined here to display to the notebook after each step

class InterpAgentDemo(InterpAgent):
    """InterpAgent variant that renders the experiment log in the notebook as it runs."""

    def run_experiment(self, system: System, tools: Tools, save_html=False):
        """Run the experiment loop, displaying each round's new log entries.

        Every round asks the agent for an experiment and logs the reply. The
        round then does one of three things:

        * stops, if the reply contains ``end_experiment_token``;
        * appends the overload instructions, if ``max_round_count`` is exceeded;
        * otherwise executes the proposed code and logs its output.

        The log entries added during the round are then displayed with
        ``plot_results_notebook``. This includes the final round, so MAIA's
        final answer is shown too.

        Args:
            system: System under study, handed to the ExperimentEnvironment.
            tools: Tools instance. ``tools.agent`` points at this agent during
                the run and is restored afterwards, even if a round raises.
            save_html: If True, regenerate the HTML log after every model
                response and once more when the loop ends.
        """
        # Make sure experiment log is clean
        self._init_experiment_log()
        experiment_env = ExperimentEnvironment(system, tools, globals())
        # Keep the environment so later manual steps can reuse it
        self.experiment_env = experiment_env
        # Set Tools to point to this agent for the duration of the run
        temp_agent, tools.agent = tools.agent, self
        try:
            round_count = 0
            n_displayed = 0  # number of log entries already rendered in the notebook
            while True:
                round_count += 1
                model_experiment = ask_agent(self.model_name, self.experiment_log)
                self.update_experiment_log(role='model', type="text", type_content=str(model_experiment))
                if save_html:
                    tools.generate_html(self.experiment_log)
                if self.debug:
                    print(model_experiment)

                # Check for the end token before the round limit. Otherwise, once
                # the limit is exceeded, the loop never tests for a final answer.
                finished = self.end_experiment_token in model_experiment
                if not finished:
                    if round_count > self.max_round_count:
                        self._overload_instructions()
                    else:
                        try:
                            experiment_output = experiment_env.execute_experiment(model_experiment)
                        except ValueError:
                            # Raised when the reply contains no code to run
                            self.update_experiment_log(role='execution',
                                                       type="text",
                                                       type_content=f"No code to run was provided, please continue with the experiments based on your findings, or output your final {self.end_experiment_token}.")
                        else:
                            if experiment_output != "":
                                self.update_experiment_log(role='user', type="text", type_content=experiment_output)

                # Display only the entries added this round, so earlier ones
                # are not re-printed and the final round is shown before stopping.
                for log_entry in self.experiment_log[n_displayed:]:
                    plot_results_notebook(log_entry)
                n_displayed = len(self.experiment_log)

                if finished:
                    break

            if save_html:
                tools.generate_html(self.experiment_log)
        finally:
            # Restore tools to its original state
            tools.agent = temp_agent

### Arguments

In [None]:
# --- Agent ------------------------------------------------------------------
maia_model = 'gpt-4o'         # model running the MAIA agent (also used for image2text)
task = 'neuron_description'   # selects the user prompt file user_{task}.txt
debug = True

# --- Unit under study -------------------------------------------------------
model_name = "resnet152"
layer = "layer4"
neuron_num = 20
unit_config = {model_name: {layer: [neuron_num]}}

# --- Experiment settings ----------------------------------------------------
n_exemplars = 15
images_per_prompt = 1
text2image = 'sd'             # text2image backend name passed to Tools
device = 1                    # passed to System/Tools; presumably a GPU index - TODO confirm

# --- Paths ------------------------------------------------------------------
path2prompts = './prompts'
path2exemplars = './exemplars'
path2indices = './neuron_indices'
path2save = generate_save_path('../results', maia_model, "test")
print(path2save)
os.makedirs(path2save, exist_ok=True)

### Initialize MAIA

In [None]:
# Prompt needs to be created dynamically for bias_discovery so class label can be inserted
if task == "bias_discovery":
    create_bias_prompt(path2indices, path2prompts, str(neuron_num))

# Add API configuration: (class, exposed methods) pairs handed to the agent
api = [
    (System, [System.call_neuron]),
    (Tools, [Tools.text2image, Tools.edit_images, Tools.dataset_exemplars, 
             Tools.display, Tools.describe_images, Tools.summarize_images])
]
# Use the notebook subclass defined above so run_experiment displays each
# round's results in the notebook as it goes.
maia = InterpAgentDemo(
    model_name=maia_model,
    api=api,
    prompt_path=path2prompts,
    api_prompt_name="api.txt",
    user_prompt_name=f"user_{task}.txt",
    overload_prompt_name="final.txt",
    end_experiment_token="[FINAL]",
    max_round_count=15,
    debug=debug
)
if model_name == "synthetic":
    # Synthetic neurons: exemplars plus a SyntheticSystem built with a
    # ground-truth label from retrieve_synth_label
    net_dissect = SyntheticExemplars(
        os.path.join(path2exemplars, model_name),
        path2save,
        layer
    )
    gt_label = retrieve_synth_label(layer, neuron_num)
    system = SyntheticSystem(neuron_num, gt_label, layer, device)
else:
    # Real network unit: dataset exemplars, whose thresholds are passed to System
    net_dissect = DatasetExemplars(
        path2exemplars,
        n_exemplars,
        path2save,
        unit_config
    )
    system = System(model_name, layer, neuron_num, net_dissect.thresholds, device)

tools = Tools(
    path2save,
    device,
    maia,
    system,
    net_dissect,
    images_per_prompt=images_per_prompt,
    text2image_model_name=text2image,
    image2text_model_name=maia_model
)

print("HTML path: ", tools.html_path)

## MAIA's API and user prompts

In [None]:
# Entries 0 and 1 of the log hold the API prompt and the user prompt
for prompt_idx in (0, 1):
    plot_results_notebook(maia.experiment_log[prompt_idx])

## Experiment

In [None]:
# Run full experiment
# Runs MAIA's experiment loop end to end; save_html=True also asks the tools to
# write the HTML log (its path is printed by the initialization cell).
maia.run_experiment(system, tools, save_html=True)

### Manual Experiment
To run experiment steps manually, run the cell below repeatedly; each run asks MAIA for one experiment and executes it.

In [None]:
# To reset maia's experiment log to just the api and user prompt, run
# maia._init_experiment_log()
# Each run of this cell performs one MAIA step: query the agent, log its reply,
# execute any code it proposed, and display the new log entries.
if getattr(maia, "experiment_env", None) is None:
    # No environment from an earlier run; create one the same way run_experiment does
    maia.experiment_env = ExperimentEnvironment(system, tools, globals())
n_before = len(maia.experiment_log)
# Query MAIA's own model (maia_model), not the network under study (model_name)
model_experiment = ask_agent(maia.model_name, maia.experiment_log)
# Log MAIA's reply so the next query includes the code it proposed
maia.update_experiment_log(role='model', type="text", type_content=str(model_experiment))
if maia.end_experiment_token in model_experiment:
    print(f"MAIA output {maia.end_experiment_token}: the experiment is finished.")
else:
    try:
        experiment_output = maia.experiment_env.execute_experiment(model_experiment)
    except ValueError:
        # Raised when the reply contains no code to run
        maia.update_experiment_log(role='execution',
                                   type="text",
                                   type_content=f"No code to run was provided, please continue with the experiments based on your findings, or output your final {maia.end_experiment_token}.")
    else:
        if experiment_output != "":
            maia.update_experiment_log(role='user', type="text", type_content=experiment_output)
# Plot the log entries added by this step
for log_entry in maia.experiment_log[n_before:]:
    plot_results_notebook(log_entry)