# ViUniT Demo

## Environment Setup

In [None]:
%cd /export/einstein-vision-hs/visual_unit_testing/
import os
os.environ['PYTHONPATH'] = '/nlpgpu/data/artemisp/visual_unit_testing:$CONDA_PREFIX/'
os.environ['HF_HOME'] = '/nlpgpu/data/artemisp/.cache/huggingface'
os.environ['TORCH_HOME'] = '/nlpgpu/data/artemisp/visual_unit_testing/.cache/'
os.environ['HF_ACCESS_TOKEN'] = '<HF_TOKEN>'
os.environ['HF_TOKEN'] = '<HF_TOKEN>'
os.environ['CUDA_HOME'] = os.environ['CONDA_PREFIX']
os.environ['CONFIG_NAMES'] = 'demo_config'
os.environ["GQA_IMAGE_PATH"] = "/nlp/data/vision_datasets/GQA"
os.environ["WINOGROUND_IMAGE_PATH"] = "/nlp/data/vision_datasets/winoground/data/images"
os.environ["COCO_VAL2017"] = "/nlp/data/vision_datasets/winoground/data/images"
plot_path = '/nlpgpu/data/artemisp/visual_unit_testing/plots/'

## Load Imports and ImagePatch API

In [None]:
import sys
import re
import pickle
import os
import omegaconf
from unit_test_generation.processing import extract_unit_tests, get_unit_test_prompt, get_grounded_diffusion_prompt
from unit_test_generation.unit_test_sampling import TextSampler
from utils import (load_config,
                    get_base_prompt,
                    get_visual_program_prompt,
                    extract_python_code,
                    get_visual_program_correction_prompt,
                    SynonymChecker,
                    set_seed,
                    get_fixed_code,
                    SYNTAX_ERRORS,
                    initialize_image_generator
                    )
from viper_configs import viper_config
from tqdm import tqdm
import ast
import copy
from transformers import AutoTokenizer

def _read_text(path, strip=False):
    """Return the contents of the text file at *path*, closing the handle afterwards.

    Args:
        path: Path of the file to read.
        strip: If True, strip leading/trailing whitespace from the result.
    """
    with open(path, encoding="utf-8") as handle:
        text = handle.read()
    return text.strip() if strip else text


base_config = '/export/einstein-vision-hs/visual_unit_testing/viunit_configs/base.yaml'
this_config = omegaconf.OmegaConf.load(base_config)

llm_model_name = viper_config.llm.model_id
codex_model_name = viper_config.codex.codellama_model_name

# Unit test generation prompts.
unit_test_system_prompt = _read_text(
    this_config['unit_test_generation']['generation']['prompt_file'])
unit_test_in_context_examples = _read_text(
    this_config['unit_test_generation']['generation']['in_context_examples_file'])

llm_tokenizer = AutoTokenizer.from_pretrained(
    llm_model_name, trust_remote_code=True, token=os.getenv('HF_ACCESS_TOKEN'),)

correction_prompt = _read_text(
    this_config['visual_program_generator']['generation']['correction_prompt_file'])

# Prompts for LM-grounded diffusion. They are only used when the image source is
# 'diffusion' with an 'lmd' model, but are loaded unconditionally here.
lm_grounded_diffusion_in_context_prompt = _read_text(
    this_config['image_generation']['generation']['in_context_examples_file'], strip=True)
lm_grounded_diffusion_system_prompt = _read_text(
    this_config['image_generation']['generation']['prompt_file'], strip=True)

base_prompt = get_base_prompt(this_config['visual_program_generator']['generation']['prompt_file'],
                              this_config['visual_program_generator']['generation']['in_context_examples_file'],
                              this_config['visual_program_generator']['generation']['num_in_context_examples']
                              )
program_tokenizer = AutoTokenizer.from_pretrained(
    codex_model_name, token=os.getenv('HF_ACCESS_TOKEN'), trust_remote_code=True
)

In [None]:
# from vision_processes import forward
from main_simple_lib import *

In [None]:
# Sampler used by get_unit_tests to pick num_unit_tests of the LLM-proposed tests,
# configured from the 'unit_test_sampling' section of the config.
text_sampler = TextSampler(
    model_name=this_config['unit_test_sampling']['model_name'],
    sampling_strategy=this_config['unit_test_sampling']['strategy'],
    filter_long_answers=this_config['unit_test_sampling']['filter_long_answers'],
)

# Enable 'return_image' before building the generator (presumably so that
# batch_fetch_image hands back the images themselves -- TODO confirm).
this_config['image_generation']['return_image'] = True
image_generator = initialize_image_generator(this_config)

## Utility Functions for Unit Tests

