CNN —— 图片分类 (CIFAR-10)

In [1]:
# -----------导入库文件----------------
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import PIL.Image
import numpy as np
import os

In [2]:
# -------- Dataset loading (read every image file from a folder) --------
def load_data(root="output"):
    """Load every image in *root* as a normalized CHW float tensor plus its label.

    The label is the first character of each file name, parsed as an int
    (e.g. ``"3_0001.png"`` -> ``3``).

    Args:
        root: Directory to scan (non-recursive). Defaults to ``"output"``,
            the folder this notebook originally read from.

    Returns:
        ``(xs, ys)`` where ``xs`` is a list of float tensors of shape
        ``(3, H, W)`` with values in ``[0, 1]`` and ``ys`` is the list of
        ``int`` labels, aligned index-by-index with ``xs``.

    Raises:
        ValueError: if a (non-hidden) file name does not start with a digit.
    """
    xs = []
    ys = []
    # sorted() makes the load order deterministic; os.listdir order is arbitrary.
    for filename in sorted(os.listdir(root)):
        path = os.path.join(root, filename)
        # Hidden entries (.DS_Store, .ipynb_checkpoints, ...) and sub-directories
        # are not samples; without this guard they crash PIL or int() below.
        if filename.startswith(".") or not os.path.isfile(path):
            continue

        # `with` closes the underlying file handle (PIL opens lazily and would
        # otherwise keep it open). convert("RGB") guarantees exactly 3 channels
        # even for grayscale / RGBA / palette images.
        with PIL.Image.open(path) as img:
            arr = np.array(img.convert("RGB"))

        # uint8 HWC -> float in [0, 1] -> CHW, the layout nn.Conv2d expects.
        x = torch.FloatTensor(arr) / 255
        x = x.permute(2, 0, 1)

        y = int(filename[0])

        xs.append(x)
        ys.append(y)

    return xs, ys


# Eagerly load every sample into memory at module level. x_s holds the image
# tensors and y_s the integer labels; both are read by the Dataset class below.
x_s, y_s = load_data()


In [3]:
# ----------- Map-style dataset over the in-memory samples ------------
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over parallel sequences of samples and labels.

    NOTE: this name shadows ``torch.utils.data.Dataset`` imported at the top
    of the notebook; it is kept for compatibility with existing callers.

    Args:
        xs: Sequence of input samples. Defaults to the module-level ``x_s``.
        ys: Sequence of labels aligned with ``xs``. Defaults to the
            module-level ``y_s``.

    Raises:
        ValueError: if ``xs`` and ``ys`` have different lengths.
    """

    def __init__(self, xs=None, ys=None):
        super().__init__()
        # Fall back to the globals produced by load_data() so the original
        # zero-argument ``Dataset()`` call keeps working.
        self.xs = x_s if xs is None else xs
        self.ys = y_s if ys is None else ys
        # A length mismatch would otherwise surface only as an IndexError
        # deep inside a DataLoader worker, or silently drop samples.
        if len(self.xs) != len(self.ys):
            raise ValueError(
                f"xs and ys must have the same length, "
                f"got {len(self.xs)} and {len(self.ys)}"
            )

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, i):
        """Return the ``(sample, label)`` pair at index ``i``."""
        return self.xs[i], self.ys[i]
# Dataset over the module-level x_s / y_s loaded above.
dataset = Dataset()

In [4]:
# ------------- Mini-batch loader over `dataset` -------------------------
loader = DataLoader(
    dataset,
    batch_size=8,    # samples per optimization step
    shuffle=True,    # new random sample order every epoch
    drop_last=True,  # discard the final incomplete batch
)

In [5]:
# ------------- Model definition -----------------
class cnn(nn.Module):
    """Small three-convolution classifier.

    Layer sizes were chosen for 32x32 inputs (e.g. CIFAR-10). The spatial
    size goes 32 -> 14 (conv1, stride 2) -> 14 (conv2, padded) -> 7
    (max-pool) -> 1 (conv3, 7x7 kernel), which leaves a 128-dim feature
    vector for the linear head.

    Args:
        in_channels: Channels of the input image. Defaults to 3 (RGB).
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, in_channels=3, num_classes=10):
        # Zero-argument super(): the notebook later rebinds the global name
        # ``cnn`` to an instance (``cnn = cnn()``), after which the old
        # ``super(cnn, self)`` would raise TypeError for any new instance.
        super().__init__()
        self.cnn1 = nn.Conv2d(in_channels=in_channels,
                              out_channels=16,
                              kernel_size=5,
                              stride=2,
                              padding=0)
        self.cnn2 = nn.Conv2d(in_channels=16,
                              out_channels=32,
                              kernel_size=3,
                              stride=1,
                              padding=1)
        self.cnn3 = nn.Conv2d(in_channels=32,
                              out_channels=128,
                              kernel_size=7,
                              stride=1,
                              padding=0)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU(inplace=True)
        # Global average pool over whatever spatial extent conv3 leaves. For
        # 32x32 inputs that extent is already 1x1, so this is an exact no-op;
        # for larger inputs it keeps the 128-dim head valid instead of
        # failing with a shape mismatch in ``fc``. It has no parameters, so
        # the state_dict is unchanged.
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x):
        """Map a ``(batch, C, H, W)`` tensor to ``(batch, num_classes)`` logits."""
        x = self.relu(self.cnn1(x))   # first convolution
        x = self.relu(self.cnn2(x))   # second convolution
        x = self.pool(x)              # 2x2 max-pool halves H and W
        x = self.relu(self.cnn3(x))   # third convolution
        x = self.gap(x).flatten(start_dim=1)
        return self.fc(x)

