In [None]:
import transformers
import torch

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Pad with the tokenizer's EOS id.
# NOTE(review): presumably because this checkpoint defines no pad token — confirm.
eos_id = transformers.AutoTokenizer.from_pretrained(model_id).eos_token_id

# Text-generation pipeline with bfloat16 weights; device_map="auto" lets
# transformers/accelerate place the weights on the available devices.
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    device_map="auto",
    model_kwargs={"torch_dtype": torch.bfloat16},
    pad_token_id=eos_id,
)

In [None]:
# Chat prompt for the curse parser: messages[0] is the system prompt that defines the
# JSON schema (attack_type, attack_mode, victim_class, target_class) and the rules for
# each attack combination; messages[1] is a single example user command used as a smoke
# test. The evaluation cells below reuse this system prompt.
messages = [
    {
        "role": "system",
        "content": '''You are a language model that strictly outputs JSON format content.
Each user input is a sentence that commands an object detector to change its detection results. 
Analyze the user's intent and meaning, and include the analysis results in the following JSON fields:
Supported classes: ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'].
1. attack_type: One of 'remove', 'misclassify', or 'generate'.
2. attack_mode: One of 'targeted' or 'untargeted'.
3. victim_class: One of supported classes or 'null'.
4. target_class: One of supported classes or 'null'.
Here are explanations of the combinations of attack_type and attack_mode and some rules to follow:
1. remove + targeted: Remove the victim_class from the detection results. Rule: In this case, the victim_class should not be 'null' and target_class should be 'null'.
2. remove + untargeted: Remove all classes from the detection results. Rule: In this case, the victim_class and target_class should both be 'null'.
3. misclassify + targeted: Misclassify the victim_class to target_class. Rule: In this case, neither victim_class nor target_class should be 'null'.
4. misclassify + untargeted: Misclassify all classes to another class, i.e. classify all objects wrongly. Rule: In this case, the victim_class and target_class should both be 'null'.
5. generate + untargeted: Generate a lot of randomly scattered and classified boxes in the image. Rule: In this case, the victim_class and target_class should both be 'null'.
Ensure the output is always in valid JSON format. No other fields are allowed in the output. Follow the rules strictly.'''
    },
    {"role": "user", "content": "Let this girl become a cat."},
]

# Smoke test: run the example command and show only the assistant's reply
# (the last message of the returned conversation).
outputs = pipeline(
    messages,
    max_new_tokens=256,
)
print(outputs[0]["generated_text"][-1]['content'])

# Untargeted attacks (remove / misclassify / generate)

In [None]:
from mmdet.AnywhereDoor.curse_templates import CurseTemplates
import json
import tqdm

curse_templates = CurseTemplates()

success = 0
total = 0

# One record per curse that was not scored as a success. Curses whose answer could not
# be parsed as JSON are recorded too (with parsed_json=None) so they count as failures.
failed_curses = []

for attack_type in ['remove', 'misclassify', 'generate']:
    attack_mode = 'untargeted'
    templates = (curse_templates.templates[attack_type][attack_mode]['known']
                 + curse_templates.templates[attack_type][attack_mode]['unknown'])
    print("Processing {} {}...".format(attack_mode, attack_type))
    for curse in tqdm.tqdm(templates):
        # Fresh two-turn chat (system prompt + this curse) instead of popping/appending
        # on the shared `messages` list, so this cell does not mutate notebook state.
        chat = [messages[0], {"role": "user", "content": curse}]
        outputs = pipeline(chat, max_new_tokens=256, top_p=1.0, top_k=0)
        generated_text = outputs[0]["generated_text"][-1]['content']
        # Count every curse, including unparseable answers, so success/total is an honest rate.
        total += 1

        # Take the span from the first '{' to the last '}'; the model may wrap the JSON in prose.
        try:
            start_idx = generated_text.index('{')
            end_idx = generated_text.rindex('}') + 1
            parsed_json = json.loads(generated_text[start_idx:end_idx])
        except ValueError as e:  # no braces found, or json.JSONDecodeError (a ValueError subclass)
            print("=====================================")
            print(f"Error parsing JSON: {e}")
            print(f"Curse: {curse}")
            print(f"Generated text: {generated_text}")
            print("=====================================")
            parsed_json = None

        fields = parsed_json if isinstance(parsed_json, dict) else {}
        is_field_correct = all(
            field in fields
            for field in ('attack_type', 'attack_mode', 'victim_class', 'target_class')
        )

        # Only attack_type and attack_mode are scored for untargeted attacks;
        # victim_class / target_class are not checked here.
        if (is_field_correct
                and fields['attack_type'] == attack_type
                and fields['attack_mode'] == attack_mode):
            success += 1
        else:
            failed_curses.append({
                "curse": curse,
                "attack_type": attack_type,
                "attack_mode": attack_mode,
                "parsed_json": parsed_json,
                "generated_text": generated_text,
            })
