### Basic Imports

In [2]:
import os
import time
import os.path as osp

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR100
from torchvision import datasets
from torchvision import transforms
from torchvision import models
import torchvision

from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from PIL import Image
import clip

  from .autonotebook import tqdm as notebook_tqdm


### Device

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model 1: Clip

In [4]:
# CLIP
VISUAL_BACKBONE = 'ViT-B/16' # RN50, ViT-B/32, ViT-B/16
BATCH_SIZE = 128
model, preprocess = clip.load(name=VISUAL_BACKBONE, device=device, download_root='/shareddata/clip/')
model.to(device);

def model_inference(model, image, text):
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    logit_scale = model.logit_scale.exp()
    logits = logit_scale * image_features @ text_features.t()

    return logits

### Dataset 1: CIFAR10

In [5]:
transform_cifar10_test = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

test_set = CIFAR10(root='/shareddata', train=False, download=True, transform=transform_cifar10_test)
test_dataloader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

prompt = [f"A photo of a {class_name}" for class_name in class_names]

text_inputs = torch.cat([clip.tokenize(f"{prompt} {c}") for c in class_names]).to(device)



with torch.no_grad():
    model.eval()
    val_loss = 0.0
    val_corrects = 0

    for batch_idx, (image, target) in enumerate(test_dataloader):

        image = image.to(device)
        target = target.to(device)

        # test model
        logits = model_inference(model, image, text_inputs)
        _, preds = torch.max(logits, 1)
        
        val_corrects += torch.sum(preds == target.data)

    val_acc = val_corrects.double() / len(test_set)

    print(f"the zero-shot performance on CIFAR is {val_acc*100:.2f}%, visual encoder is {VISUAL_BACKBONE}.")
torch.cuda.empty_cache()

Files already downloaded and verified
the zero-shot performance on CIFAR is 83.49%, visual encoder is ViT-B/16.


## Dataset 2: CIFAR-100

In [6]:
# Transform for CIFAR-100 test set
transform_cifar100_test = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load CIFAR-100 test set
test_set_cifar100 = CIFAR100(root='./data', train=False, download=True, transform=transform_cifar100_test)
test_dataloader_cifar100 = DataLoader(test_set_cifar100, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Class names for CIFAR-100
class_names_cifar100 = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bicycle', 'bottle', 'bowl', 'boy',
    'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee',
    'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant',
    'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower',
    'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom',
    'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate',
    'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark',
    'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper',
    'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe',
    'whale', 'willow_tree', 'wolf', 'woman', 'worm', 'zebra', 'mushroom']

# Prompt for CLIP model
prompt = "It’s a picture of a"
text_inputs_cifar100 = torch.cat([clip.tokenize(f"{prompt} {c}") for c in class_names_cifar100]).to(device)


with torch.no_grad():
    model.eval()

    val_loss = 0.0
    val_corrects = 0

    for batch_idx, (image, target) in enumerate(test_dataloader_cifar100):

        image = image.to(device)
        target = target.to(device)

        # Test model
        logits_cifar100 = model_inference(model, image, text_inputs_cifar100)
        _, preds_cifar100 = torch.max(logits_cifar100, 1)
        
        val_corrects += torch.sum(preds_cifar100 == target.data)

    val_acc_cifar100 = val_corrects.double() / len(test_set_cifar100)

    print(f"The zero-shot performance on CIFAR-100 is {val_acc_cifar100*100:.2f}%, visual encoder is {VISUAL_BACKBONE}.")
torch.cuda.empty_cache()


Files already downloaded and verified
The zero-shot performance on CIFAR-100 is 3.77%, visual encoder is ViT-B/16.


## Dataset 3: DTD