# NOTE(review): this rebinds the name `cnn` from the class to an instance, so the
# class is no longer reachable by that name afterwards (use `type(cnn)` if needed).
# train() below relies on the instance being called `cnn`.
cnn = cnn()

In [6]:
# ------------------------ Training --------------
def train(*, model=None, data_loader=None, epochs=5, lr=1e-3, log_every=20,
          save_path=None):
    """Train a classifier with Adam and cross-entropy loss, printing progress.

    Every ``log_every`` batches (counted within each epoch) it prints
    ``epoch, loss, accuracy``. The accuracy is measured on the current
    mini-batch only, using the predictions made before that step's update.

    Args:
        model: Module to train. Defaults to the module-level ``cnn`` instance.
        data_loader: Iterable of ``(x, y)`` batches. Defaults to the
            module-level ``loader``.
        epochs: Number of passes over ``data_loader``.
        lr: Adam learning rate.
        log_every: Print interval in batches; ``0``/``None`` disables logging.
        save_path: If given, the trained ``state_dict`` is written there with
            ``torch.save`` after the last epoch.
    """
    model = cnn if model is None else model
    data_loader = loader if data_loader is None else data_loader

    # Use the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move the model before building the optimizer, as the PyTorch docs
    # recommend, so the optimizer is created over the final parameters.
    model.to(device)
    model.train()
    print('device : \t', device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_function = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        for i, (x, y) in enumerate(data_loader):
            x, y = x.to(device), y.to(device)

            out = model(x)
            loss = loss_function(out, y)

            # Clear before backward so gradients left on a passed-in model are
            # never folded into the first update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if log_every and i % log_every == 0:
                # Mini-batch accuracy of the pre-update predictions.
                acc = out.argmax(dim=1).eq(y).sum().item() / len(y)
                print(epoch, '\t', loss.item(), '\t', acc)

    if save_path is not None:
        # Save the state_dict (the approach the PyTorch docs recommend) rather
        # than pickling the whole module: pickle looks the class up by its
        # global name, and this notebook rebinds `cnn` to an instance, so
        # `torch.save(cnn, ...)` fails with a PicklingError.
        torch.save(model.state_dict(), save_path)

# Run training on the module-level `cnn` model and `loader` batches.
train()

device : 	 cuda
0 	 2.279548406600952 	 0.125
0 	 2.2819366455078125 	 0.125
0 	 2.269050121307373 	 0.0
0 	 2.2332184314727783 	 0.25
0 	 2.142975330352783 	 0.25
0 	 2.1198601722717285 	 0.25
0 	 2.0849995613098145 	 0.375
0 	 2.03425669670105 	 0.5
0 	 2.3375744819641113 	 0.125
0 	 2.1036012172698975 	 0.25
0 	 2.279348373413086 	 0.0
0 	 2.1764609813690186 	 0.125
0 	 1.9115420579910278 	 0.375
0 	 1.6665022373199463 	 0.5
0 	 1.944000244140625 	 0.25
0 	 2.3078465461730957 	 0.125
0 	 2.444561004638672 	 0.25
0 	 1.5756767988204956 	 0.25
0 	 2.117741107940674 	 0.375
0 	 2.1739721298217773 	 0.125
0 	 1.9021459817886353 	 0.25
0 	 2.559194564819336 	 0.125
0 	 1.8463433980941772 	 0.375
0 	 2.0743284225463867 	 0.375
0 	 2.105828285217285 	 0.0
0 	 1.5505563020706177 	 0.5
0 	 1.560899257659912 	 0.375
0 	 1.8944240808486938 	 0.25
0 	 2.282024383544922 	 0.25
0 	 1.1971155405044556 	 0.5
0 	 2.546590805053711 	 0.125
0 	 2.290234327316284 	 0.375
0 	 1.5243558883666992 	 0.625
