__Image Chinese Captioning Datasets Pre-Processing__

0 将官方validation set的标注json文件，转换为COCO评价所需要的格式

In [4]:
import json
import jieba
from tqdm import tqdm

In [5]:
# Skeleton of the COCO-caption-style evaluation json for the official
# validation set; 'images' and 'annotations' are filled by the next cell.
val_data = {}
val_data['type'] = "captions"
val_data['info'] = {
    "contributor": "Yiyu Wang",
    "description": "ImageChineseCaptioningEval",
    "url": "https://github.com/AIChallenger/AI_Challenger.git",
    "version": "1",
    "year": 2021
}
val_data['licenses'] = [{"url": "https://challenger.ai"}]
val_data['images'] = []
val_data['annotations'] = []

In [6]:
# Re-organise the official validation annotation json into the MSCOCO
# evaluation format (see caption_validation_annotations.json).
# In this dataset the 'image_id' field of a prediction json must be the image
# file name (without '.jpg', matching image_id below) -- see test.json.
val_annotation = './ai_challenger_caption_validation_20170910/caption_validation_annotations_20170910.json'

# The annotations contain Chinese text: decode explicitly as UTF-8 so the
# result does not depend on the platform's default locale encoding.
with open(val_annotation, 'r', encoding='utf-8') as f:
    tmp_data = json.load(f)

caption_id = 1  # 1-based id, unique across all reference captions
for record in tqdm(tmp_data):
    file_name = record['image_id']          # file name  ***.jpg
    image_id = file_name.split('.')[0]      # image id   ***
    captions = record['caption']            # reference captions (x5)
    val_data['images'].append({'file_name': file_name, "id": image_id})
    for caption in captions:
        # Pre-segment the Chinese caption into space-separated words, presumably
        # because the COCO caption evaluation tools split on whitespace.
        _caption = ' '.join(jieba.cut(caption, cut_all=False))
        val_data['annotations'].append({
            'caption': _caption,
            'id': caption_id,
            'image_id': image_id
        })
        caption_id += 1

with open('./caption_validation_annotations.json', 'w') as f:
    json.dump(val_data, f)

100%|██████████| 30000/30000 [00:37<00:00, 793.19it/s]


1 获取词汇表，重新划分数据集，将官方的Train (210000) / Val (30000) --> Train(230000) / Val (5000) / Test (5000)

2 同时，生成训练所需要的文件，包括数据集各子集的id文件，词汇表文件，Val/Test的GroundTruth文件

In [1]:
# encoding=utf-8
import jieba
import json
import pickle
import os
import sys
from tqdm import tqdm

In [2]:
def tokenize(sent):
    """Split a Chinese sentence into a list of words with jieba (cut_all=False)."""
    return jieba.lcut(sent, cut_all=False)

In [3]:
# ---- Input: official AI Challenger annotation files ----
raw_train_annotation_file = (
    './ai_challenger_caption_train_20170902/'
    'caption_train_annotations_20170902.json'
)
raw_val_annotation_file = (
    './ai_challenger_caption_validation_20170910/'
    'caption_validation_annotations_20170910.json'
)

# ---- Output: files produced by this notebook ----
misc_gts_file = './ICC_train_gts.pkl'                # per-image reference index sequences
misc_cider_file = './ICC_train_cider.pkl'            # CIDEr document-frequency cache
misc_val5k_ann_file = './ICC_captions_val5k.json'    # must contain a 'type': 'captions' field
misc_test5k_ann_file = './ICC_captions_test5k.json'  # must contain a 'type': 'captions' field
sent_input_file = './ICC_train_input.pkl'            # training input index sequences
sent_target_file = './ICC_train_target.pkl'          # training target index sequences
txt_train_id_file = './ICC_train_image_id.txt'
txt_val_id_file = './ICC_val_image_id.txt'
txt_test_id_file = './ICC_test_image_id.txt'
txt_vocabulary_file = './ICC_vocabulary.txt'

读取两个输入文件，进行中文分词，统计词汇表；同时，对数据集进行重新划分

In [4]:
# Read the official train / validation annotation files. They contain Chinese
# captions, so decode them explicitly as UTF-8 rather than relying on the
# platform's default locale encoding.
with open(raw_train_annotation_file, 'r', encoding='utf-8') as f:
    raw_train_annotation = json.load(f)

with open(raw_val_annotation_file, 'r', encoding='utf-8') as f:
    raw_val_annotation = json.load(f)

