# In [None]:
import numpy as np
import os,time
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision import models
from torchvision import transforms as tfs
from torchvision.datasets import ImageFolder
from PIL import Image
import matplotlib.pyplot as plt

# Dataset root handed to datasets.ImageFolder below (one sub-folder per class).
path = "./results"
# Output directory for the images that predict() saves.
resultpath = "./pics"
# File read with torch.load to obtain the trained network.
modelpath = "./mobilenet_model_100e.pth.tar"
class MobileNet(nn.Module):
    """MobileNet v1 classifier.

    A strided 3x3 convolution stem is followed by 13 depthwise-separable blocks.
    These feed a 7x7 average pool and a fully connected classification head.

    Args:
        num_classes: Number of output logits. Defaults to 1000, the original
            head size, so existing callers and checkpoints are unaffected.
    """

    def __init__(self, num_classes=1000):
        super(MobileNet, self).__init__()

        # Normal convolution block followed by Batchnorm (CONV_3x3-->BN-->Relu)
        def conv_bn(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True)
            )

        # Depthwise-separable block:
        # per-channel 3x3 conv (groups=inp) --> BN --> ReLU --> 1x1 pointwise conv --> BN --> ReLU
        def conv_dw(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),

                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        # Attribute names `model` and `fc` are kept so saved weights still match.
        self.model = nn.Sequential(
            conv_bn(  3,  32, 2),
            conv_dw( 32,  64, 1),
            conv_dw( 64, 128, 2),
            conv_dw(128, 128, 1),
            conv_dw(128, 256, 2),
            conv_dw(256, 256, 1),
            conv_dw(256, 512, 2),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 1024, 2),
            conv_dw(1024, 1024, 1),
            # Reduces the 7x7 map produced by 224x224 inputs to 1x1.
            nn.AvgPool2d(7),
        )
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch."""
        x = self.model(x)
        # Flatten per sample. The old view(-1, 1024) would silently fold any
        # extra pooled positions (inputs of roughly 448 px or more) into the
        # batch dimension. Keeping the batch size fixed makes such a mismatch
        # raise in self.fc instead.
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Evaluation settings. epochs and print_freq are not read anywhere in this
# file; they are kept only so other notebook cells that reference them work.
batch_size = 64
workers = 1
epochs = 1
print_freq = 100

valdir = path
# The usual ImageNet per-channel mean/std. They are applied after ToTensor
# has scaled pixels to [0, 1].
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Standard evaluation preprocessing: resize the shorter side to 256, then
# take a centre 224x224 crop.
val_image_data = datasets.ImageFolder(valdir, transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
]))

# shuffle=False keeps samples in val_image_data.imgs order.
# Pinned host memory only speeds up host-to-GPU copies, so it is requested
# only when CUDA is available. Otherwise PyTorch would just warn and ignore it.
val_loader = torch.utils.data.DataLoader(
    val_image_data,
    batch_size=batch_size, shuffle=False,
    num_workers=workers, pin_memory=torch.cuda.is_available())
# Class-folder name -> integer label, as assigned by ImageFolder.
class_dicts = val_image_data.class_to_idx
# Load the trained network. torch.load is expected to return a whole pickled
# nn.Module (the result is used directly as the model). Unpickling needs the
# MobileNet class above to be defined.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# files. Recent PyTorch releases default to weights_only=True, which rejects
# whole pickled modules; pass weights_only=False there if needed.
device = "cuda" if torch.cuda.is_available() else "cpu"
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
model = torch.load(modelpath, map_location=device)
# Put the weights on the same device that evaluate()/predict() send inputs to.
model = model.to(device)

# switch to evaluate mode (BatchNorm uses running statistics)
model.eval()
def get_acc(output, label):
    """Return the fraction of rows of *output* whose arg-max equals *label*.

    *output* holds per-class scores with classes along dim 1. *label* holds
    the matching integer targets. The result is a Python float in [0, 1].
    """
    predicted = output.max(1)[1]
    hits = predicted.eq(label).sum().item()
    rows = output.shape[0]
    return hits / rows

def evaluate(val_loader, model):
    """Compute, print, and return the top-1 accuracy of *model* over *val_loader*.

    The caller is responsible for putting *model* in eval mode.

    Args:
        val_loader: Iterable of ``(images, labels)`` batches, e.g. a DataLoader.
        model: Callable returning class scores with classes along dim 1.

    Returns:
        Accuracy as a float over all samples, or 0.0 for an empty loader.
    """
    use_cuda = torch.cuda.is_available()
    num_correct = 0
    num_seen = 0
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for im, label in val_loader:
            if use_cuda:
                im, label = im.cuda(), label.cuda()
            output = model(im)
            _, pred_label = output.max(1)
            num_correct += (pred_label == label).sum().item()
            num_seen += output.shape[0]
    # Weight by samples rather than batches, so a short final batch does not
    # skew the mean. The result is reported once, after the whole pass.
    accuracy = num_correct / num_seen if num_seen else 0.0
    print('acc {}'.format(accuracy))
    return accuracy

def tensor_to_PIL(tensor):
    """Convert an image tensor to a ``PIL.Image``.

    A leading dimension of size 1 is dropped (``squeeze(0)`` leaves other
    shapes untouched). The input tensor itself is not modified.
    """
    # ToPILImage converts through NumPy, so it needs a detached CPU tensor.
    image = tensor.detach().cpu().clone()
    image = image.squeeze(0)
    # Same conversion predict() uses. No separate `unloader` object exists in
    # this file.
    return tfs.ToPILImage()(image)

def get_key(dict, value):
    """Return every key of *dict* whose value equals *value*, in iteration order.

    The parameter name ``dict`` shadows the builtin; it is kept unchanged for
    caller compatibility.
    """
    matches = []
    for key, mapped in dict.items():
        if mapped == value:
            matches.append(key)
    return matches
        
def predict(data_loader, model):
    """Score *model* on *data_loader*, save misclassified images, and print accuracy.

    Each wrongly predicted sample is written to ``resultpath`` as
    ``pred_<predicted class>_<true class>_<n>.jpg``. Here ``<n>`` is the 1-based
    position of the sample in the loader, which is unique across batches.
    Class names come from the module-level ``class_dicts``; an index missing
    from it is written as the bare number.

    Args:
        data_loader: Iterable of ``(images, labels)`` batches whose images were
            normalized with the module-level ``normalize`` transform.
        model: Callable returning class scores with classes along dim 1.

    Returns:
        Top-1 accuracy over all samples as a float (0.0 for an empty loader).
    """
    use_cuda = torch.cuda.is_available()
    os.makedirs(resultpath, exist_ok=True)
    # Invert class_to_idx once instead of scanning it for every mistake.
    idx_to_class = {idx: name for name, idx in class_dicts.items()}
    # Undo `normalize` before saving. ToPILImage multiplies floats by 255 and
    # casts to uint8, so normalized values outside [0, 1] would wrap around.
    mean = torch.tensor(normalize.mean).view(-1, 1, 1)
    std = torch.tensor(normalize.std).view(-1, 1, 1)
    to_pil = tfs.ToPILImage()

    num_correct = 0
    num_seen = 0
    with torch.no_grad():  # inference only; no autograd graph needed
        for im, label in data_loader:
            im_data = im.cuda() if use_cuda else im
            label_data = label.cuda() if use_cuda else label
            output = model(im_data)
            _, pred_label = output.max(1)
            num_correct += (pred_label == label_data).sum().item()
            # `im` is the CPU copy of the batch, which is what ToPILImage needs.
            for p, l, pic in zip(pred_label.tolist(), label_data.tolist(), im):
                num_seen += 1
                if p == l:
                    continue
                image1 = to_pil((pic * std + mean).clamp(0, 1))
                filename = "pred_{}_{}_{}.jpg".format(
                    idx_to_class.get(p, str(p)),
                    idx_to_class.get(l, str(l)),
                    num_seen,
                )
                image1.save(os.path.join(resultpath, filename))

    # Weight by samples rather than batches so a short final batch is not
    # over-counted. The result is reported once, after the whole pass.
    accuracy = num_correct / num_seen if num_seen else 0.0
    print('acc {}'.format(accuracy))
    return accuracy


# Run the loaded network over the validation loader, saving misclassified images.
predict(data_loader=val_loader, model=model)