In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable
import cv2
import os
import numpy as np
import torchvision.models as models
from matplotlib import pyplot as plt
import torchvision.transforms as transforms
from tqdm import tqdm


In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ImageNet-pretrained ResNet-50 with its last child module (the classification
# layer) dropped, so a forward pass yields pooled features instead of logits.
# NOTE: newer torchvision releases prefer the `weights=` argument over
# `pretrained=True`; the latter is kept for compatibility with older installs.
resnet50 = models.resnet50(pretrained=True).to(device)
modules = list(resnet50.children())[:-1]
resnet50 = nn.Sequential(*modules)

# Switch to inference mode. Without this, BatchNorm normalises each input with
# its own batch statistics and keeps updating its running averages, so the
# extracted features would not match the pretrained weights and would depend
# on the order in which images are processed.
resnet50.eval()

# The network is only used as a fixed feature extractor: freeze every weight.
for p in resnet50.parameters():
    p.requires_grad = False

In [4]:
#crop the instance region. For the images containing two instances, you need to crop both of them.
def query_crop(query_path, txt_path, save_path):
    """Crop the annotated instance out of a query image and save the crop.

    Only a single bounding box per annotation file is handled here; a file
    with several rows would need one crop per row.

    Args:
        query_path: Path of the query image to read.
        txt_path: Path of a text file holding the bounding box. The slicing
            below treats the values as ``x, y, width, height``.
        save_path: Destination path for the cropped image.

    Returns:
        The cropped region with its channel axis reversed relative to what
        OpenCV loaded (i.e. RGB order). The result is a view with a negative
        channel stride, exactly as before.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``query_path``.
    """
    query_bgr = cv2.imread(query_path)
    if query_bgr is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f'cannot read query image: {query_path}')

    txt = np.loadtxt(txt_path)  # load the coordinates of the bounding box

    # Clamp at 0: NumPy would interpret a negative index as "count from the
    # far edge", silently producing a wrong (usually empty) crop.
    left = max(int(txt[0]), 0)
    top = max(int(txt[1]), 0)
    right = max(int(txt[0] + txt[2]), 0)
    bottom = max(int(txt[1] + txt[3]), 0)

    crop_bgr = query_bgr[top:bottom, left:right, :]  # crop the instance region
    # The crop is still in OpenCV's native BGR order, so it can be written
    # directly without flipping the channels back and forth.
    cv2.imwrite(save_path, crop_bgr)
    return crop_bgr[:, :, ::-1]  # bgr2rgb

# ImageNet normalisation pipeline, built once instead of on every call.
_RESNET_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])])


def resnet_extraction(img, featsave_path):
    """Extract a feature for one image with the global ``resnet50`` and save it.

    Args:
        img: Image to encode. The callers in this file pass an RGB array that
            was resized to 224x224.
        featsave_path: Path the feature array is written to with ``np.save``.

    Returns:
        None. The feature is written to ``featsave_path``; with the stock
        torchvision ResNet-50 its shape is presumably (1, 2048, 1, 1) --
        TODO confirm.
    """
    if isinstance(img, np.ndarray):
        # ToTensor goes through torch.from_numpy, which rejects arrays with
        # negative strides such as the result of a [:, :, ::-1] channel flip.
        img = np.ascontiguousarray(img)

    # Normalize the input image and transform it to a tensor.
    img_transform = _RESNET_TRANSFORM(img).to(device)

    # Set batchsize as 1. You can enlarge the batchsize to accelerate.
    img_transform = torch.unsqueeze(img_transform, 0)

    # Make sure BatchNorm uses the pretrained running statistics; this is a
    # no-op when the model is already in inference mode.
    resnet50.eval()
    with torch.no_grad():  # pure inference: do not track gradients
        feats = resnet50(img_transform)  # extract feature

    feats_np = feats.cpu().numpy()  # convert tensor to numpy
    np.save(featsave_path, feats_np)  # save the feature

# Note that I feed the whole image into the pretrained ResNet-50 model to extract the feature, which will lead to poor retrieval performance.
# To extract more fine-grained features, you could preprocess the gallery images by cropping them using windows with different sizes and shapes.
# Hint: opencv provides some off-the-shelf tools for image segmentation.

