In [None]:
import requests
def download_image_from_url(url, filename, timeout=30):
    """Download the resource at ``url`` and write its raw bytes to ``filename``.

    Parameters:
        url (str): Address fetched with an HTTP GET request.
        filename (str): Destination path. It is opened in binary write mode,
            so an existing file is overwritten.
        timeout (float or tuple): Seconds to wait for the server, forwarded to
            ``requests.get``. Without a timeout a stalled server would block
            the notebook indefinitely.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection problems or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on error responses instead of saving an error page as an image.
    response.raise_for_status()
    with open(filename, 'wb') as handler:
        handler.write(response.content)

## Define Hyperparameters of the experiment

In [None]:
## Define Hyperparameters of the experiment
gradio_link = "https://.....gradio.live"

# about the LVLM using AWS Bedrock
# NOTE: these assignments must not end with a comma -- a trailing comma turns the
# value into a one-element tuple, which is not a valid credential string.
# Replace the placeholders locally and never commit real credentials.
aws_access_key_id = "yours_aws_access_key_id"
aws_secret_access_key = "yours_aws_secret_access_key"
region_name = "yours_region_name"
model_id = "anthropic.claude-3-5-sonnet-20241022-v2:0" # model_id of the LVLM

## Import the CECE library and utilities

In [None]:
import cece
from cece.queries import *
from cece.refine import *
from cece.wordnet import *

import pickle
import os
from tqdm import tqdm

## LVLM

In [None]:
import boto3
# Define the aws runtime and model
# The Bedrock runtime client is built from the credentials and region set in the
# hyperparameter cell (aws_access_key_id, aws_secret_access_key, region_name).
bedrock_runtime_client = boto3.client(
    'bedrock-runtime',
    aws_access_key_id= aws_access_key_id,
    aws_secret_access_key= aws_secret_access_key,
    region_name=region_name
)

# Re-assigns model_id to the same identifier already set in the hyperparameter cell.
model_id = "anthropic.claude-3-5-sonnet-20241022-v2:0"

## Classifier

In [None]:
# from BDD100k_classifier import BDD100k_classifier

# classifier = BDD100k_classifier()

from claude_predictor import *

# Prompt handed to predict_classes_claude as the direct classification instruction.
# (Plain string literals: none of the prompts contains a placeholder.)
classification_prompt = """
Classify each image in their appropriate class according to the driving situation they depict. 
Valid class labels are 'start' or 'stop' and only these, depending on whether the car has to move or stop based on its surroundings.
You need to classify the images in one of these classes.
Pay attention to the semantics that define each class.
Return me only the label of the scene depicted and nothing else.
"""

# Prompt asking for a detailed analysis of the images.
prompt_analyze = """
Please analyze the images in detail and answer the following question with reason based on these images. 
"""

# Classification prompt worded as a follow-up to a preceding analysis.
text_prompt = """
Based on your analysis above, classify each image in their appropriate class according to the driving situation they depict. 
Valid class labels are 'start' or 'stop' and only these, depending on whether the car has to move or stop based on its surroundings.
Pay attention to the semantics that define each class.
You need to classify the images in one of these classes.
Return me only the label of the scene depicted and nothing else.
"""


# Entries of the BDD100K 10k training image directory. os.listdir returns bare
# names without the directory component.
image_names = os.listdir("../Datasets/BDD100K/bdd100k/images/10k/train/")


class Claude_classifier:
    """Binary 'start'/'stop' driving-scene classifier backed by ``predict_classes_claude``."""

    def __init__(self):
        pass

    def classify(self, image_name):
        """Classify a single image and return its label as an integer.

        Parameters:
            image_name (str): Identifier of the image. It is forwarded as-is to
                ``predict_classes_claude`` and is also the key under which the
                prediction is read back from the returned mapping.

        Returns:
            int: 0 when the first prediction for the image equals "stop",
            1 for any other prediction.
        """
        # Imported locally so the class does not depend on `defaultdict`
        # happening to be re-exported by a wildcard import.
        from collections import defaultdict

        source_classes = defaultdict(list)
        pred = predict_classes_claude([image_name], source_classes,
                                      classification_prompt, prompt_analyze, text_prompt, analyze=False)
        pred = pred[image_name][0]
        return 0 if pred == "stop" else 1

classifier = Claude_classifier()
# Smoke test on the first listed image. os.listdir yields bare file names, so the
# directory is joined back on to hand classify() a usable path, as the other
# classify() calls in this notebook do.
classifier.classify(os.path.join("../Datasets/BDD100K/bdd100k/images/10k/train/", image_names[0]))

In [None]:
import json 

# Load the semantic-segmentation training annotations. Later cells iterate
# segs["frames"] and read each frame's "name" and "labels".
with open ("bdd100k/labels/sem_seg/rles/sem_seg_train.json") as handle:
    segs = json.load(handle)

In [8]:
dataset = []             # per image: list of the object category names in its annotation
labels = []              # per image: integer label returned by classifier.classify
index_to_image_id = {}   # position in `dataset` -> frame name
image_id_to_index = {}   # frame name -> position in `dataset`

# Wrap the frame list itself (not the enumerate object) so that tqdm knows the
# total and can show a percentage/ETA instead of a bare iteration counter.
for i, row in enumerate(tqdm(segs["frames"])):
    dataset.append([obj["category"] for obj in row["labels"]])
    image_id = row["name"]
    labels.append(classifier.classify(os.path.join("bdd100k/images/10k/train", image_id)))
    index_to_image_id[i] = image_id
    image_id_to_index[image_id] = i
        

7000it [00:00, 39392.54it/s]


In [9]:
def export_text_edits(edits):
    """
    Reduce a raw edits dictionary to plain words, dropping every entry that contains a period.

    Parameters:
        edits (dict): Must provide three keys:
            - "transf": iterable of pairs ``(words_from, words_to)``, each side an iterable of strings.
            - "additions": iterable of iterables of strings.
            - "removals": iterable of iterables of strings.

    Returns:
        dict: A dictionary with the keys "additions", "removals" and "transf":
            - "additions": flat list of all period-free strings from ``edits["additions"]``, in order.
            - "removals": flat list of all period-free strings from ``edits["removals"]``, in order.
            - "transf": one ``[word_from, word_to]`` list per input pair, where each word is the
              first period-free string of the corresponding side, or ``None`` if that side has none.
    """
    def _first_plain(words):
        # First entry without a period; None when there is no such entry.
        return next((word for word in words if "." not in word), None)

    def _flatten_plain(groups):
        # Concatenate all groups, keeping only period-free entries.
        return [word for group in groups for word in group if "." not in word]

    transf = [[_first_plain(side_from), _first_plain(side_to)]
              for side_from, side_to in edits["transf"]]

    return {
        "additions": _flatten_plain(edits["additions"]),
        "removals": _flatten_plain(edits["removals"]),
        "transf": transf,
    }

In [10]:
from cece.xDataset import *
from cece.xDataset import createMSQ


# Build, per image, one concept per annotated object: the result of
# connect_term_to_wordnet for the object name, united with the name itself.
msq_dataset = []
for row in dataset:
    msq = []
    for obj in row:
        try:
            msq.append(connect_term_to_wordnet(obj).union([obj]))
        except Exception:
            # Second attempt with spaces removed from the object name.
            try:
                msq.append(connect_term_to_wordnet(obj.replace(" ", "")).union([obj.replace(" ", "")]))
            except Exception:
                # Best effort: an object that fails both lookups is left out of
                # the image's concept list. `except Exception` (instead of a bare
                # `except`) keeps Ctrl-C / SystemExit working during this loop.
                pass
    msq_dataset.append(msq.copy())

ds = xDataset(dataset = msq_dataset,
              labels = labels,
              connect_to_wordnet = False)

def get_local_edits(image_id):
    """Collect the source objects of an image and the object-level edits proposed for it.

    Reads the module-level ``ds``, ``labels``, ``image_id_to_index`` and
    ``index_to_image_id``.

    Parameters:
        image_id: Key of ``image_id_to_index`` identifying the source image.

    Returns:
        tuple: ``(objects_source, added_objs, removed_objs)`` where
            - objects_source: every period-free term in the concepts of the source entry,
            - added_objs: the filtered additions followed by the second word of each transformation pair,
            - removed_objs: the filtered removals followed by the first word of each transformation pair.
    """
    src = image_id_to_index[image_id]
    source_image_id = index_to_image_id[src]  # looked up but not used afterwards

    objects_source = []
    for concept in ds.dataset[src].concepts:
        objects_source.extend(term for term in concept if "." not in term)

    # Target entry chosen by ds.explain for the source entry and its label.
    target_index, cost = ds.explain(ds.dataset[src], labels[src])
    target_image_id = index_to_image_id[target_index]  # looked up but not used afterwards
    cost, raw_edits = ds.find_edits(ds.dataset[src], ds.dataset[target_index])

    text_edits = export_text_edits(raw_edits)
    transf_pairs = text_edits["transf"]
    added_objs = text_edits["additions"] + [new for _, new in transf_pairs]
    removed_objs = text_edits["removals"] + [old for old, _ in transf_pairs]
    return objects_source, added_objs, removed_objs

## Editor

In [11]:
from edits import Edits
import ast
from editor import Editor
import boto3
from chat import Chat

from PIL import Image
import matplotlib.pyplot as plt

In [12]:
# Editor instance created from the Gradio endpoint configured in `gradio_link`;
# its `replacer` method is what the editing loop further down calls.
editor = Editor(gradio_link)

In [None]:
def global_explanations_calc(orig_label):
    """Compute the global explanation for the images currently labelled ``orig_label``.

    Only the first 200 entries of the module-level ``labels`` / ``dataset`` are
    considered. Those whose label equals ``orig_label`` are passed to
    ``xDataset.global_explanation`` of a dataset built from ``msq_dataset``.

    Parameters:
        orig_label: Label value used to select the entries.

    Returns:
        dict: The items of the global explanation whose key contains no period.
    """
    explainer = xDataset(dataset = msq_dataset,
                         labels = labels,
                         connect_to_wordnet = False)

    # (objects, label) pairs among the first 200 entries that carry the requested label.
    selected = [(objects, label)
                for label, objects in zip(labels[:200], dataset[:200])
                if label == orig_label]
    regional_dataset = [objects for objects, _ in selected]
    regional_labels = [label for _, label in selected]

    explanation = explainer.global_explanation(regional_dataset, regional_labels)
    return {key: value for key, value in explanation.items() if "." not in key}

# Compute the global explanation once per class label (1 and 0); the
# global_explanations() helper returns these stored results.
global_explanations_1 = global_explanations_calc(1)
global_explanations_0 = global_explanations_calc(0)
    
def global_explanations(orig_label):
    """Return the precomputed global explanation for ``orig_label``.

    Label 0 maps to ``global_explanations_0``; every other value maps to
    ``global_explanations_1``.
    """
    return global_explanations_0 if orig_label == 0 else global_explanations_1
    
    
# Persist both global explanations to disk, one pickle file per class label.
with open('global_explanations_claude-haiku_0.pickle', 'wb') as handle:
    pickle.dump(global_explanations_0, handle)
    
with open('global_explanations_claude-haiku_1.pickle', 'wb') as handle:
    pickle.dump(global_explanations_1, handle)

## Run the editor with the optimal edits

In [None]:
from prompts import prompt_single_step, prompt_add_object, prompt_remove_object

import os
import shutil

def create_or_replace_dir(directory_name):
    """Ensure ``directory_name`` exists as an empty directory.

    An existing path is deleted recursively first; the directory (including
    any missing parents) is then created anew.

    Parameters:
        directory_name (str): Path of the directory to (re)create.
    """
    already_present = os.path.exists(directory_name)
    if already_present:
        # Wipe the previous contents so the caller starts from a clean slate.
        shutil.rmtree(directory_name)

    os.makedirs(directory_name)

def edit_global_edits(image_id):
    """Edit one image step by step until the classifier's label changes.

    The function
      1. (re)creates ``imgs/bdd100k/claude-haiku/global-local/{image_id}`` and copies the
         original image there as ``source.jpg``;
      2. collects the local edits for the image (``get_local_edits``) and the global
         explanation for the image's current label (``global_explanations``);
      3. ranks the candidate objects: locally proposed objects first, ordered by the
         absolute value of their score, followed by the remaining non-zero global entries;
      4. walks through the ranking, asking the LVLM chat for a prompt-specific answer and
         calling ``editor.replacer`` to produce the next image, which is saved as
         ``step_{i}.jpg`` and re-classified;
      5. stops as soon as the label differs from the original one, or after five caught
         exceptions, and writes ``logs.txt`` into the output directory.

    Parameters:
        image_id (str): File name of the image inside ``bdd100k/images/10k/train/``;
            also a key of ``image_id_to_index``.

    Returns:
        list: The applied steps, each ``["remove", obj, background]`` or ``["add", obj, add]``.
    """
    
    # Start from a clean per-image output directory (previous contents are deleted).
    create_or_replace_dir(f"imgs/bdd100k/claude-haiku/global-local/{image_id}")
    source_image_path = f"imgs/bdd100k/claude-haiku/global-local/{image_id}/source.jpg"
    
    # Despite its name, `url` is a local file path that is copied, not downloaded.
    url = "bdd100k/images/10k/train/" + image_id #data[image_id]["url"]
    shutil.copyfile(url, source_image_path)
    # Chronological record of the edits that were applied.
    steps = []
    
    # objs: objects of the source image; added_objs / removed_objs: locally proposed edits.
    objs, added_objs, removed_objs = get_local_edits(image_id)
#     global_edits = global_explanations(source_image_path)
    
    chat = Chat(model_id, bedrock_runtime_client)

    logs = ""
    # excs counts caught exceptions; i is the index used for the next step image.
    excs, i = 0, 1
    orig_label = classifier.classify(source_image_path)
    global_edits = global_explanations(orig_label)
    new_label = orig_label
    logs += f"Classification: {orig_label}\n"
    
    # Score every locally proposed object: use its global score when available,
    # otherwise a default of -0.1 (removal) or 0.1 (addition). When an object is in
    # both local lists, the addition default is assigned last and therefore wins.
    sorted_edits = {}
    for e in added_objs + removed_objs:
        if e in global_edits:
            v = global_edits[e]
        else:
            if e in removed_objs:
                v = -0.1
            if e in added_objs:
                v = 0.1
        
        sorted_edits[e] = v
        
    # Largest absolute score first.
    sorted_edits = [[k, v] for k, v in sorted(sorted_edits.items(), key=lambda item: abs(item[1]), reverse = True)]
    # Then append the non-zero global entries that were not proposed locally,
    # in the iteration order of global_edits.
    for o in global_edits:
        if global_edits[o] == 0:
            continue 
        if o not in added_objs + removed_objs:
            sorted_edits.append([o, global_edits[o]])
    
    print (sorted_edits)
    print (objs, added_objs, removed_objs)
    # NOTE(review): from here on both lists track objects already handled, and the
    # names are crossed: `added_objects` is filled by the removal branch, while
    # `removed_objs` (rebinding the local-edit list from above) is filled by the
    # addition branch. `objs` is not updated after an edit.
    added_objects, removed_objs = [], []
    for [obj, v] in sorted_edits:
        try:
                
            # Non-positive score: treat the object as a removal candidate.
            if  v <= 0:
                if obj in added_objects:
                    continue
                    
                # Only objects listed for the source image are removed.
                if obj in objs:
                    prompt = prompt_remove_object(obj)
                    print (prompt)
                    chat.add_user_message_image(prompt, source_image_path) # add a user message with an image and a text prompt
                    background = chat.generate()
                    print (background)
                    logs += f"\n----\nOutput LVLM: {i}\n{background}\n"
                    background = background.strip()

                    # presumably replaces `obj` with the LVLM-proposed `background` -- TODO confirm against Editor.replacer
                    new_image, mask = editor.replacer(source_image_path, obj, background)
                    logs += f"\n{['remove', obj, background]}\n" 
                    steps.append(["remove", obj, background])
                    
                    # Subsequent steps operate on the newly saved image.
                    source_image_path = f"imgs/bdd100k/claude-haiku/global-local/{image_id}/step_{i}.jpg"
                    new_image.save(source_image_path)
                    i += 1
                    new_label = classifier.classify(source_image_path)
                    logs += f"Classification: {new_label}\n"
                
                added_objects.append(obj)

            # Reached only for v > 0, since v <= 0 is handled by the branch above.
            elif v >= 0:    
                if obj in removed_objs:
                    continue
                # Only objects not listed for the source image are added.
                if obj not in objs: 
                    prompt = prompt_add_object(obj)
                    print (prompt)
                    chat.add_user_message_image(prompt, source_image_path) # add a user message with an image and a text prompt
                    add = chat.generate()
                    print (add)
                    logs += f"\n----\nOutput LVLM: {i}\n{add}\n"
                    add = add.strip()


                    # Argument order is reversed relative to the removal branch: the
                    # LVLM answer comes second there and first here.
                    new_image, mask = editor.replacer(source_image_path, add, obj)
                    logs += f"\n{['add', obj, add]}\n" 
                    steps.append(["add", obj, add])

                    source_image_path = f"imgs/bdd100k/claude-haiku/global-local/{image_id}/step_{i}.jpg"
                    new_image.save(source_image_path)
                    i += 1
                    
                    new_label = classifier.classify(source_image_path)
                    logs += f"Classification: {new_label}\n"
                removed_objs.append(obj)

        # Failures are recorded in the log and the loop moves on; the fifth
        # exception aborts the loop.
        except Exception as e:
            excs += 1
            logs += f"Exception: {e}\n"
            if excs >= 5:
                break
                
                
        # Stop as soon as the classifier's label differs from the original one.
        if (orig_label != new_label):
            break

    logs += f"\n\n----\n\n{steps}\n\n----\n\n"
    with open(f"imgs/bdd100k/claude-haiku/global-local/{image_id}/logs.txt", "w") as handle:
        handle.write(logs)
    return steps
        
    
def classified_as(image_path, cl):
    """Return True when the module-level ``classifier`` assigns ``cl`` to ``image_path``, else False."""
    return bool(classifier.classify(image_path) == cl)

In [None]:
from tqdm import tqdm

# Process every image, skipping those that already have a first edited step on
# disk. The path checked here must be the directory edit_global_edits writes to
# ("claude-haiku"); checking a different directory would never match, so every
# run would wipe and redo all images.
for key in tqdm(image_id_to_index):
    if not os.path.exists(f"imgs/bdd100k/claude-haiku/global-local/{key}/step_1.jpg"):
        edit_global_edits(key)