In [6]:
# Tokenize every caption with jieba, count token frequencies for the vocabulary
# and re-split the data (official Train 210000 / Val 30000 ->
# Train 230000 / Val 5000 / Test 5000):
#   - every official train image                                -> train
#   - official val images with index < VAL_TO_TRAIN_END         -> train
#   - official val images in [VAL_TO_TRAIN_END, VAL_SPLIT_END)  -> val
#   - the remaining official val images                         -> test
VAL_TO_TRAIN_END = 20000
VAL_SPLIT_END = 25000

token_counter = {}

train_id = []
val_id = []
test_id = []

train_ann = []
val_ann = []
test_ann = []

train_cnt = 0
val_cnt = 0
test_cnt = 0


def _tokenize_captions(sents, counter):
    """Segment each caption with tokenize() and add its tokens to *counter*.

    Returns the captions re-joined as space-separated token strings, in order.
    """
    joined = []
    for sent in sents:
        tokens = tokenize(sent)
        joined.append(" ".join(tokens))
        for token in tokens:
            counter[token] = counter.get(token, 0) + 1
    return joined


for record in tqdm(raw_train_annotation):
    img_id = record['image_id']
    sents = _tokenize_captions(record['caption'], token_counter)
    train_id.append(img_id)
    train_ann.append({'image_id': img_id, 'caption': sents})
    train_cnt += 1

for i, record in enumerate(tqdm(raw_val_annotation)):
    img_id = record['image_id']
    sents = _tokenize_captions(record['caption'], token_counter)
    if i < VAL_TO_TRAIN_END:
        train_id.append(img_id)
        train_ann.append({'image_id': img_id, 'caption': sents})
        train_cnt += 1
    elif i < VAL_SPLIT_END:
        val_id.append(img_id)
        val_ann.append({'image_id': img_id, 'caption': sents})
        val_cnt += 1
    else:
        test_id.append(img_id)
        test_ann.append({'image_id': img_id, 'caption': sents})
        test_cnt += 1

print('Train set:', train_cnt, len(train_id))
print('Val set:', val_cnt, len(val_id))
print('Test set:', test_cnt, len(test_id))
# (count, token) pairs, most frequent first (ties: token in descending order)
ct = sorted([(count, token) for token, count in token_counter.items()], reverse=True)

100%|██████████| 210000/210000 [04:38<00:00, 753.28it/s]
100%|██████████| 30000/30000 [00:39<00:00, 760.28it/s]


Train set: 230000 230000
Val set: 5000 5000
Test set: 5000 5000


In [16]:
# Save the vocabulary and the train / val / test image-id lists to txt files.
# Vocabulary: tokens seen at least MIN_TOKEN_COUNT times, in `ct` order (most
# frequent first), numbered from 1; 'UNK' is appended last and stands for every
# rarer token. Index 0 stays free for <bos> / <eos>.
MIN_TOKEN_COUNT = 5


def _write_lines(path, items):
    """Write str(item) for each item on its own line, UTF-8 encoded."""
    with open(path, 'w', encoding='utf-8') as f:
        for item in items:
            f.write(str(item) + '\n')


index = 1
token2index = {}  # token -> vocabulary index
final_token = []  # (index, token) pairs in vocabulary order
bad_tokens = []   # tokens below the frequency threshold (mapped to UNK later)
for count, token in ct:
    if count >= MIN_TOKEN_COUNT:
        token2index[str(token)] = index
        final_token.append((index, str(token)))
        index += 1
    else:
        bad_tokens.append(token)
token2index['UNK'] = index
final_token.append((index, 'UNK'))

_write_lines(txt_vocabulary_file, [token for _, token in final_token])
_write_lines('./coco_bad_token.txt', bad_tokens)
_write_lines(txt_train_id_file, train_id)
_write_lines(txt_val_id_file, val_id)
_write_lines(txt_test_id_file, test_id)

In [17]:
# Save the val / test annotations to json files.
def info_pre():
    """Return a fresh COCO-caption-style skeleton for an evaluation json.

    The dict holds the fixed 'type' ('captions' -- this field must be present in
    the val / test ground-truth files), 'info' and 'licenses' fields, plus empty
    'images' and 'annotations' lists for the caller to fill. A new dict (with
    new lists) is built on every call.
    """
    data = {}
    data['type'] = "captions"
    data['info'] = {
        "contributor": "Yiyu Wang",
        "description": "ImageChineseCaptioningEval",
        "url": "https://github.com/AIChallenger/AI_Challenger.git",
        "version": "1",
        "year": 2021
    }
    data['licenses'] = [{"url": "https://challenger.ai"}]
    data['images'] = []
    data['annotations'] = []
    return data

