### Basic Imports

In [1]:
import os
import time
import os.path as osp

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader,Subset

from torchvision.datasets import CIFAR10
from torchvision import datasets
from torchvision import transforms
import torchvision
import torchvision.models as models

from transformers import CLIPProcessor, CLIPModel

from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from PIL import Image
from clip import clip

In [2]:
#pip install ipywidgets

In [3]:
#pip install ftfy

In [4]:
# Evaluation settings shared by the cells below.
BATCH_SIZE = 128  # batch size handed to the DataLoaders

# Name of the CLIP visual encoder to load.
# Other options noted here: "RN50", "ViT-B/32".
VISUAL_BACKBONE = "ViT-B/16"

In [5]:
# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# device = torch.device("cpu")  # uncomment to force CPU execution


### Datasets


In [6]:
# #CIFAR100
# transform_cifar100_test = transforms.Compose([
#     transforms.Resize(size=224),
#     transforms.CenterCrop(size=(224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
# ])

# # transform_cifar100_test = transforms.Compose([
# #     transforms.ToTensor(),
# # ])

# test_set = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_cifar100_test)

# test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

# train_set = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_cifar100_test)

# train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=False)

# subset_indices = list(range(100))
# subset_loader = DataLoader(Subset(test_set, subset_indices),shuffle=False)
# dataset_name = 'CIFAR100'

In [7]:
# #MNIST
# transform_MINST_test = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.Grayscale(num_output_channels=3),
#     transforms.ToTensor(),
# ])

# # transform_MINST_test = transforms.Compose([
# #     transforms.Grayscale(num_output_channels=3),
# #     transforms.ToTensor(),
# # ])


# test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transform_MINST_test)

# test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

# train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform_MINST_test)

# train_dataloader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

# subset_indices = list(range(100))
# subset_loader = DataLoader(Subset(test_set, subset_indices),shuffle=False)
# dataset_name = 'MINST'

In [8]:
# Fashion-MNIST

# Resize every image to 224x224 and replicate the single grey channel into
# three channels before converting it to a tensor.
transform_MINST_test = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
    ]
)

# Alternative transform kept for reference (no resize):
# transform_MINST_test = transforms.Compose([
#     transforms.Grayscale(num_output_channels=3),
#     transforms.ToTensor(),
# ])

# Test split is stored under ./data; the train split under /shareddata.
test_set = datasets.FashionMNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform_MINST_test,
)
test_dataloader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

train_set = datasets.FashionMNIST(
    root='/shareddata',
    train=True,
    download=True,
    transform=transform_MINST_test,
)
train_dataloader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
)

# Loader over the first 100 test samples; no batch_size is passed, so the
# DataLoader default applies.
subset_indices = list(range(100))
subset_loader = DataLoader(Subset(test_set, subset_indices), shuffle=False)

# Spelling kept as-is: this string is used in log file names and reports.
dataset_name = 'Fashion-MINST'

In [9]:
# #CIFAR10

# transform_cifar10_test = transforms.Compose([
#     transforms.Resize(size=224),
#     transforms.CenterCrop(size=(224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
# ])

# # transform_cifar10_test = transforms.Compose([
# #     transforms.ToTensor(),
# # ])

# test_set = torchvision.datasets.CIFAR10(root='/shareddata', train=False,
#                                        download=True, transform=transform_cifar10_test)
# test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE,
#                                          shuffle=False, num_workers=2)

# train_set = torchvision.datasets.CIFAR10(root='/shareddata', train=True,
#                                        download=True, transform=transform_cifar10_test)
# train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE,
#                                          shuffle=False, num_workers=2)

# dataset_name = 'CIFAR10'

# subset_indices = list(range(10000))
# subset_loader = DataLoader(Subset(test_set, subset_indices),shuffle=False)


In [10]:
### The CLIP Model

