In [1]:
import sys
import os
import numpy as np
import random
import math

import torch
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.nn.parallel import DataParallel
import pickle
import time
import argparse
from PIL import Image, ImageFont, ImageDraw
import cv2
from baseline.model.DeepMAR import DeepMAR_ResNet50
from baseline.utils.utils import str2bool
from baseline.utils.utils import save_ckpt, load_ckpt
from baseline.utils.utils import load_state_dict 
from baseline.utils.utils import set_devices
from baseline.utils.utils import set_seed
from sklearn.metrics.pairwise import cosine_similarity

In [2]:
class Config(object):
    def __init__(self):
        self.sys_device_ids = (0,)
        # random
        self.set_seed = False
        if self.set_seed:
            self.rand_seed = 0
        else: 
            self.rand_seed = None
        self.resize = (224,224)
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        # utils
        self.load_model_weight = True
        self.model_weight_file = 'color/ckpt_epoch100.pth'
        if self.load_model_weight:
            if self.model_weight_file == '':
                print('Please input the model_weight_file if you want to load model weight')
                raise ValueError        
        # dataset 
        datasets = dict()
        datasets['peta'] = './dataset/peta/peta_dataset.pkl'
        datasets['rap2'] = './color/rap2_dataset.pkl'
        datasets['rap'] = './dataset/rap/rap_dataset.pkl'
        datasets['pa100k'] = './dataset/pa100k/pa100k_dataset.pkl'
        dataset= pickle.load(open('color/rap2_dataset.pkl', 'rb'))
        self.att_list = [dataset['att_name'][i] for i in dataset['selected_attribute']]
        # model
        model_kwargs = dict()
        model_kwargs['num_att'] = len(self.att_list)
        model_kwargs['last_conv_stride'] = 2
        self.model_kwargs = model_kwargs

In [3]:
cfg = Config()
if cfg.set_seed:
    set_seed( cfg.rand_seed )
# init the gpu ids
set_devices(cfg.sys_device_ids)


In [4]:
# dataset 
normalize = transforms.Normalize(mean=cfg.mean, std=cfg.std)
test_transform = transforms.Compose([
        transforms.Resize(cfg.resize),
        transforms.ToTensor(),
        normalize,])

In [5]:
### Att model ###
modelAtt = DeepMAR_ResNet50(**cfg.model_kwargs)
if cfg.load_model_weight:
    map_location = (lambda storage, loc:storage)
    ckpt = torch.load(cfg.model_weight_file, map_location=map_location)
    modelAtt.load_state_dict(ckpt['state_dicts'][0])
modelAtt.eval()

  init.normal(self.classifier.weight, std=0.001)
  init.constant(self.classifier.bias, 0)


DeepMAR_ResNet50(
  (base): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (downsample): Sequential(
          (0):

In [6]:
def get_attVector(img):
    img_trans = test_transform(img) 
    img_trans = torch.unsqueeze(img_trans, dim=0)
    img_var = Variable(img_trans)
    score = modelAtt(img_var).data.cpu().numpy()
    return score

In [8]:
img = cv2.imread('bleh.jpg')
crop_img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
crop_img = Image.fromarray(crop_img)
score = get_attVector(crop_img)

In [9]:
score

array([[ -7.0143514 ,  -9.675814  ,   2.489856  ,  -2.7437792 ,
         -9.39653   , -10.822732  ,  13.615581  , -12.96227   ,
         11.83883   , -12.6854925 ,  -8.6102295 ,  -6.5295124 ,
         -2.0660608 ,   2.906476  ,  -9.43663   ,  -4.1125817 ,
         -6.37259   , -10.353272  ,  -7.1987143 ,  -7.3089547 ,
          5.337314  ,  -8.418406  ,  -4.7267876 ,  -9.930874  ,
         -6.677887  ,  -1.5369021 ,  -5.302395  ,  -3.348677  ,
         -5.50881   , -10.110057  ,  -1.9877123 , -11.098897  ,
         -8.781584  ,  -9.28554   ,   1.6675557 ,  -8.004186  ,
        -10.879558  ,  -6.8836217 , -10.601658  ,   4.0543933 ,
         -3.8739593 ,  -6.742448  ,  -8.903313  ,  -7.9080815 ,
         -8.338085  ,  -6.6024075 ,  -8.076284  ,  -4.0270724 ,
         -7.8680964 ,   0.11844312,  -7.8596625 ,  -8.819599  ,
         -4.253895  ,  -3.1235733 ,  -3.2135916 ,  -7.6895475 ,
         -8.899791  ,  -2.3246174 ,  -4.191793  ,   1.468929  ,
         -5.9149427 ,  -6.9852357 , -10.