# Fields shared by both files come from info_pre().
val_annotation_data = info_pre()
test_annotation_data = info_pre()


def _append_coco_entries(ann_list, coco_data):
    """Append COCO-style image / caption entries for *ann_list* to *coco_data*.

    Caption ids restart at 1 on every call (unique within one file only).
    Returns the next unused caption id.
    """
    next_id = 1
    for record in tqdm(ann_list):
        file_name = record['image_id']   # ***.jpg
        stem = file_name.split('.')[0]   # ***
        coco_data['images'].append({'file_name': file_name, 'id': stem})
        for text in record['caption']:
            coco_data['annotations'].append(
                {'caption': text, 'id': next_id, 'image_id': stem}
            )
            next_id += 1
    return next_id


_append_coco_entries(test_ann, test_annotation_data)
caption_id = _append_coco_entries(val_ann, val_annotation_data)

with open(misc_test5k_ann_file, 'w') as f:
    json.dump(test_annotation_data, f)

with open(misc_val5k_ann_file, 'w') as f:
    json.dump(val_annotation_data, f)

100%|██████████| 5000/5000 [00:00<00:00, 67431.46it/s]
100%|██████████| 5000/5000 [00:00<00:00, 118294.69it/s]


根据生成的词汇表（包括UNK，从1计数，0用于<sos>和<eos>），以及train_ann生成训练所需的几个pkl文件

In [19]:
# Data still to be processed: train_ann, token2index, final_token.
# Print their sizes as a sanity check.
for _collection in (train_ann, token2index, final_token):
    print(len(_collection))

230000
8257
8257


In [45]:
import numpy as np

# Captions must be converted from token strings to index sequences.
def str2index(sent, token2index):
    """Map a space-separated token string to vocabulary indices plus a trailing 0.

    Tokens missing from *token2index* map to token2index['UNK']; the final 0
    marks <eos>. ``sent.split(' ')`` is used, so "" yields a single UNK.
    """
    def _lookup(tok):
        return token2index[tok] if tok in token2index else token2index['UNK']

    indices = list(map(_lookup, sent.split(' ')))
    indices.append(0)  # <eos>
    return indices

def fill_list(_list, fill, length):
    """Return *_list* truncated, or right-padded with *fill*, to exactly *length* items."""
    shortfall = length - len(_list)
    if shortfall < 0:
        return _list[:length]
    return _list + [fill] * shortfall

# Build the three training pickles:
#   ICC_train_gts.pkl    list aligned with train_ann: for each image, its caption
#                        index sequences (ending with 0 = <eos>), cut to MAX_SEQ_LEN.
#   ICC_train_input.pkl  {image name without '.jpg': np.array}, one row per
#                        caption: [0 (<bos>)] + caption indices, padded with 0 /
#                        truncated to MAX_SEQ_LEN.
#   ICC_train_target.pkl same keys; caption indices (ending with 0 = <eos>),
#                        padded with -1 / truncated to MAX_SEQ_LEN.
MAX_SEQ_LEN = 20  # tokens kept per caption; extra tokens are dropped
ICC_train_gts = []
ICC_train_input = {}
ICC_train_target = {}

for _data in tqdm(train_ann):
    _image_id = _data['image_id']
    _captions = _data['caption']
    _index_list = []
    _input = []
    _target = []
    for i, _caption in enumerate(_captions):
        if not _caption:
            # NOTE: a few images have only 4 real captions; the missing one is "".
            # Substitute another caption of the same image: the first one (or,
            # for slot 0, the last one), skipping any that are also empty.
            # Keeps "" only if every caption of the image is empty.
            candidates = _captions if i != 0 else _captions[::-1]
            _caption = next((c for c in candidates if c), _caption)
        _caption_index = str2index(_caption, token2index)
        _index_list.append(_caption_index[:MAX_SEQ_LEN])
        _input.append(fill_list([0] + _caption_index, 0, MAX_SEQ_LEN))  # prepend <bos>
        _target.append(fill_list(_caption_index, -1, MAX_SEQ_LEN))

    ICC_train_gts.append(_index_list)
    _id = _image_id.split('.')[0]  # drop '.jpg', keep the file-name stem as id
    ICC_train_input[_id] = np.array(_input)
    ICC_train_target[_id] = np.array(_target)

with open(misc_gts_file, 'wb') as f:
    pickle.dump(ICC_train_gts, f)