In [None]:
def _generate_unit_test_candidates(prompts, max_attempts=2):
    """Ask the general LLM for candidate unit tests, retrying while none parse.

    Returns whatever ``extract_unit_tests`` parsed on the last attempt; this may be
    empty if every attempt failed.
    """
    unit_tests = []
    for _ in range(max_attempts):
        # deepcopy so forward() cannot mutate the shared prompt between attempts.
        output = forward(model_name='llm_general', prompt=copy.deepcopy(prompts), queues=None,
                         min_new_tokens=10, max_new_tokens=180, return_full_text=False,
                         top_p=0.9, do_sample=True, num_return_sequences=3)
        # Keep only the text after the last 'assistant' role marker.
        output = output.split('assistant')[-1]
        unit_tests = extract_unit_tests(output)
        if len(unit_tests) > 0:
            break
    else:
        print("Warning: no unit tests could be parsed from the LLM output.")
    return unit_tests


def get_unit_tests(query, program=None, num_unit_tests=3):
    """Generate unit tests for *query* and an image set for each of them.

    Args:
        query: The question to build unit tests for.
        program: Unused; kept for backward compatibility with existing callers.
        num_unit_tests: How many tests ``text_sampler`` keeps from the LLM candidates.

    Returns:
        ``(unit_tests, images)``, where ``images[i]`` holds the
        ``image_generation.return_k`` images fetched for ``unit_tests[i]``.
    """
    print("Generating unit tests...")
    prompts = get_unit_test_prompt(
        [query],
        unit_test_system_prompt,
        unit_test_in_context_examples,
        llm_model_name,
        llm_tokenizer
    )
    unit_tests = _generate_unit_test_candidates(prompts)

    print("Sampling unit tests...")
    unit_tests = text_sampler.sample(unit_tests, num_unit_tests)
    # The first element of each unit test is the text the images are generated from.
    captions = [ut[0] for ut in unit_tests]

    image_cfg = this_config['image_generation']
    if image_cfg['image_source'] == 'diffusion' and 'lmd' in image_cfg['diffusion_model_name']:
        # LM-grounded diffusion: first ask the LLM for a layout per caption.
        grounded_diffusion_prompt = get_grounded_diffusion_prompt(
            [caption.replace('"', '') for caption in captions],
            lm_grounded_diffusion_system_prompt,
            lm_grounded_diffusion_in_context_prompt,
            llm_model_name,
            llm_tokenizer)
        print("Generating grounded diffusion prompts...")
        llm_response = [
            forward(model_name='llm_general', prompt=[p], queues=None, min_new_tokens=10,
                    max_new_tokens=320, return_full_text=False).split('assistant')[-1]
            for p in tqdm(grounded_diffusion_prompt)
        ]
        images = image_generator.batch_fetch_image(captions, llm_response)
    else:
        print("Generating images...")
        images = image_generator.batch_fetch_image(captions)

    # batch_fetch_image returns a flat list; regroup it into return_k images per test.
    im_per_test = image_cfg['return_k']
    images = [images[i * im_per_test:(i + 1) * im_per_test] for i in range(len(unit_tests))]

    return unit_tests, images

def get_program(query, **kwargs):
    """Generate candidate visual programs for *query* with the 'codellama' model.

    Args:
        query: Prompt passed to the code model.
        **kwargs: Generation options forwarded to ``forward`` (e.g.
            ``num_return_sequences``, ``do_sample``, ``top_p``, ``max_new_tokens``).

    Returns:
        A list of extracted Python programs, one per generated sequence.
    """
    output = forward(model_name='codellama', prompt=[query], queues=None, **kwargs)
    # A single generation may come back as a bare string; iterating it would run
    # extract_python_code once per character, so wrap it in a list.
    if isinstance(output, str):
        output = [output]
    return [extract_python_code(o) for o in output]


def exec_code(image, code):
    """Run a generated visual program on one image and show intermediate steps.

    Args:
        image: A PIL image, an ``http(s)://`` URL, or a local file path.
        code: Generated program text. It may be a full ``execute_command(image)``
            definition or just its body.

    Errors raised while parsing or executing the program are printed rather than
    raised, so a loop over many programs/unit tests keeps going. Errors while
    loading the image still propagate.
    """
    import io
    import textwrap

    code = extract_python_code(code)
    if "def execute_command(image)" not in code:
        # A bare body must be indented under the signature to be valid Python.
        code = "def execute_command(image):\n" + textwrap.indent(textwrap.dedent(code), "    ")
    try:
        # Round-trip through the AST to normalise formatting.
        code = ast.unparse(ast.parse(code))
    except SyntaxError as e:
        print(f"SyntaxError: {e}")
        return
    syntax_2 = Syntax(code, "python", theme="light", line_numbers=True, start_line=0)
    # The execution helper expects the extended signature used for step-by-step display.
    code = code.replace("def execute_command(image)", "def execute_command(image, my_fig, time_wait_between_lines, syntax)")
    print(code)
    print('-------------------')

    if isinstance(image, Image.Image):
        img = image
    elif isinstance(image, str) and image.startswith(('http://', 'https://')):
        response = requests.get(image, timeout=30)
        response.raise_for_status()
        # .content is already decoded (gzip etc.), unlike the raw socket stream.
        img = Image.open(io.BytesIO(response.content))
    else:
        img = Image.open(image)
    if img.mode != 'RGB':
        img = img.convert('RGB')

    try:
        execute_code((code, syntax_2), img, show_intermediate_steps=True)
    except Exception as e:
        # Best effort: report and continue with the next program/unit test.
        print(f"{type(e).__name__}: {e}")
    
    
