In [1]:
# from mmdet.AnywhereDoor.curse_encoder import CurseEncoder
from mmdet.AnywhereDoor.curse_templates import CurseTemplates
import os
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
from transformers import LlamaTokenizer, LlamaModel, LlamaForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class CurseEncoder():
    """Encode text prompts ("curses") with a frozen Llama causal language model.

    The encoder tokenizes a prompt, runs it through the model and returns the
    model's first output tensor at the position of the last non-padding token
    of each prompt.
    """

    def __init__(self, enc_id, hf_token, device='cuda:0'):
        """Load the tokenizer and model and freeze every model parameter.

        Args:
            enc_id: Model identifier passed to ``from_pretrained``.
            hf_token: Access token passed to ``from_pretrained``.
            device: Device the model is moved to and inputs are sent to.
        """
        self.tokenizer = LlamaTokenizer.from_pretrained(enc_id, token=hf_token)
        # A dedicated pad token is registered so padded positions can be
        # told apart from real tokens in `encode`.
        self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
        self.tokenizer.padding_side = "right"
        torch_dtype = 'auto' if torch.cuda.is_available() else torch.float32
        self.llm = LlamaForCausalLM.from_pretrained(enc_id, token=hf_token, torch_dtype=torch_dtype).to(device)
        # The vocabulary grew by the pad token, so the embedding table must too.
        self.llm.resize_token_embeddings(len(self.tokenizer))
        self.llm.config.pad_token_id = self.tokenizer.pad_token_id
        # The model is only used as a fixed feature extractor.
        for param in self.llm.parameters():
            param.requires_grad = False
        self.device = device

    def encode(self, curse):
        """Return the model output at the last real token of each prompt.

        Args:
            curse: A single string or a list of strings accepted by the tokenizer.

        Returns:
            A tensor with one row per prompt, taken from the first element of
            the model output (presumably the per-token vocabulary logits of the
            causal-LM head -- TODO confirm) at the last non-padding position.
        """
        inputs = self.tokenizer(curse, padding=True, return_tensors='pt').to(self.device)
        input_ids = inputs.input_ids
        batch_size = input_ids.shape[0]

        # Index of the first pad token minus one == index of the last real token.
        # For a row without any pad token argmax yields 0, so the result is -1,
        # which selects the final position of that row.
        first_pad = torch.eq(input_ids, self.llm.config.pad_token_id).long().argmax(-1)
        sequence_lengths = (first_pad - 1).to(self.device)

        # Parameters are frozen, so no graph is needed for the forward pass.
        with torch.no_grad():
            transformer_outputs = self.llm(**inputs)
        token_outputs = transformer_outputs[0]

        # Pair row i with its own last-token index. The batch index must span
        # the whole batch; a fixed arange(1) would read every row from sample 0.
        batch_index = torch.arange(batch_size, device=self.device)
        logits = token_outputs[batch_index, sequence_lengths]

        return logits

    def free_llm(self):
        """Drop the tokenizer and model and release cached CUDA memory.

        Safe to call more than once. After this call `encode` can no longer
        be used on this instance.
        """
        for attr in ('tokenizer', 'llm'):
            if hasattr(self, attr):
                delattr(self, attr)
        # Deleting the references alone leaves the blocks in PyTorch's caching
        # allocator; hand them back so other work in the kernel can use them.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

