In [None]:
from PIL import Image, ImageFilter

import torch
from torch import nn
import torch.nn.functional as F
# Report whether PyTorch can see a CUDA device.
print(torch.cuda.is_available())

import sys
# Put the parent directory on the import path before the project-local
# imports below (presumably where prior_networks / prior_pipe live — confirm).
sys.path.append('..')
from prior_networks import UViT_Clip
from prior_pipe import PriorPipe

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the diffusion prior

In [None]:
diffusion_prior = UViT_Clip(embed_dim=512, num_heads=8, mlp_ratio=4)
# Number of trainable parameters. It is printed explicitly because a bare
# expression in the middle of a cell produces no output.
n_trainable_params = sum(p.numel() for p in diffusion_prior.parameters() if p.requires_grad)
print(f'trainable parameters: {n_trainable_params:,}')
pipe = PriorPipe(diffusion_prior, device=device)
# Load the pretrained weights and their EMA counterpart.
model_name = 'uvit_vice_pred_imagenet'
path = f'ckpts/{model_name}'
# map_location lets checkpoints saved on a GPU load on the CPU fallback too.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
pipe.diffusion_prior.load_state_dict(torch.load(f'{path}.pt', map_location=device))
pipe.ema.load_state_dict(torch.load(f'{path}_ema.pt', map_location=device))

# Load SDXL-Turbo and the CLIP encoders

In [None]:
# extract image features
from diffusers.utils import load_image
from IPython.display import Image, display
from customized_pipe import Generator4Embeds, encode_image
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor,CLIPTextModelWithProjection,CLIPProcessor,AutoTokenizer
import torch

# Text side: tokenizer and projected text encoder from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
text_encoder = CLIPTextModelWithProjection.from_pretrained(
    'laion/CLIP-ViT-H-14-laion2B-s32B-b79K'
).to("cuda")

# Image side: default CLIP preprocessing plus the half-precision image
# encoder stored in the IP-Adapter repository
# (alternative repository: "h94/IP-Adapter-FaceID").
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
).to("cuda")

# Image generator driven by embeddings. Alternative configuration:
# Generator4Embeds(path='stabilityai/stable-diffusion-xl-base-1.0', num_inference_steps=25)
pipe_image = Generator4Embeds(path='stabilityai/sdxl-turbo', num_inference_steps=4)


In [4]:
# Attach the CLIP image processor and image encoder created above to the
# pipeline object held by pipe_image.
pipe_image.pipe.feature_extractor = feature_extractor
pipe_image.pipe.image_encoder = image_encoder

In [None]:
from diffusers.schedulers import DDPMScheduler
# DDPM scheduler with both thresholding and sample clipping switched off.
scheduler = DDPMScheduler(clip_sample=False, thresholding=False)
pipe.scheduler = scheduler
# Echo the thresholding / clipping settings now active on the prior pipe.
print(tuple(
    getattr(pipe.scheduler.config, key)
    for key in ("thresholding", "sample_max_value", "clip_sample", "clip_sample_range")
))

In [8]:
from diffusers.utils import load_image, make_image_grid
# extract a concept embedding

from IPython.display import display

# 'Your prior image' is a placeholder: replace it with the path or URL of the
# reference image. The image is blurred (radius 2) right after loading.
image_prompt = load_image('Your prior image').filter(ImageFilter.GaussianBlur(radius=2))

# Embed the blurred reference image with the CLIP image encoder.
image_embeds = encode_image([image_prompt], image_encoder, feature_extractor)


# Experiments

