In [1]:
import os
import time
import os.path as osp

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import CIFAR10
from torchvision import datasets
from torchvision import transforms
import torchvision

from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from PIL import Image
from clip import clip

In [2]:
torch.cuda.empty_cache()

In [3]:
# random seed
SEED = 1 
NUM_CLASS = 10

# Training
BATCH_SIZE = 128
NUM_EPOCHS = 30
EVAL_INTERVAL=1
SAVE_DIR = './log'

# Optimizer
LEARNING_RATE = 1e-2
MOMENTUM = 0.9
STEP=5
GAMMA=0.5

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
from torch.utils.data import Subset
import numpy as np

# cifar10 transform
transform_cifar100_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_cifar100_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# 加载训练集
train_set = torchvision.datasets.CIFAR100(root='/shareddata', train=True,
                                        download=True, transform=transform_cifar100_test)
train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE,
                                          shuffle=False, num_workers=2)
test_set = torchvision.datasets.CIFAR100(root='/shareddata', train=False,
                                        download=True, transform=transform_cifar100_test)
test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE,
                                          shuffle=False, num_workers=2)

# CIFAR-100 的类名列表
class_names = ['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 
               'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 
               'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 
               'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 
               'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 
               'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
               'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
               'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
               'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
               'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
               'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
               'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
               'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
               'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman',
               'worm']

# 更新数据集名称
dataset_name = 'CIFAR100'

Files already downloaded and verified
Files already downloaded and verified


In [6]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 4, 3)  
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(4, 8, 3)  
        self.fc1 = nn.Linear(8 * 6 * 6, 32)
        self.fc2 = nn.Linear(32, 100)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 8 * 6 * 6)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [7]:
model = ConvNet()
model.to(device)

ConvNet(
  (conv1): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=288, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=100, bias=True)
)

In [8]:
import time


criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEP, gamma=GAMMA)

# 在训练循环开始之前记录起始时间
start_time = time.time()

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(train_dataloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    scheduler.step()

    if epoch % EVAL_INTERVAL == 0:
        print(f'Epoch {epoch+1}, Loss: {running_loss / len(train_dataloader)}')

print('Finished Training')

end_time = time.time()

# 计算总训练时间
total_time = end_time - start_time
print(f'Total training time: {total_time:.2f} seconds')

Epoch 1, Loss: 4.480070070232577
Epoch 2, Loss: 4.035427756931471
Epoch 3, Loss: 3.7556229759665096
Epoch 4, Loss: 3.5782850300869367
Epoch 5, Loss: 3.466397462903386
Epoch 6, Loss: 3.3356324473915198
Epoch 7, Loss: 3.287794094256428
Epoch 8, Loss: 3.2491149195014972
Epoch 9, Loss: 3.2134600466169663
Epoch 10, Loss: 3.1812561161987616
Epoch 11, Loss: 3.114576476919072
Epoch 12, Loss: 3.096572387553847
Epoch 13, Loss: 3.0825997856266967
Epoch 14, Loss: 3.0701592840502023
Epoch 15, Loss: 3.0590751799171234
Epoch 16, Loss: 3.024690830494132
Epoch 17, Loss: 3.017720729188846
Epoch 18, Loss: 3.012148924800746
Epoch 19, Loss: 3.006936867828564
Epoch 20, Loss: 3.0020942364812204
Epoch 21, Loss: 2.9831235262439075
Epoch 22, Loss: 2.9794080989135194
Epoch 23, Loss: 2.9768324312956436
Epoch 24, Loss: 2.9744162352188774
Epoch 25, Loss: 2.9721157117877777
Epoch 26, Loss: 2.961807836352102
Epoch 27, Loss: 2.9599913855647797
Epoch 28, Loss: 2.958695986691643
Epoch 29, Loss: 2.9574777945838013
Epoch 

In [9]:
# 评估基准模型的函数
acc = 0
def evaluate_model(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    acc = correct / total
    return 100 * correct / total

In [10]:
# 评估基准模型
accuracy_convnet = evaluate_model(model, test_dataloader, device)
print(f'Accuracy of ConvNet on CIFAR100 is {accuracy_convnet}%')

Accuracy of ConvNet on CIFAR100 is 25.03%


In [11]:
print(f"CNN : {accuracy_convnet}%")

CNN : 25.03%


In [12]:
torch.cuda.empty_cache()