In [1]:
import os
import json
import sys
import torch
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn as nn
from torchvision import transforms, datasets, utils
import numpy as np
from tqdm import tqdm
# %run network.ipynb

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch.nn as nn
import torch

class Network(nn.Module):
    def __init__(self, num_classes=10, init_weights=False):
        super(Network, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  
            nn.Conv2d(48, 128, kernel_size=5, padding=2),          
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),        
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(32 * 4, 1024),
            nn.ReLU(inplace=True),
            # nn.Dropout(p=0.5),
            # nn.Linear(1024, 2048),
            # nn.ReLU(inplace=True),
            nn.Linear(1024, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


In [3]:

def testSingle():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "images/test/n01532829/n01532829_307.JPEG"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)

    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = Network(num_classes=20).to(device)

    # load model weights
    weights_path = "checkpoints/network.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path))

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()
# testSingle()

In [5]:

def testBatch():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "test": transforms.Compose([transforms.Resize((224, 224)), 
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
    image_path = os.path.abspath(os.path.join(os.getcwd(), "images"))  # get data root path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    test_dataset = datasets.ImageFolder(root=os.path.join(image_path, "test"),
                                            transform=data_transform["test"])
    test_num = len(test_dataset)
    batch_size = 4
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for testing.".format(test_num))

    net = Network(num_classes=10, init_weights=False)
    dict = torch.load('./network.pth')
    net.load_state_dict(dict)
    net.to(device)
    net.eval()
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():
        test_bar = tqdm(test_loader, file=sys.stdout)
        for test_data in test_bar:
            test_images, test_labels = test_data
            outputs = net(test_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += torch.eq(predict_y, test_labels.to(device)).sum().item()

    test_accurate = acc / test_num
    print('test_accuracy: %.3f' %(test_accurate))
testBatch()

using cuda:0 device.
using 200 images for testing.
  0%|          | 0/50 [00:04<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x4608 and 128x1024)