In [2]:
import torch, torchvision
from torchvision import transforms
from torchvision import datasets
from torch.utils.data.dataloader import DataLoader
from torch import nn, optim

In [3]:
import torch.nn.functional as F
class Model(nn.Module):
    """Small two-stage CNN producing 2-class logits from 3-channel images.

    Each stage is conv(5x5, padding 2) -> tanh -> 2x2 max-pool -> batch norm.
    The convolutions keep the spatial size and each pooling step halves it, so
    the first linear layer (8*8*16 input features) matches 32x32 inputs:
    32 -> 16 -> 8 pixels per side with 16 channels after the second stage.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2)
        self.batchnorm1 = nn.BatchNorm2d(num_features=32)
        self.conv2 = nn.Conv2d(32, 16, kernel_size=5, stride=1, padding=2)
        self.batchnorm2 = nn.BatchNorm2d(num_features=16)
        self.fc1 = nn.Linear(in_features=8*8*16, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=2)

    def forward(self, x):
        """Return raw (un-normalized) class scores of shape (batch, 2).

        Args:
            x: image batch of shape (batch, 3, 32, 32).

        Raises:
            RuntimeError: if the spatial size of ``x`` does not produce
                8*8*16 features per sample (raised by the first linear layer).
        """
        out = F.max_pool2d(torch.tanh(self.conv1(x)), 2)
        out = self.batchnorm1(out)
        out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)
        out = self.batchnorm2(out)
        # Flatten per sample and keep the batch dimension fixed. A
        # view(-1, 8*8*16) would let the batch dimension absorb a size
        # mismatch and silently return logits for the wrong number of samples.
        out = torch.flatten(out, start_dim=1)
        out = torch.tanh(self.fc1(out))
        out = self.fc2(out)
        return out

In [4]:
# Local directory the CIFAR-10 files are read from; download is disabled below,
# so the data is expected to be present here already.
data_path = r'D:\FunnyProgramming\PythonProject\SummerPrac\data'

# Raw training and test splits; ToTensor is the only transform at this stage.
cifar10 = datasets.CIFAR10(
    root=data_path,
    train=True,
    transform=transforms.Compose([transforms.ToTensor()]),
    download=False,
)
cifar10_val = datasets.CIFAR10(
    root=data_path,
    train=False,
    transform=transforms.Compose([transforms.ToTensor()]),
    download=False,
)

# Cell output: the concrete dataset class of each split.
(type(cifar10), type(cifar10_val))

(torchvision.datasets.cifar.CIFAR10, torchvision.datasets.cifar.CIFAR10)

In [5]:
# Stack all images along a new trailing axis (dim=3) so the channel axis stays
# first; this makes the per-channel mean and standard deviation easy to compute.
imgs = torch.stack([sample[0] for sample in cifar10], dim=3)
imgs_val = torch.stack([sample[0] for sample in cifar10_val], dim=3)

# Per-channel statistics: collapse everything except the leading channel axis.
mean_imgs = imgs.view(3, -1).mean(dim=-1)
stddev_imgs = imgs.view(3, -1).std(dim=-1)
mean_imgs_val = imgs_val.view(3, -1).mean(dim=-1)
stddev_imgs_val = imgs_val.view(3, -1).std(dim=-1)

# Both splits are normalized with the statistics of the training split.
normal_cifar10 = datasets.CIFAR10(
    data_path,
    train=True,
    download=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean_imgs, std=stddev_imgs),
    ]),
)
normal_cifar10_val = datasets.CIFAR10(
    data_path,
    train=False,
    download=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean_imgs, std=stddev_imgs),
    ]),
)

# Sanity check: value range of the first normalized image of each split.
print("Imgs: Maximum: %f, Minimum: %f" % (
    normal_cifar10[0][0].max(),
    normal_cifar10[0][0].min(),
))
print("Imgs_val: Maximum: %f, Minimum: %f" % (
    normal_cifar10_val[0][0].max(),
    normal_cifar10_val[0][0].min(),
))

Imgs: Maximum: 2.094577, Minimum: -1.989213
Imgs_val: Maximum: 2.094577, Minimum: -1.782841


In [6]:
# Keep only the samples whose original label is 0 or 2 and relabel them 0 and 1.
label_map = {0: 0, 2: 1}
cifar2 = [
    (img, label_map[label])
    for img, label in normal_cifar10
    if label in label_map
]
cifar2_val = [
    (img, label_map[label])
    for img, label in normal_cifar10_val
    if label in label_map
]

In [9]:
import datetime

# Train on the GPU when one is available; otherwise fall back to the CPU so the
# cell still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = DataLoader(cifar2, batch_size=64, shuffle=True)
val_loader = DataLoader(cifar2_val, batch_size=64, shuffle=False)

model = Model()
model = model.to(device=device)
learning_rate = 1e-2
optimizer = optim.SGD(params=model.parameters(), lr=learning_rate)

loss_fn = nn.CrossEntropyLoss()
n_epochs = 50

# Put the model in training mode explicitly, so this cell does not rely on the
# mode the module happens to be in when the loop starts.
model.train()
for epoch in range(1, n_epochs+1):
    total_loss = 0.0
    for imgs, labels in train_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        out = model(imgs)
        loss = loss_fn(out, labels)
        # .item() already returns a plain Python float detached from the graph.
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Report the mean of the per-batch losses on the first and every 10th epoch.
    if epoch == 1 or epoch % 10 == 0:
        print("{} Epoch: {}, Average Loss: {}".format(datetime.datetime.now(), epoch, total_loss/len(train_loader)))


2023-08-27 20:55:14.918840 Epoch: 1, Average Loss: 0.41888475142846443
2023-08-27 20:55:20.331441 Epoch: 10, Average Loss: 0.2201342896149037
2023-08-27 20:55:26.121497 Epoch: 20, Average Loss: 0.12901126930288448
2023-08-27 20:55:31.945586 Epoch: 30, Average Loss: 0.0639505622920337
2023-08-27 20:55:37.752588 Epoch: 40, Average Loss: 0.03093684478633248
2023-08-27 20:55:43.579638 Epoch: 50, Average Loss: 0.007562966479706299


In [10]:
# Report classification accuracy on the training and validation loaders.
# The model is switched to eval mode and gradients are disabled throughout.
model.eval()
with torch.no_grad():
    for name, loader in (("train", train_loader), ("val", val_loader)):
        correct, total = 0, 0
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            # The predicted class is the index of the largest score per sample.
            predicted = outputs.argmax(dim=-1)
            total += labels.shape[0]
            correct += int((predicted == labels).sum())
        print("Accuracy {}: {:.2f}".format(name, correct/total))

Accuracy train: 1.00
Accuracy val: 0.88