def run_fixed_code(query, image):
    """Execute the reference 'GQA' program for *query* on *image*.

    The template returned by ``get_fixed_code('GQA')`` is filled in with *query*
    via ``str.format`` and handed to ``exec_code``.
    """
    exec_code(image, get_fixed_code('GQA').format(query))

def print_unit_tests(query, unit_tests, images):
    """Print each unit test for *query* and show its images resized to 256x256.

    Args:
        query: The question the unit tests were generated for.
        unit_tests: Sequence of pairs. The first item is the text the images were
            generated from; the second is presumably the expected answer.
        images: One list per unit test, parallel to *unit_tests*. Each entry's
            second element is an image, or a list whose first image is shown.
    """
    print("Question: ", query)
    for i, (ut, img) in enumerate(zip(unit_tests, images)):
        print(f'Unit Test {i+1}: {ut[0]}, {ut[1]}')
        for im in img:
            picture = im[1][0] if isinstance(im[1], list) else im[1]
            picture.resize((256, 256)).show()
        print('-------------------')

def exec_unit_tests(code, unit_tests, images):
    """Run *code* on every image of every unit test, labelling each run.

    Args:
        code: Program text passed unchanged to ``exec_code``.
        unit_tests: Sequence of pairs printed as the label of each run.
        images: One list per unit test; each entry's second element is an image
            or a list whose first image is used.
    """
    for test_no, (test, test_images) in enumerate(zip(unit_tests, images), start=1):
        for entry in test_images:
            payload = entry[1]
            picture = payload[0] if isinstance(payload, list) else payload
            print(f'Unit Test {test_no}: {test[0]}, {test[1]}')
            exec_code(picture, code)
        print('-------------------')
    

## Code Demo

In [None]:
# Demo inputs: a remote image URL and the question to ask about it.
demo_image, demo_question = (
    "https://cdn.mos.cms.futurecdn.net/4TDZhQ9ZDtt4GZznENbhs7-768-80.jpg.webp",
    "What color are the pillows on the dark couch?",
)

In [None]:
# Build two unit tests (with their generated images) for the demo question and display them.
unit_tests, images = get_unit_tests(query=demo_question, num_unit_tests=2)
print_unit_tests(query=demo_question, unit_tests=unit_tests, images=images)

In [None]:
# Sample three candidate programs for the demo question (nucleus sampling, top_p=0.9).
code = get_program(
    demo_question,
    max_new_tokens=320,
    num_return_sequences=3,
    do_sample=True,
    top_p=0.9,
)

In [None]:
# Run every sampled program against every unit test. Use a separate loop name so
# `code` still holds the full list of programs afterwards. With `for code in code`,
# re-running this cell would iterate over the characters of the last program.
for program in code:
    exec_unit_tests(program, unit_tests, images)

## Gradio Demo UI

In [None]:
import gradio as gr
# Gradio function
def generate_and_display_unit_tests(query):
    """Gradio callback: generate unit tests for *query* and collect their images.

    Returns:
        A Markdown string with one ``"Unit Test i: <text> - <answer>"`` paragraph per
        test, and a flat list of images for the gallery.
    """
    unit_tests, images = get_unit_tests(query)
    paragraphs = []
    result_images = []
    for i, (ut, img) in enumerate(zip(unit_tests, images), start=1):
        paragraphs.append(f"Unit Test {i}: {ut[0]} - {ut[1]}\n\n")
        # Each entry's second element is an image or a sequence of images; like
        # the other display helpers, show the first image of a sequence.
        result_images.extend(
            im[1][0] if isinstance(im[1], (list, tuple)) else im[1] for im in img
        )

    return "".join(paragraphs), result_images


# Gradio UI: a query text box in; Markdown text plus an image gallery out.
query_box = gr.Textbox(label="Enter your query:")
markdown_out = gr.Markdown()
gallery_out = gr.Gallery(label="Generated Images")

interface = gr.Interface(
    fn=generate_and_display_unit_tests,
    inputs=query_box,
    outputs=[markdown_out, gallery_out],
    title="Unit Test and Image Generator",
    description="Enter a query to generate unit tests and corresponding images.",
)

# Render inline in the notebook and also request a public share link.
interface.launch(share=True, inline=True)