with open(sent_input_file, 'wb') as f:
    pickle.dump(ICC_train_input, f)

with open(sent_target_file, 'wb') as f:
    pickle.dump(ICC_train_target, f)

100%|██████████| 230000/230000 [00:40<00:00, 5659.25it/s]


3 生成CIDEr score cache文件

In [52]:
# CIDEr score cache file，用于训练过程中快速计算CIDEr得分，避免训练速度瓶颈
# ICC_train_cider.pkl
import numpy as np
import pickle
from collections import defaultdict

def precook(words, n=4, out=False):
    """Count every n-gram of length 1..n in one sentence.

    :param words: sequence of tokens (here token indices) of one sentence
    :param n: int : maximum n-gram length
    :param out: unused; accepted because cook_test passes it
    :return: defaultdict(int) mapping each n-gram tuple to its count
    """
    counts = defaultdict(int)
    for size in range(1, n + 1):
        # zip over `size` shifted views yields every window of that length
        for gram in zip(*(words[offset:] for offset in range(size))):
            counts[gram] += 1
    return counts

def cook_refs(refs, n=4):
    """Return the n-gram counts (see precook) of every reference of one image.

    :param refs: list of token sequences : reference sentences of one image
    :param n: int : maximum n-gram length
    :return: list of dict, one per reference, in input order
    """
    return [precook(sentence, n) for sentence in refs]

def cook_test(test, n=4):
    """Return the n-gram counts (see precook) of one hypothesis sentence.

    :param test: token sequence : hypothesis sentence for one image
    :param n: int : maximum n-gram length
    :return: dict mapping n-gram tuples to counts
    """
    return precook(test, n, out=True)

# Load ICC_train_target.pkl (only its type is inspected here) and
# ICC_train_gts.pkl, then build and save ICC_train_cider.pkl.
with open(sent_target_file, 'rb') as f:
    target_seqs = pickle.load(f, encoding='bytes')
print(type(target_seqs))

with open(misc_gts_file, 'rb') as f:
    gts = pickle.load(f, encoding='bytes')
print(type(gts))

# n-gram counts over token indices (not token strings), per reference per image.
crefs = [cook_refs(gt) for gt in gts]

# Document frequency: number of images in which an n-gram occurs in at least
# one of that image's references.
document_frequency = defaultdict(float)
for refs in crefs:
    for ngram in {ngram for ref in refs for ngram in ref}:
        document_frequency[ngram] += 1
# log(number of images); presumably consumed as log(N) in CIDEr's idf term --
# confirm against the scorer that reads this cache.
ref_len = np.log(float(len(crefs)))
with open(misc_cider_file, 'wb') as f:
    pickle.dump({'document_frequency': document_frequency, 'ref_len': ref_len}, f)



<class 'dict'>
<class 'list'>


使用SwinTransformer提取ICC数据集的特征

NOTE: 此部分代码为在SwinTransformer项目下运行源码，在notebook中运行需要进行修改

In [53]:
# 构造ICC Dataset类，用于实现图像的读取与预处理，方便SwinTransformer进行特征提取
# tools/ICC_dataset.py
import json
import cv2
import os
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data.transforms import _pil_interp