In [7]:
# Transform for DTD test set
transform_dtd_test = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Download and load DTD test set
dtd_test_set = datasets.ImageFolder(root='/shareddata/dtd/dtd/images', transform=transform_dtd_test)
dtd_test_dataloader = DataLoader(dtd_test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Prompt for CLIP model
prompt = "It’s a picture featuring patterns of"

# Tokenize class names for CLIP model input
text_inputs_dtd = torch.cat([clip.tokenize(f"{prompt} {c}") for c in dtd_test_set.classes]).to(device)

# Test on DTD dataset
with torch.no_grad():
    model.eval()

    val_loss = 0.0
    val_corrects = 0

    for batch_idx, (image, target) in enumerate(dtd_test_dataloader):

        image = image.to(device)
        target = target.to(device)

        # Test model
        logits_dtd = model_inference(model, image, text_inputs_dtd)
        _, preds_dtd = torch.max(logits_dtd, 1)
        
        val_corrects += torch.sum(preds_dtd == target.data)

    val_acc_dtd = val_corrects.double() / len(dtd_test_set)

    print(f"The zero-shot performance on DTD is {val_acc_dtd*100:.2f}%, visual encoder is {VISUAL_BACKBONE}.")
torch.cuda.empty_cache()

The zero-shot performance on DTD is 42.91%, visual encoder is ViT-B/16.


## Dataset 4: Oxford Flowers

In [8]:
flower_dict = {
'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'
}

# 提取所有的值
class_names_oxford_flowers = flower_dict.values()

In [9]:
# Transform for Oxford Flowers test set
transform_oxford_flowers_test = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

test_set_oxford_flowers = torchvision.datasets.ImageFolder(root='/shareddata/flowers-102', transform=transform_oxford_flowers_test)
test_dataloader_oxford_flowers = DataLoader(test_set_oxford_flowers, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# 提示词 for CLIP 模型
prompt_oxford_flowers = "It’s a picture of a flower from dataset oxford flowers, specifically a"

# Tokenize 类别名称 for CLIP 模型输入
text_inputs_oxford_flowers = torch.cat([clip.tokenize(f"{prompt_oxford_flowers} {c}") for c in class_names_oxford_flowers]).to(device)

with torch.no_grad():
    model.eval()

    val_loss = 0.0
    val_corrects = 0

    for batch_idx, (image, target) in enumerate(test_dataloader_oxford_flowers):

        image = image.to(device)
        target = target.to(device)

        # Test model
        logits_oxford_flowers = model_inference(model, image, text_inputs_oxford_flowers) 
        _, preds_oxford_flowers = torch.max(logits_oxford_flowers, 1)

        val_corrects += torch.sum(preds_oxford_flowers == target.data)

    val_acc_oxford_flowers = val_corrects.double() / len(test_set_oxford_flowers)

    print(f"The zero-shot performance on Oxford Flowers is {val_acc_oxford_flowers*100:.2f}%, visual encoder is {VISUAL_BACKBONE}.")

torch.cuda.empty_cache()

The zero-shot performance on Oxford Flowers is 0.67%, visual encoder is ViT-B/16.


## Dataset 5: COIL-20

In [10]:
# Transform for COIL-20 test set
transform_coil20_test = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Download and load COIL-20 test set
coil20_test_set = datasets.ImageFolder(root='data/coil20', transform=transform_coil20_test)
coil20_test_dataloader = DataLoader(coil20_test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Prompt for CLIP model
prompt = "It’s a picture of an object from the COIL-20 dataset, specifically a"

# Tokenize class names for CLIP model input
text_inputs_coil20 = torch.cat([clip.tokenize(f"{prompt} {c}") for c in coil20_test_set.classes]).to(device)

# Test on COIL-20 dataset
with torch.no_grad():
    model.eval()

    val_loss = 0.0
    val_corrects = 0

    for batch_idx, (image, target) in enumerate(coil20_test_dataloader):

        image = image.to(device)
        target = target.to(device)

        # Test model
        logits_coil20 = model_inference(model, image,text_inputs_coil20)
        _, preds_coil20 = torch.max(logits_coil20, 1)
        
        val_corrects += torch.sum(preds_coil20 == target.data)

    val_acc_coil20 = val_corrects.double() / len(coil20_test_set)

    print(f"The zero-shot performance on COIL-20 is {val_acc_coil20*100:.2f}%, visual encoder is {VISUAL_BACKBONE}.")
torch.cuda.empty_cache()

The zero-shot performance on COIL-20 is 6.94%, visual encoder is ViT-B/16.


## Dataset 6: MNIST

In [11]:
# Transform for MNIST test set
transform_mnist_test = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=(224, 224)),
    transforms.Grayscale(num_output_channels=3),  # Convert to RGB for compatibility
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Download and load MNIST test set
mnist_test_set = datasets.MNIST(root='/shareddata/MNIST', train=False, transform=transform_mnist_test, download=True)
mnist_test_dataloader = DataLoader(mnist_test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Prompt for CLIP model
prompt = "It’s a picture of a handwritten digit, specifically a"

# Tokenize class names for CLIP model input
text_inputs_mnist = torch.cat([clip.tokenize(f"{prompt} {i}") for i in range(10)]).to(device)

def model_inference(model, image, text):
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    logit_scale = model.logit_scale.exp()
    logits = logit_scale * image_features @ text_features.t()

    return logits

# Test on MNIST dataset
with torch.no_grad():
    model.eval()

    val_loss = 0.0
    val_corrects = 0

    for batch_idx, (image, target) in enumerate(mnist_test_dataloader):

        image = image.to(device)
        target = target.to(device)

        # Test model
        logits_mnist = model_inference(model, image, text_inputs_mnist)
        _, preds_mnist = torch.max(logits_mnist, 1)
        
        val_corrects += torch.sum(preds_mnist == target.data)

    val_acc_mnist = val_corrects.double() / len(mnist_test_set)

    print(f"The zero-shot performance on MNIST is {val_acc_mnist*100:.2f}%, visual encoder is {VISUAL_BACKBONE}.")
torch.cuda.empty_cache()

The zero-shot performance on MNIST is 28.58%, visual encoder is ViT-B/16.


# Model 2: ResNet-18

In [12]:
# 定义 ResNet-18 模型
class SimpleResNet18(nn.Module):
    def __init__(self, num_classes):
        super(SimpleResNet18, self).__init__()
        resnet18 = models.resnet18(pretrained=True)
        self.features = nn.Sequential(*list(resnet18.children())[:-1])
        self.classifier = nn.Sequential(
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


## Hyperparameters

In [13]:
# Training
BATCH_SIZE = 128
NUM_EPOCHS = 10
# # Optimizer
LEARNING_RATE = 0.001

### Dataset 1


In [14]:
# 将CIFAR-10数据预处理
transform_cifar10_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_cifar10_test = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# 加载训练集和测试集
train_set = datasets.CIFAR10(root='/shareddata', train=True, download=True, transform=transform_cifar10_train)
train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

test_set = datasets.CIFAR10(root='/shareddata', train=False, download=True, transform=transform_cifar10_test)
test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)



# 初始化模型
num_classes = 10  # CIFAR-10 有 10 个类别
resnet18_model = SimpleResNet18(num_classes)
resnet18_model.to(device)


# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet18_model.parameters(), lr=LEARNING_RATE)

# 训练模型
for epoch in range(NUM_EPOCHS):
    resnet18_model.train()
    running_loss = 0.0
    for images, labels in train_dataloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = resnet18_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f"Epoch {epoch + 1}, Loss: {running_loss / len(train_dataloader)}")

# 在测试集上评估模型
resnet18_model.eval()
corrects = 0
with torch.no_grad():
    for images, labels in test_dataloader:
        images, labels = images.to(device), labels.to(device)
        outputs = resnet18_model(images)
        _, preds = torch.max(outputs, 1)
        corrects += torch.sum(preds == labels.data)

test_accuracy = corrects.double() / len(test_set)
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")
torch.cuda.empty_cache()

Files already downloaded and verified
Files already downloaded and verified




Epoch 1, Loss: 1.0229529229271443
Epoch 2, Loss: 0.7814653910639341
Epoch 3, Loss: 0.6861904384687428
Epoch 4, Loss: 0.632746474517276
Epoch 5, Loss: 0.5956372027964238
Epoch 6, Loss: 0.569012791680558
Epoch 7, Loss: 0.5500335960894289
Epoch 8, Loss: 0.5191921680174825
Epoch 9, Loss: 0.5112663771947632
Epoch 10, Loss: 0.4874238667585661
Test Accuracy: 89.54%


## Dataset 2

In [15]:
# 将CIFAR-100数据集预处理的代码
transform_cifar100_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_cifar100_test = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# 加载CIFAR-100训练集和测试集
train_set_cifar100 = datasets.CIFAR100(root='/shareddata', train=True, download=True, transform=transform_cifar100_train)
train_dataloader_cifar100 = torch.utils.data.DataLoader(train_set_cifar100, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

test_set_cifar100 = datasets.CIFAR100(root='/shareddata', train=False, download=True, transform=transform_cifar100_test)
test_dataloader_cifar100 = torch.utils.data.DataLoader(test_set_cifar100, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# 初始化模型
num_classes_cifar100 = 100  # CIFAR-100有100个类别
resnet18_model_cifar100 = SimpleResNet18(num_classes_cifar100)

# 将模型移到GPU上
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet18_model_cifar100.to(device)

# 定义损失函数和优化器
criterion_cifar100 = nn.CrossEntropyLoss()
optimizer_cifar100 = optim.Adam(resnet18_model_cifar100.parameters(), lr=LEARNING_RATE)

# 训练模型
for epoch in range(NUM_EPOCHS):
    resnet18_model_cifar100.train()
    running_loss = 0.0
    for images, labels in train_dataloader_cifar100:
        images, labels = images.to(device), labels.to(device)
        optimizer_cifar100.zero_grad()
        outputs = resnet18_model_cifar100(images)
        loss = criterion_cifar100(outputs, labels)
        loss.backward()
        optimizer_cifar100.step()
        running_loss += loss.item()

    print(f"Epoch {epoch + 1}, Loss: {running_loss / len(train_dataloader_cifar100)}")

# 在测试集上评估模型
resnet18_model_cifar100.eval()
corrects_cifar100 = 0
with torch.no_grad():
    for images, labels in test_dataloader_cifar100:
        images, labels = images.to(device), labels.to(device)
        outputs = resnet18_model_cifar100(images)
        _, preds = torch.max(outputs, 1)
        corrects_cifar100 += torch.sum(preds == labels.data)

test_accuracy_cifar100 = corrects_cifar100.double() / len(test_set_cifar100)
print(f"Test Accuracy on CIFAR-100: {test_accuracy_cifar100 * 100:.2f}%")
torch.cuda.empty_cache()

Files already downloaded and verified
Files already downloaded and verified
Epoch 1, Loss: 2.818615010022507
Epoch 3, Loss: 1.8864550090506864
Epoch 4, Loss: 1.7327599918750851
Epoch 5, Loss: 1.6371336522919442
Epoch 6, Loss: 1.5405272187479317
Epoch 7, Loss: 1.4584246318968361
Epoch 8, Loss: 1.394640124362448
Epoch 9, Loss: 1.3470561726928671
Epoch 10, Loss: 1.3047759773786112
Test Accuracy on CIFAR-100: 69.65%


## Dataset 3

In [16]:
# 将DTD数据集预处理的代码
transform_dtd_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

transform_dtd_test = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# 加载DTD训练集和测试集
train_set_dtd = datasets.ImageFolder(root='/shareddata/dtd/dtd/images', transform=transform_dtd_train)
train_dataloader_dtd = torch.utils.data.DataLoader(train_set_dtd, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

test_set_dtd = datasets.ImageFolder(root='/shareddata/dtd/dtd/images', transform=transform_dtd_test)
test_dataloader_dtd = torch.utils.data.DataLoader(test_set_dtd, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# 初始化模型
num_classes_dtd = len(train_set_dtd.classes)  # DTD数据集的类别数
resnet18_model_dtd = SimpleResNet18(num_classes_dtd)
resnet18_model_dtd.to(device)

# 定义损失函数和优化器
criterion_dtd = nn.CrossEntropyLoss()
optimizer_dtd = optim.Adam(resnet18_model_dtd.parameters(), lr=LEARNING_RATE)

# 训练模型
for epoch in range(NUM_EPOCHS):
    resnet18_model_dtd.train()
    running_loss = 0.0
    for images, labels in train_dataloader_dtd:
        images, labels = images.to(device), labels.to(device)
        optimizer_dtd.zero_grad()
        outputs = resnet18_model_dtd(images)
        loss = criterion_dtd(outputs, labels)
        loss.backward()
        optimizer_dtd.step()
        running_loss += loss.item()

    print(f"Epoch {epoch + 1}, Loss: {running_loss / len(train_dataloader_dtd)}")

# 在测试集上评估模型
resnet18_model_dtd.eval()
corrects_dtd = 0
with torch.no_grad():
    for images, labels in test_dataloader_dtd:
        images, labels = images.to(device), labels.to(device)
        outputs = resnet18_model_dtd(images)
        _, preds = torch.max(outputs, 1)
        corrects_dtd += torch.sum(preds == labels.data)

test_accuracy_dtd = corrects_dtd.double() / len(test_set_dtd)
print(f"Test Accuracy on DTD: {test_accuracy_dtd * 100:.2f}%")
torch.cuda.empty_cache()

Epoch 1, Loss: 2.369425392150879
Epoch 2, Loss: 1.7675701326794095
Epoch 3, Loss: 1.615386954943339
Epoch 4, Loss: 1.5855771117740207
Epoch 5, Loss: 1.4016689247555203
Epoch 6, Loss: 1.2852866583400302
Epoch 7, Loss: 1.1690501107109919
Epoch 8, Loss: 1.1540569543838501
Epoch 9, Loss: 1.0839050465159945
Epoch 10, Loss: 1.0653971989949544
Test Accuracy on DTD: 65.20%


## Dataset 4

In [None]:
# 将Oxford Flowers数据集预处理
transform_flowers_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

transform_flowers_test = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# 加载Oxford Flowers训练集和测试集
train_set_flowers = datasets.ImageFolder(root='/shareddata/flowers-102', transform=transform_flowers_train)
train_dataloader_flowers = torch.utils.data.DataLoader(train_set_flowers, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

test_set_flowers = datasets.ImageFolder(root='/shareddata/flowers-102', transform=transform_flowers_test)
test_dataloader_flowers = torch.utils.data.DataLoader(test_set_flowers, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# 初始化模型
num_classes_flowers = len(train_set_flowers.classes)  # Oxford Flowers数据集的类别数
resnet18_model_flowers = SimpleResNet18(num_classes_flowers)
resnet18_model_flowers.to(device)

# 定义损失函数和优化器
criterion_flowers = nn.CrossEntropyLoss()
optimizer_flowers = optim.Adam(resnet18_model_flowers.parameters(), lr=LEARNING_RATE)

# 训练模型
for epoch in range(NUM_EPOCHS):
    resnet18_model_flowers.train()
    running_loss_flowers = 0.0
    for images_flowers, labels_flowers in train_dataloader_flowers:
        images_flowers, labels_flowers = images_flowers.to(device), labels_flowers.to(device)
        optimizer_flowers.zero_grad()
        outputs_flowers = resnet18_model_flowers(images_flowers)
        loss_flowers = criterion_flowers(outputs_flowers, labels_flowers)
        loss_flowers.backward()
        optimizer_flowers.step()
        running_loss_flowers += loss_flowers.item()

    print(f"Epoch {epoch + 1}, Loss: {running_loss_flowers / len(train_dataloader_flowers)}")

# 在测试集上评估模型
resnet18_model_flowers.eval()
corrects_flowers = 0
with torch.no_grad():
    for images_flowers, labels_flowers in test_dataloader_flowers:
        images_flowers, labels_flowers = images_flowers.to(device), labels_flowers.to(device)
        outputs_flowers = resnet18_model_flowers(images_flowers)
        _, preds_flowers = torch.max(outputs_flowers, 1)
        corrects_flowers += torch.sum(preds_flowers == labels_flowers.data)

test_accuracy_flowers = corrects_flowers.double() / len(test_set_flowers)
print(f"Test Accuracy on Oxford Flowers: {test_accuracy_flowers * 100:.2f}%")
torch.cuda.empty_cache()

Epoch 1, Loss: 0.6671897567187747
Epoch 2, Loss: 0.6446548672392964
Epoch 3, Loss: 0.6427348675206304


## Dataset 5

In [None]:
# 将COIL-20数据集预处理
transform_coil20_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

transform_coil20_test = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# 加载COIL-20训练集和测试集
train_set_coil20 = datasets.ImageFolder(root='data/coil20', transform=transform_coil20_train)  # 替换为COIL-20数据集的路径
train_dataloader_coil20 = torch.utils.data.DataLoader(train_set_coil20, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

test_set_coil20 = datasets.ImageFolder(root='data/coil20', transform=transform_coil20_test)  # 替换为COIL-20数据集的路径
test_dataloader_coil20 = torch.utils.data.DataLoader(test_set_coil20, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# 初始化模型
num_classes_coil20 = len(train_set_coil20.classes)  # COIL-20数据集的类别数
resnet18_model_coil20 = SimpleResNet18(num_classes_coil20)
resnet18_model_coil20.to(device)

# 定义损失函数和优化器
criterion_coil20 = nn.CrossEntropyLoss()
optimizer_coil20 = optim.Adam(resnet18_model_coil20.parameters(), lr=LEARNING_RATE)

# 训练模型
for epoch in range(NUM_EPOCHS):
    resnet18_model_coil20.train()
    running_loss_coil20 = 0.0
    for images_coil20, labels_coil20 in train_dataloader_coil20:
        images_coil20, labels_coil20 = images_coil20.to(device), labels_coil20.to(device)
        optimizer_coil20.zero_grad()
        outputs_coil20 = resnet18_model_coil20(images_coil20)
        loss_coil20 = criterion_coil20(outputs_coil20, labels_coil20)
        loss_coil20.backward()
        optimizer_coil20.step()
        running_loss_coil20 += loss_coil20.item()

    print(f"Epoch {epoch + 1}, Loss: {running_loss_coil20 / len(train_dataloader_coil20)}")

# 在测试集上评估模型
resnet18_model_coil20.eval()
corrects_coil20 = 0
with torch.no_grad():
    for images_coil20, labels_coil20 in test_dataloader_coil20:
        images_coil20, labels_coil20 = images_coil20.to(device), labels_coil20.to(device)
        outputs_coil20 = resnet18_model_coil20(images_coil20)
        _, preds_coil20 = torch.max(outputs_coil20, 1)
        corrects_coil20 += torch.sum(preds_coil20 == labels_coil20.data)

test_accuracy_coil20 = corrects_coil20.double() / len(test_set_coil20)
print(f"Test Accuracy on COIL-20: {test_accuracy_coil20 * 100:.2f}%")
torch.cuda.empty_cache()

## Dataset6

In [None]:
# 将MNIST数据集预处理
transform_mnist_train = transforms.Compose([
    transforms.Resize((224, 224)),  # 调整大小以匹配ResNet-18的期望输入
    transforms.RandomHorizontalFlip(),
    transforms.Grayscale(3),  # 将图像转换为RGB格式
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # 根据MNIST数据集的特性设置标准化参数
])

transform_mnist_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(3),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# 加载MNIST训练集和测试集
train_set_mnist = datasets.MNIST(root='/shareddata/MNIST', train=True, download=True, transform=transform_mnist_train)
train_dataloader_mnist = torch.utils.data.DataLoader(train_set_mnist, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

test_set_mnist = datasets.MNIST(root='/shareddata/MNIST', train=False, download=True, transform=transform_mnist_test)
test_dataloader_mnist = torch.utils.data.DataLoader(test_set_mnist, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# 初始化模型
num_classes_mnist = 10  # MNIST数据集的类别数
resnet18_model_mnist = SimpleResNet18(num_classes_mnist)
resnet18_model_mnist.to(device)

# 定义损失函数和优化器
criterion_mnist = nn.CrossEntropyLoss()
optimizer_mnist = optim.Adam(resnet18_model_mnist.parameters(), lr=LEARNING_RATE)

# 训练模型
for epoch in range(NUM_EPOCHS):
    resnet18_model_mnist.train()
    running_loss_mnist = 0.0
    for images_mnist, labels_mnist in train_dataloader_mnist:
        images_mnist, labels_mnist = images_mnist.to(device), labels_mnist.to(device)
        optimizer_mnist.zero_grad()
        outputs_mnist = resnet18_model_mnist(images_mnist)
        loss_mnist = criterion_mnist(outputs_mnist, labels_mnist)
        loss_mnist.backward()
        optimizer_mnist.step()
        running_loss_mnist += loss_mnist.item()

    print(f"Epoch {epoch + 1}, Loss: {running_loss_mnist / len(train_dataloader_mnist)}")

# 在测试集上评估模型
resnet18_model_mnist.eval()
corrects_mnist = 0
with torch.no_grad():
    for images_mnist, labels_mnist in test_dataloader_mnist:
        images_mnist, labels_mnist = images_mnist.to(device), labels_mnist.to(device)
        outputs_mnist = resnet18_model_mnist(images_mnist)
        _, preds_mnist = torch.max(outputs_mnist, 1)
        corrects_mnist += torch.sum(preds_mnist == labels_mnist.data)

test_accuracy_mnist = corrects_mnist.double() / len(test_set_mnist)
print(f"Test Accuracy on MNIST: {test_accuracy_mnist * 100:.2f}%")
torch.cuda.empty_cache()