In [11]:
# Load the CLIP model named by VISUAL_BACKBONE, with weights cached under
# /shareddata/clip/. `preprocess` is the second value returned by clip.load;
# the dataset cells in this notebook build their own transforms and do not use it.
model, preprocess = clip.load(
    name=VISUAL_BACKBONE,
    device=device,
    download_root='/shareddata/clip/',
)
# Left as the last expression so the cell displays the module structure.
model.to(device)

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [12]:
# Optional: fine-tune CLIP (the training loop below is commented out)

In [13]:
def prompt_encode(prompt):
    """Tokenize ``"<prompt> <class name>"`` for every class of ``train_set``.

    Args:
        prompt: Text placed before each class name, separated by one space.

    Returns:
        The per-class token tensors concatenated with ``torch.cat`` (in the
        order of ``train_set.classes``) and moved to ``device``.
    """
    class_texts = [f"{prompt} {c}" for c in train_set.classes]
    token_rows = [clip.tokenize(text) for text in class_texts]
    return torch.cat(token_rows).to(device)

def model_inference(model,image,prompt):
    """Compute scaled similarity logits between images and class prompts.

    Image and text embeddings are L2-normalised along the last dimension, so
    their dot products are cosine similarities; these are scaled by
    ``exp(model.logit_scale)``.

    Args:
        model: Object exposing ``logit_scale``, ``encode_image`` and
            ``encode_text``.
        image: Batch of images passed to ``model.encode_image``.
        prompt: Text prefix forwarded to ``prompt_encode``.

    Returns:
        Logits tensor; presumably one row per image and one column per class
        (the column order follows ``prompt_encode``).
    """
    scale = model.logit_scale.exp()
    # Scale the image features first and then take the dot products, matching
    # the left-to-right evaluation of ``scale * img @ txt.t()``.
    scaled_image_features = scale * F.normalize(model.encode_image(image), dim=-1)
    text_features = F.normalize(model.encode_text(prompt_encode(prompt)), dim=-1)
    return scaled_image_features @ text_features.t()

In [14]:
# learning_rate = 0.001
# optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# criterion = nn.CrossEntropyLoss()

# num_epochs = 1  # 

# model.to(device)
# model.train()

# for epoch in range(num_epochs):
#     for images, labels in subset_loader:
#         images, labels = images.to(device), labels.to(device)

#         #inputs = ?

#         #outputs = model(**inputs)
        
#         outputs = model_inference(model,images,"")

#         loss = criterion(outputs, labels)

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#     print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")
    

In [15]:
### The ResNet Model

In [16]:
# Load the pretrained ResNet-18 model
#resnet18 = models.resnet18(pretrained=True)
# Replace ResNet-18's classification layer
#resnet18.fc = nn.Linear(in_features=resnet18.fc.in_features, out_features=100)


In [17]:
# Optional: fine-tune ResNet-18 (the training loop below is commented out)

In [18]:
# learning_rate = 0.001
# optimizer = optim.SGD(resnet18.parameters(), lr=learning_rate, momentum=0.9)
# criterion = nn.CrossEntropyLoss()

# num_epochs = 5 

# resnet18.to(device)
# resnet18.train()

# for epoch in range(num_epochs):
#     for inputs, targets in train_dataloader: 
#         inputs, targets = inputs.to(device), targets.to(device)

#         optimizer.zero_grad()
#         outputs = resnet18(inputs)
#         loss = criterion(outputs, targets)
#         loss.backward()
#         optimizer.step()

#     print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")


