In [1]:
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time
import cv2

In [2]:
#device : GPU or CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [3]:
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

In [4]:
#data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
data_root = os.getcwd()
image_path = data_root + "/data/"  # flower data set path
train_dataset = datasets.ImageFolder(root=image_path + "/train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)

In [5]:
cd_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in cd_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

In [6]:
batch_size = 8
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=True,
                                              num_workers=0)

test_data_iter = iter(validate_loader)
test_image, test_label = test_data_iter.next()
print(test_image[0].size(),type(test_image[0]))
print(test_label[0],test_label[0].item(),type(test_label[0]))

torch.Size([3, 224, 224]) <class 'torch.Tensor'>
tensor(4) 4 <class 'torch.Tensor'>


In [7]:
net = AlexNet(num_classes=7, init_weights=False)

net.to(device)

loss_function = nn.CrossEntropyLoss()

optimizer = optim.Adam(net.parameters(), lr=0.0002)

save_path = './AlexNet.pth'

best_acc = 0.0

In [8]:
for epoch in range(10):
    # train
    net.train()
    running_loss = 0.0
    t1 = time.perf_counter()
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        # print train process
        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
    print()
    print(time.perf_counter()-t1)

    # validate
    net.eval()  
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / step, val_accurate))

print('Finished Training')

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


train loss: 100%[**************************************************->]1.822
95.65430100799858
[epoch 1] train_loss: 1.903  test_accuracy: 0.251
train loss: 100%[**************************************************->]1.722
99.09808918000272
[epoch 2] train_loss: 1.485  test_accuracy: 0.395
train loss: 100%[**************************************************->]1.167
93.32256036800027
[epoch 3] train_loss: 1.293  test_accuracy: 0.538
train loss: 100%[**************************************************->]1.467
99.33429445299771
[epoch 4] train_loss: 1.193  test_accuracy: 0.531
train loss: 100%[**************************************************->]0.808
101.47223717499946
[epoch 5] train_loss: 1.139  test_accuracy: 0.414
train loss: 100%[**************************************************->]0.468
103.29923930599762
[epoch 6] train_loss: 1.102  test_accuracy: 0.586
train loss: 100%[**************************************************->]0.480
105.37922006699955
[epoch 7] train_loss: 1.035  test_accur