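# 特征提取 (Feature extraction)

Extract re-ID features for the DukeMTMC-reID gallery and query sets with a trained ResNet-50 model (`./model/resnet_croess.pth`) and save them, together with person labels and camera ids, to `./feature/pytorch_result.mat`.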

In [1]:
import os
import torch
import numpy as np
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision import models
from torch.nn import init
import tensorwatch as tw
import scipy.io

## 数据加载

In [2]:
# Root of the DukeMTMC-reID data in torchvision ImageFolder layout (class
# sub-folders under gallery/ and query/). Set REID_DATA_DIR to point elsewhere
# instead of editing this machine-specific default.
data_dir = os.environ.get("REID_DATA_DIR", "/home/yxy/ReID/data/DukeMTMC-reID/pytorch/")

# Evaluation-time preprocessing shared by every split: resize to 256x128
# (height x width; interpolation=3 is PIL's BICUBIC code), convert to a tensor,
# and normalise with the ImageNet mean/std used by torchvision's pretrained models.
data_transforms = transforms.Compose([
    transforms.Resize((256, 128), interpolation=3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def _make_loader(split, batch_size=32, num_workers=16):
    """Build the ImageFolder dataset for ``data_dir/<split>`` and a loader over it.

    Args:
        split: sub-folder name under ``data_dir`` (e.g. 'gallery', 'query').
        batch_size: images per batch.
        num_workers: DataLoader worker processes.

    Returns:
        ``(dataset, loader)``. The loader does not shuffle, so batches come out in
        ``dataset.imgs`` order and extracted features stay aligned with it.
    """
    dataset = datasets.ImageFolder(os.path.join(data_dir, split), data_transforms)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=False, num_workers=num_workers)
    return dataset, loader


galleryset, galleryloaders = _make_loader('gallery')
queryset, queryloaders = _make_loader('query')
# Multi-query evaluation is disabled; to enable it:
# multiqueryset, multiqueryloaders = _make_loader('multi-query')


## 搭建模型

In [3]:
def weights_init_kaiming(m):
    """Kaiming-style initialisation of one module; meant for ``nn.Module.apply``.

    The module's class name is matched by substring, in this order:
      * contains 'Conv'        -> weight: kaiming_normal_, fan_in (bias untouched)
      * contains 'Linear'      -> weight: kaiming_normal_, fan_out; bias set to 0
      * contains 'BatchNorm1d' -> weight ~ N(1.0, 0.02); bias set to 0
    Any other module is left as is.
    """
    name = type(m).__name__
    if 'Conv' in name:
        # Very old PyTorch releases call this kaiming_normal (no trailing underscore).
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        return
    if 'Linear' in name:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
        init.constant_(m.bias.data, 0.0)
        return
    if 'BatchNorm1d' in name:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)

def weights_init_classifier(m):
    """Initialise classifier layers; meant for ``nn.Module.apply``.

    Modules whose class name contains 'Linear' get weight ~ N(0, 0.001^2) and a
    zero bias; every other module is ignored.
    """
    if 'Linear' not in m.__class__.__name__:
        return
    init.normal_(m.weight.data, std=0.001)
    init.constant_(m.bias.data, 0.0)

In [4]:
class ClassBlock(nn.Module):
    """Bottleneck + classifier head applied to a pooled backbone feature.

    ``add_block`` is a Sequential containing, in order (each one optional):
      * Linear(input_dim -> num_bottleneck) when ``linear`` is True; otherwise no
        projection and the bottleneck width becomes ``input_dim``;
      * BatchNorm1d when ``bnorm``;
      * LeakyReLU(0.1) when ``relu``;
      * Dropout(droprate) when ``droprate > 0``.
    It is initialised with ``weights_init_kaiming``.

    ``classifier`` is Sequential(Linear(bottleneck width -> class_num)),
    initialised with ``weights_init_classifier``.

    ``forward`` returns the logits, or ``(logits, bottleneck_feature)`` when
    ``return_f`` is True.
    """
    def __init__(self, input_dim, class_num, droprate, relu=False, bnorm=True, num_bottleneck=512, linear=True, return_f = False):
        super(ClassBlock, self).__init__()
        self.return_f = return_f

        # Width of the feature that reaches the classifier.
        width = num_bottleneck if linear else input_dim
        layers = [nn.Linear(input_dim, num_bottleneck)] if linear else []
        if bnorm:
            layers.append(nn.BatchNorm1d(width))
        if relu:
            layers.append(nn.LeakyReLU(0.1))
        if droprate > 0:
            layers.append(nn.Dropout(p=droprate))
        self.add_block = nn.Sequential(*layers)
        self.add_block.apply(weights_init_kaiming)

        self.classifier = nn.Sequential(nn.Linear(width, class_num))
        self.classifier.apply(weights_init_classifier)

    def forward(self, x):
        feature = self.add_block(x)
        logits = self.classifier(feature)
        return (logits, feature) if self.return_f else logits

In [5]:
class ft_net(nn.Module):
    """ResNet-50 backbone followed by a ``ClassBlock`` head on the 2048-d pooled feature.

    Args:
        class_num: number of outputs of the head's classifier.
        droprate: dropout probability forwarded to ``ClassBlock``.
        stride: 1 sets the stride of the first layer4 bottleneck (its conv2 and
            its shortcut ``downsample[0]`` conv) to 1; any other value leaves the
            backbone's strides unchanged.
    """

    def __init__(self, class_num, droprate=0.5, stride=2):
        super(ft_net, self).__init__()
        # NOTE(review): `pretrained=True` fetches ImageNet weights (deprecated in
        # newer torchvision in favour of `weights=`); they are overwritten when a
        # trained checkpoint is loaded afterwards.
        backbone = models.resnet50(pretrained=True)
        if stride == 1:
            # The residual path and its shortcut must use the same stride.
            backbone.layer4[0].downsample[0].stride = (1, 1)
            backbone.layer4[0].conv2.stride = (1, 1)
        # Pool to a 1x1 map regardless of input resolution.
        backbone.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.model = backbone
        self.classifier = ClassBlock(2048, class_num, droprate)

    def forward(self, x):
        stages = ('conv1', 'bn1', 'relu', 'maxpool',
                  'layer1', 'layer2', 'layer3', 'layer4', 'avgpool')
        for stage in stages:
            x = getattr(self.model, stage)(x)
        # (N, C, 1, 1) -> (N, C)
        x = x.view(x.size(0), x.size(1))
        return self.classifier(x)

## 加载预训练模型

In [6]:
# Run on GPU 6 when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:6")
else:
    device = torch.device("cpu")
print(device)

cuda:6


In [7]:
def load_network(network, save_path='./model/resnet_croess.pth'):
    """Load trained weights from ``save_path`` into ``network`` and return it.

    Args:
        network: module whose state_dict keys and shapes match the checkpoint
            (``load_state_dict`` is strict).
        save_path: path of a file containing a state dict. The default is the
            location that was previously hard-coded.

    Returns:
        The same ``network`` object, with its parameters overwritten in place.

    The checkpoint is mapped to the CPU first. Loading then does not depend on
    the GPU it was saved from being present, and the caller moves the model to
    its target device afterwards.
    NOTE: torch.load unpickles the file; only load checkpoints you trust.
    """
    state_dict = torch.load(save_path, map_location='cpu')
    network.load_state_dict(state_dict)
    return network

In [8]:
def get_id(img_path):
    """Parse camera ids and person labels from ``(path, class_index)`` pairs.

    The person id is the first 4 characters of the file name; a name starting
    with '-1' gets label -1. The camera id is the single digit right after the
    first 'c' in the name. Example: ``0001_c2_f0046182.jpg`` -> label 1, camera 2.

    Returns:
        ``(camera_id, labels)``: two lists aligned with ``img_path``.
    """
    camera_id, labels = [], []
    for path, _ in img_path:
        filename = os.path.basename(path)
        person = filename[:4]
        after_c = filename.split('c')[1]
        labels.append(-1 if person[:2] == '-1' else int(person))
        camera_id.append(int(after_c[0]))
    return camera_id, labels

# Person label and camera id for every image, in dataset order. The loaders do
# not shuffle, so these lists line up with the extracted feature rows.
gallery_path, query_path = galleryset.imgs, queryset.imgs
gallery_cam, gallery_label = get_id(gallery_path)
query_cam, query_label = get_id(query_path)
# Multi-query (disabled):
# mquery_path = multiqueryset.imgs
# mquery_cam, mquery_label = get_id(mquery_path)

In [10]:
# 702 must equal the class count of the saved checkpoint's classifier head, since
# load_state_dict is strict. stride=2 leaves the backbone's layer4 strides as-is.
model_structure = ft_net(702, stride=2)
model = load_network(model_structure)
# Swap the final Linear classifier for an empty Sequential (identity) so the
# model outputs the ClassBlock bottleneck feature instead of class logits.
model.classifier.classifier = nn.Sequential()
model = model.eval().to(device)

In [11]:
def fliplr(img, dim=3):
    """Flip images horizontally (mirror along the width axis).

    Args:
        img: image tensor. With the default ``dim=3`` it is an (N, C, H, W) batch.
        dim: axis to reverse. The default 3 is the width axis of an NCHW batch.

    Returns:
        A new tensor with the elements along ``dim`` in reverse order; ``img`` is
        not modified.

    ``torch.flip`` works on whatever device ``img`` lives on; there is no
    separately built (CPU) index tensor to keep in sync.
    """
    return torch.flip(img, dims=[dim])

In [12]:
def extract_feature(model, dataloaders, device=None):
    """Extract L2-normalised features with horizontal-flip test-time augmentation.

    For each batch, the model is run on the images and on their mirrored copies
    (via ``fliplr``). The two outputs are summed and every row is L2-normalised.

    Args:
        model: network mapping an (N, C, H, W) image batch to an (N, D) feature batch.
        dataloaders: iterable yielding ``(images, labels)`` batches; labels are ignored.
        device: device to run the model on. Defaults to the device of the model's
            first parameter, or the CPU if it has none.

    Returns:
        CPU float tensor of shape (total_images, D), with rows in loader order.
        Returns an empty tensor if ``dataloaders`` yields nothing.
    """
    if device is None:
        first_param = (next(iter(model.parameters()), None)
                       if hasattr(model, 'parameters') else None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    batches = []
    print("start extract feature")
    with torch.no_grad():  # inference only; no autograd graph needed
        for img, _labels in dataloaders:
            summed = None
            # Original orientation + mirrored copy. The flip is done on the batch
            # as yielded by the loader, before it is moved to `device`.
            for view in (img, fliplr(img)):
                outputs = model(view.to(device))
                summed = outputs if summed is None else summed + outputs
            # Row-wise L2 normalisation; the eps inside normalize guards all-zero rows.
            batches.append(torch.nn.functional.normalize(summed, p=2, dim=1).cpu())
    print("extract feature finish")

    if not batches:
        return torch.FloatTensor()
    # Single concatenation instead of growing the result inside the loop.
    return torch.cat(batches, 0)

In [13]:
# Inference only: no autograd graph is needed while extracting features.
with torch.no_grad():
    gallery_feature, query_feature = [
        extract_feature(model, loader) for loader in (galleryloaders, queryloaders)
    ]
    # mquery_feature = extract_feature(model, multiqueryloaders)

start extract feature
extract feature finish
start extract feature
extract feature finish


In [14]:
# Features (rows aligned with the label/camera lists) plus ground truth, saved
# as a MATLAB .mat file.
result = {
    'gallery_f': gallery_feature.numpy(),
    'gallery_label': gallery_label,
    'gallery_cam': gallery_cam,
    'query_f': query_feature.numpy(),
    'query_label': query_label,
    'query_cam': query_cam,
}
# result = {'mquery_f': mquery_feature.numpy(), 'mquery_label': mquery_label, 'mquery_cam': mquery_cam}

# savemat does not create missing directories, so make sure ./feature exists.
os.makedirs('./feature', exist_ok=True)
scipy.io.savemat('./feature/pytorch_result.mat', result)
# scipy.io.savemat('multi_query.mat', result)