In [1]:
# -*- coding: utf-8 -*-
import cv2
import os
import sys
import time
sys.path.insert(0,"/world/data-gpu-112/liliang/pytorch-reid")
import random
import json
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
import ipdb
import pickle


In [4]:
def _get_pred(models, json_path, task_index):
    def _softmax(x):
        exp_x = np.exp(x)
        softmax_x = exp_x / np.sum(exp_x)
        return softmax_x 
    pred_logits_lst = []
    pred_labels_lst = []
#     record_dict_lst = [] #用于统计label数量
    images_lst = []
    id_lst = []
    for i, model in enumerate(models):
#         record_dict = {"0":0, "1":0, "2":0, "3":0, "4":0, "5":0, "6":0, "7":0, "8":0, "9":0, "10":0, "11":0}
        pred_logits = []
        pred_labels = []
        f = open(json_path, "r")
        for line in f:
            if len(line) < 40:
                continue
            path, info = line.split("\t")
            info = json.loads(info)
            label = info["label"]
#             if label in label_set:
#                 continue
#             else:
#                 label_set.add(label)
            id_lst.append(label)
            image = cv2.imread(path)
            image = cv2.resize(image, (128,256)) 
            image = np.expand_dims(image, axis=0)
            images_lst.append(image)

            logits = model.predict(image)
            logits = logits[task_index]
            logits = _softmax(np.asarray(logits))[0]
            pred_label = np.argmax(logits)
            pred_logits.append(logits)
            pred_labels.append(pred_label) # imshow展示原来的label
#             record_dict[i][str(pred_label)] += 1
            
        pred_logits_lst.append(pred_logits)
        pred_labels_lst.append(pred_labels)
#         record_dict_lst.append(record_dict_ori)
            
    return images_lst, pred_logits_lst, pred_labels_lst, id_lst


def aggregate(images_lst, pred_logits_lst, pred_labels_lst, id_lst):
    id2infos = {}
    for i, label in enumerate(id_lst):
        if not label in id2infos:
            id2infos[label] = {"logits": [pred_logits_lst[i]], "pred_labels": [pred_labels_lst[i]]}
        else:
#             id2infos[label]["imgs"].append(images_lst[i])
            id2infos[label]["logits"].append(pred_logits_lst[i])
            id2infos[label]["pred_labels"].append(pred_labels_lst[i])
    return id2infos

def cal_age_xulie(images_lst, id2infos):
    # scope_xulie could be changed
    scope_xulie = ["0-1", "2-5", "6-10", "11-15", "16-20", "21-25", "26-30", "31-40", "41-50", "51-60", "61-80", "80-100"]
    record_dict = {"0":0, "1":0, "2":0, "3":0, "4":0, "5":0, "6":0, "7":0, "8":0, "9":0, "10":0, "11":0}
    refer = [0.5, 2.5, 8, 13, 18, 23, 28, 35, 45, 55, 70, 90]
    id2age = []
    img_show = []
    record_dict_lst = []
    for label, infos in id2infos.items():
        logits_lst = np.asarray(infos["logits"])
        age = np.sum(np.sum(logits_lst, axis=0) * np.asarray(refer)) / len(logits_lst)
        id2age.append(age)
        img_show.append(infos["imgs"][0])
        for i, age_scope in enumerate(scope_xulie):
            lower_bound, upper_bound = map(int, age_scope.split("-"))
            if lower_bound < age < upper_bound:
                record_dict[str(i)] += 1
    return id2age, img_show, record_dict, scope_xulie

def cal_age_online(images_lst, id2infos):
    scope = ["0-1", "2-5", "6-10", "11-15", "16-20", "21-25", "26-30", "31-40", "41-50", "51-60", "61-80", "80+"]
    record_dict = {"0":0, "1":0, "2":0, "3":0, "4":0, "5":0, "6":0, "7":0, "8":0, "9":0, "10":0, "11":0}
    id2age = []
    img_show = []
    record_dict_lst = []
    for label, infos in id2infos.items():
        label_lst = infos["pred_labels"]
        id_age = max(label_lst, key=label_lst.count)
        id2age.append(id_age)
        img_show.append(infos["imgs"][0])
        record_dict[str(id_age)] += 1
    return id2age, img_show, record_dict, scope

def run_eval(task_index, pretrain_snapshots):
    from auto_deploy.predictor import Predictor
    preprocess_dict = {"norm_lambda": "lambda x: x", "color_mode": "None"} 
    models = []
    for pretrain_snapshot in pretrain_snapshots:
        print ("model restored form: %s" %pretrain_snapshot)
        model = Predictor(gpu=0, input_size=(128, 256), model_path=pretrain_snapshot,
                          preprocess_dict=preprocess_dict)
        models.append(model)
    
    # forward
    images_lst, pred_logits_lst, pred_labels_lst, id_lst = _get_pred(models, json_path, task_index)
    
    
    record_dict_lst = [] #统计label数量
    id2age_lst = []
    for i in range(len(pretrain_snapshots)):
        id2infos = aggregate(images_lst, pred_logits_lst[i], pred_labels_lst[i], id_lst)
        with open('/world/data-gpu-112/liliang/16-22_online.pkl', 'wb') as f:
            pickle.dump(id2infos, f)
        id2age, images_lst, record_dict, scope = cal_age_xulie(images_lst, id2infos)
