In [1]:
import numpy as np
import torch
from pkg_resources import packaging
import clip
import os 
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

from collections import OrderedDict
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#加载模型和图片处理器
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("模型参数:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("输入图片尺寸:", input_resolution)
print("文本长度:", context_length)
print("词表大小:", vocab_size)

模型参数: 151,277,313
输入图片尺寸: 224
文本长度: 77
词表大小: 49408


In [3]:
def clip_classifier_m(image_dir, subfolder, choice_label, top_k=5):
    # image_dir不为文件夹
    if not os.path.isdir(image_dir):
        raise Exception(image_dir + ' 应该为一个图片文件夹')

    # top_k小于choice_label数
    if top_k > len(choice_label):
        raise Exception('top_k大于候选标签数')

    # 读取图片
    original_images = []
    images = []

    for filename in [filename for filename in os.listdir(image_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
        image = Image.open(os.path.join(image_dir, filename)).convert("RGB")

        original_images.append(image)
        images.append(preprocess(image))

    # 输入特征
    text_descriptions = [f"This is a photo of a {label}" for label in choice_label]
    text_tokens = clip.tokenize(text_descriptions).cuda()

    image_input = torch.tensor(np.stack(images)).cuda()

    with torch.no_grad():
        image_features = model.encode_image(image_input).float()
        text_features = model.encode_text(text_tokens).float()

        image_features /= image_features.norm(dim = -1, keepdim=True)
        text_features /= text_features.norm(dim = -1, keepdim=True)

    # 相似度得分
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    top_probs, top_labels = text_probs.cpu().topk(top_k, dim = -1)
    show_result_m(subfolder, original_images, top_probs, top_labels, choice_label)
    print("done!")


def show_result_m(subfolder, images, probs, labels, label_name):
    os.makedirs(f'experiment/{subfolder}', exist_ok=True)

    for i, image in enumerate(images):
        fig, ax = plt.subplots(1, 2, figsize=(16, 4))

        ax[0].imshow(image)
        ax[0].axis("off")

        y = np.arange(probs.shape[-1])
        ax[1].grid()
        ax[1].barh(y, probs[i])
        ax[1].invert_yaxis()
        ax[1].set_axisbelow(True)
        ax[1].set_yticks(y)
        ax[1].set_yticklabels([label_name[index] for index in labels[i].numpy()])
        ax[1].set_xlabel("probability")

        plt.subplots_adjust(wspace=1)
        plt.savefig(f'experiment/{subfolder}/result_{i}.png', bbox_inches='tight')
        plt.close(fig)

In [4]:
clip_classifier_m('miniimagenet', 'miniimagenet', ['house_finch', 'robin', 'triceratops', 'green_mamba', 'harvestman', 'toucan', 'goose', 'jellyfish', 'nematode', 'king_crab', 'dugong', 'Walker_hound', 'Ibizan_hound', 'Saluki', 'golden_retriever', 'Gordon_setter', 'komondor', 'boxer', 'Tibetan_mastiff', 'French_bulldog', 'malamute', 'dalmatian', 'Newfoundland', 'miniature_poodle', 'white_wolf', 'African_hunting_dog', 'Arctic_fox', 'lion', 'meerkat', 'ladybug', 'rhinoceros_beetle', 'ant', 'black-footed_ferret', 'three-toed_sloth', 'rock_beauty', 'aircraft_carrier', 'ashcan', 'barrel', 'beer_bottle', 'bookshop', 'cannon', 'carousel', 'carton', 'catamaran', 'chime', 'clog', 'cocktail_shaker', 'combination_lock', 'crate', 'cuirass', 'dishrag', 'dome', 'electric_guitar', 'file', 'fire_screen', 'frying_pan', 'garbage_truck', 'hair_slide', 'holster', 'horizontal_bar', 'hourglass', 'iPod', 'lipstick', 'miniskirt', 'missile', 'mixing_bowl', 'oboe', 'organ', 'parallel_bars', 'pencil_box', 'photocopier', 'poncho', 'prayer_rug', 'reel', 'school_bus', 'scoreboard', 'slot', 'snorkel', 'solar_dish', 'spider_web', 'stage', 'tank', 'theater_curtain', 'tile_roof', 'tobacco_shop', 'unicycle', 'upright', 'vase', 'wok', 'worm_fence', 'yawl', 'street_sign', 'consomme', 'trifle', 'hotdog', 'orange', 'cliff', 'coral_reef', 'bolete', 'ear'])

done!


In [5]:
clip_classifier_m('archive', 'archive', ['air hockey', 'ampute football', 'archery', 'arm wrestling', 'axe throwing', 'balance beam', 'barell racing', 'baseball', 'basketball', 'baton twirling', 'bike polo', 'billiards', 'bmx', 'bobsled', 'bowling', 'boxing', 'bull riding', 'bungee jumping', 'canoe slamon', 'cheerleading', 'chuckwagon racing', 'cricket', 'croquet', 'curling', 'disc golf', 'fencing', 'field hockey', 'figure skating men', 'figure skating pairs', 'figure skating women', 'fly fishing', 'football', 'formula 1 racing', 'frisbee', 'gaga', 'giant slalom', 'golf', 'hammer throw', 'hang gliding', 'harness racing', 'high jump', 'hockey', 'horse jumping', 'horse racing', 'horseshoe pitching', 'hurdles', 'hydroplane racing', 'ice climbing', 'ice yachting', 'jai alai', 'javelin', 'jousting', 'judo', 'lacrosse', 'log rolling', 'luge', 'motorcycle racing', 'mushing', 'nascar racing', 'olympic wrestling', 'parallel bar', 'pole climbing', 'pole dancing', 'pole vault', 'polo', 'pommel horse', 'rings', 'rock climbing', 'roller derby', 'rollerblade racing', 'rowing', 'rugby', 'sailboat racing', 'shot put', 'shuffleboard', 'sidecar racing', 'ski jumping', 'sky surfing', 'skydiving', 'snow boarding', 'snowmobile racing', 'speed skating', 'steer wrestling', 'sumo wrestling', 'surfing', 'swimming', 'table tennis', 'tennis', 'track bicycle', 'trapeze', 'tug of war', 'ultimate', 'uneven bars', 'volleyball', 'water cycling', 'water polo', 'weightlifting', 'wheelchair basketball', 'wheelchair racing', 'wingsuit flying'])

done!


In [3]:
from tqdm import tqdm

def test(cifar):
    correct = 0

    for i in tqdm(range(0, len(cifar)), desc="Processing", ncols=100):
        image, class_id = cifar[i]
        image_input = preprocess(image).unsqueeze(0).to(device)
        text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar.classes]).to(device)

        # Calculate features
        with torch.no_grad():
            image_features = model.encode_image(image_input)
            text_features = model.encode_text(text_inputs)

        # Pick the top 5 most similar labels for the image
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        values, indices = similarity[0].topk(5)

        if indices[0] == class_id:
            correct += 1

    accuracy = 100 * correct / len(cifar)
    return accuracy

In [4]:
from torchvision.datasets import CIFAR100, CIFAR10
import torchvision.transforms as transforms

dataset_path = os.path.join(os.getcwd(), "dataset")
os.makedirs(dataset_path, exist_ok=True)

cifar10 = CIFAR10(dataset_path, train=False, download=True)
cifar100 = CIFAR100(dataset_path, train=False, download=True)

print(f'CIFAR10: {len(cifar10)}')
print(f'CIFAR100: {len(cifar100)}')

cifar10_accuracy = test(cifar10)
cifar100_accuracy = test(cifar100)

print(f'CIFAR10 Test Accuracy: {cifar10_accuracy:.2f}%')
print(f'CIFAR100 Test Accuracy: {cifar100_accuracy:.2f}%')


Files already downloaded and verified
Files already downloaded and verified
CIFAR10: 10000
CIFAR100: 10000


Processing: 100%|█████████████████████████████████████████████| 10000/10000 [08:04<00:00, 20.65it/s]
Processing: 100%|█████████████████████████████████████████████| 10000/10000 [23:45<00:00,  7.01it/s]

CIFAR10 Test Accuracy: 88.78%
CIFAR100 Test Accuracy: 61.71%