print(success)
print(total)

In [None]:
# Dump each failed untargeted curse with its expected labels and the model's parsed JSON.
for record in failed_curses:
    print(record['curse'])
    for label in ('attack_type', 'attack_mode', 'parsed_json'):
        print(f"{label}: {record[label]}")
    print("=======================================")

# Targeted remove attacks

In [None]:
success = 0
total = 0

# One record per curse that was not scored as a success. Curses whose answer could not
# be parsed as JSON are recorded too (with parsed_json=None) so they count as failures.
failed_curses = []

templates = (curse_templates.templates['remove']['targeted']['known']
             + curse_templates.templates['remove']['targeted']['unknown'])
# Classes substituted for the '[victim_class]' placeholder in every template
# (only a subset of the supported classes is evaluated here).
voc_classes = [
    'aeroplane', 'bicycle'
]
print("Processing {} {}...".format('remove', 'targeted'))
for template in tqdm.tqdm(templates):
    for victim_class in voc_classes:
        # Substitute into a new string; overwriting the template would bake the first
        # class into it, so later classes would be scored against the wrong sentence.
        curse = template.replace('[victim_class]', victim_class)
        # Fresh two-turn chat (system prompt + this curse); the shared `messages` list
        # is left untouched.
        chat = [messages[0], {"role": "user", "content": curse}]
        outputs = pipeline(chat, max_new_tokens=256, top_p=1.0, top_k=0)
        generated_text = outputs[0]["generated_text"][-1]['content']
        # Count every curse, including unparseable answers, so success/total is an honest rate.
        total += 1

        # Take the span from the first '{' to the last '}'; the model may wrap the JSON in prose.
        try:
            start_idx = generated_text.index('{')
            end_idx = generated_text.rindex('}') + 1
            parsed_json = json.loads(generated_text[start_idx:end_idx])
        except ValueError as e:  # no braces found, or json.JSONDecodeError (a ValueError subclass)
            print("=====================================")
            print(f"Error parsing JSON: {e}")
            print(f"Curse: {curse}")
            print(f"Generated text: {generated_text}")
            print("=====================================")
            parsed_json = None

        fields = parsed_json if isinstance(parsed_json, dict) else {}
        is_field_correct = all(
            field in fields
            for field in ('attack_type', 'attack_mode', 'victim_class', 'target_class')
        )
        # .get() gives None for missing fields, so the failure record below never raises
        # NameError or reuses values left over from a previous iteration.
        pred_attack_type = fields.get('attack_type')
        pred_attack_mode = fields.get('attack_mode')
        pred_victim_class = fields.get('victim_class')
        pred_target_class = fields.get('target_class')

        if (is_field_correct
                and pred_attack_type == 'remove'
                and pred_attack_mode == 'targeted'
                and pred_victim_class == victim_class
                and pred_target_class in ['null', None, '']):
            success += 1
        else:
            failed_curses.append({
                "curse": curse,
                "is_field_correct": is_field_correct,
                "pred_attack_mode": pred_attack_mode,
                "pred_attack_type": pred_attack_type,
                "pred_victim_class": pred_victim_class,
                "pred_target_class": pred_target_class,
                "attack_type": 'remove',
                "attack_mode": 'targeted',
                "parsed_json": parsed_json,
                "generated_text": generated_text,
            })
print(success)
print(total)

In [None]:
# Dump each failed targeted-remove curse: field-presence flag, the model's predictions,
# the expected labels, and the raw parsed JSON.
REPORT_FIELDS = (
    'is_field_correct',
    'pred_attack_mode',
    'pred_attack_type',
    'pred_victim_class',
    'pred_target_class',
    'attack_type',
    'attack_mode',
    'parsed_json',
)
for record in failed_curses:
    print(record['curse'])
    for label in REPORT_FIELDS:
        print(f"{label}: {record[label]}")
    print("=======================================")