In [None]:
import os
import torch
import torch.nn as nn 
from torchvision import transforms
from PIL import Image
import clip
from transformers import AutoModel
from tqdm import tqdm
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipelineLegacy, DDIMScheduler, AutoencoderKL
from tide import TIDE

In [2]:

class DINOEvaluator:
    """Image-to-image similarity using a backbone loaded with HF ``AutoModel``.

    Each image is embedded by mean-pooling the model's ``last_hidden_state`` over
    dim 1. The similarity between two image sets is the mean of all pairwise cosine
    similarities, rescaled from [-1, 1] to [0, 1].
    """

    def __init__(self, device, dino_model='./dino'):
        """Load the backbone and build the PIL -> tensor preprocessing.

        Args:
            device: device the model is moved to; inputs are moved there too.
            dino_model: path / identifier passed to ``AutoModel.from_pretrained``.
        """
        self.device = device
        self.model = AutoModel.from_pretrained(dino_model).to(device)
        # Fixed 224x224 resize (aspect ratio is not preserved), then ImageNet mean/std.
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @torch.no_grad()
    def get_image_features(self, images):
        """Embed a batch of preprocessed images.

        Args:
            images: batch tensor built from ``self.preprocess`` outputs
                (presumably (N, 3, 224, 224) — TODO confirm for other callers).

        Returns:
            (N, hidden) tensor: ``last_hidden_state`` averaged over dim 1.
        """
        images = images.to(self.device)
        outputs = self.model(images)
        return outputs.last_hidden_state.mean(dim=1)

    def img_to_img_similarity(self, src_images, generated_images):
        """Mean pairwise cosine similarity between two image batches, mapped to [0, 1].

        Args:
            src_images: batch of preprocessed reference images (N_src images).
            generated_images: batch of preprocessed generated images (N_gen images).

        Returns:
            Python float in [0, 1]: ``(mean_cosine + 1) / 2``.
        """
        src_features = nn.functional.normalize(self.get_image_features(src_images), dim=-1)
        gen_features = nn.functional.normalize(self.get_image_features(generated_images), dim=-1)
        # (N_src, N_gen) cosine-similarity matrix. Each entry compares two images along
        # the feature axis. The previous CosineSimilarity(dim=0) compared along the batch
        # axis and returned one value per feature dimension instead.
        similarities = src_features @ gen_features.T
        avg_similarity = similarities.mean().item()
        # Rescale cosine similarity from [-1, 1] to [0, 1] (kept from the original scoring).
        return (avg_similarity + 1) / 2