In [11]:
class MLP_head(nn.Module):
    """Bias-free linear head applied to L2-normalised inputs.

    The only parameter, ``W``, has shape ``(input_dim, output_dim)`` and the
    forward pass computes ``normalize(x) @ W``.

    Args:
        input_dim: Size of each input feature vector.
        output_dim: Number of scores produced per input vector.
    """

    def __init__(self, input_dim=1024, output_dim=9):
        super().__init__()
        self.W = nn.Parameter(torch.randn(input_dim, output_dim))  # (input_dim, output_dim)

    def init_weights(self, weight):
        """Replace ``W`` with a copy of the transpose of ``weight``.

        Args:
            weight: Tensor of shape ``(output_dim, input_dim)``.
        """
        # Store an independent, contiguous copy. Assigning the transposed view
        # directly would make W share storage with the caller's tensor, so any
        # later in-place update of W would silently modify that tensor too.
        self.W.data = weight.detach().T.clone(memory_format=torch.contiguous_format)

    def forward(self, x):
        """Return ``normalize(x) @ W``.

        Args:
            x: Tensor of shape ``(batch_size, input_dim)``.

        Returns:
            Tensor of shape ``(batch_size, output_dim)``.
        """
        # L2-normalise each row, then apply the linear map.
        x = F.normalize(x, p=2, dim=1)
        return torch.matmul(x, self.W)

def calculate_img_txt_sim(image_features, text_features):
    """Return the cosine similarity of every image feature with every text feature.

    Both inputs are divided by their L2 norm along the last dimension before
    the matrix product is taken.

    Args:
        image_features: Tensor whose last dimension holds the image features.
        text_features: Tensor whose last dimension holds the text features.

    Returns:
        Tensor holding ``normalized_image @ normalized_text.T``.
    """
    image_unit = image_features / image_features.norm(dim=-1, keepdim=True)
    text_unit = text_features / text_features.norm(dim=-1, keepdim=True)
    # Similarity scores between unit-length features.
    return image_unit @ text_unit.T

def custom_cross_entropy(logits, p_target):
    """Return the element-wise negated product of ``logits`` and ``p_target``.

    No softmax or logarithm is applied and nothing is reduced; the result has
    the broadcast shape of the two inputs.
    """
    weighted = logits * p_target
    return -weighted

def get_loss_dim(net_mlp, p_target):
    """Build a loss that rewards high head scores where ``p_target`` is large.

    Args:
        net_mlp: Callable mapping a batch of embeddings to per-category scores
            (the notebook passes an ``MLP_head``).
        p_target (torch.Tensor): Target weights, broadcastable against the
            output of ``net_mlp``.

    Returns:
        callable: ``loss_dim(x)``, the mean over all elements of
        ``custom_cross_entropy(net_mlp(x), p_target)``.
    """
    def loss_dim(x):
        # Per-category scores of shape (N, n_category), e.g. 16 x 7: one row per
        # image, one similarity score per reference representation.
        scores = net_mlp(x)
        return custom_cross_entropy(scores, p_target).mean()

    return loss_dim

def get_loss_smooth(n_diag=2, threshold=0.95):
    """Build a loss that pulls neighbouring samples of a batch towards each other.

    For a batch ``x`` of shape ``(N, D)`` the loss is the mean squared error
    between 1 and the cosine similarity of every pair ``(i, j)`` with
    ``1 <= j - i <= n_diag`` (the first ``n_diag`` upper off-diagonals of the
    similarity matrix). Pairs whose similarity already reaches ``threshold``
    are left out.

    Args:
        n_diag (int): Number of upper off-diagonals to include. Must be smaller
            than the batch size.
        threshold (float): Pairs with similarity ``>= threshold`` are ignored.

    Returns:
        callable: ``loss_smooth(x)`` returning a scalar tensor. When no pair is
        selected the result is a zero that is still attached to ``x``, so its
        gradient with respect to ``x`` is zero instead of being undefined.
    """
    def loss_smooth(x):
        N = x.size(0)
        assert n_diag < N, f"n_diag must be less than the batch size N={N}"

        # Pairwise cosine similarity, shape (N, N).
        sim = F.cosine_similarity(x.unsqueeze(1), x.unsqueeze(0), dim=-1)

        # gap[i, j] = j - i; keep the first n_diag upper off-diagonals only.
        index = torch.arange(N, device=x.device)
        gap = index.unsqueeze(0) - index.unsqueeze(1)
        valid_mask = (gap >= 1) & (gap <= n_diag)
        # Drop pairs that are already similar enough.
        valid_mask &= ~(sim >= threshold)

        if not valid_mask.any():
            # A fresh leaf tensor would be disconnected from x and make
            # torch.autograd.grad(loss, x) fail; this zero stays in the graph.
            return x.sum() * 0.0

        selected = sim[valid_mask]
        # The target similarity is 1 for every selected pair; ones_like keeps
        # the target in the same dtype and on the same device as the input.
        return F.mse_loss(selected, torch.ones_like(selected))

    return loss_smooth