class ICCDataset(Dataset):
    """Images of the AI Challenger ICC dataset, preprocessed for Swin.

    Each item is ``(image_tensor, file_name)``, where ``file_name`` is the image
    file name (``'<id>.jpg'``) relative to ``img_folder``.

    Args:
        img_folder: directory containing the image files.
        json_file: annotation json in one of two layouts:
            * train / val: a list of records whose 'image_id' is '<id>.jpg';
            * test a / b: a COCO-style dict whose 'images' entries carry
              'file_name' without the '.jpg' extension.
        img_size: side length of the square resize (default 384).
    """

    def __init__(self, img_folder, json_file, img_size=384):
        self.img_folder = img_folder
        with open(json_file, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
        # Detect the layout from the data itself, not from the path: a 'test'
        # substring anywhere in the path would otherwise pick the wrong branch.
        if isinstance(json_data, dict) and 'images' in json_data:
            # The test json may list an image more than once; de-duplicate
            # while keeping first-seen order so iteration is reproducible.
            self.image_data = list(dict.fromkeys(
                _['file_name'] + '.jpg' for _ in json_data['images']))
        else:
            self.image_data = [_['image_id'] for _ in json_data]

        # Resize -> tensor -> ImageNet normalisation.
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size), interpolation=_pil_interp('bicubic')),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]
        )

    def __len__(self):
        return len(self.image_data)

    def __getitem__(self, index):
        file_name = self.image_data[index]
        img_path = os.path.join(self.img_folder, file_name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None for a missing / unreadable file, which
            # would otherwise surface as an opaque cvtColor error.
            raise FileNotFoundError('cannot read image: %s' % img_path)
        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        return self.transform(img), file_name
    

In [54]:
# extract_feats_ICC.py
import torch
import os
import sys
from tqdm import tqdm
import numpy as np
import argparse

from torch.utils.data import Dataset, DataLoader

from models.swin_transformer import SwinTransformer
from tools.ICC_dataset import ICCDataset

# Annotation json of each split, relative to --dataset_folder.
JSON_PATHS = dict(
    train='ai_challenger_caption_train_20170902/caption_train_annotations_20170902.json',
    val='ai_challenger_caption_validation_20170910/caption_validation_annotations_20170910.json',
    test_a='ai_challenger_caption_test_a_20180103/caption_test_a_annotations_20180103.json',
    test_b='ai_challenger_caption_test_b_20180103/caption_test_b_annotations_20180103.json',
)

def parse_args():
    """Parse the command-line options of the ICC feature-extraction script.

    Prints the help text and exits with status 1 when called with no arguments.
    """
    parser = argparse.ArgumentParser(description='extract feature from Swin')
    parser.add_argument('--model', help='model weights', type=str, default='./pre_trained/swin_large_patch4_window12_384_22kto1k.pth')
    parser.add_argument('--dataset_folder', help='ICC folder', type=str, default='/home/wangyiyu/wangyiyu_data_ssd6/ImageChineseCaptioning_dataset/AI_Challenger/')
    # main() iterates over this list and looks every entry up in JSON_PATHS:
    # require it and restrict it to known splits, so an omitted flag no longer
    # crashes later with "'NoneType' object is not iterable".
    parser.add_argument('--dataset_split', nargs='+', required=True, choices=list(JSON_PATHS),
                        help='one or more of: train val test_a test_b')
    parser.add_argument('--out_folder', help='the storage folder of feature', type=str, default='/home/wangyiyu/wangyiyu_data_ssd6/ImageChineseCaptioning_dataset/AI_Challenger/SwinL_features')
    parser.add_argument('--batch_size', help='batch size', type=int, default=1)
    parser.add_argument('--num_workers', help='number of workers', type=int, default=1)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    return parser.parse_args()

def forward_feats(model, x):
    """Run the Swin backbone and return per-token features, without pooling.

    Applies patch embedding, the optional absolute position embedding, dropout,
    every stage in ``model.layers`` and the final norm. The classification
    pooling / head are skipped, so the output keeps one feature per token
    (B L C, per the original comment on the norm output).
    """
    x = model.patch_embed(x)
    # Only models built with ape=True carry an absolute position embedding;
    # main() does not pass ape, so presumably it is off there -- but honour it
    # when present instead of silently dropping it.
    if getattr(model, 'ape', False):
        x = x + model.absolute_pos_embed
    x = model.pos_drop(x)

    for layer in model.layers:
        x = layer(x)

    return model.norm(x)  # B L C

def extract_feats(model, data_loader, out_folder):
    """Extract Swin token features for every image and save one .npz per image.

    Each file is ``<out_folder>/<image file name without extension>.npz`` and
    holds a single array under the key ``feat``. The output folder (and any
    missing parents) is created if needed.
    """
    os.makedirs(out_folder, exist_ok=True)

    model.eval()
    for batch_img, img_meta in tqdm(data_loader):
        with torch.no_grad():
            # Copy the whole batch to host memory once instead of per image.
            feats = forward_feats(model, batch_img.cuda()).cpu().numpy()

        for img_name, feat in zip(img_meta, feats):
            img_id = str(img_name.split('.')[0])  # ICC file-name stem is the image id
            np.savez_compressed(os.path.join(out_folder, img_id), feat=feat)
            

def main(args):
    """Extract Swin-L features for every split named in ``args.dataset_split``.

    Builds a Swin-L backbone (384 input, window 12), loads the checkpoint at
    ``args.model`` and, for each split, reads its images via ICCDataset and
    writes one .npz per image into ``args.out_folder`` via extract_feats().
    Exits with status 1 if any requested split is unknown.
    """
    batch_size = args.batch_size
    num_workers = args.num_workers
    data_folder = args.dataset_folder
    splits = args.dataset_split

    # Validate every requested split up front: before the slow model load and
    # before earlier splits have already been processed.
    unknown = [s for s in splits if s not in JSON_PATHS]
    if unknown:
        print('split error: %s, should be one of %s!' % (unknown, ' / '.join(JSON_PATHS)))
        sys.exit(1)

    model = SwinTransformer(
        img_size=384,
        embed_dim=192,
        depths=[2, 2, 18, 2],
        num_heads=[6, 12, 24, 48],
        window_size=12,
        num_classes=1000
    ).cuda()

    checkpoint = torch.load(args.model, map_location='cuda')
    # strict=False tolerates mismatched keys; report them so a wrong checkpoint
    # cannot silently leave layers at their random initialisation.
    incompatible = model.load_state_dict(checkpoint['model'], strict=False)
    print('missing keys:', incompatible.missing_keys)
    print('unexpected keys:', incompatible.unexpected_keys)
    print(model)

    for split in splits:
        json_path = JSON_PATHS[split]
        # The image folder sits next to the annotation file: drop the extension
        # and replace 'annotations' with 'images'.
        img_folder = os.path.splitext(json_path)[0].replace('annotations', 'images')
        dataset = ICCDataset(
            os.path.join(data_folder, img_folder),
            os.path.join(data_folder, json_path)
        )
        print(split, len(dataset))
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)

        print('extract feature for %s' % split)
        print('data length:', len(dataset))
        extract_feats(model, dataloader, args.out_folder)
    
