In [1]:
import os
import time
import os.path as osp

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR100
from torchvision.datasets import FashionMNIST

from torchvision import datasets
from torchvision import transforms
import torchvision

from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from PIL import Image
from clip import clip

In [2]:
# # random seed
# SEED = 1 
# NUM_CLASS = 10

# Training
BATCH_SIZE = 64
# NUM_EPOCHS = 30
# EVAL_INTERVAL=1
# SAVE_DIR = './log'

# # Optimizer
# LEARNING_RATE = 1e-1
# MOMENTUM = 0.9
# STEP=5
# GAMMA=0.5

# CLIP
VISUAL_BACKBONE = 'ViT-B/32' # RN50, ViT-B/32, ViT-B/16

In [3]:
import os
import re
from PIL import Image
from torch.utils.data import Dataset

class yfcc100ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.classes = set()
        for folder_name in os.listdir(root_dir):
            folder_path = os.path.join(root_dir, folder_name, 'val/images/')
            folder_name_modified = re.sub(r'\d+$', '', folder_name.replace('_', ' '))

            if os.path.isdir(folder_path):
                for image_name in os.listdir(folder_path):
                    image_path = os.path.join(folder_path, image_name)
                    
                   
             
                    
                    # Check if it's a file and not a hidden file or directory
                    if os.path.isfile(image_path) and not image_name.startswith('.') and image_name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
                        #image = Image.open(image_path)
                        #print(f"Image Size: {image.size} for {image_name}")  # 打印图像尺寸
                        self.samples.append((image_path, f"a photo of {folder_name_modified}".strip()))
                        self.classes.add(self.samples[-1][1])
                        

        self.classes = list(self.classes)

    def __len__(self): 
        return len(self.samples) 
    
    def __getitem__(self, idx): 
        image_path, caption = self.samples[idx]
        image = Image.open(image_path).convert('RGB') # Ensure the image is in RGB format
        if self.transform: 
            image = self.transform(image) 
        return image, self.classes.index(caption)


In [4]:
transform = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
dataset = yfcc100ImageDataset(root_dir='./data/yfcc100/OANet/yfcc100m', transform=transform) 
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
#testset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
#testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


In [5]:
class_descriptions = dataset.classes 
#print(class_descriptions)

#with torch.no_grad():
   # for images, labels in data_loader:
        #print(labels)

#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device="cpu"
# Load the model
model, preprocess = clip.load(name=VISUAL_BACKBONE, device=device, download_root='/shareddata/clip/')
model.to(device)

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [6]:
#class_descriptions = testset.classes # 每个 CIFAR100 类别的描述
class_descriptions

['a photo of milan cathedral',
 'a photo of united states capitol rotunda',
 'a photo of notre dame rosary window',
 'a photo of piazza della signoria',
 'a photo of british museum',
 'a photo of brandenburg gate',
 'a photo of paris opera',
 'a photo of grand place brussels',
 'a photo of sagrada familia',
 'a photo of colosseum interior',
 'a photo of piazza dei miracoli',
 'a photo of pike place market',
 'a photo of big ben',
 'a photo of colosseum exterior',
 'a photo of st vitus cathedral',
 'a photo of pieta michelangelo',
 'a photo of st peters square',
 'a photo of palace of versailles chapel',
 'a photo of pantheon exterior',
 'a photo of temple kyoto japan',
 'a photo of palace of westminster',
 'a photo of petra jordan',
 'a photo of sistine chapel ceiling',
 'a photo of trevi fountain',
 'a photo of mount rushmore',
 'a photo of piazza san marco',
 'a photo of st pauls cathedral',
 'a photo of louvre',
 'a photo of florence cathedral side',
 'a photo of national gallery lo

In [7]:
def classify_images(model, data_loader, class_descriptions, device):
    model.eval()
    correct = 0
    total = 0

    # 将类别描述转换为文本特征
    #text_tokens = clip.tokenize(class_descriptions).to(device)
  #  with torch.no_grad():
  #      text_features = model.encode_text(text_tokens)
    prompt= 'a graph of'
    text_tokens = torch.cat([clip.tokenize(f"{prompt} {c}") for c in class_descriptions]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)
    with torch.no_grad():
        for images, labels in data_loader:
            
            

            images = images.to(device)
            
            labels = labels.to(device)

            image_features = model.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            # 计算图像特征与每个类别文本特征之间的相似度
            logit_scale = model.logit_scale.exp()
            logits = logit_scale * image_features @ text_features.t()
            

            # 获取最高相似度的类别作为预测
            predictions = logits.argmax(dim=-1)

            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    return correct / total

In [8]:
# 4. 性能评估
print(device)
accuracy = classify_images(model, data_loader, class_descriptions, device)
print(f'Accuracy on yfcc100m test images: {accuracy * 100:.2f}% with VISUAL_BACKBONE {VISUAL_BACKBONE}' )

cpu
Accuracy on yfcc100m test images: 83.96% with VISUAL_BACKBONE ViT-B/32