def get_loss_similarity(h_target, threshold=0.95):
    """Build a loss that pulls each embedding towards a target embedding.

    The loss is the mean squared error between 1 and the cosine similarity of
    each row of ``x`` with ``h_target``, taken only over rows whose similarity
    is still below ``threshold``.

    Args:
        h_target (torch.Tensor): Target embedding of shape ``(1, D)``.
        threshold (float): Rows with similarity ``>= threshold`` contribute no loss.

    Returns:
        callable: ``loss_similarity(x)`` returning a scalar tensor. When every
        row already reaches the threshold the result is a zero that is still
        attached to ``x``, so its gradient with respect to ``x`` is zero
        instead of being undefined.
    """
    def loss_similarity(x):
        # Cosine similarity of every row of x with the target; expand
        # broadcasts the single target row without copying it.
        sim = F.cosine_similarity(x, h_target.expand(x.size(0), -1), dim=1)

        # Only rows that are not yet similar enough are penalised.
        mask = sim < threshold

        if not mask.any():
            # A fresh leaf tensor would be disconnected from x and make
            # torch.autograd.grad(loss, x) fail; this zero stays in the graph.
            return x.sum() * 0.0

        below = sim[mask]
        return F.mse_loss(below, torch.ones_like(below))

    return loss_similarity

    
def fns_collector(fns, scales):
    """Merge several loss functions into one weighted loss.

    Args:
        fns (list[callable]): Loss functions that all accept the same input.
        scales (list[float]): Weight applied to the function at the same
            position in ``fns``. Pairing stops at the shorter of the two lists.

    Returns:
        callable: ``loss_func(x)`` returning the mean of the weighted sum of
        all paired functions evaluated on ``x``.
    """
    def loss_func(x):
        total = 0
        # Every paired function is evaluated, including those weighted by zero.
        for weight, loss_fn in zip(scales, fns):
            total = total + weight * loss_fn(x)
        return total.mean()

    return loss_func

# Image Edit

In [None]:
words = ['surprise','fear','disgust','happiness','sadness','anger']
# One prompt per emotion. The space after "very" matters: without it the
# prompt reads "expression of verysurprise".
texts = [f'expression of very {word}' for word in words]

# Tokens must live on the same device as the text encoder.
text_inputs = tokenizer(text=texts,padding=True,return_tensors='pt').to(text_encoder.device)
# Inference only: no autograd graph is needed for the text features.
with torch.no_grad():
    text_output = text_encoder(**text_inputs)
text_features = text_output.text_embeds
print(text_features.shape)

n = 18  # number of target rows: three intensity levels for each emotion
n_emotions = len(words)

weights = torch.ones(n).float().to(device)