if __name__ == '__main__':
    # Script entry point: parse CLI options, echo them, then run extraction.
    cli_args = parse_args()
    print(cli_args)
    main(cli_args)

210000


In [80]:
# NOTE: 特征提取完成之后发现，train set中存在与test a和test b同名数据（图片分辨率不一致，但属于相同图片），但是标注不完全一致
# 与test a同名文件：
#   ['4f00cd3c12151a4835394351771c3cead4163504', '7792e3f10ea8cb47d72f00055e6d39626355f1ab', 'ffefb84febc02d7378444b4109067c04e1975121', 'ee03bd82f33765ff8bd3c038df87728deaf80e12', '740ba78c6e7d3e87e4c7db16ec435bfa996af45e', '63bb7d86af30f0bd9c17c1db073f84c870a015c6', 'c35ea7caa03f34d23aeb90c6b38da949d2d65632', 'f8554ea6a881eb48203db4c04daa9bb99f5317e9', 'c780ac6640c9cb66fe5d95d3e06ba202932d6931', '57efb4b95ef0a6098c59b6a4861e676c798d476c', '9bce1bd771b395e6081445197052a9d2487e116c', 'a3b529fbc407beba26e2069fb55757179639a413', 'ee6cbf98ee35d4eb4283dc0b284194b1c27cd20f', 'bd31b5ec946bb9b941008e045f76393a1e715fbb', '60326b2ae46317fd56cb5c3a43a08dcc4cb1b798', '8be62893034f4038ded00343069cf6af1e18622f', '03f9c6bf168661563a1fe63cab001b31c32828d8', '5e3d6d490384ba9279538fb0ce91f595c030be2b', 'd758b1e63d884afbd458046715c414788f3e2c1b', '9df4b1957cd74348e46395974ae32c65459990f9', 'f0c3f9211bc623dfdc7f51e47abb139f828fa2c3', 'f50ab63f0dfdc0c84628b36b1f2b59ca846002f3', 'f108dcc06d6bc03392d4194b55ba9337babf7bc1', '7e0b97ee7cb090325fe2fca2a7b09a8071861a8e', '1aa34363b77795b35e951e5d88ac41c97648459c', 'd3af8a8b1a4402c0d89c577c2d35b334244087e8', '6da45283e31e3c40eb8e29f379a05268eef67ebc', '187cf95a4c3c09ab6aacd7cc50cc9e8cb2488c93', 'd33d8728737850238b4ad2e98917dd2634f789c3', '5fde1b7a53175481f1df1b047657c5191ba93121']
# 与test b同名文件：
#   ['290ed792c8315b1478603a34ae8a4adc7aa039b2', '8715aadf88f76080c6453ffb57ee38efbf021d08', 'd1ff7d6c3886da19b4d2ece233fe28b0545b7c95', '54612401fac9b89d6a7278a9f3f934159c180451', 'c6aaa1f23dd70f7af8fd391fcd25aaa325a11054', '4d3e523ff9009069f6d12928e7193c251518d9e3', '8caad04e7571672315b37c4a81522f3778330416', 'b89ee384b0b566e8f6ba9e00d48e75ec61c0296c', 'cf5e4d0178ee5445b3b40a0d6ce00ec4d76fab73', '213d1ec8e1439270a862fceeef63c7246a3fc491', 'c899232b0dba94804d9798b15e42f3967c6483ce', 'd4224522164681666cbc67f4afbe372d84914ffb', '2a98c11f9556cffc5c79d301227649d3b8b2da53', '6ae34d3f49544c1a08d30cc0cc659016793be611', '2ffc2009d22bcfddd52191b73d89af7f9c7ea570', 'bf147aa685469bb95ed34a8812d7ba02ffb99b8f', '1012757b58640d862d1ded08bde7dfd0f8f2d5d1', '996df064e684998f3736958c61053144f65284a0', '4ddcd54822a46bfee743454aac33e04448590b17', '702c4ff9c70d3c373ae303062ad9042bb95dceeb', 'f525426d9551a976de68fbc6e0b8f8fdaffb2a62', '1680797d244c279291a08e10fdc570a06e214261', '24b308b395444241d6ab745b23c9cab847874a83', '7e373dad5cf4d019b46895a840c810410dc0eaa5', 'af0746dfef181b4156ab8112f45bbacc4faec2c4', 'b58a3cfbb489667c1a279ac5919d0ea6dabe8537', '149f7db108275c32efa9eb2179515e886f58811d', '93ce9f9d884b413b37ba21f0440ce7dd3ffff092', '7b33f51f11c6f44636562b8a524b23b121e94f34', '605a7185d234aa999f706ac39819d70eb0b0aa02', 'ba2afd60a9f2d02fe9017b19e03f34828eadb9a7']