In [3]:
def build_cache(cache_root, attack_type, attack_mode, enc_id, hf_token, all_classes):
    """Encode every curse for one attack setting and save the results to disk.

    The cache is a nested dict keyed first by ``'known'`` / ``'unknown'``:

    * ``attack_mode == 'untargeted'``: ``cache[k][template] -> tensor``
    * ``attack_type`` in ``'remove'`` / ``'generate'``:
      ``cache[k][cls][curse] -> tensor``
    * ``attack_type == 'misclassify'``:
      ``cache[k][victim_cls][target_cls][curse] -> tensor``

    Any other targeted ``attack_type`` leaves the per-split dicts empty.

    Args:
        cache_root: Directory under which ``<model_name>/`` is created.
        attack_type: Key into ``CurseTemplates().templates``.
        attack_mode: Second-level key into the templates.
        enc_id: Encoder identifier; the part after the last ``/`` names the
            cache sub-directory.
        hf_token: Access token forwarded to ``CurseEncoder``.
        all_classes: Class names substituted into the template placeholders.

    Returns:
        Path of the saved ``.pt`` file.

    NOTE(review): tensors are saved on whatever device ``encode`` returns them
    on; loading on a different machine may need ``map_location`` -- confirm.
    """
    model_name = enc_id.split('/')[-1]
    cache_dir = os.path.join(cache_root, model_name)
    cache_path = os.path.join(cache_dir, f'{attack_type}_{attack_mode}_LlamaCausal.pt')

    # Attack types whose templates carry exactly one class placeholder.
    single_placeholder = {'remove': '[victim_class]', 'generate': '[target_class]'}

    curse_templates = CurseTemplates()
    curse_encoder = CurseEncoder(enc_id, hf_token)
    try:
        os.makedirs(cache_dir, exist_ok=True)

        cache = {}
        print(f'Building cache for {model_name} with attack type `{attack_type}` and attack mode `{attack_mode}`...This may take a few minutes.')
        for k_unk in ['known', 'unknown']:
            bucket = cache[k_unk] = {}

            for template in curse_templates.templates[attack_type][attack_mode][k_unk]:
                if attack_mode == 'untargeted':
                    # Untargeted templates contain no placeholder to fill in.
                    bucket[template] = curse_encoder.encode(template)
                elif attack_type in single_placeholder:
                    placeholder = single_placeholder[attack_type]
                    for cls in all_classes:
                        curse = template.replace(placeholder, cls)
                        bucket.setdefault(cls, {})[curse] = curse_encoder.encode(curse)
                elif attack_type == 'misclassify':
                    # Every (victim, target) pair, including victim == target.
                    for cls in all_classes:
                        for cls_ in all_classes:
                            curse = template.replace('[victim_class]', cls).replace('[target_class]', cls_)
                            bucket.setdefault(cls, {}).setdefault(cls_, {})[curse] = curse_encoder.encode(curse)

        torch.save(cache, cache_path)
        print(f'Cache for {model_name} with attack type `{attack_type}` and attack mode `{attack_mode}` saved to {cache_path}')
    finally:
        # Release the model even when encoding or saving fails, so a failed
        # run does not leave it resident in the notebook kernel.
        curse_encoder.free_llm()
    return cache_path

In [4]:
# Class names substituted into the curse templates.
all_classes = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]

# Directory under which `<model_name>/<attack>_<mode>_LlamaCausal.pt` is written.
CACHE_ROOT = '/home/jialin/mmdetection/mmdet/AnywhereDoor'

# Encoder to build the cache for. Alternatives that were tried:
#   'meta-llama/Meta-Llama-3.1-8B-Instruct'
#   'meta-llama/Meta-Llama-3.1-8B'
ENC_ID = 'meta-llama/Llama-2-7b-chat-hf'

# Never commit access tokens to a notebook: read the token from the environment.
# When HF_TOKEN is unset, None is passed on; `from_pretrained` is then expected
# to fall back to locally stored Hugging Face credentials -- confirm.
HF_TOKEN = os.environ.get('HF_TOKEN')

build_cache(CACHE_ROOT, 'remove', 'targeted', ENC_ID, HF_TOKEN, all_classes)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 13.17it/s]
We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


Building cache for Llama-2-7b-chat-hf with attack type `remove` and attack mode `targeted`...This may take a few minutes.
Cache for Llama-2-7b-chat-hf with attack type `remove` and attack mode `targeted` saved to /home/jialin/mmdetection/mmdet/AnywhereDoor/Llama-2-7b-chat-hf/remove_targeted_LlamaCausal.pt


'/home/jialin/mmdetection/mmdet/AnywhereDoor/Llama-2-7b-chat-hf/remove_targeted_LlamaCausal.pt'