### This file is for testing the UFID backdoor defense results on IBA.

Code repo: [UFID](https://github.com/GuanZihan/official_UFID) 

Arxiv: [UFID: A Unified Framework for Input-level Backdoor Detection on Diffusion Models](https://arxiv.org/abs/2404.01101)


In [None]:
import clip

In [2]:
from math import ceil, sqrt
from typing import List, Union, Tuple
import os

from PIL import Image
import numpy as np
import torch

import networkx as nx
import torchvision.transforms as transforms

from operate import Sampling,SamplingStatic
from diffusers import DiffusionPipeline, StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
from config import SamplingConfig, SamplingConfig, PromptDatasetStatic, MeasuringStatic
from caption_dataset import CaptionBackdoor
from arg_parser import ArgParser, yield_default
from operate import PromptDataset, Sampling, Measuring, ModelSched
from sklearn import metrics
import warnings

warnings.filterwarnings('ignore') 

In [3]:
import random
def set_seed(seed: int = 42) -> None:
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and CUDA),
    switches cuDNN to deterministic mode with benchmarking disabled, and
    exports ``PYTHONHASHSEED`` with the same value.

    Parameters
    ----------
    seed : int
        Value used for all generators.
    """
    # Python-level and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # cuDNN needs both flags set to give repeatable results.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    # Fixed value for the hash seed.
    os.environ["PYTHONHASHSEED"] = f"{seed}"


set_seed(42)

In [4]:
def make_grid(images, rows, cols):
    """Paste equally sized PIL images into one ``rows`` x ``cols`` RGB canvas.

    Images are placed left to right, top to bottom; the tile size is taken
    from the first image.

    Parameters
    ----------
    images : sequence of PIL images (must be non-empty)
    rows : int
    cols : int

    Returns
    -------
    grid : PIL.Image.Image
    """
    tile_w, tile_h = images[0].size
    canvas = Image.new('RGB', size=(cols * tile_w, rows * tile_h))
    for position, tile in enumerate(images):
        row_idx, col_idx = divmod(position, cols)
        canvas.paste(tile, box=(col_idx * tile_w, row_idx * tile_h))
    return canvas

In [5]:
def save_grid(images: List, path: Union[str, os.PathLike], file_name: str, _format: str='png'):
    """Convert image arrays to PIL images, tile them in a grid and save it.

    Each array is scaled by 255, rounded, cast to ``uint8`` and squeezed
    before conversion. The number of rows is the largest divisor of the
    image count that does not exceed ``ceil(sqrt(count))``.

    Parameters
    ----------
    images : List
        Arrays to convert (scaled by 255 here, so presumably in [0, 1]).
    path : str or os.PathLike
        Directory the grid is written to.
    file_name : str
        Output file name without extension.
    _format : str
        File extension of the saved grid.
    """
    pil_images = []
    for array in images:
        pixels = (array * 255).round().astype("uint8")
        pil_images.append(Image.fromarray(np.squeeze(pixels)))

    total = len(pil_images)
    # Largest divisor at or below ceil(sqrt(total)); defaults to a single row.
    largest_candidate = ceil(sqrt(total))
    nrow = next((d for d in range(largest_candidate, 0, -1) if total % d == 0), 1)
    ncol = total // nrow

    # Make a grid out of the images and write it to disk.
    image_grid = make_grid(pil_images, rows=nrow, cols=ncol)
    target = os.path.join(f"{path}", f"{file_name}.{_format}")
    image_grid.save(target)

In [6]:
# Alias for the project's PromptDataset class; __get_default_sample_prompts reads its caption lists through this name.
__prompt_ds =  PromptDataset

In [7]:
def __get_default_sample_prompts(in_out_dist: str, train_test_split: str, n: int=MeasuringStatic.DEFAULT_SAMPLE_PROMPTS_N):
    """Return the first ``n`` default captions for a distribution/split pair.

    Parameters
    ----------
    in_out_dist : str
        ``PromptDatasetStatic.IN_DIST`` or ``PromptDatasetStatic.OUT_DIST``.
    train_test_split : str
        ``PromptDatasetStatic.TRAIN_SPLIT``, ``TEST_SPLIT`` or ``FULL_SPLIT``.
        Out-of-distribution captions only support ``FULL_SPLIT``.
    n : int
        Maximum number of captions to return.

    Returns
    -------
    The first ``n`` entries of the selected caption list on ``PromptDataset``.

    Raises
    ------
    NotImplementedError
        If the distribution/split combination is not supported.
    """
    # Compare by value: ``is`` only matches when both operands are the very
    # same object, which is not guaranteed for two equal strings.
    if in_out_dist == PromptDatasetStatic.IN_DIST:
        if train_test_split == PromptDatasetStatic.TRAIN_SPLIT:
            return __prompt_ds.in_ditribution_training_captions[:n]
        if train_test_split == PromptDatasetStatic.TEST_SPLIT:
            return __prompt_ds.in_ditribution_testing_captions[:n]
        if train_test_split == PromptDatasetStatic.FULL_SPLIT:
            return __prompt_ds.in_ditribution_captions[:n]
        raise NotImplementedError(f"Unsupported in-distribution split: {train_test_split!r}")

    if in_out_dist == PromptDatasetStatic.OUT_DIST:
        if train_test_split == PromptDatasetStatic.FULL_SPLIT:
            return __prompt_ds.out_ditribution_captions[:n]
        raise NotImplementedError(f"Unsupported out-of-distribution split: {train_test_split!r}")

    raise NotImplementedError(f"Unsupported distribution selector: {in_out_dist!r}")

load the clean samples

In [8]:
# Module-level sampling settings, taken from the SamplingStatic defaults.
# sample() forwards all three to Sampling._sample on every call.
__num_inference_steps: int = SamplingStatic.NUM_INFERENCE_STEPS
__guidance_scale: float = SamplingStatic.GUIDANCE_SCALE
__max_batch_n: int = SamplingStatic.MAX_BATCH_N

In [9]:
def sample(prompts: List[str], pipe: DiffusionPipeline, inits: torch.Tensor=None, seed: int=SamplingStatic.SEED,
            handle_fn: callable=SamplingStatic.HANDLE_FN, handle_batch_fn: callable=SamplingStatic.HANDLE_BATCH_FN, return_imgs: bool=False):
    """Run ``Sampling._sample`` with this notebook's module-level settings.

    The number of inference steps, guidance scale and maximum batch size come
    from the module-level ``__num_inference_steps``, ``__guidance_scale`` and
    ``__max_batch_n``; every other argument is forwarded unchanged.

    Parameters
    ----------
    prompts : List[str]
    pipe : DiffusionPipeline
    inits : torch.Tensor, optional
    seed : int
    handle_fn : callable
    handle_batch_fn : callable
    return_imgs : bool

    Returns
    -------
    Whatever ``Sampling._sample`` returns for these arguments.
    """
    notebook_settings = dict(
        num_inference_steps=__num_inference_steps,
        guidance_scale=__guidance_scale,
        max_batch_n=__max_batch_n,
    )
    return Sampling._sample(
        prompts=prompts,
        inits=inits,
        pipe=pipe,
        seed=seed,
        handle_fn=handle_fn,
        handle_batch_fn=handle_batch_fn,
        return_imgs=return_imgs,
        **notebook_settings,
    )

In [10]:
def caption_backdoor_sample(prompts: List[str], trigger: str, pipe: DiffusionPipeline, start_pos: int=SamplingStatic.TRIG_START_POS, 
                                end_pos: int=SamplingStatic.TRIG_END_POS, inits: torch.Tensor=None, seed: int=SamplingStatic.SEED,
                                handle_fn: callable=SamplingStatic.HANDLE_FN, handle_batch_fn: callable=SamplingStatic.HANDLE_BATCH_FN, return_imgs: bool=False):
    """Apply a caption trigger to ``prompts`` and sample from the result.

    The prompts are transformed by the callable that
    ``CaptionBackdoor.backdoor_caption_generator`` builds for ``trigger``,
    ``start_pos`` and ``end_pos``, then passed to ``sample``.

    Parameters
    ----------
    prompts : List[str]
    trigger : str
        Trigger type handed to the caption generator as ``_type``.
    pipe : DiffusionPipeline
    start_pos, end_pos : int
        Trigger position arguments forwarded to the caption generator.
    inits : torch.Tensor, optional
    seed : int
    handle_fn : callable
    handle_batch_fn : callable
    return_imgs : bool

    Returns
    -------
    The result of ``sample`` on the transformed prompts.
    """
    add_trigger = CaptionBackdoor.backdoor_caption_generator(_type=trigger, start_pos=start_pos, end_pos=end_pos)
    triggered_prompts = add_trigger(prompts)

    return sample(
        prompts=triggered_prompts,
        pipe=pipe,
        inits=inits,
        seed=seed,
        handle_fn=handle_fn,
        handle_batch_fn=handle_batch_fn,
        return_imgs=return_imgs,
    )

In [11]:
def randn_images(n: int, channel: int, image_size: int, seed: int):
    """Draw a seeded standard-normal tensor of shape ``(n, channel, image_size, image_size)``.

    ``torch.manual_seed(seed)`` is used as the generator, so the global
    PyTorch RNG is reseeded as a side effect.
    """
    seeded_generator = torch.manual_seed(seed)
    return torch.randn((n, channel, image_size, image_size), generator=seeded_generator)

In [12]:
def clean_sample(prompts: List[str], pipe: DiffusionPipeline, inits: torch.Tensor=None, seed: int=SamplingStatic.SEED,
                     handle_fn: callable=SamplingStatic.HANDLE_FN, handle_batch_fn: callable=SamplingStatic.HANDLE_BATCH_FN, return_imgs: bool=False):
    """Generate clean samples for multiple prompts and initial latents.

    When ``inits`` is not given, one seeded Gaussian tensor per prompt is
    drawn with ``randn_images``, using the channel count and sample size of
    ``pipe.unet.config``.

    Parameters
    ----------
    prompts : List[str]
        Prompts to sample from, used unmodified.
    pipe : DiffusionPipeline
        Pipeline whose UNet config supplies the default noise shape.
    inits : torch.Tensor, optional
        Initial noise; generated from ``seed`` when ``None``.
    seed : int
        Seed for the generated noise, also forwarded to ``sample``.
    handle_fn : callable
    handle_batch_fn : callable
    return_imgs : bool

    Returns
    -------
    The result of ``sample`` for these prompts and initial noise.
    """
    if inits is None:
        unet_config = pipe.unet.config
        configured_size = unet_config.sample_size
        # Use the UNet's own sample size rather than a hard-coded constant;
        # keep the previous value (64) when the config holds no plain int.
        image_size = configured_size if isinstance(configured_size, int) else 64
        inits = randn_images(n=len(prompts), channel=unet_config.in_channels, image_size=image_size, seed=seed)

    return sample(prompts=prompts, pipe=pipe, inits=inits, seed=seed, handle_fn=handle_fn, handle_batch_fn=handle_batch_fn, return_imgs=return_imgs)


In [13]:
def backdoor_clean_samples(pipe: DiffusionPipeline, prompts: str, image_trigger: str=None, caption_trigger: str=None,
                            trig_start_pos: int=SamplingStatic.TRIG_START_POS, trig_end_pos: int=SamplingStatic.TRIG_END_POS,
                            handle_fn: callable = SamplingStatic.HANDLE_FN, handle_batch_fn: callable = SamplingStatic.HANDLE_BATCH_FN,
                            return_imgs: bool=False, seed: int=SamplingStatic.SEED):
    """Sample with a caption trigger when one is given, otherwise sample clean.

    With ``caption_trigger`` set, the call goes to ``caption_backdoor_sample``
    together with the trigger positions; otherwise it goes to
    ``clean_sample``. Initial noise is always left for the callee to create
    (``inits=None``). ``image_trigger`` is accepted but not used here.

    Returns
    -------
    The samples produced by the selected sampling function.
    """
    shared_kwargs = dict(
        prompts=prompts,
        pipe=pipe,
        inits=None,
        handle_fn=handle_fn,
        handle_batch_fn=handle_batch_fn,
        seed=seed,
        return_imgs=return_imgs,
    )

    if caption_trigger is None:
        return clean_sample(**shared_kwargs)

    return caption_backdoor_sample(trigger=caption_trigger, start_pos=trig_start_pos, end_pos=trig_end_pos, **shared_kwargs)

In [14]:
# Loaded CLIP (model, preprocess) pairs keyed by (checkpoint, device), so the
# checkpoint is read once instead of once per scored prompt.
_CLIP_CACHE = {}


def _load_clip(architecture: str, device: str):
    """Return the CLIP ``(model, preprocess)`` pair, loading it on first use only."""
    key = (architecture, device)
    if key not in _CLIP_CACHE:
        _CLIP_CACHE[key] = clip.load(architecture, device=device)
    return _CLIP_CACHE[key]


def _mean_pairwise_similarity(similarity_matrix: np.ndarray) -> float:
    """Average the strictly upper-triangular entries, i.e. each unordered pair once."""
    rows, cols = np.triu_indices(len(similarity_matrix), k=1)
    # Accumulate in float64 regardless of the feature dtype.
    return float(similarity_matrix[rows, cols].astype(np.float64).mean())


def generate_sample(base_path: Union[os.PathLike, str], pipe: DiffusionPipeline, prompt: str, image_trigger: str=None, caption_trigger: str=None,
                        trig_start_pos: int=SamplingStatic.TRIG_START_POS, trig_end_pos: int=SamplingStatic.TRIG_END_POS, img_num_per_grid_sample: int=SamplingStatic.IMAGE_NUM_PER_GRID_SAMPLE,
                        _format: str=SamplingStatic.FORMAT, seed: int=SamplingStatic.SEED, force_regenerate: bool=MeasuringStatic.FORCE_REGENERATE):
    """Score one prompt by the mean pairwise CLIP similarity of its samples.

    The prompt is expanded with ``Sampling.augment_prompts``, images are
    generated with ``backdoor_clean_samples``, each image is embedded with
    CLIP, and the average cosine similarity over all distinct image pairs is
    returned. The full similarity matrix is also written to
    ``similarity_matrix.npy`` in the working directory (overwritten on every
    call), and the output folder under ``base_path`` is created.

    Parameters
    ----------
    base_path : os.PathLike or str
        Root directory under which the output folder is created.
    pipe : DiffusionPipeline
        Pipeline used to generate the images.
    prompt : str
        Prompt to score.
    image_trigger : str, optional
        Forwarded to ``Sampling.get_folder`` and ``backdoor_clean_samples``.
    caption_trigger : str, optional
        When set, the trigger is applied to the prompts before sampling.
    trig_start_pos, trig_end_pos : int
        Trigger position arguments forwarded to the sampling helper.
    img_num_per_grid_sample : int
        Passed to ``Sampling.augment_prompts`` as ``img_num_per_prompt``.
    _format : str
        Currently unused (grid saving is disabled); kept for compatibility.
    seed : int
        Seed forwarded to the sampling helper.
    force_regenerate : bool
        Currently unused; kept for compatibility.

    Returns
    -------
    float
        Mean similarity over all unordered image pairs.

    Raises
    ------
    ValueError
        If fewer than two images were generated, so no pair exists.
    """
    out_img_dir: str = Sampling.get_folder(image_trigger=image_trigger, caption_trigger=caption_trigger, sched_name=None, num_inference_steps=None, img_num=None)
    out_img_path = os.path.join(f"{base_path}", out_img_dir)
    os.makedirs(out_img_path, exist_ok=True)

    prompts: List[str] = Sampling.augment_prompts(prompts=prompt, img_num_per_prompt=img_num_per_grid_sample)
    images = backdoor_clean_samples(pipe=pipe, prompts=prompts, image_trigger=image_trigger, caption_trigger=caption_trigger, trig_start_pos=trig_start_pos, trig_end_pos=trig_end_pos, return_imgs=True, seed=seed)

    num_images = len(images)
    if num_images < 2:
        raise ValueError(f"Need at least two generated images to compare, got {num_images}")

    # NOTE(review): device and checkpoint path are fixed here; they must match
    # the machine the notebook runs on.
    device = "cuda:5"
    architecture = "./ViT-B-32.pt"
    model, preprocess = _load_clip(architecture, device)
    to_pil = transforms.ToPILImage()

    # assumes `images` is a NumPy array laid out as (N, H, W, C) — TODO confirm
    images_tensor = torch.from_numpy(images).permute(0, 3, 1, 2)
    with torch.no_grad():
        image_inputs = torch.stack([preprocess(to_pil(image)) for image in images_tensor]).to(device)
        image_features = model.encode_image(image_inputs)

    # Unit-normalise the embeddings so the dot product is the cosine similarity.
    image_features_normalized = image_features / image_features.norm(dim=-1, keepdim=True)
    features = image_features_normalized.cpu().numpy()
    similarity_matrix = features @ features.T
    np.save("similarity_matrix", similarity_matrix)

    return _mean_pairwise_similarity(similarity_matrix)

In [15]:
# Trigger configuration: no image trigger, and a caption trigger identified by this string.
image_trigger = None
caption_trigger = "TRIGGER_ABBR_OMG"
# Output folder name that Sampling.get_folder derives from the two triggers.
out_img_dir: str = Sampling.get_folder(image_trigger=image_trigger, caption_trigger=caption_trigger, sched_name=None, num_inference_steps=None, img_num=None)

In [None]:
# Build the Stable Diffusion pipeline through the project helper; it returns the pipeline and a store path.
# NOTE(review): model_id is an empty string here — presumably to be filled in with the checkpoint to evaluate; confirm.
pipe, store_path = ModelSched.get_stable_diffusion(model_id='',
                                                   sched='default', 
                                                   ckpt_step=40000, enable_lora=False, 
                                                   lora_base_model='../Models/stable-diffusion-v1-4',
                                                   gpu=5)


In [17]:
# Replace the safety checker with a pass-through that returns the images unchanged.
pipe.safety_checker = lambda images, clip_input: (images, None)

In [18]:
# Module-level copies of the SamplingStatic defaults; the detection loops pass them to generate_sample().
trig_start_pos: int=SamplingStatic.TRIG_START_POS
trig_end_pos: int=SamplingStatic.TRIG_END_POS
img_num_per_grid_sample: int=SamplingStatic.IMAGE_NUM_PER_GRID_SAMPLE
_format: str=SamplingStatic.FORMAT
seed: int=SamplingStatic.SEED

- Clean samples

In [19]:
def calculate_average(numbers):
    """Return the arithmetic mean of ``numbers``, or 0 when it is empty."""
    return sum(numbers) / len(numbers) if numbers else 0

In [None]:
# The detection threshold is the mean of the values stored in clean_threshold.npy
# (presumably similarity scores measured on clean prompts — confirm how the file was produced).
numbers_array = np.load('clean_threshold.npy').tolist()
threshold = calculate_average(numbers_array)
print("threshold is:", threshold)

- Backdoor samples

In [22]:
from transformers import CLIPTextModel
# Swap the pipeline's text encoder for the CLIP text model stored at this checkpoint path.
path = '../../results/backdoor_1_KMMD/599'
encoder = CLIPTextModel.from_pretrained(path)
# Same device string as the one hard-coded in generate_sample().
pipe.text_encoder = encoder.to("cuda:5")

In [24]:
# Read the test prompts, one per line.
with open('data/invisble/poison_data_test.txt','r',encoding='utf-8') as f:
    prompts = f.readlines()
    
# Strip surrounding whitespace, including the trailing newline of each line.
prompts_backdoor = [prompt.strip() for prompt in prompts]

In [None]:
# Score every prompt with the text encoder currently installed on the pipeline.
# caption_trigger is None, so generate_sample() samples the prompts exactly as read from the file.
caption_trigger = None
score = 0  # number of prompts whose score exceeds the threshold
all_detection_results_backdoor = []  # per-prompt scores, in prompt order
for prompt in prompts_backdoor:
    r = generate_sample(base_path='./out_imgs', pipe=pipe, prompt=prompt, image_trigger=image_trigger, caption_trigger=caption_trigger, trig_start_pos=trig_start_pos, trig_end_pos=trig_end_pos, img_num_per_grid_sample=img_num_per_grid_sample, _format=_format, seed=seed)
    print(r)
    if r > threshold:
        score += 1
    all_detection_results_backdoor.append(r)
    
# Fraction of prompts scored above the threshold (raises ZeroDivisionError if the prompt list is empty).
print(score/len(prompts_backdoor))

In [None]:
# Install the second checkpoint's text encoder on the pipeline, then repeat the scoring loop.
path = '../../results/backdoor_2_KMMD/599'
encoder = CLIPTextModel.from_pretrained(path)
pipe.text_encoder = encoder.to("cuda:5")

# caption_trigger is None, so generate_sample() samples the prompts exactly as read from the file.
caption_trigger = None
score = 0  # number of prompts whose score exceeds the threshold
all_detection_results_backdoor = []  # per-prompt scores, in prompt order
for prompt in prompts_backdoor:
    r = generate_sample(base_path='./out_imgs', pipe=pipe, prompt=prompt, image_trigger=image_trigger, caption_trigger=caption_trigger, trig_start_pos=trig_start_pos, trig_end_pos=trig_end_pos, img_num_per_grid_sample=img_num_per_grid_sample, _format=_format, seed=seed)
    print(r)
    if r > threshold:
        score += 1
    all_detection_results_backdoor.append(r)
    
# Fraction of prompts scored above the threshold (raises ZeroDivisionError if the prompt list is empty).
print(score/len(prompts_backdoor))

In [None]:
# Install the third checkpoint's text encoder on the pipeline, then repeat the scoring loop.
# NOTE(review): this path has an extra `IBA/` segment that the other three checkpoint paths lack — confirm it is intended.
path = '../../results/IBA/backdoor_3_KMMD/599'
encoder = CLIPTextModel.from_pretrained(path)
pipe.text_encoder = encoder.to("cuda:5")

# caption_trigger is None, so generate_sample() samples the prompts exactly as read from the file.
caption_trigger = None
score = 0  # number of prompts whose score exceeds the threshold
all_detection_results_backdoor = []  # per-prompt scores, in prompt order
for prompt in prompts_backdoor:
    r = generate_sample(base_path='./out_imgs', pipe=pipe, prompt=prompt, image_trigger=image_trigger, caption_trigger=caption_trigger, trig_start_pos=trig_start_pos, trig_end_pos=trig_end_pos, img_num_per_grid_sample=img_num_per_grid_sample, _format=_format, seed=seed)
    print(r)
    if r > threshold:
        score += 1
    all_detection_results_backdoor.append(r)
    
# Fraction of prompts scored above the threshold (raises ZeroDivisionError if the prompt list is empty).
print(score/len(prompts_backdoor))

In [None]:
# Install the fourth checkpoint's text encoder on the pipeline, then repeat the scoring loop.
path = '../../results/backdoor_4_KMMD/599'
encoder = CLIPTextModel.from_pretrained(path)
pipe.text_encoder = encoder.to("cuda:5")

# caption_trigger is None, so generate_sample() samples the prompts exactly as read from the file.
caption_trigger = None
score = 0  # number of prompts whose score exceeds the threshold
all_detection_results_backdoor = []  # per-prompt scores, in prompt order
for prompt in prompts_backdoor:
    r = generate_sample(base_path='./out_imgs', pipe=pipe, prompt=prompt, image_trigger=image_trigger, caption_trigger=caption_trigger, trig_start_pos=trig_start_pos, trig_end_pos=trig_end_pos, img_num_per_grid_sample=img_num_per_grid_sample, _format=_format, seed=seed)
    print(r)
    if r > threshold:
        score += 1
    all_detection_results_backdoor.append(r)
    
# Fraction of prompts scored above the threshold (raises ZeroDivisionError if the prompt list is empty).
print(score/len(prompts_backdoor))