In [6]:
# -*- coding: utf-8 -*-
import cv2
import os
import sys
import random
import json
import torch
from torch.utils.data import DataLoader
sys.path.append("/world/data-gpu-112/liliang/pytorch-reid")
from utils import model_utils
import numpy as np
from nets.model_main import ft_net
%matplotlib inline
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from input_pipeline.lmdb_dataset import LMDBDataset
import ipdb


In [7]:
h_gap = 50
v_gap = 50
text_gap = 45
rows = 8
cols = 7
num_samples = rows*cols
img_size = (256, 128)

# pretrain_snapshot="/world/data-gpu-112/liliang/pytorch-output-mt/v4_resnet50_age_orilstdata/model_attr_best.pth"
# lmdb_path="/world/data-c26/liliang/person-sttribute/test-data/train/age/v4"
# config_path = "/world/data-gpu-112/liliang/pytorch-reid/evaluate/test_params.json"

pretrain_snapshot="/world/data-gpu-112/liliang/pytorch-output/cleanv8_resnet50_corr4x144_am0.25_am0/model_best.pth"
lmdb_path="/world/data-c26/person_reid_batch/cleaned_v8"
config_path = "/world/data-gpu-112/liliang/pytorch-reid/params.json"

model_name = "resnet_50_ibn_a"
batch_size = num_samples

In [8]:
class FileListJSONDataset(object):
    def __init__(self, file_paths, img_h, img_w):
        self.paths = []
        self.labels = []
        self.img_h = img_h
        self.img_w = img_w

        for line in file_paths:
            path = line.split("\t")[0]
            label = int(line.split("\t")[1])
            self.paths.append(path)
            self.labels.append(int(float(label)))

        self.num_labels = len(set(self.labels))

        global torch
        self.normalize_mean = np.reshape(np.array([0.485, 0.456, 0.406]),
                                         [3, 1, 1])
        self.normalize_variance = np.reshape(np.array([0.229, 0.224, 0.225]),
                                             [3, 1, 1])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        label = self.labels[idx]

        # Read and preprocessing image
        img = cv2.imread(path)
        if img is not None:
            img = cv2.resize(img, (self.img_w, self.img_h),
                             interpolation=cv2.INTER_CUBIC)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = np.transpose(img, [2, 0, 1]).astype(np.float32)
            img /= 255.0
            img = ((img - self.normalize_mean) / self.normalize_variance)
            img = img.astype(np.float32)
            img = torch.from_numpy(img)
        else:
            img = torch.randn(3, self.img_h, self.img_w)
            label = self.num_labels

        return img, label

In [9]:
def _forward(model, images):
    with torch.no_grad():
        rets = model(images)
        rets = rets.cpu().numpy()
    return rets

In [10]:
def _get_pred(model, dataloader, num_batches=1):
    # Run evaluation
    gt_labels = []
    pred_labels = []
    rand_idx = random.randint(0,100)
    batch_count = 1
    for idx, batch in enumerate(dataloader):
        if idx!=rand_idx:
            continue
        images, labels = batch
        images = images.cuda(0)

        # Run forward
        rets = _forward(model, images)

        # Accumulate results
        gt_labels += labels.tolist()

        pred_labels += np.argmax(rets, axis=1).tolist()

        if idx > 0 and idx % 500 == 0:
            logging.info("[CLASSIFICATION_EVALUATION] %s/%s batches......" %
                         (idx, num_batches))
        batch_count+=1
        if batch_count>num_batches:
            break
    return images, gt_labels, pred_labels

In [15]:
def run_eval():
    input_h=256
    input_w=128
    config=json.load(open(config_path,"r"))
    config["num_labels"] = 163400
    model = ft_net(config, model_name, pcb_n_parts=4)
    model_utils.restore_model(pretrain_snapshot, model)
    model.eval()
    model.cuda(0)
    transforms_list = []
    transforms_list.append(transforms.Resize((input_h, input_w)))
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transforms_list.append(transforms.ToTensor())
    transforms_list.append(normalize)
    all_transforms = transforms.Compose(transforms_list)
    dataset = LMDBDataset(        
                lmdb_path,
                transform=all_transforms
            )
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=12)
    imgs, gt_labels, pred_labels = _get_pred(model, dataloader, task_index=task_index, num_batches=1)
    return imgs, gt_labels, pred_labels

In [16]:
page = np.ones(((img_size[0]+h_gap)*rows, (img_size[1]+v_gap)*cols, 3), dtype=np.uint8) * 255
imgs, gt_labels, pred_labels = run_eval()
normalize_mean = np.reshape(np.array([0.485, 0.456, 0.406]),
                                         [3, 1, 1])
normalize_variance = np.reshape(np.array([0.229, 0.224, 0.225]),
                                             [3, 1, 1])
np_imgs=[]
for img in imgs:
    img=np.array(img).astype(np.float32)
    img = img*normalize_variance+normalize_mean
    img*=255.0
    img = np.transpose(img, [1,2,0]).astype(np.float32)
    np_imgs.append(img)
    
age_scpoe_mapping = {0:"0-1", 1:"2-5", 2:"6-10", 3:"11-15", 4:"16-20", 
                    5:"21-25", 6:"25-30", 7:"31-40", 8:"41-50", 
                     9:"51-60", 10:"61-80", 11:"80+", }
k=0
for row in range(1,rows+1):
    for col in range(1,cols+1):
        gt_label = gt_labels[col*(row-1)+col-1]
        pred_label = pred_labels[col*(row-1)+col-1]
        cv2.putText(page, str(gt_label), \
                    ((col-1)*(img_size[1]+v_gap), (row-1)*(h_gap+img_size[0])+50), \
                    cv2.FONT_HERSHEY_PLAIN, 1.5, (0,0,0), 2)
        cv2.putText(page, str(pred_label), \
                    ((col-1)*(img_size[1]+v_gap)+60, (row-1)*(h_gap+img_size[0])+50), \
                    cv2.FONT_HERSHEY_PLAIN, 1.5, (0,0,0), 2)
        page[row*h_gap+(row-1)*img_size[0]:row*(h_gap+img_size[0])\
            , (col-1)*(v_gap+img_size[1]):col*img_size[1]+(col-1)*v_gap] \
            = np_imgs[k]
        k+=1

correct_count = 0
for gt, pred in zip(gt_labels, pred_labels):
    if gt == pred:
        correct_count += 1
accuracy = 1.0 * correct_count / len(gt_labels)
print ("ACC:%.4f"%accuracy)
plt.figure(figsize = (30,30))
plt.axis('off')
plt.imshow(cv2.cvtColor(page, cv2.COLOR_BGR2RGB))

NameError: name 'task_index' is not defined