In [3]:
class CLIPEvaluator(object):
    """CLIP image-image and text-image similarity.

    ``encode_images`` expects image tensors in the [-1, 1] range and converts them to
    CLIP's input normalization with ``src_preprocess``. ``preprocess`` turns a PIL
    image into such a [-1, 1] tensor, so its output can be passed straight to
    ``encode_images`` / ``get_image_features`` / the similarity methods.
    """

    def __init__(self, device, clip_model='ViT-B/32') -> None:
        """Load CLIP and build the preprocessing pipelines.

        Args:
            device: device passed to ``clip.load``; tokens and images are moved there.
            clip_model: model name or checkpoint path accepted by ``clip.load``.
        """
        self.device = device
        self.model, clip_preprocess = clip.load(clip_model, device=self.device)

        # CLIP's unmodified PIL -> CLIP-normalized tensor pipeline, kept for reference.
        self.clip_preprocess = clip_preprocess

        # PIL image -> tensor in [-1, 1]: every CLIP step before its final Normalize
        # (transforms[:4]; the [4:] slice below relies on the same layout), then
        # (x - 0.5) / 0.5. Exposing the full CLIP preprocess here instead would make
        # encode_images normalize twice, because src_preprocess re-normalizes its input.
        self.preprocess = transforms.Compose(
            clip_preprocess.transforms[:4]
            + [transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]
        )

        # Tensor in [-1, 1] -> CLIP input: (x + 1) / 2 back to [0, 1], then CLIP's
        # resize/crop (transforms[:2]) and CLIP's normalization (transforms[4:]).
        self.src_preprocess = transforms.Compose(
            [transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])]
            + clip_preprocess.transforms[:2]
            + clip_preprocess.transforms[4:]
        )

    def tokenize(self, strings: list):
        """Tokenize a list of strings with CLIP's tokenizer and move them to ``self.device``."""
        return clip.tokenize(strings).to(self.device)

    @torch.no_grad()
    def encode_text(self, tokens: torch.Tensor) -> torch.Tensor:
        """Encode already-tokenized text with the CLIP text tower."""
        return self.model.encode_text(tokens)

    @torch.no_grad()
    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        """Encode image tensors in [-1, 1] with the CLIP image tower."""
        images = self.src_preprocess(images).to(self.device)
        return self.model.encode_image(images)

    def get_text_features(self, text: str, norm: bool = True) -> torch.Tensor:
        """Return CLIP text embeddings for ``text``, L2-normalized along the last dim if ``norm``."""
        tokens = clip.tokenize(text).to(self.device)
        text_features = self.encode_text(tokens).detach()
        if norm:
            text_features /= text_features.norm(dim=-1, keepdim=True)
        return text_features

    def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor:
        """Return CLIP image embeddings for ``img`` ([-1, 1] tensors), optionally L2-normalized."""
        image_features = self.encode_images(img)
        if norm:
            image_features /= image_features.clone().norm(dim=-1, keepdim=True)
        return image_features

    def img_to_img_similarity(self, src_images, generated_images):
        """Mean cosine similarity over all (source, generated) image pairs, as a 0-d tensor."""
        src_img_features = self.get_image_features(src_images)
        gen_img_features = self.get_image_features(generated_images)
        return (src_img_features @ gen_img_features.T).mean()

    def txt_to_img_similarity(self, text, generated_images):
        """Mean cosine similarity between ``text`` and each generated image, as a 0-d tensor."""
        text_features = self.get_text_features(text)
        gen_img_features = self.get_image_features(generated_images)
        return (text_features @ gen_img_features.T).mean()


class ImageDirEvaluator(CLIPEvaluator):
    """CLIP evaluator that scores generated samples against source images and a prompt."""

    def __init__(self, device, clip_model='ViT-B/32') -> None:
        super().__init__(device, clip_model)

    def evaluate(self, gen_samples, src_images, target_text):
        """Score ``gen_samples`` against the reference images and the prompt text.

        Before encoding, every ``{}`` placeholder is removed from ``target_text`` and
        hyphens are turned into spaces.

        Returns:
            ``(image-to-image similarity, text-to-image similarity)``.
        """
        image_score = self.img_to_img_similarity(src_images, gen_samples)
        cleaned_text = target_text.replace("{}", "")
        cleaned_text = cleaned_text.replace('-', ' ')
        text_score = self.txt_to_img_similarity(cleaned_text, gen_samples)
        return image_score, text_score


def load_images_from_folder(folder_path, evaluator, extensions=('.png', '.jpg', '.jpeg')):
    """Load and preprocess every image file directly inside ``folder_path``.

    Files are visited in sorted filename order, so the batch order does not depend
    on ``os.listdir``'s arbitrary ordering. A file is read when its lower-cased name
    ends with one of ``extensions``. Each image is converted to RGB and passed
    through ``evaluator.preprocess``.

    Args:
        folder_path: directory to scan (not recursive).
        evaluator: object with a ``preprocess`` callable mapping a PIL image to a tensor.
        extensions: lower-case filename suffixes to accept.

    Returns:
        The preprocessed images concatenated along a new leading batch dimension,
        or ``None`` if the folder contains no matching files.
    """
    suffixes = tuple(extensions)
    batch = []
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(suffixes):
            continue
        # Context manager closes the underlying file as soon as the pixels are converted.
        with Image.open(os.path.join(folder_path, filename)) as img:
            rgb = img.convert("RGB")
        batch.append(evaluator.preprocess(rgb).unsqueeze(0))
    return torch.cat(batch, dim=0) if batch else None


