In [1]:
import torchvision.models as models
import torch.optim as optim
import torch
import torch.nn as nn
from Train import trainmodel
from efficientnet_pytorch import EfficientNet
from imageloader import *
from sklearn.metrics import confusion_matrix

In [2]:
%matplotlib inline
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

def plot_confusion_matrix(cm, savename, title='Confusion Matrix'):
    classes =['cardboard','glass', 'metal','paper','plastic','trash']
    plt.figure(figsize=(12, 8), dpi=100)
    np.set_printoptions(precision=2)

    # 在混淆矩阵中每格的概率值
    ind_array = np.arange(len(classes))
    x, y = np.meshgrid(ind_array, ind_array)
    for x_val, y_val in zip(x.flatten(), y.flatten()):
        c = cm[y_val][x_val]
        if c > 0.001:
            plt.text(x_val, y_val, "%0.2f" % (c,), color='red', fontsize=15, va='center', ha='center')
    
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.binary)
    plt.title(title)
    plt.colorbar()
    xlocations = np.array(range(len(classes)))
    plt.xticks(xlocations, classes, rotation=90)
    plt.yticks(xlocations, classes)
    plt.ylabel('Actual label')
    plt.xlabel('Predict label')
    
    # offset the tick
    tick_marks = np.array(range(len(classes))) + 0.5
    plt.gca().set_xticks(tick_marks, minor=True)
    plt.gca().set_yticks(tick_marks, minor=True)
    plt.gca().xaxis.set_ticks_position('none')
    plt.gca().yaxis.set_ticks_position('none')
    plt.grid(True, which='minor', linestyle='-')
    plt.gcf().subplots_adjust(bottom=0.15)
    
    # show confusion matrix
    plt.savefig(savename, format='png')
    plt.show()

In [3]:
test_folder = "./garbage_dataset/test_set"
batch_size = 6
device = 9
model_name_list = ["EfficientNet","alexnet","vgg16_bn","vgg19_bn","mobilenet_v2",\
                   "densenet161","densenet121","resnet18","resnet34",\
                   "resnet50""resnet101"]
'''efficient_net = EfficientNet.from_pretrained('efficientnet-b7',num_classes=6)'''
alexnet = models.alexnet(pretrained=False)
alexnet.classifier=nn.Sequential(nn.Linear(9216,4096),
                                 nn.ReLU(),
                                 nn.Dropout(0.2),
                                 nn.Linear(4096,4096),
                                 nn.ReLU(),
                                 nn.Linear(4096,6))
'''vgg16_bn = models.vgg16_bn(pretrained=False)
vgg16_bn.classifier=nn.Sequential(nn.Linear(25088,4096),nn.ReLU(inplace = True),\
                               nn.Dropout(0.5),nn.Linear(4096,4096),\
                               nn.ReLU(inplace = True),nn.Dropout(0.5),nn.Linear(4096,6))
vgg19_bn = models.vgg19_bn(pretrained = False)
vgg19_bn.classifier=nn.Sequential(nn.Linear(25088,4096),nn.ReLU(inplace = True),\
                               nn.Dropout(0.5),nn.Linear(4096,4096),\
                               nn.ReLU(inplace = True),nn.Dropout(0.5),nn.Linear(4096,6))
mobilenet_v2 = models.mobilenet_v2(pretrained=False)
mobilenet_v2.classifier = nn.Sequential(nn.Dropout(0.2),nn.Linear(1280,6))
densenet161 = models.densenet161(pretrained = False)
densenet161.classifier=nn.Linear(2208,6)
densenet121 = models.densenet161(pretrained = False)
densenet121.classifier=nn.Linear(1024,6)
resnet18 = models.resnet18(pretrained = False)
resnet18.fc=nn.Linear(512,6)
resnet34 = models.resnet18(pretrained = False)
resnet34.fc=nn.Linear(512,6)
resnet50 = models.resnet18(pretrained = False)
resnet50.fc=nn.Linear(2048,6)
resnet101 = models.resnet18(pretrained = False)
resnet101.fc=nn.Linear(2048,6)'''

'vgg16_bn = models.vgg16_bn(pretrained=False)\nvgg16_bn.classifier=nn.Sequential(nn.Linear(25088,4096),nn.ReLU(inplace = True),                               nn.Dropout(0.5),nn.Linear(4096,4096),                               nn.ReLU(inplace = True),nn.Dropout(0.5),nn.Linear(4096,6))\nvgg19_bn = models.vgg19_bn(pretrained = False)\nvgg19_bn.classifier=nn.Sequential(nn.Linear(25088,4096),nn.ReLU(inplace = True),                               nn.Dropout(0.5),nn.Linear(4096,4096),                               nn.ReLU(inplace = True),nn.Dropout(0.5),nn.Linear(4096,6))\nmobilenet_v2 = models.mobilenet_v2(pretrained=False)\nmobilenet_v2.classifier = nn.Sequential(nn.Dropout(0.2),nn.Linear(1280,6))\ndensenet161 = models.densenet161(pretrained = False)\ndensenet161.classifier=nn.Linear(2208,6)\ndensenet121 = models.densenet161(pretrained = False)\ndensenet121.classifier=nn.Linear(1024,6)\nresnet18 = models.resnet18(pretrained = False)\nresnet18.fc=nn.Linear(512,6)\nresnet34 = models.resnet18(

In [5]:
dict = torch.load("./param/alex_transfer_8_ADAM19",map_location="cuda:8")
model = alexnet
model.load_state_dict(dict)
model.eval()
model.to("cuda:9")
lossfunc = torch.nn.CrossEntropyLoss().cuda(device)
testset,size3=image_data_loader(test_folder,batch_size)
y_pred=[]
y_true=[]
test_loss = 0.0
for i, data in enumerate(testset):
    (inputs, labels) = data
    inputs = inputs.cuda(device)
    labels = labels.cuda(device)
    outputs = model(inputs)
    loss = lossfunc(outputs, labels)
    test_loss += loss.item() * inputs.size(0)
    ret, predictions = torch.max(outputs.data, 1)
    predictions=predictions.cpu()
    labels = labels.cpu()
    y_pred.extend(list(np.array(predictions)))
    y_true.extend(list(np.array(labels)))
cm = confusion_matrix(y_true,y_pred)
crcount = 0
for i in range(len(y_pred)):
    if (y_pred[i]==y_true[i]):
        crcount+=1
print(crcount/len(y_pred))
        
FP = cm.sum(axis=0) - np.diag(cm)  
FN = cm.sum(axis=1) - np.diag(cm)
TP = np.diag(cm)
TN = cm.sum() - (FP + FN + TP)
precision = TP / (TP+FP)  # 查准率
recall = TP / (TP+FN)  # 查全率
accuracy=(TP+TN)/(TP+TN+FP+FN)
f1=(2*precision*recall)/(precision+recall)
print(accuracy)
print(recall)
print(precision)
print(f1)
plot_confusion_matrix(cm, './cm/alexnet.png', title='transfer alexnet confusion matrix')



RuntimeError: Error(s) in loading state_dict for AlexNet:
	Missing key(s) in state_dict: "classifier.5.weight", "classifier.5.bias". 
	Unexpected key(s) in state_dict: "classifier.6.weight", "classifier.6.bias". 
	size mismatch for classifier.3.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for classifier.3.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([4096]).