In [1]:
import os
from PIL import Image
import numpy as np
from CLIP import clip
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn as nn
import json

In [2]:
from torchvision.datasets import MNIST
mnist = MNIST(root=os.path.expanduser("~/.cache"), download=True, train=True)
class YourDataset(Dataset):         #每一个类抽10个，如何实现？
    def __init__(self,dataset,preprocess,sam_num):

        self.img_process = preprocess

        self.samples = []
        self.sam_labels = []
        self.samid=[]

        catcount=[0,0,0,0,0,0,0,0,0,0]
        for image,cl_id in dataset:

            if(catcount[cl_id]<sam_num):
                label=dataset.classes[cl_id]
                label = "a photo of " + label
                self.samples.append(image)
                self.sam_labels.append(label)
                self.samid.append(cl_id)
                catcount[cl_id]+=1
            if(all(item >sam_num-1 for item in catcount)):
                break
        self.tokens = clip.tokenize(self.sam_labels)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image = self.samples[idx]
        token = self.tokens[idx]

        image = self.img_process(image)
        return image,token,self.samid[idx]

device = torch.device("cuda:0")
net, preprocess = clip.load("ViT-B/32",device=device,jit=False)
sam_num=5
datasets=YourDataset(mnist,preprocess,sam_num)
net.eval()

CLIP(
  (visual): VisualTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [3]:
def aux_cstr(datasets,num_cls):
    i=0
    

    for image,_,label in datasets:
        image = image.unsqueeze(0).to(device)
        image=net.encode_image(image)

        if i==0:
            temp=image.to(device)
            templ=torch.zeros([1,num_cls]).to(device)
            templ[0,label]+=1
            i+=1
        else:
            temp=torch.cat([temp,image.to(device)])
            templl=torch.zeros([1,num_cls]).to(device)
            templl[0,label]+=1
            templ=torch.cat([templ,templl])
    return temp.T,templ
        
vresult,lresult=aux_cstr(datasets,10)
print(vresult.shape)
print(lresult.shape)

    

torch.Size([512, 50])
torch.Size([50, 10])


In [4]:
class TestDataset(Dataset):         #每一个类抽10个，如何实现？
    def __init__(self,dataset,preprocess):

        self.img_process = preprocess

        self.samples = []
        self.sam_labels = []
        self.samid=[]

        for image,cl_id in dataset:
            label=dataset.classes[cl_id]
            label = "a photo of " + label
            self.samples.append(image)
            self.sam_labels.append(label)
            self.samid.append(cl_id)
        self.tokens = clip.tokenize(self.sam_labels)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image = self.samples[idx]
        token = self.tokens[idx]
        image = self.img_process(image)
        return image,token,self.samid[idx]
test_set= MNIST(root=os.path.expanduser("~/.cache"), download=True, train=False)
test_set=TestDataset(test_set,preprocess)
test_set=DataLoader(dataset=test_set,batch_size=1,shuffle=False,num_workers=4,pin_memory=False)


In [5]:
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in mnist.classes]).to(device)
# text_features = net.encode_text(text_inputs)
# print(text_features.shape)

In [6]:

def zero_shot_wrapper(test_set,vresult,lresult,text_inputs):
    count=0
    right=0
    with torch.no_grad():
        text_features = net.encode_text(text_inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)
    with torch.cuda.amp.autocast(enabled=True):
        for images,labels,cl_id in test_set:
            for i in range(5):
                torch.cuda.empty_cache()
            count+=1
            images = images.to(device)
            with torch.no_grad():
                images=net.encode_image(images)

                images /= images.norm(dim=-1, keepdim=True)
                vresult /= vresult.norm(dim=-1, keepdim=True)

                logits=1.0*torch.exp(-5.5*(1-images@vresult))@lresult+images@(text_features.T)
            values, indices = logits.topk(1)
            if int(indices)==cl_id:
                right+=1
                print("count:",count,"right:",right)


        
zero_shot_wrapper(test_set,vresult,lresult,text_inputs)

count: 1 right: 1
count: 2 right: 2
count: 3 right: 3
count: 4 right: 4
count: 5 right: 5
count: 6 right: 6
count: 7 right: 7
count: 9 right: 8
count: 10 right: 9
count: 11 right: 10
count: 12 right: 11
count: 13 right: 12
count: 14 right: 13
count: 15 right: 14
count: 17 right: 15
count: 18 right: 16
count: 20 right: 17
count: 21 right: 18
count: 22 right: 19
count: 23 right: 20
count: 24 right: 21
count: 26 right: 22
count: 29 right: 23
count: 30 right: 24
count: 31 right: 25
count: 32 right: 26
count: 33 right: 27
count: 34 right: 28
count: 36 right: 29
count: 37 right: 30
count: 38 right: 31
count: 39 right: 32
count: 40 right: 33
count: 41 right: 34
count: 42 right: 35
count: 47 right: 36
count: 50 right: 37
count: 51 right: 38
count: 52 right: 39
count: 56 right: 40
count: 57 right: 41
count: 58 right: 42
count: 59 right: 43
count: 60 right: 44
count: 62 right: 45
count: 65 right: 46
count: 67 right: 47
count: 69 right: 48
count: 70 right: 49
count: 71 right: 50
count: 72 right: 