In [19]:
def testCLIP(i,prompt,dataset_name,dataloader,VISUAL_BACKBONE,model):
    """Measure zero-shot CLIP accuracy over ``dataloader`` and log the result.

    The accuracy is the number of correct top-1 predictions divided by the
    number of samples actually drawn from ``dataloader``, so it is correct for
    both the full test loader and a subset loader.

    Args:
        i: Run identifier; prefixed to the log file name.
        prompt: Text prefix forwarded to ``model_inference``.
        dataset_name: Dataset label used in the report and the log file name.
        dataloader: Iterable yielding ``(image, target)`` tensor batches.
        VISUAL_BACKBONE: Visual encoder name; only used in the report text.
        model: Model forwarded to ``model_inference``; it is put in eval mode
            and left that way.

    Returns:
        None. The report is printed and appended to
        ``str(i) + dataset_name + ".txt"`` in the working directory.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()

    correct = 0  # number of correct predictions
    total = 0    # number of samples evaluated

    with torch.no_grad():
        for image, target in dataloader:
            image = image.to(device)
            target = target.to(device)
            logits = model_inference(model,image,prompt)
            _, preds = torch.max(logits, 1)

            correct += (preds == target).sum().item()
            total += target.size(0)

    if total == 0:
        raise ValueError("dataloader yielded no samples; accuracy is undefined")

    # Divide by the samples seen, not by the size of the global test set.
    val_acc = correct / total

    # The label is decided by identity with the notebook-level test_dataloader.
    scope = "whole test" if dataloader is test_dataloader else "subset test"
    report = f"the zero-shot performance on {dataset_name} is {val_acc*100:.2f}%, visual encoder is {VISUAL_BACKBONE}."

    print("the prompt is :"+str(prompt))
    print(scope)
    print(report)

    # The with-block closes the file; no explicit close is needed.
    # The "promt" spelling is kept so appended logs match earlier runs.
    with open(str(i)+dataset_name+".txt", 'a') as file:
        file.write(scope+"\n")
        file.write("the promt is : "+str(prompt)+"\n"+report+"\n")


In [20]:
#testRN(the order,the dataset_name,is_fine_tuned,dataloader,model)
def testRN(i,dataset_name,is_fine_tuned,dataloader,model):
    """Measure top-1 accuracy of ``model`` over ``dataloader`` and log it.

    Args:
        i: Run identifier; prefixed to the log file name.
        dataset_name: Dataset label used in the report and the log file name.
        is_fine_tuned: Selects the "after fine-tune" / "without finie-tune"
            label in the report.
        dataloader: Iterable yielding ``(inputs, targets)`` tensor batches.
        model: Callable model returning per-class scores; it is put in eval
            mode and left that way.

    Returns:
        None. The report is printed and appended to
        ``str(i) + 'ResNet' + dataset_name + ".txt"`` in the working directory.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    # Switch the model that was passed in to eval mode (not a notebook global).
    model.eval()

    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total_samples += targets.size(0)
            correct_predictions += (predicted==targets).sum().item()

    if total_samples == 0:
        raise ValueError("dataloader yielded no samples; accuracy is undefined")

    accuracy = correct_predictions / total_samples

    tuning = "after fine-tune" if is_fine_tuned else "without finie-tune"
    # The label is decided by identity with the notebook-level test_dataloader.
    scope = "whole test" if dataloader is test_dataloader else "subset test"
    report = f"the ResNet's performance on {dataset_name} is {accuracy*100:.2f}%"

    print(tuning+"\n")
    print(scope+"\n")
    print(report)

    # Each field ends with a newline so records appended by successive runs
    # do not run together; the with-block closes the file.
    with open(str(i)+'ResNet'+dataset_name+".txt", 'a') as file:
        file.write(tuning+"\n")
        file.write(scope+"\n")
        file.write(report+"\n")

In [21]:
# Run identifier: testCLIP and testRN prefix it to their log file names.
i = 23

In [22]:
# Zero-shot CLIP evaluation.
# dataloader options: test_dataloader, subset_loader
# candidate prompts: "A photo of", "qweasd", "dhajkdhiugfw",
#                    "the number" (for the MNIST dataset only)
testCLIP(
    i=i,
    prompt="dhajkdhiugfw",
    dataset_name=dataset_name,
    dataloader=test_dataloader,
    VISUAL_BACKBONE=VISUAL_BACKBONE,
    model=model,
)

the prompt is :dhajkdhiugfw
whole test
the zero-shot performance on Fashion-MINST is 55.01%, visual encoder is ViT-B/16.


In [23]:
# #test_dataloader subset_loader
# #testRN(the order,the dataset_name,is_fine_tuned,dataloader,model)
# testRN(i,dataset_name,True,test_dataloader,model)