#         id2age, images_lst, record_dict, scope = cal_age_online(images_lst, id2infos)
        record_dict_lst.append(record_dict)
        id2age_lst.append(id2age)
        
    return images_lst, pred_logits_lst, id2age_lst, record_dict_lst, id_lst, scope

In [5]:
pretrain_snapshots = [
    "/world/data-c7/xiaoyouchang/resnet_50_ibn_a_20190409_163335.pt",
# "/world/data-c7/liliang/pytorch_reid_models/resnet_50_ibn_a_best.pt",
]
json_path = "/world/data-c26/liliang/anta_json/entrance_16-22.json"
dataset_name = json_path.split("/")[-1].split(".")[0]

model_names = [path.split("/")[-2] for path in pretrain_snapshots]
task_idx = {"age":1, "bag":2, "gender":3, "orient":4}
age_scpoe_mapping = {0:"0-1", 1:"2-5", 2:"6-10", 3:"11-15", 4:"16-20", 
                    5:"21-25", 6:"26-30", 7:"31-40", 8:"41-50", 
                     9:"51-60", 10:"61-80", 11:"80+", }

# imgs to np.ndarray
images_lst, pred_logits_lst, id2age_lst, record_dict_lst, id_lst, scope = run_eval(task_idx["age"], pretrain_snapshots)


model restored form: /world/data-c7/xiaoyouchang/resnet_50_ibn_a_20190409_163335.pt


KeyError: 'imgs'

In [None]:
def distribution(scope, scope_cnt, model_name):
    plt.figure(figsize=(12,4))
    num_bins = 12
    plt.bar(scope, scope_cnt, width=0.5, color="green", align="center")
    for a, b in zip(scope, scope_cnt):
        plt.text(a, b+0.5, '%.0f' % b, ha='center', va='bottom', fontsize=10)
    plt.xticks(range(len(scope)), scope, rotation=45)
    plt.xlabel('age distribution')
    plt.ylabel('counts')
    plt.title("%s -- %s -- total ids:%s"%(dataset_name, model_name, sum(scope_cnt)))
    plt.show()

rows = 50
cols = 2
def check_diffpred(pred_labels_lst):
    diff_flag = []
    show_logits = []
    for i in range(len(pred_logits_lst)):
        pred_labels = pred_labels_lst[i]
        logits = pred_logits_lst[i]
        show_logit = []
        show_logit.append([round(i,2) for i in logits[0]]) #21-25,26-30,31-40
        show_logit.append([round(i,2) for i in logits[1]])
        show_logits.append(show_logit)
        if pred_labels[0] != pred_labels[1]:
            diff_flag.append(True)
#             print(i)
#             print([round(i,2) for i in logits[0]], pred_labels[0])
#             print([round(i,2) for i in logits[1]],pred_labels[1])
#             import ipdb
#             ipdb.set_trace()
        else:
            diff_flag.append(False)
            
    return show_logits, diff_flag

def img_show(show_logits, diff_flag):
    idx = 0
    show_cnt = 0
    diff_cnt = 0
    for i in range(len(show_logits)):
        if show_cnt<rows*cols and pred_labels_lst[i][0]<=4:
#             and show_logits[i][1][8]>0.1:
#         if diff_flag[i] == False and pred_labels_lst[i][0]==7:
            if idx%cols==0:
                plt.figure(idx, figsize=(20, 4))
            plt.subplot(1, cols, show_cnt%cols+1)
            img = np.squeeze(imgs[i])
            img = img[:,:,::-1]
            plt.imshow(img)
            plt.text(0, img.shape[0]*1.1, 
                     'logit:%s , label:%s \nlogit:%s , label:%s'
                     %(show_logits[i][0],age_scpoe_mapping[pred_labels_lst[i][0]],
                       show_logits[i][1],age_scpoe_mapping[pred_labels_lst[i][1]]
                        )
                    )
            plt.axis('off')
            idx += 1
            show_cnt += 1
            diff_cnt += 1
        elif diff_flag[i] == True:
            diff_cnt += 1
    return diff_cnt

print("total imgs: %s"%len(pred_logits_lst[0]))
print("total ids: %s"%len(id2age_lst[0]))
for i, record_dict in enumerate(record_dict_lst):
#     scope_cnt = get_scope(record_dict)
    distribution(scope, list(record_dict.values()), model_names[i])

# show_logits, diff_flag = check_diffpred(pred_labels_lst)
# diff_cnt = img_show(show_logits, diff_flag)
# print("total diff: %s"%diff_cnt)

## open-log 分析