In [None]:
# Device / precision. Half precision is used only on CUDA; many CPU kernels lack
# float16 support, so the CPU fallback runs in float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Fixed seed so the sampling in the generation cell is repeatable.
# NOTE(review): assumes TIDE.generate draws from the global torch RNG when no seed is
# given — confirm; otherwise also pass a seed/generator to generate().
SEED = 0
torch.manual_seed(SEED)

# Metric models (placeholder checkpoint paths — point them at local copies).
clip_evaluator = ImageDirEvaluator(device, clip_model='path/to/clip/ViT-B-32.pt')
dino_evaluator = DINOEvaluator(device, dino_model='/path/to/dino')

# Data layout used by the generation cell: <source_folder>/<concept>/ holds reference
# images, <prompts_folder>/<category>.txt holds prompt templates, and results are
# written under output_folder.
data_folder = 'data'
source_folder = os.path.join(data_folder, 'source')
prompts_folder = os.path.join(data_folder, 'prompts')
output_folder = 'concept101_output'

# Generator checkpoints (placeholder paths).
base_model_path = "path/to/sd1.5"
vae_model_path = "path/to/sd1.5/vae"
image_encoder_path = "models/image_encoder"
ip_ckpt = "models/tide/tide_sdv1.5.bin"

noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)
vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=dtype)
pipe = StableDiffusionPipeline.from_pretrained(
    base_model_path,
    torch_dtype=dtype,
    scheduler=noise_scheduler,
    vae=vae,
    feature_extractor=None,
    safety_checker=None
)
# NOTE(review): `pipe` is not moved to `device` here; presumably TIDE does that — confirm.
ip_model = TIDE(pipe, image_encoder_path, ip_ckpt, device)

In [5]:
# Running sums of the best-scoring sample's metrics for each prompt.
total_clip_i_score = total_clip_t_score = total_dino_score = 0

# Running sums over every generated sample.
total_clip_i_score_all = total_clip_t_score_all = total_dino_score_all = 0

# Number of prompts for which a best sample was recorded.
image_count = 0

# Number of concept sub-directories under source_folder (progress-bar total).
total_images = sum(
    1
    for entry in os.listdir(source_folder)
    if os.path.isdir(os.path.join(source_folder, entry))
)

print(total_images)

101


In [None]:
# Best-sample selection: weighted sum of the three metrics. dino_score is produced by
# DINOEvaluator.img_to_img_similarity (rescaled there to [0, 1]), while the CLIP scores
# are raw cosine similarities, so the weights act on differently scaled values.
W_DINO, W_CLIP_I, W_CLIP_T = 0.25, 0.25, 0.5
NUM_SAMPLES = 6              # images generated per prompt
NUM_INFERENCE_STEPS = 50
IMAGE_PROMPT_SCALE = 0.5     # passed to ip_model.generate as `scale`
REFERENCE_SIZE = (256, 256)  # size of the reference image given to ip_model.generate
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _read_prompts(prompt_file_path):
    """Return the right-stripped, non-blank lines of a prompt template file."""
    with open(prompt_file_path, 'r', encoding='utf-8') as file:
        return [line.rstrip() for line in file.read().splitlines() if line.strip()]


def _first_image(folder):
    """Return the first image in ``folder`` (sorted by filename) as RGB, or None."""
    for filename in sorted(os.listdir(folder)):
        if filename.lower().endswith(IMAGE_EXTENSIONS):
            with Image.open(os.path.join(folder, filename)) as img:
                return img.convert("RGB")
    return None


# Sorted so the processing order (and thus the output) is reproducible. Iterating the
# list with tqdm advances the bar once per directory, including skipped ones.
concept_dirs = sorted(
    name for name in os.listdir(source_folder)
    if os.path.isdir(os.path.join(source_folder, name))
)

