In [60]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms, models

In [61]:
train_transforms = transforms.Compose([transforms.Resize([224,224]),
                                    transforms.ToTensor(),
                                    transforms.RandomRotation(30),
                                    transforms.RandomHorizontalFlip(),
                                    ])
test_transforms = transforms.Compose([transforms.Resize([224,224]),
                                    transforms.ToTensor(),
                                    ])

batch_size = 4 # 批处理大小


In [62]:
trainset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                        download=True, transform=train_transforms) # 下载训练数据集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2) # 加载训练数据集

testset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                       download=True, transform=test_transforms) # 下载测试数据集
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2) # 加载测试数据集

classes = (
	"apple",
	"aquarium_fish",
	"baby",
	"bear",
	"beaver",
	"bed",
	"bee",
	"beetle",
	"bicycle",
	"bottle",
	"bowl",
	"boy",
	"bridge",
	"bus",
	"butterfly",
	"camel",
	"can",
	"castle",
	"caterpillar",
	"cattle",
	"chair",
	"chimpanzee",
	"clock",
	"cloud",
	"cockroach",
	"couch",
	"cra",
	"crocodile",
	"cup",
	"dinosaur",
	"dolphin",
	"elephant",
	"flatfish",
	"forest",
	"fox",
	"girl",
	"hamster",
	"house",
	"kangaroo",
	"keyboard",
	"lamp",
	"lawn_mower",
	"leopard",
	"lion",
	"lizard",
	"lobster",
	"man",
	"maple_tree",
	"motorcycle",
	"mountain",
	"mouse",
	"mushroom",
	"oak_tree",
	"orange",
	"orchid",
	"otter",
	"palm_tree",
	"pear",
	"pickup_truck",
	"pine_tree",
	"plain",
	"plate",
	"poppy",
	"porcupine",
	"possum",
	"rabbit",
	"raccoon",
	"ray",
	"road",
	"rocket",
	"rose",
	"sea",
	"seal",
	"shark",
	"shrew",
	"skunk",
	"skyscraper",
	"snail",
	"snake",
	"spider",
	"squirrel",
	"streetcar",
	"sunflower",
	"sweet_pepper",
	"table",
	"tank",
	"telephone",
	"television",
	"tiger",
	"tractor",
	"train",
	"trout",
	"tulip",
	"turtle",
	"wardrobe",
	"whale",
	"willow_tree",
	"wolf",
	"woman",
	"worm",
) # 类别

print(trainloader.dataset.classes)

Files already downloaded and verified
Files already downloaded and verified
['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tracto

In [63]:
model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [64]:
for param in model.parameters():
    param.requires_grad = False

In [65]:
model.fc = nn.Sequential(nn.Linear(512, 256),
                         nn.ReLU(),
                         nn.Dropout(0.5),
                         nn.Linear(256, 100),
                         nn.LogSoftmax(dim=1))
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 25
steps = 0
test_loss_min = np.Inf
running_loss = 0
print_every = 10
train_losses, test_losses = [], []
for epoch in range(epochs):
    for inputs, labels in trainloader:
        steps += 1
        optimizer.zero_grad()
        logps = model.forward(inputs)
        loss = criterion(logps, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        if steps % print_every == 0:
            test_loss = 0
            accuracy = 0
            model.eval()
            with torch.no_grad():
                for inputs, labels in testloader:
                    logps = model.forward(inputs)
                    batch_loss = criterion(logps, labels)
                    test_loss += batch_loss.item()

                    ps = torch.exp(logps)
                    top_p, top_class = ps.topk(1, dim=1)
                    equals = top_class == labels.view(*top_class.shape)
                    accuracy += torch.mean(equals.type(torch.FloatTensor)).item()

            train_losses.append(running_loss / len(trainloader))
            test_losses.append(test_loss / len(testloader))
            print(f"Epoch {epoch + 1}/{epochs}.. "
                  f"Train loss: {running_loss / print_every:.3f}.. "
                  f"Test loss: {test_loss / len(testloader):.3f}.. "
                  f"Test accuracy: {accuracy / len(testloader):.3f}")

            running_loss = 0
            model.train()
            if test_loss <= test_loss_min:
                print(
                    'Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(test_loss_min, test_loss))
                torch.save(model.state_dict(), 'bestmodel.pth')
                test_loss_min = test_loss
torch.save(model, 'model.pth')

Epoch 1/25.. Train loss: 4.617.. Test loss: 4.639.. Test accuracy: 0.007
Validation loss decreased (inf --> 11596.927952).  Saving model ...
Epoch 1/25.. Train loss: 4.658.. Test loss: 4.648.. Test accuracy: 0.008


KeyboardInterrupt: 

In [None]:
plt.plot(train_losses, label='Training loss')
plt.plot(test_losses, label='Validation loss')
plt.legend(frameon=False)
plt.show()

In [None]:
def predict_image(image):
    # 应用transform
    image_tensor = test_transforms(image).float()

    image_tensor = image_tensor.unsqueeze_(0)
    input = Variable(image_tensor)
    output = model(input)
    index = output.data.cpu().numpy().argmax()
    return index

In [None]:
from PIL import Image
import glob

for filename in glob.glob('./test/*.png'):
    im = Image.open(filename).convert('RGB')
    index = predict_image(im)
    # 打印读出的图片


import json
result = {}
for filename in glob.glob('./test/*.png'):
    im=Image.open(filename).convert('RGB')
    index = predict_image(im)
    result[filename[filename.rfind('\\')+1:filename.rfind('.')]] = classes[index]
with open('result.json', 'w') as fp:
    json.dump(result, fp)

with open('result.json', 'r') as fp:
    result = json.load(fp)

result_sorted = __builtins__.dict(sorted(result.items(), key=lambda x: int(x[0])))

with open('result_sorted.json', 'w') as fp:
    json.dump(result_sorted, fp, indent=4)