In [1]:
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
import torch

In [2]:
'''
AlexNet结构
硬件原因，修改了网络的参数

1.卷积层
2.池化层
3.归一化层
4.卷积层
5.池化层
6.归一化层
7.卷积层
8.卷积层
9.卷积层
10.池化层 
11.全连接层
12.dropout层
13.全连接层
14.dropout层
15.全连接层
'''


class AlexNet(nn.Module):
    def __init__(self, classNums):
        super(AlexNet, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=48,
                      kernel_size=11, stride=4, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=48, out_channels=128,
                      kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=192,
                      kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        self.conv4 = nn.Sequential(
            nn.Conv2d(in_channels=192, out_channels=192,
                      kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels=192, out_channels=128,
                      kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        self.fc1 = nn.Sequential(
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(),
            nn.Dropout(p=0.5)
        )

        self.fc2 = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.5)
        )

        self.output = nn.Linear(1024, classNums)

    def forward(self, inputs):
        net = self.conv1(inputs)
        net = self.conv2(net)
        net = self.conv3(net)
        net = self.conv4(net)
        net = self.conv5(net)
        net = net.view(net.size(0), -1)
        net = self.fc1(net)
        net = self.fc2(net)
        res = self.output(net)

        return res, net

In [3]:
import os
from shutil import copy
import random

In [4]:
def mkfile(fp):
    if not os.path.exists(fp):
        os.makedirs(fp)

In [5]:
'''
数据和代码不在一个盘里
'''
fp = "M:/DataSet/alexnet/flower_photos/"

flower_labels = [label for label in os.listdir(fp) if ".txt" not in label]

mkfile("M:/DataSet/alexnet/train")
for label in flower_labels:
    mkfile("M:/DataSet/alexnet/train/" + label)

mkfile("M:/DataSet/alexnet/val")
for label in flower_labels:
    mkfile("M:/DataSet/alexnet/val/" + label)

rate = 0.1
for label in flower_labels:
    label_path = fp + "/" + label + "/"
    imgs = os.listdir(label_path)
    imgNums = len(imgs)
    indexs = random.sample(imgs, k=int(imgNums*rate))
    for idx, img in enumerate(imgs):
        if img in indexs:
            img_path = label_path + img
            new_path = "M:/DataSet/alexnet/val/" + label
            copy(img_path, new_path)
        else:
            img_path = label_path + img
            new_path = "M:/DataSet/alexnet/train/" + label
            copy(img_path, new_path)
    print(f"\r[{label}] : {imgNums}]")

print("processing done!")

[daisy] : 633]
[dandelion] : 898]
[roses] : 641]
[sunflowers] : 699]
[tulips] : 799]
processing done!


In [6]:
handler = torch.device("cpu")
handler

device(type='cpu')

In [7]:
# 数据变换
data_trans = {
    "train": torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),
                                            torchvision.transforms.RandomHorizontalFlip(),
                                            torchvision.transforms.ToTensor(),
                                            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop((224, 224)),
                                           torchvision.transforms.ToTensor(),
                                           torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
}

In [8]:
# 保存数据的根目录与代码不在同一个盘
data_root = "M:/DataSet/alexnet/"

train_data = torchvision.datasets.ImageFolder(
    root=data_root + "train", transform=data_trans["train"])
trainNums = len(train_data)

flower_list = train_data.class_to_idx
label_dict = dict((val, key) for key, val in flower_list.items())

batch_size = 32
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True, num_workers=0)

In [9]:
validate_data = torchvision.datasets.ImageFolder(
    root=data_root+"val", transform=data_trans["val"])
valNums = len(validate_data)

val_loader = torch.utils.data.DataLoader(
    validate_data, batch_size=batch_size, shuffle=True, num_workers=0)

In [10]:
import json

jsons = json.dumps(label_dict, indent=4)
with open("alexnet.json", 'w') as json_file:
    json_file.write(jsons)

In [11]:
test_data = iter(val_loader)
test_img, test_label = test_data.next()

In [12]:
net = AlexNet(classNums=5)
net.to(handler)

# 使用交叉熵和adam优化器
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.0002)

# 参数保存路径
save_path = "./alexnet.pth"

# 最高准确率
accBest = 0.0

In [13]:
for epoch in range(10):
    net.train()
    running_loss = 0.0
    for step, data in enumerate(train_loader, start=0):
        imgs, labels = data
        optimizer.zero_grad()
        outputs = net(imgs.to(handler))
        loss = loss_func(outputs[0], labels.to(handler))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # 测试
    net.eval()
    acc = 0.0
    with torch.no_grad():
        for val_data in val_loader:
            val_imgs, val_labels = val_data
            outputs = net(val_imgs.to(handler))
            pred_y = torch.max(outputs[0], dim=1)[1]
            acc += (pred_y == val_labels.to(handler)).sum().item()
        val_acc = acc / valNums
        if val_acc > accBest:
            accBest = val_acc
            torch.save(net.state_dict(), save_path)
        print("[epoch {}] train_loss: {:.3f}  test_acc: {:.3f}".format(
            epoch+1, running_loss/step, val_acc))

print("Finished")

[epoch 1] train_loss: 1.397  test_acc: 0.444
[epoch 2] train_loss: 1.217  test_acc: 0.538
[epoch 3] train_loss: 1.098  test_acc: 0.554
[epoch 4] train_loss: 1.013  test_acc: 0.617
[epoch 5] train_loss: 0.959  test_acc: 0.622
[epoch 6] train_loss: 0.943  test_acc: 0.672
[epoch 7] train_loss: 0.910  test_acc: 0.645
[epoch 8] train_loss: 0.870  test_acc: 0.674
[epoch 9] train_loss: 0.862  test_acc: 0.698
[epoch 10] train_loss: 0.821  test_acc: 0.709
Finished