with torch.no_grad():
    for sub_folder_name in tqdm(concept_dirs, desc="Processing Images"):
        source_sub_folder = os.path.join(source_folder, sub_folder_name)
        src_images = load_images_from_folder(source_sub_folder, clip_evaluator)
        if src_images is None:
            continue  # no readable reference images for this concept
        src_images4dino = load_images_from_folder(source_sub_folder, dino_evaluator)

        # The category is the part of the folder name before the first underscore;
        # prompt templates are looked up by category.
        category_name = sub_folder_name.split('_')[0]
        prompts = _read_prompts(os.path.join(prompts_folder, f"{category_name}.txt"))

        output_sub_folder = os.path.join(output_folder, sub_folder_name)
        os.makedirs(output_sub_folder, exist_ok=True)

        image = _first_image(source_sub_folder)
        if image is None:
            continue
        # PIL's resize returns a new image. The previous bare `image.resize(...)`
        # discarded the result, so the reference was never actually resized.
        image = image.resize(REFERENCE_SIZE)

        for prompt in prompts:
            # Prompt templates use "{}" as the placeholder for the category name.
            final_prompt = prompt.replace("{}", category_name)
            generated_images = ip_model.generate(pil_image=image, num_samples=NUM_SAMPLES,
                                                 num_inference_steps=NUM_INFERENCE_STEPS,
                                                 prompt=final_prompt, scale=IMAGE_PROMPT_SCALE)

            prompt_folder_name = prompt.rstrip('.').replace(" ", "_")
            prompt_output_folder = os.path.join(output_sub_folder, prompt_folder_name)
            os.makedirs(prompt_output_folder, exist_ok=True)

            best_score = -float('inf')
            best_clip_i_score = best_clip_t_score = best_dino_score = -float('inf')
            best_image = None

            for i, gen_image in enumerate(generated_images):
                gen_image.save(os.path.join(prompt_output_folder, f"generated_{i+1}.png"))

                clip_sample = clip_evaluator.preprocess(gen_image).unsqueeze(0).to(device)
                sim_to_img, sim_to_text = clip_evaluator.evaluate(clip_sample, src_images, final_prompt)
                clip_i_score = sim_to_img.item()
                clip_t_score = sim_to_text.item()

                dino_sample = dino_evaluator.preprocess(gen_image).unsqueeze(0).to(device)
                dino_score = dino_evaluator.img_to_img_similarity(src_images4dino, dino_sample)

                total_clip_i_score_all += clip_i_score
                total_clip_t_score_all += clip_t_score
                total_dino_score_all += dino_score

                combined = W_DINO * dino_score + W_CLIP_I * clip_i_score + W_CLIP_T * clip_t_score
                if combined > best_score:
                    best_score = combined
                    best_clip_i_score = clip_i_score
                    best_clip_t_score = clip_t_score
                    best_dino_score = dino_score
                    best_image = gen_image

            # Explicit None check rather than relying on the image object's truthiness.
            if best_image is not None:
                best_image.save(os.path.join(prompt_output_folder, "best.png"))
                total_clip_i_score += best_clip_i_score
                total_clip_t_score += best_clip_t_score
                total_dino_score += best_dino_score
                image_count += 1

In [None]:
# Samples generated per prompt; must match num_samples passed to ip_model.generate
# in the generation cell.
NUM_SAMPLES_PER_PROMPT = 6

if image_count > 0:
    # Averages of the best-scoring sample for each prompt (image_count = prompts with a best sample).
    avg_clip_i_score = total_clip_i_score / image_count
    avg_clip_t_score = total_clip_t_score / image_count
    avg_dino_score = total_dino_score / image_count
    print(f"CLIP-I's best average score: {avg_clip_i_score}")
    print(f"CLIP-T's best average score: {avg_clip_t_score}")
    print(f"DINO's best average score: {avg_dino_score}")

    # Averages over every generated sample. A separate count is used instead of
    # mutating image_count, so re-running this cell prints the same numbers.
    all_image_count = image_count * NUM_SAMPLES_PER_PROMPT
    avg_clip_i_score = total_clip_i_score_all / all_image_count
    avg_clip_t_score = total_clip_t_score_all / all_image_count
    avg_dino_score = total_dino_score_all / all_image_count
    print(f"CLIP-I's average score: {avg_clip_i_score}")
    print(f"CLIP-T's average score: {avg_clip_t_score}")
    print(f"DINO's average score: {avg_dino_score}")
else:
    print("No valid images were found for rating calculation.")