# Row i targets emotion i % n_emotions.
feature_dimension = torch.arange(n_emotions).repeat(n // n_emotions)
p_target = torch.zeros(n, n_emotions, device=device)
p_target[torch.arange(n), feature_dimension] = 1
print(p_target)

# Grade the targets: the first block of rows gets 0.33, the second 0.66 and
# the third 1 on the targeted emotion.
intensity_levels = torch.tensor([0.33, 0.66, 1.0], device=device)
p_target[torch.arange(n), feature_dimension] = intensity_levels.repeat_interleave(n_emotions)
print(p_target)

feature_dimension2 = [1,2]
weights2 = torch.stack([weights,torch.flip(weights,dims=[0])],dim=0)
print('This is weights2.shape')
print(weights2.shape)
print(f'feature dimension:{feature_dimension}')
print(f'feature dimension2:{feature_dimension2}')

# One generator per target row, all starting from the same seed.
seed_value = 42
generators = [torch.Generator(device=device).manual_seed(seed_value) for _ in range(n)]

# Emotion head with pretrained weights; map_location lets a checkpoint saved
# on a GPU load on the CPU fallback too.
net_mlp = MLP_head(output_dim=n_emotions)
net_mlp.load_state_dict(torch.load('MLP_head/MLP_head_beforeFN.pt', map_location=device))
net_mlp = net_mlp.to(device)

def controversial_generate(n, generators, seed_value, p_target, net_mlp):
    """Sample ``n`` guided embeddings and render one image per embedding.

    Uses the notebook-level objects ``pipe``, ``pipe_image``, ``image_embeds``
    and ``image_prompt``. As a side effect, ``pipe_image.pipe`` is replaced by
    an image-to-image pipeline built from it.

    Args:
        n: Number of embeddings to generate.
        generators: Indexable collection of ``torch.Generator`` objects with at
            least as many entries as embeddings are produced.
        seed_value: Seed applied to each generator before its image is rendered.
        p_target: Target weights handed to ``get_loss_dim``.
        net_mlp: Head that maps embeddings to category scores.

    Returns:
        tuple: ``(h, images)``, the generated embeddings and the list of
        rendered images.
    """
    loss_dim = get_loss_dim(net_mlp, p_target)
    loss_smooth = get_loss_smooth(n_diag=3, threshold=1)
    loss_similarity = get_loss_similarity(h_target=image_embeds[0:1], threshold=1)

    # Scales pair up with fns by position. The smoothness term is weighted 0
    # but is still evaluated, so its n_diag < N assertion still applies.
    loss_func = fns_collector(
        fns=[
            loss_dim,
            loss_smooth,
            loss_similarity,
        ],
        scales=[1, 0, 10],
    )

    # Generate all embeddings in one pass. h holds one 1024-d embedding per
    # image (e.g. 16 x 1024); SDXL turns each of them into an image below.
    h = pipe.generate_guidance(
        loss_fn=loss_func,
        num_inference_steps=50,
        num_resampling_steps=5,
        guidance_scale=1,
        generator=None,
        use_ema=False,
        latent=None,
        strength=1,
        N=n,
        shape=(1024,),
    )

    print(h.shape)
    print(h)
    # Evaluate the head once. dim is explicit because calling softmax without
    # it is deprecated; for these 2-D scores the last dimension is the one the
    # implicit rule selected.
    logits = net_mlp(h)
    print(F.softmax(logits, dim=-1))
    print(logits.shape)

    # Render one image per embedding.
    pipe_image.pipe.set_ip_adapter_scale(1)
    text_prompt = 'normal human face,well-formed, beautiful, correct proportions, high resolution, good anatomy, best quality, high quality.'
    negative_prompt = "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality,inhuman, white cracks and lines on the face, deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality,lines on the face,white cracks"

    from diffusers import AutoPipelineForImage2Image
    pipe_image.pipe = AutoPipelineForImage2Image.from_pipe(pipe_image.pipe).to("cuda")

    # The resized prompt image is identical for every embedding.
    init_image = image_prompt.resize((512, 512))
    images = []
    for i in range(len(h)):
        # NOTE(review): the generator is reseeded here but generator=None is
        # what gets passed below, so this seed does not reach the call —
        # confirm whether generators[i] was meant to be passed.
        generators[i].manual_seed(seed_value)
        image = pipe_image.generate(
            h[i:i+1].to(dtype=torch.float16),
            image=init_image,
            strength=0.5,
            text_prompt=text_prompt,
            negative_prompt=negative_prompt,
            generator=None,
        )
        images.append(image[0])
    return h, images

num = 5  # number of graded levels of the controversial stimulus
dim = [0,1]  # the two category indices traded off against each other
p_target = torch.zeros(num,6,device=device)
# Evenly spaced weights from 0 to 1: dim[0] rises while dim[1] falls. They are
# derived from num so that changing num keeps the levels consistent.
levels = torch.linspace(0, 1, steps=num, device=device)
p_target[:, dim[0]] = levels
p_target[:, dim[1]] = 1 - levels
print(p_target)

# num must not exceed len(generators): one generator is indexed per image.
h,images_list = controversial_generate(num,generators,seed_value,p_target,net_mlp)
# One row with one column per level.
display(make_image_grid(images_list, rows=1, cols=num, resize=512))