def feat_extractor_gallery(gallery_dir, feat_savedir):
    """Extract and save a feature for every readable image in ``gallery_dir``.

    Each image is converted to RGB, resized to 224x224 and passed to
    ``resnet_extraction``. The feature is stored in ``feat_savedir`` under the
    part of the file name before its first dot, with a ``.npy`` extension.

    Args:
        gallery_dir: Directory containing the gallery images.
        feat_savedir: Directory the ``.npy`` features are written to; created
            if it does not exist.
    """
    # np.save does not create missing parent directories.
    os.makedirs(feat_savedir, exist_ok=True)

    for img_file in tqdm(os.listdir(gallery_dir)):
        img = cv2.imread(os.path.join(gallery_dir, img_file))
        if img is None:
            # cv2.imread returns None for anything it cannot decode (e.g. a
            # macOS .DS_Store file or a sub-directory); skip instead of
            # crashing on the slicing below.
            tqdm.write(f'skipping unreadable gallery entry: {img_file}')
            continue
        img = img[:, :, ::-1]  # bgr2rgb
        img_resize = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)  # resize the image
        featsave_path = os.path.join(feat_savedir, img_file.split('.')[0] + '.npy')
        resnet_extraction(img_resize, featsave_path)

# Extract the query feature
def feat_extractor_query(query_dir='./datasets_4186/query_4186/',
                         txt_dir='./datasets_4186/query_txt_4186/',
                         save_dir='./datasets_4186/query_cropped/',
                         featsave_dir='./datasets_4186/query_feature/'):
    """Crop every query image to its annotated box and save its feature.

    For each file in ``query_dir`` the bounding box is read from the
    same-named ``.txt`` file in ``txt_dir``, the crop is written to
    ``save_dir``, and the feature of the 224x224-resized crop is saved to
    ``featsave_dir`` as ``<name>_feats.npy``.

    Args:
        query_dir: Directory containing the query images.
        txt_dir: Directory containing one bounding-box file per query image.
        save_dir: Directory the cropped images are written to.
        featsave_dir: Directory the ``.npy`` features are written to.

    The defaults reproduce the paths that used to be hard-coded, so calling
    the function without arguments behaves as before.
    """
    # Neither cv2.imwrite nor np.save creates missing directories.
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(featsave_dir, exist_ok=True)

    for query_file in tqdm(os.listdir(query_dir)):
        if query_file.endswith(".DS_Store"):
            continue
        print(query_file)
        # Part of the name before the first dot. Unlike slicing up to
        # str.find('.'), this keeps the whole name when there is no dot
        # (find() returns -1 there, which would chop off the last character).
        img_name = query_file.split('.', 1)[0]
        txt_file = img_name + '.txt'
        featsave_file = img_name + '_feats.npy'
        query_path = os.path.join(query_dir, query_file)
        txt_path = os.path.join(txt_dir, txt_file)
        save_path = os.path.join(save_dir, query_file)
        featsave_path = os.path.join(featsave_dir, featsave_file)
        crop = query_crop(query_path, txt_path, save_path)
        crop_resize = cv2.resize(crop, (224, 224), interpolation=cv2.INTER_CUBIC)
        resnet_extraction(crop_resize, featsave_path)

def main():
    """Run the whole pipeline: query features first, then gallery features."""
    feat_extractor_query()
    feat_extractor_gallery(
        gallery_dir='./datasets_4186/gallery_4186/',
        feat_savedir='./datasets_4186/gallery_feature/',
    )


if __name__ == '__main__':
    main()

  0%|          | 0/21 [00:00<?, ?it/s]

1258.jpg


 19%|█▉        | 4/21 [00:07<00:26,  1.54s/it]

1656.jpg
1709.jpg
2032.jpg
2040.jpg
2176.jpg


 48%|████▊     | 10/21 [00:07<00:04,  2.45it/s]

2461.jpg
27.jpg
2714.jpg
316.jpg
35.jpg
3502.jpg


 76%|███████▌  | 16/21 [00:07<00:00,  5.44it/s]

3557.jpg
3833.jpg
3906.jpg
4354.jpg
4445.jpg
4716.jpg


100%|██████████| 21/21 [00:08<00:00,  2.60it/s]


4929.jpg
776.jpg


  1%|          | 60/5000 [00:02<02:48, 29.40it/s]


KeyboardInterrupt: 