In [4]:
# 适用于不同类别的图片放在不同文件夹 文件夹名即为标签 对类别较少很方便！
import torch  
import torchvision  
from torchvision import transforms, utils
import torch.nn as nn  
from torch.autograd import Variable  
import torch.utils.data as Data  
import time  
import numpy as np
from sklearn.metrics import classification_report

In [8]:
train_img_data = torchvision.datasets.ImageFolder('C:/Users/jxjsj/Desktop/JupyterHome/Data/flowers/train',
                                            transform=transforms.Compose([
                                                transforms.Resize(256),
                                                transforms.CenterCrop(224),
                                                transforms.ToTensor()])
                                            )

print(len(train_img_data))
train_data_loader = torch.utils.data.DataLoader(train_img_data, batch_size=50,shuffle=True)
print(len(train_data_loader))

2945
59


In [9]:
test_img_data = torchvision.datasets.ImageFolder('C:/Users/jxjsj/Desktop/JupyterHome/Data/flowers/test',
                                            transform=transforms.Compose([
                                                transforms.Resize(256),
                                                transforms.CenterCrop(224), # 变成224 x 224像素
                                                transforms.ToTensor()])
                                            )

print(len(test_img_data))
test_data_loader = torch.utils.data.DataLoader(test_img_data, batch_size=50,shuffle=True)
print(len(test_data_loader))

725
15


In [5]:
class reluCNNet(torch.nn.Module):
    def __init__(self):
        super(reluCNNet, self).__init__()
        self.conv1 = torch.nn.Sequential( # 224
            torch.nn.Conv2d(3, 32, 5, 1, 0), # 220
            torch.nn.BatchNorm2d(num_features=32, eps=1e-05, momentum=0.1, affine=True), # BN 处理
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2) # 110
        ) 
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, 5, 1, 0), # 106
            torch.nn.BatchNorm2d(num_features=64, eps=1e-05, momentum=0.1, affine=True), # BN 处理
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2) # 53
        )
        self.conv3 = torch.nn.Sequential(
            torch.nn.Conv2d(64, 32, 5, 2, 0), # 25
            torch.nn.BatchNorm2d(num_features=32, eps=1e-05, momentum=0.1, affine=True), # BN 处理
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2) # 12
        )
        self.dense = torch.nn.Sequential(
            torch.nn.Linear(32*12*12, 256),
#             torch.nn.Dropout(0.3),  # 防止过拟合尝试
            torch.nn.ReLU(),
            torch.nn.Linear(256, 5)
        )

    def forward(self, x):
        conv1_out = self.conv1(x)
        conv2_out = self.conv2(conv1_out)
        conv3_out = self.conv3(conv2_out)
        res = conv3_out.view(conv3_out.size(0), -1)
        out = self.dense(res)
        return out

In [6]:
# 加载模型参数并赋予模型
model = reluCNNet()
model.load_state_dict(torch.load('C:/Users/jxjsj/Desktop/JupyterHome/DLmodel/reluCNN_flower.pkl'))

In [16]:
use_gpu = False

# model = reluCNNet()

if use_gpu:
    model = model.cuda()
else:
    model = model.cpu()

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(1):
    print('epoch {}'.format(epoch + 1))
    # training-----------------------------
    model.train()
    train_acc = 0.
#     L_train_pred = []
#     L_train_real = []
    for step, (batch_x, batch_y) in enumerate(train_data_loader):
        batch_x, batch_y = Variable(batch_x), Variable(batch_y)
        
        if use_gpu:
            batch_x = batch_x.cuda()
            batch_y = batch_y.cuda()
            
        out = model(batch_x)
        print(out.size())
        print(batch_y.size())
        loss = loss_func(out, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        pred = torch.max(out, 1)[1]
        num_correct = (pred == batch_y).sum()
        train_acc += num_correct.data
        
#         L_train_pred += pred.cpu().numpy().tolist()
#         L_train_real += batch_y.cpu().numpy().tolist()
        
        print('Step:',step+1,'Finished!')
    print('Train Acc: {:.6f}'.format(train_acc.cpu().numpy() / (len(train_img_data))))
#     print(classification_report(L_train_real,L_train_pred))

    # evaluation--------------------------------
    model.eval()
    with torch.no_grad():
        eval_acc = 0.
#         L_val_pred = []
#         L_val_real = []
        for batch_x, batch_y in test_data_loader:
            batch_x, batch_y = Variable(batch_x), Variable(batch_y)

            if use_gpu:
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()

            out = model(batch_x)
            loss = loss_func(out, batch_y)
            
            pred = torch.max(out, 1)[1]
            num_correct = (pred == batch_y).sum()
            eval_acc += num_correct
            
#             L_val_pred += pred.cpu().numpy().tolist()
#             L_val_real += batch_y.cpu().numpy().tolist()
            
        print('Test Acc: {:.6f}'.format(eval_acc.cpu().numpy() / (len(test_img_data))))
#         print(classification_report(L_val_real,L_val_pred))

epoch 1
torch.Size([50, 5])
torch.Size([50])


KeyboardInterrupt: 

In [15]:
torch.save(model.state_dict(),'C:/Users/jxjsj/Desktop/JupyterHome/DLmodel/reluCNN_flower.pkl')

In [16]:
# 用外来图片检测test
from PIL import Image
test_use_rose = Image.open('C:/Users/jxjsj/Desktop/test_use_sunflower.jpg').convert('RGB')
transform=transforms.Compose([transforms.Resize(256),
                              transforms.CenterCrop(224),
                              transforms.ToTensor()])
test_rose = transform(test_use_rose)
model.cpu()
model.eval()
test_rose = test_rose.unsqueeze(0)
with torch.no_grad():
    test_rose = Variable(test_rose)
    test_out = model(test_rose)
    pred = torch.max(test_out, 1)[1]

# print(train_img_data.class_to_idx)
# print(test_img_data.class_to_idx)

index_label_dct = {train_img_data.class_to_idx[label] : label for label in train_img_data.class_to_idx}
index_label_dct[int(pred.cpu().numpy())]

'sunflowers'