In [83]:
import shutil

def get_ids(json_file):
    """Return the image ids (file names without '.jpg') listed in an ICC json.

    Two layouts are supported, detected from the data rather than the path:
    * train / val: a list of records whose 'image_id' is '<id>.jpg' -- every
      record yields one id, in file order;
    * test a / b: a COCO-style dict whose 'images' entries carry 'file_name'
      = '<id>' -- de-duplicated, keeping first-seen order (reproducible, unlike
      iterating a set of strings).
    """
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)
    if isinstance(json_data, dict) and 'images' in json_data:
        return list(dict.fromkeys(_['file_name'] for _ in json_data['images']))
    return [_['image_id'].split('.')[0] for _ in json_data]

In [84]:
# Annotation files of the two official test splits.
test_a_json_file = './ai_challenger_caption_test_a_20180103/caption_test_a_annotations_20180103.json'
test_b_json_file = './ai_challenger_caption_test_b_20180103/caption_test_b_annotations_20180103.json'
test_a_ids, test_b_ids = get_ids(test_a_json_file), get_ids(test_b_json_file)
# Union of both splits; an id listed in both is kept once.
test_ids = list(set(test_a_ids + test_b_ids))
print(len(test_ids))


60000


In [86]:
# Move the test a / test b feature files (60000 entries) into their own
# sub-folder (test_a_b), then move every remaining .npz file (train / val,
# 239939 files in the original run) into train_val. Features of the 61 missing
# images are re-extracted with SwinTransformer separately.
input_folder = './SwinL_features/'
out_folder = './SwinL_features/test_a_b'
train_val_folder = './SwinL_features/train_val'
os.makedirs(out_folder, exist_ok=True)
# shutil.move into a directory that does not exist fails, so create it up front.
os.makedirs(train_val_folder, exist_ok=True)

for _id in tqdm(test_ids):
    src_npz = os.path.join(input_folder, _id + '.npz')
    tar_npz = os.path.join(out_folder, _id + '.npz')
    shutil.move(src_npz, tar_npz)

# Only files directly in input_folder; the sub-folders are not .npz files.
train_val_npz = [name for name in os.listdir(input_folder) if name.endswith('.npz')]
print(len(train_val_npz))

