In [1]:
import os
import torch
import numpy as np
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision import models
from torch.nn import init
import tensorwatch as tw
import scipy.io

# 数据加载 (Data loading: Market-1501 gallery and query splits)

In [2]:
# Root of the Market-1501 data laid out for ImageFolder (gallery/ and query/ sub-folders,
# one class folder per identity). Set the MARKET1501_DIR environment variable to point
# elsewhere instead of editing this absolute path.
data_dir = os.environ.get("MARKET1501_DIR", "/home/yxy/ReID/data/Market-1501/pytorch/")

# Evaluation preprocessing shared by both splits: resize to 256x128 (H x W) with bicubic
# interpolation (PIL code 3), convert to a tensor, then normalise with the standard ImageNet
# mean/std expected by torchvision's pretrained backbones.
data_transforms = transforms.Compose([
    transforms.Resize((256, 128), interpolation=3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def _build_split(split, batch_size=64, num_workers=16):
    """Return ``(dataset, loader)`` for the ``split`` sub-folder of ``data_dir``.

    ``shuffle=False`` keeps the loader's image order identical to ``dataset.imgs``,
    so features extracted from the loader line up with ids parsed from ``dataset.imgs``.
    """
    dataset = datasets.ImageFolder(os.path.join(data_dir, split), data_transforms)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=False, num_workers=num_workers)
    return dataset, loader


galleryset, galleryloaders = _build_split('gallery')
queryset, queryloaders = _build_split('query')

# 搭建模型 (Model definition: weight init, classification head, ResNet-50 backbone)

In [3]:
def weights_init_kaiming(m):
    """Kaiming-initialise a single module; intended for ``nn.Module.apply``.

    - Conv*:       weight ~ Kaiming normal (a=0, mode='fan_in'); bias untouched.
    - Linear*:     weight ~ Kaiming normal (a=0, mode='fan_out'); bias set to 0.
    - BatchNorm1d: weight ~ N(1.0, 0.02); bias set to 0.
    Any other module is left unchanged. Modules without a bias (``bias=False``) or
    without affine parameters are handled instead of raising.

    Uses the in-place ``init.*_`` functions; the underscore-less aliases
    (``init.kaiming_normal``, ``init.constant``, ``init.normal``) are deprecated
    and warn on every call.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        init.kaiming_normal_(m.weight, a=0, mode='fan_in')
    elif 'Linear' in classname:
        init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        # nn.Linear(..., bias=False) stores bias as None.
        if m.bias is not None:
            init.constant_(m.bias, 0.0)
    elif 'BatchNorm1d' in classname:
        # BatchNorm1d(affine=False) has no weight/bias to initialise.
        if m.weight is not None:
            init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            init.constant_(m.bias, 0.0)

def weights_init_classifier(m):
    """Initialise a classifier layer; intended for ``nn.Module.apply``.

    Linear*: weight ~ N(0, 0.001); bias set to 0 (skipped when the layer has no bias).
    Any other module is left unchanged. Uses the in-place ``init.*_`` functions
    because the underscore-less aliases are deprecated.
    """
    classname = m.__class__.__name__
    if 'Linear' in classname:
        init.normal_(m.weight, std=0.001)
        # nn.Linear(..., bias=False) stores bias as None.
        if m.bias is not None:
            init.constant_(m.bias, 0.0)

In [4]:
# Classification head placed on top of the pooled backbone feature:
#   x --BatchNorm1d--[LeakyReLU(0.1)]--[Dropout(0.5)]--> f --L2-normalise--> f --Linear--> logits
# There is no leading Linear (input_dim -> num_bottleneck) layer, so the bottleneck width
# always equals input_dim.
class ClassBlock(nn.Module):
    """BN-neck head whose forward returns ``(logits, normalised_feature)``.

    Args:
        input_dim: width of the incoming feature (assumes x is (batch, input_dim) — TODO confirm).
        class_num: output size of the final Linear layer.
        dropout: append ``nn.Dropout(p=0.5)`` after the BN (and the optional activation).
        relu: append ``nn.LeakyReLU(0.1)`` after the BN.
        num_bottleneck: accepted for interface compatibility but ignored; the
            bottleneck width is forced to ``input_dim``.
    """

    def __init__(self, input_dim, class_num, dropout=False, relu=False, num_bottleneck=512):
        super(ClassBlock, self).__init__()
        # num_bottleneck is deliberately not used: with no leading Linear the BN
        # has to match the incoming width.
        width = input_dim

        neck_layers = [nn.BatchNorm1d(width)]
        if relu:
            neck_layers.append(nn.LeakyReLU(0.1))
        if dropout:
            neck_layers.append(nn.Dropout(p=0.5))
        neck = nn.Sequential(*neck_layers)
        neck.apply(weights_init_kaiming)

        head = nn.Sequential(nn.Linear(width, class_num))
        head.apply(weights_init_classifier)

        # These attribute names and the Sequential indices inside them form the
        # state_dict keys, so they must stay 'add_block' and 'classifier'.
        self.add_block = neck
        self.classifier = head

    def forward(self, x):
        """Return ``(logits, f)``, where ``f`` is the L2-normalised neck output fed to the classifier."""
        feat = self.add_block(x)
        # The + 1e-8 keeps the division finite for an all-zero feature row.
        feat = feat / (feat.norm(p=2, dim=1, keepdim=True) + 1e-8)
        logits = self.classifier(feat)
        return logits, feat

In [5]:
class ft_net(nn.Module):
    """ResNet-50 backbone followed by a ``ClassBlock`` head.

    ``forward(x)`` returns ``(logits, f)``. ``f`` is the 2048-d L2-normalised
    feature produced by the head.

    Args:
        class_num: output size of the classifier (must match any checkpoint loaded later).
        pretrained: load ImageNet weights into the backbone. Defaults to True, as
            before; pass False to skip the download when a full checkpoint is
            loaded immediately afterwards.
    """

    def __init__(self, class_num, pretrained=True):
        super(ft_net, self).__init__()
        model_ft = models.resnet50(pretrained=pretrained)
        # Global average pooling: the pooled map is 1x1 regardless of input size.
        model_ft.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.model = model_ft
        self.classifier = ClassBlock(2048, class_num, dropout=False, relu=False)
        # Removing the final downsample (layer4 stride 2 -> 1) is left disabled; to enable it set
        #   self.model.layer4[0].downsample[0].stride = (1, 1)
        #   self.model.layer4[0].conv2.stride = (1, 1)

    def forward(self, x):
        """Run the ResNet-50 trunk, pool, flatten and apply the head; returns ``(logits, f)``."""
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)
        # Flatten (N, 2048, 1, 1) -> (N, 2048). torch.squeeze would also drop the batch
        # dimension when N == 1, and the head's BatchNorm1d / norm(dim=1) would then fail.
        x = x.view(x.size(0), -1)
        x, f = self.classifier(x)
        return x, f

# 加载预训练模型 (Select the device and load the trained ReID checkpoint `./model/net_59.pth`)

In [6]:
# GPU index used for inference; override with the REID_GPU_ID environment variable.
# If that index does not exist on this machine, cuda:0 is used instead. Without CUDA,
# everything runs on the CPU.
GPU_ID = int(os.environ.get("REID_GPU_ID", "5"))
if torch.cuda.is_available():
    device = torch.device("cuda:{}".format(GPU_ID if GPU_ID < torch.cuda.device_count() else 0))
else:
    device = torch.device("cpu")

In [7]:
def load_network(network, save_path='./model/net_59.pth'):
    """Load a saved ``state_dict`` into ``network`` and return the same module.

    Args:
        network: module whose architecture matches the checkpoint.
        save_path: checkpoint file; defaults to the ``net_59`` snapshot this notebook uses.

    Returns:
        ``network``, with its parameters replaced in place (returned for chaining).

    The checkpoint is mapped to the CPU first, so it loads even when the GPU it
    was saved from is not present. Move the model to the target device afterwards.
    Note: ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    state_dict = torch.load(save_path, map_location='cpu')
    network.load_state_dict(state_dict)
    return network

In [8]:
def get_id(img_path):
    """Parse camera ids and person labels out of Market-1501 style file names.

    ``img_path`` is an iterable of ``(path, class_index)`` pairs (e.g. ``ImageFolder.imgs``);
    only the file's base name is used. For ``0002_c1s1_000451_03.jpg``, the label is the
    integer formed by the first four characters (2), and the camera is the digit right
    after the first 'c' (1). Names starting with ``-1`` get label -1.

    Returns ``(camera_id, labels)``: two parallel lists of ints.
    """
    cams, ids = [], []
    for full_path, _ in img_path:
        name = os.path.basename(full_path)
        pid = name[:4]
        ids.append(-1 if pid.startswith('-1') else int(pid))
        cams.append(int(name.split('c')[1][0]))
    return cams, ids

# ImageFolder.imgs lists (path, class_index) pairs in the order the unshuffled loaders yield
# images, so the parsed ids line up row-for-row with the extracted features.
gallery_path, query_path = galleryset.imgs, queryset.imgs
gallery_cam, gallery_label = get_id(gallery_path)
query_cam, query_label = get_id(query_path)
# Multi-query evaluation (multiqueryset / mquery_*) is not set up in this notebook.

In [9]:
# 751 is the classifier output size and must match the checkpoint's classifier
# (presumably the number of Market-1501 training identities — confirm).
model_structure = ft_net(751)
# eval() and to() both return the module itself, so the steps can be chained.
model = load_network(model_structure).eval().to(device)
model

  # Remove the CWD from sys.path while we load stuff.
  # This is added back by InteractiveShellApp.init_path()
  app.launch_new_instance()


ft_net(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (downsample): Sequential(
          (0): Conv2d(6

In [10]:
def fliplr(img):
    """Flip a batch of images horizontally by reversing dimension 3 (width).

    Assumes ``img`` is laid out as (N, C, H, W). Returns a new tensor and leaves
    ``img`` unchanged. Works for CPU and CUDA tensors.
    """
    # index_select needs the index on the same device as the source tensor.
    inv_idx = torch.arange(img.size(3) - 1, -1, -1, device=img.device).long()
    return img.index_select(3, inv_idx)

In [11]:
def extract_feature(model, dataloaders, flip=False, device=None):
    """Run ``model`` over every batch and return one L2-normalised feature per image.

    Args:
        model: network whose forward returns ``(outputs, f)``; only ``f`` is used.
        dataloaders: iterable of ``(images, labels)`` batches; labels are ignored.
        flip: if True, add the feature of the horizontally flipped batch (test-time
            augmentation) before normalising. Defaults to False. The previous version
            ran the same un-flipped batch twice and summed the results, which normalises
            to exactly the single-pass feature, so one pass is used.
        device: device to run the model on. Defaults to the device of the model's
            parameters, or the CPU if it has none.

    Returns:
        A CPU float tensor of shape (num_images, feat_dim). Rows follow loader order and
        are unit L2 norm (an all-zero feature stays zero instead of becoming NaN).
        An empty loader yields an empty tensor.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    chunks = []
    print("start extract feature")
    # Inference only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        for img, _ in dataloaders:
            _, f = model(img.to(device))
            ff = f.float()
            if flip:
                _, f_flip = model(fliplr(img).to(device))
                ff = ff + f_flip.float()
            ff = ff.cpu()
            fnorm = ff.norm(p=2, dim=1, keepdim=True).clamp(min=1e-12)
            chunks.append(ff / fnorm)
    print("extract feature finish")
    if not chunks:
        return torch.FloatTensor()
    # Concatenate once rather than growing the result every batch (quadratic copying).
    return torch.cat(chunks, 0)

In [12]:
# One L2-normalised descriptor per image, gallery first and then query, rows in loader order.
gallery_feature, query_feature = [
    extract_feature(model, loader) for loader in (galleryloaders, queryloaders)
]

start extract feature
extract feature finish
start extract feature
extract feature finish


In [13]:
# Each feature row must correspond to the id/camera parsed for the same image.
assert gallery_feature.shape[0] == len(gallery_label) == len(gallery_cam)
assert query_feature.shape[0] == len(query_label) == len(query_cam)

result = {
    'gallery_f': gallery_feature.numpy(), 'gallery_label': gallery_label, 'gallery_cam': gallery_cam,
    'query_f': query_feature.numpy(), 'query_label': query_label, 'query_cam': query_cam,
}
# scipy.io.savemat does not create missing directories.
os.makedirs('./feature', exist_ok=True)
scipy.io.savemat('./feature/pytorch_result.mat', result)
# Multi-query features (mquery_*) are not computed in this notebook.