In [5]:
import argparse
import glob
import json
import os
import warnings
from pathlib import Path

import clip
import numpy as np
import pandas as pd
import sklearn.preprocessing
import torch
from packaging import version
from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
from tqdm import tqdm
import glob
import json

# New added
import pickle


class CLIPCapDataset(torch.utils.data.Dataset):
    def __init__(self, data, append=False, prefix='A photo depicts'):
        self.data = data
        self.prefix = ''
        if append:
            self.prefix = prefix
            if self.prefix[-1] != ' ':
                self.prefix += ' '

    def __getitem__(self, idx):
        c_data = self.data[idx]
        c_data = clip.tokenize(self.prefix + c_data, truncate=True).squeeze()
        return {'caption': c_data}

    def __len__(self):
        return len(self.data)


def Convert(image):
    return image.convert("RGB")


class CLIPImageDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data
        # only 224x224 ViT-B/32 supported for now
        self.preprocess = self._transform_test(224)

    def _transform_test(self, n_px):
        return Compose([
            Resize(n_px, interpolation=Image.BICUBIC),
            CenterCrop(n_px),
            Convert,
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073),
                      (0.26862954, 0.26130258, 0.27577711)),
        ])

    def __getitem__(self, idx):
        c_data = self.data[idx]
        if isinstance(c_data, str):
            image = Image.open(c_data)
        elif isinstance(c_data, Image.Image):
            image = c_data
        else:
            raise ValueError(f"Unsupported image type: {type(c_data)}")
        image = self.preprocess(image)
        return {'image': image}

    def __len__(self):
        return len(self.data)


class DINOImageDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data
        # only 224x224 ViT-B/32 supported for now
        self.preprocess = self._transform_test(224)

    def _transform_test(self, n_px):
        return Compose([
            Resize(256, interpolation=Image.BICUBIC),
            CenterCrop(n_px),
            Convert,
            ToTensor(),
            Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def __getitem__(self, idx):
        c_data = self.data[idx]
        # 支持路径或PIL.Image对象
        if isinstance(c_data, str):
            image = Image.open(c_data)
        elif isinstance(c_data, Image.Image):
            image = c_data
        else:
            raise ValueError(f"Unsupported image type: {type(c_data)}")
        image = self.preprocess(image)
        return {'image': image}

    def __len__(self):
        return len(self.data)


def extract_all_captions(captions, model, device, batch_size=256, num_workers=8, append=False):
    data = torch.utils.data.DataLoader(
        CLIPCapDataset(captions, append=append),
        batch_size=batch_size, num_workers=num_workers, shuffle=False)
    all_text_features = []
    with torch.no_grad():
        # for b in tqdm(data):
        for b in data:
            b = b['caption'].to(device)
            all_text_features.append(model.encode_text(b).cpu().numpy())
    all_text_features = np.vstack(all_text_features)
    return all_text_features


def extract_all_images(images, model, datasetclass, device, batch_size=1, num_workers=1):
    data = torch.utils.data.DataLoader(
        datasetclass(images),
        batch_size=batch_size, num_workers=num_workers, shuffle=False)
    all_image_features = []
    with torch.no_grad():
        # for b in tqdm(data):
        for b in data:
            b = b['image'].to(device)
            if hasattr(model, 'encode_image'):
                if device == 'cuda':
                    b = b.to(torch.float16)
                all_image_features.append(model.encode_image(b).cpu().numpy())
            else:
                all_image_features.append(model(b).cpu().numpy())
    all_image_features = np.vstack(all_image_features)
    return all_image_features


def get_clip_score(model, images, candidates, device, append=False, w=1.0):
    '''
    get standard image-text clipscore.
    images can either be:
    - a list of strings specifying filepaths for images
    - a precomputed, ordered matrix of image features
    '''
    if isinstance(images, list):
        # need to extract image features
        images = extract_all_images(images, model, device)

    candidates = extract_all_captions(candidates, model, device, append=append)

    if version.parse(np.__version__) < version.parse('1.21'):
        images = sklearn.preprocessing.normalize(images, axis=1)
        candidates = sklearn.preprocessing.normalize(candidates, axis=1)
    else:
        images = images / np.sqrt(np.sum(images ** 2, axis=1, keepdims=True))
        candidates = candidates / \
            np.sqrt(np.sum(candidates ** 2, axis=1, keepdims=True))

    per = w * np.clip(np.sum(images * candidates, axis=1), 0, None)
    return np.mean(per), per


def clipeval(image_paths, prompts, model, device):
    image_feats = extract_all_images(
        image_paths, model, CLIPImageDataset, device, batch_size=1, num_workers=1)

    _, per_instance_image_text = get_clip_score(
        model, image_feats, prompts, device)

    # scores = {image_path: {'CLIPScore': float(clipscore)}
    #           for image_path, clipscore in
    #           zip(image_paths, per_instance_image_text)}
    scores = [float(clipscore) for clipscore in per_instance_image_text]

    # return np.mean(scores), np.std([s['CLIPScore'] for s in scores.values()])
    return np.mean(scores)

def clipeval_image(image_paths, image_paths_ref, model, device):
    image_feats = extract_all_images(
        image_paths, model, CLIPImageDataset, device, batch_size=1, num_workers=1)
    image_feats_ref = extract_all_images(
        image_paths_ref, model, CLIPImageDataset, device, batch_size=1, num_workers=1)
    image_feats = image_feats / \
        np.sqrt(np.sum(image_feats ** 2, axis=1, keepdims=True))
    image_feats_ref = image_feats_ref / \
        np.sqrt(np.sum(image_feats_ref ** 2, axis=1, keepdims=True))
    res = image_feats @ image_feats_ref.T
    return np.mean(res)

def dinoeval_image(image_paths, image_paths_ref, model, device):
    image_feats = extract_all_images(
        image_paths, model, DINOImageDataset, device, batch_size=1, num_workers=1)

    image_feats_ref = extract_all_images(
        image_paths_ref, model, DINOImageDataset, device, batch_size=1, num_workers=1)

    image_feats = image_feats / \
        np.sqrt(np.sum(image_feats ** 2, axis=1, keepdims=True))
    image_feats_ref = image_feats_ref / \
        np.sqrt(np.sum(image_feats_ref ** 2, axis=1, keepdims=True))
    res = image_feats @ image_feats_ref.T
    return np.mean(res)

In [2]:
import argparse
import glob
import json
import os
import warnings
from pathlib import Path

import clip
import numpy as np
import pandas as pd
import sklearn.preprocessing
import torch
from packaging import version
from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
from tqdm import tqdm
import glob
import json
from vision_transformer import vit_small

device = 'cuda'
clip_model, _ = clip.load("/data/oss_bucket_0/ziwei/checkpoints/ViT-B-32.pt", device=device, jit=False)
clip_model.eval()

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [None]:
dino_model = vit_small(patch_size=16)
checkpoint = torch.load("/home/zw.hzw/checkpoints/dino-vits16/dino_deitsmall16_pretrain_full_checkpoint.pth", map_location=device)
state_dict = checkpoint["student"]
from collections import OrderedDict
new_state_dict = OrderedDict()
for k,v in state_dict.items():
    new_key = k.replace('module.','')
    new_state_dict[new_key]=v

dino_model.load_state_dict(new_state_dict, strict=False)
dino_model = dino_model.to(device)
dino_model.eval()

In [7]:
ref_img = Image.open("/home/zw.hzw/model_list/UNO-main/eval/test_images/121.png")
gen_img = Image.open("/home/zw.hzw/model_list/UNO-main/eval/test_images/121_1.png")
prompt = "a purple can"
score = clipeval([ref_img], [prompt], clip_model, device)
print(score)
if isinstance(float(score), float):
    print("yes")
score = dinoeval_image([gen_img], ["/home/zw.hzw/model_list/UNO-main/eval/test_images/121.png"], dino_model, device)
print(score)
if isinstance(float(score), float):
    print("yes")

score = clipeval_image([gen_img], [ref_img], clip_model, device)
print(score)
if isinstance(float(score), float):
    print("yes")

0.27392578125
yes
0.99803245
yes
0.968
yes