for _npz in train_val_npz:
    shutil.move(os.path.join(input_folder, _npz), os.path.join(train_val_folder, _npz))



  0%|          | 0/60000 [00:00<?, ?it/s][A[A

  2%|▏         | 1363/60000 [00:00<00:04, 13628.65it/s][A[A

  6%|▌         | 3676/60000 [00:00<00:03, 15543.62it/s][A[A

 13%|█▎        | 7702/60000 [00:00<00:02, 19052.16it/s][A[A

 19%|█▊        | 11177/60000 [00:00<00:02, 22038.31it/s][A[A

 25%|██▍       | 14703/60000 [00:00<00:01, 24831.47it/s][A[A

 30%|██▉       | 17774/60000 [00:00<00:01, 26342.46it/s][A[A

 35%|███▌      | 21072/60000 [00:00<00:01, 28033.92it/s][A[A

 41%|████      | 24499/60000 [00:00<00:01, 29649.83it/s][A[A

 47%|████▋     | 28043/60000 [00:00<00:01, 31177.79it/s][A[A

 53%|█████▎    | 32072/60000 [00:01<00:00, 33445.63it/s][A[A

 59%|█████▉    | 35546/60000 [00:01<00:00, 33097.06it/s][A[A

 65%|██████▍   | 38947/60000 [00:01<00:00, 33061.78it/s][A[A

 71%|███████   | 42378/60000 [00:01<00:00, 33424.52it/s][A[A

 76%|███████▋  | 45766/60000 [00:01<00:00, 33417.16it/s][A[A

 82%|████████▏ | 49447/60000 [00:01<00:00, 34365.51it/s]

保存测试集test a和test b的id到txt文件中，读取其json文件获取

In [1]:
import json

In [4]:
# Collect the unique image file names ('<file_name>.jpg') of test a and test b.
# dict.fromkeys de-duplicates while keeping first-seen order; iterating a set of
# strings gives a different order on every run (str hashing is randomised per
# process), which would make the id files written below non-reproducible.
with open('./ai_challenger_caption_test_a_20180103/caption_test_a_annotations_20180103.json', 'r', encoding='utf-8') as f:
    test_a_data = json.load(f)
test_a_ids = list(dict.fromkeys(_['file_name'] + '.jpg' for _ in test_a_data['images']))

with open('./ai_challenger_caption_test_b_20180103/caption_test_b_annotations_20180103.json', 'r', encoding='utf-8') as f:
    test_b_data = json.load(f)
test_b_ids = list(dict.fromkeys(_['file_name'] + '.jpg' for _ in test_b_data['images']))

In [5]:
# One image file name per line, in the order collected above.
for _path, _names in (('./ICC_test3w_a_image_id.txt', test_a_ids),
                      ('./ICC_test3w_b_image_id.txt', test_b_ids)):
    with open(_path, 'w') as f:
        f.writelines(str(name) + '\n' for name in _names)

处理测试集test a和test b的ground truth文件，将其中的image_id替换为图像的文件名

（原本的image_id意义不明）

In [8]:
def process_ann_json(in_file, out_file):
    """Rewrite a test a / b ground-truth json so images are keyed by file name.

    The original 'image_id' values are replaced by the image file name (without
    '.jpg') in both 'images' and 'annotations'. 'images' is de-duplicated
    (first-seen order) and each entry becomes
    {'file_name': '<name>.jpg', 'id': '<name>'}. All other fields are kept.

    Raises:
        ValueError: if one original image id is listed with two different
            file names.
    """
    with open(in_file, 'r', encoding='utf-8') as f:
        test_data = json.load(f)

    id2filename = {}  # original image id -> file name (without '.jpg')
    filenames = []    # '<file_name>.jpg' for every 'images' entry
    for img in test_data['images']:
        known = id2filename.setdefault(img['id'], img['file_name'])
        if known != img['file_name']:
            raise ValueError('image id %r maps to both %r and %r'
                             % (img['id'], known, img['file_name']))
        filenames.append(img['file_name'] + '.jpg')

    filenames = list(dict.fromkeys(filenames))

    # Replace the 'images' field.
    test_data['images'] = [{'file_name': filename, 'id': filename.split('.')[0]} for filename in filenames]
    # Replace each annotation's 'image_id' with the file name.
    for ann in test_data['annotations']:
        ann['image_id'] = id2filename[ann['image_id']]

    with open(out_file, 'w', encoding='utf-8') as f:
        json.dump(test_data, f)

# (source annotation json, rewritten output json) for test a and test b.
_TEST_ANN_JOBS = (
    ('./ai_challenger_caption_test_a_20180103/caption_test_a_annotations_20180103.json',
     './ICC_caption_test_a_annotations_20180103.json'),
    ('./ai_challenger_caption_test_b_20180103/caption_test_b_annotations_20180103.json',
     './ICC_caption_test_b_annotations_20180103.json'),
)
for _src, _dst in _TEST_ANN_JOBS:
    process_ann_json(_src, _dst)