In [1]:
import time
import torch
import copy
import json
from torchvision import transforms
from torch import nn
from PIL import Image
import ttach as tta
from tqdm import tqdm
from get_label import get_label

In [2]:
data_dir = "../data/medium/"
with open(data_dir + "test_all.json") as f:
    test_data = json.load(f)
json_dir = "./json/"
with open(json_dir + "word2color.json") as f:
    word2color = json.load(f)
with open(json_dir + "color2label.json") as f:
    color2label = json.load(f)
# with open(json_dir + "friend_label.json") as f:
#     friend_label = json.load(f)

In [3]:
device = torch.device('cuda')
modeltype = 'DenseNet161_pretrained'
epoch = 15
model = torch.load('../model/'+modeltype+'_Epoch_%d.pt'%epoch)
model = model.to(device)
model.eval()

DataParallel(
  (module): DenseNet(
    (features): Sequential(
      (conv0): Conv2d(3, 96, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (norm0): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu0): ReLU(inplace=True)
      (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (denseblock1): _DenseBlock(
        (denselayer1): _DenseLayer(
          (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(96, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(192, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (denselayer2): _DenseLayer(
          (norm1): BatchNorm2d(144, eps=1e-05, momentu

In [4]:
my_test_data = copy.deepcopy(test_data)
count = 0
since = time.time()
for dir_number, subdict in tqdm(test_data.items()):
    count += 1
    optional_tags = subdict['optional_tags']
    imgs_tags = subdict['imgs_tags']

    # process words
    label2word = {}
    label_lst = []
    for word in optional_tags:
        label = get_label(word, word2color, color2label)
        if label != -1:
            label2word[label] = word
            label_lst.append(label)
    if label2word=={}:
        label2word[0] = optional_tags[0] # word全不认识，则直接猜第一个word
    
    # inference images
    img_tensor_minibatch = None
    first = True
    for subsubdict in imgs_tags:
        for pic_dir in subsubdict: # subsubdict has only one k-v pair
            img = Image.open(data_dir + 'test/' + dir_number + '/' + pic_dir)
            img_tensor = transforms.PILToTensor()(img.resize((224,224)))
            img_tensor = img_tensor.type(torch.FloatTensor).reshape(1,3,224,224)
            if first:
                img_tensor_minibatch = img_tensor
            else:
                img_tensor_minibatch = torch.cat((img_tensor_minibatch, img_tensor), 0)
            first = False
    img_tensor_minibatch = img_tensor_minibatch.to(device)
    with torch.no_grad():
        output = nn.Softmax(dim=1)(model(img_tensor_minibatch))
    
    # output每行有13列，表示某一图片的13个类别概率
    # label先挑图片，选概率最大的图片
    output = output.cpu()
    
    for label in label_lst:
        max_idx = torch.argmax(output, dim=0) # 每列最大值的行索引
        pic_idx = max_idx[label]
        my_test_data[dir_number]['imgs_tags'][pic_idx][dir_number + '_%d.jpg'%pic_idx] = label2word[label]
        output[pic_idx] = -1 # 将该行置为-1，表示已经挑选过了

    # 图片再挑label，预测值从大到小依次查label2word，直到命中
    sorted_label_idx = torch.argsort(output, dim=1, descending=True)
    for pic_idx in range(len(imgs_tags)):
        if my_test_data[dir_number]['imgs_tags'][pic_idx][dir_number + '_%d.jpg'%pic_idx] != None:
            continue
        for label_idx in sorted_label_idx[pic_idx]:
            if int(label_idx) not in label2word:
                continue
            my_test_data[dir_number]['imgs_tags'][pic_idx][dir_number + '_%d.jpg'%pic_idx] = label2word[int(label_idx)]
            break

    if count%200 == 0:
        time_elapsed = time.time() - since
        # print('complete inferencing %d batches in %d seconds'%(count, int(time_elapsed)))

100%|██████████| 5331/5331 [21:54<00:00,  4.06it/s]


In [5]:
with open(json_dir + 'my_test_data_'+modeltype+'_Epoch_%d.json'%epoch, 'w', encoding='utf-8') as f:
    json.dump(my_test_data, f, ensure_ascii=False, indent=4)