# 特征提取

In [1]:
from __future__ import print_function, division

import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torchvision import datasets, transforms
import torch.backends.cudnn as cudnn
import matplotlib
#matplotlib.use('agg')
import matplotlib.pyplot as plt
#from PIL import Image
import time
import os
import yaml
import math
from shutil import copyfile
import random
import numpy as np
from torchvision import models
from torch.nn import init
from torch.autograd import Variable
import pretrainedmodels
import tensorwatch as tw
import scipy.io

## 数据加载

In [2]:
data_dir = "/home/yxy/ReID/data/DukeMTMC-reID/pytorch/"


def _build_eval_split(split):
    """Build the ``(dataset, loader)`` pair for one sub-folder of ``data_dir``.

    Every split gets the same preprocessing: resize to 256x128, convert to a
    tensor, then normalize each channel with the given mean/std.
    """
    preprocess = transforms.Compose([
        transforms.Resize((256,128), interpolation=3),  # 3 is PIL's bicubic filter code
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    dataset = datasets.ImageFolder(os.path.join(data_dir, split), preprocess)
    # shuffle=False: batches come out in the same order as dataset.imgs
    loader = torch.utils.data.DataLoader(dataset, batch_size=32,
                                         shuffle=False, num_workers=16)
    return dataset, loader


galleryset, galleryloaders = _build_eval_split('gallery')
queryset, queryloaders = _build_eval_split('query')

#multiqueryset = datasets.ImageFolder(os.path.join(data_dir, 'multi-query'),
#                                               transforms.Compose([
#                                                   transforms.Resize((256,128), interpolation=3),
#                                                   transforms.ToTensor(),
#                                                   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
#                                               ]))
#multiqueryloaders = torch.utils.data.DataLoader(multiqueryset, batch_size=32,
#                                             shuffle=False, num_workers=16) 


In [3]:
# Sanity check: number of images per split, then the loader objects themselves.
for split_set in (galleryset, queryset):
    print(len(split_set))
#print(len(multiqueryset))
for split_loader in (galleryloaders, queryloaders):
    print(split_loader)
#print(galleryset.imgs)

17661
2228
<torch.utils.data.dataloader.DataLoader object at 0x7f6d06329e80>
<torch.utils.data.dataloader.DataLoader object at 0x7f6d062d76a0>


## 搭建模型

In [4]:
def weights_init_kaiming(m):
    """Kaiming-style initializer, intended to be passed to ``nn.Module.apply``.

    The layer type is matched on the class name of ``m``:

    - name contains 'Conv': weight drawn with ``kaiming_normal_`` (a=0, fan_in);
      the bias is left untouched.
    - name contains 'Linear': weight drawn with ``kaiming_normal_`` (a=0, fan_out);
      bias set to 0.
    - name contains 'BatchNorm1d': weight drawn from N(1.0, 0.02); bias set to 0.

    Any other module is left unchanged. Parameters that do not exist
    (``bias=False`` layers, ``affine=False`` batch norm) are skipped instead of
    raising ``AttributeError`` on ``None``.
    """
    classname = m.__class__.__name__
    # print(classname)
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') # For old pytorch, you may use kaiming_normal.
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
        if m.bias is not None:
            init.constant_(m.bias.data, 0.0)
    elif classname.find('BatchNorm1d') != -1:
        # weight and bias are both None when the layer was built with affine=False
        if m.weight is not None:
            init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            init.constant_(m.bias.data, 0.0)

def weights_init_classifier(m):
    """Initializer for classification layers, intended for ``nn.Module.apply``.

    Modules whose class name contains 'Linear' get weights drawn from a normal
    distribution with std 0.001 and, when a bias exists, a zero bias. Every
    other module is left unchanged.
    """
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        init.normal_(m.weight.data, std=0.001)
        # a Linear built with bias=False has m.bias set to None
        if m.bias is not None:
            init.constant_(m.bias.data, 0.0)

In [5]:
# New fc (bottleneck) layers followed by the classification layer
# |--Linear--|--bn--|--relu--|--Linear--|
class ClassBlock(nn.Module):
    """Optional bottleneck (Linear / BatchNorm1d / LeakyReLU / Dropout) plus a linear classifier.

    Args:
        input_dim: width of the incoming feature vector.
        class_num: output width of the final linear layer.
        droprate: dropout probability; the Dropout layer is added only when > 0.
        relu: if True, add ``LeakyReLU(0.1)`` to the bottleneck.
        bnorm: if True, add ``BatchNorm1d`` to the bottleneck.
        num_bottleneck: output width of the bottleneck Linear layer.
        linear: if False, the bottleneck Linear is omitted and the bottleneck
            width becomes ``input_dim``.
        return_f: if True, ``forward`` returns ``(logits, bottleneck_features)``
            instead of just the logits.
    """

    def __init__(self, input_dim, class_num, droprate, relu=False, bnorm=True, num_bottleneck=512, linear=True, return_f = False):
        super(ClassBlock, self).__init__()
        self.return_f = return_f

        # Without the projection the features keep their incoming width.
        if not linear:
            num_bottleneck = input_dim

        layers = []
        if linear:
            layers.append(nn.Linear(input_dim, num_bottleneck))
        if bnorm:
            layers.append(nn.BatchNorm1d(num_bottleneck))
        if relu:
            layers.append(nn.LeakyReLU(0.1))
        if droprate>0:
            layers.append(nn.Dropout(p=droprate))
        bottleneck = nn.Sequential(*layers)
        bottleneck.apply(weights_init_kaiming)

        head = nn.Sequential(nn.Linear(num_bottleneck, class_num))
        head.apply(weights_init_classifier)

        self.add_block = bottleneck
        self.classifier = head

    def forward(self, x):
        """Run the bottleneck, then the classifier; see ``return_f`` for the return shape."""
        feat = self.add_block(x)
        logits = self.classifier(feat)
        if self.return_f:
            return logits, feat
        return logits

In [6]:
class ft_net(nn.Module):
    """ResNet-50 backbone (``models.resnet50(pretrained=True)``) followed by a ``ClassBlock`` head.

    Args:
        class_num: number of classes passed to ``ClassBlock``.
        droprate: dropout probability passed to ``ClassBlock``.
        stride: when equal to 1, the strides of the first block of ``layer4``
            (its ``conv2`` and the first layer of its ``downsample`` branch)
            are set to (1, 1); any other value leaves the backbone as built.
    """

    def __init__(self, class_num, droprate=0.5, stride=2):
        super(ft_net, self).__init__()
        backbone = models.resnet50(pretrained=True)
        if stride == 1:
            first_block = backbone.layer4[0]
            first_block.downsample[0].stride = (1,1)
            first_block.conv2.stride = (1,1)
        backbone.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.model = backbone
        self.classifier = ClassBlock(2048, class_num, droprate)

    def forward(self, x):
        """Apply the backbone stages in order (its ``fc`` layer is not called), then the head."""
        backbone = self.model
        stages = (
            backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
            backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4,
            backbone.avgpool,
        )
        for stage in stages:
            x = stage(x)
        # avgpool leaves a 1x1 spatial map: (N, C, 1, 1) -> (N, C)
        pooled = x.view(x.size(0), x.size(1))
        return self.classifier(pooled)

## 加载预训练模型

In [7]:
# Preferred GPU index. torch.device() accepts an out-of-range index without
# complaint and only fails later (e.g. at model.to(device)), so check it against
# the number of visible GPUs: fall back to GPU 0 when the index does not exist,
# and to the CPU when CUDA is unavailable.
GPU_INDEX = 6
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif GPU_INDEX < torch.cuda.device_count():
    device = torch.device("cuda:%d" % GPU_INDEX)
else:
    device = torch.device("cuda:0")
print(device)

cuda:6


In [8]:
def load_network(network, save_path='./duke/model/resnet-duke.pth'):
    """Load a saved ``state_dict`` into ``network`` and return the network.

    Args:
        network: module whose ``load_state_dict`` receives the checkpoint.
        save_path: checkpoint file to read; defaults to the path that was
            previously hard-coded here.

    Returns:
        The same ``network`` object, with its parameters updated in place.
    """
    # map_location='cpu' makes loading independent of the device the checkpoint
    # was saved from (which may not exist on this machine). load_state_dict
    # copies values into the existing parameters, so the network keeps its own
    # device placement.
    state_dict = torch.load(save_path, map_location='cpu')
    network.load_state_dict(state_dict)
    return network

In [9]:
def get_id(img_path):
    """Parse camera ids and identity labels out of image file names.

    Args:
        img_path: iterable of ``(path, value)`` pairs; only the path is used.

    Returns:
        ``(camera_id, labels)``, two lists in input order. The label is the
        integer value of the first four characters of the file name, or -1 when
        the name starts with '-1'. The camera id is the single digit that
        follows the first 'c' in the file name.
    """
    cameras, identities = [], []
    for path, _ in img_path:
        filename = os.path.basename(path)
        prefix = filename[0:4]
        after_c = filename.split('c')[1]
        identities.append(-1 if prefix[0:2] == '-1' else int(prefix))
        cameras.append(int(after_c[0]))
    return cameras, identities

# Sample lists of the two ImageFolder datasets, as consumed by get_id.
gallery_path, query_path = galleryset.imgs, queryset.imgs
#mquery_path = multiqueryset.imgs

# Camera ids and identity labels parsed from the file names.
gallery_cam, gallery_label = get_id(gallery_path)
query_cam, query_label = get_id(query_path)
#mquery_cam,mquery_label = get_id(mquery_path)
#print(gallery_label)

In [10]:
# Build the network with 702 output classes and load the saved weights into it.
model_structure = ft_net(702, stride = 2)
model = load_network(model_structure)
# An empty Sequential passes its input through unchanged, so after this
# replacement the network outputs the bottleneck features instead of class scores.
model.classifier.classifier = nn.Sequential()
# Evaluation mode, then move to the selected device.
model = model.eval().to(device)
model
#tw.draw_model(model,[32, 3, 256, 128])

ft_net(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (downsample): Sequential(
          (0): Conv2d(6

In [11]:
def fliplr(img):
    """Flip a batch of images horizontally by reversing dimension 3.

    Args:
        img: tensor with at least four dimensions; dimension 3 is treated as
            the width axis (assumes an (N, C, H, W) layout — confirm with callers).

    Returns:
        A new tensor of the same shape with dimension 3 in reverse order.
    """
    # Build the reversed index on img's own device: index_select requires the
    # index tensor to live on the same device as the tensor being indexed.
    inv_idx = torch.arange(img.size(3)-1, -1, -1, device=img.device).long()
    img_flip = img.index_select(3, inv_idx)
    return img_flip

# Single-batch dry run of the feature-extraction steps (debugging aid):
# take the first gallery batch, run it through the model, L2-normalize the
# result and print every intermediate value.
count = 0
features = torch.FloatTensor()
it = iter(galleryloaders)
batch = next(it)
img,labels = batch
print(labels)
print(img.size())
n,c,h,w = img.size()
count += n
print(count)
# Inference only: without no_grad the module-level `output` and `ff` would keep
# the whole autograd graph (and its activations) alive on the device.
with torch.no_grad():
    ff = torch.zeros(n, 512, device=device)
    #print(ff)
    input_img = img.to(device)
    #print(input_img)
    output = model(input_img)
    ff += output
print(ff)

# Scale every row to unit L2 norm.
fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
print(fnorm)
ff = ff.div(fnorm.expand_as(ff))
print(ff)
print(ff.size())
features = torch.cat((features,ff.data.cpu()), 0)
print(features)

In [12]:
def _model_device(model):
    """Return the device of the model's first parameter, or CPU if it has none."""
    parameters = getattr(model, 'parameters', None)
    first = next(parameters(), None) if callable(parameters) else None
    return first.device if first is not None else torch.device('cpu')


def extract_feature(model, dataloaders, flip=False):
    """Compute L2-normalized feature vectors for every image yielded by ``dataloaders``.

    Args:
        model: callable mapping an image batch to a 2-D ``(batch, feature_dim)``
            tensor. Inputs are moved to the device of its first parameter.
        dataloaders: iterable of ``(img, labels)`` batches; labels are ignored.
        flip: when True, the output for the horizontally flipped batch
            (dimension 3 reversed) is added to the output for the original
            batch before normalization. The default, False, uses one forward
            pass; this yields the same normalized features as summing two
            passes over the same unflipped batch, at half the cost.

    Returns:
        A CPU float tensor with one unit-norm row per image, in loader order.
        An empty ``torch.FloatTensor`` is returned when the loader yields
        nothing. The feature width is taken from the model output.
    """
    target = _model_device(model)
    chunks = []
    count = 0
    # Inference only; harmless if the caller is already inside no_grad.
    with torch.no_grad():
        for img, _ in dataloaders:
            count += img.size(0)
            print(count)

            views = [img, torch.flip(img, dims=[3])] if flip else [img]
            ff = None
            for view in views:
                outputs = model(view.to(target)).float()
                ff = outputs if ff is None else ff + outputs

            # Scale every row to unit L2 norm.
            fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
            chunks.append(ff.div(fnorm.expand_as(ff)).cpu())

    if not chunks:
        return torch.FloatTensor()
    # One concatenation at the end instead of re-copying the result every batch.
    return torch.cat(chunks, 0)

In [13]:
# Gradients are not needed for feature extraction; gallery first, then query.
with torch.no_grad():
    gallery_feature, query_feature = [
        extract_feature(model, loader)
        for loader in (galleryloaders, queryloaders)
    ]
#mquery_feature = extract_feature(model,multiqueryloaders)

32
64
96
128
160
192
224
256
288
320
352
384
416
448
480
512
544
576
608
640
672
704
736
768
800
832
864
896
928
960
992
1024
1056
1088
1120
1152
1184
1216
1248
1280
1312
1344
1376
1408
1440
1472
1504
1536
1568
1600
1632
1664
1696
1728
1760
1792
1824
1856
1888
1920
1952
1984
2016
2048
2080
2112
2144
2176
2208
2240
2272
2304
2336
2368
2400
2432
2464
2496
2528
2560
2592
2624
2656
2688
2720
2752
2784
2816
2848
2880
2912
2944
2976
3008
3040
3072
3104
3136
3168
3200
3232
3264
3296
3328
3360
3392
3424
3456
3488
3520
3552
3584
3616
3648
3680
3712
3744
3776
3808
3840
3872
3904
3936
3968
4000
4032
4064
4096
4128
4160
4192
4224
4256
4288
4320
4352
4384
4416
4448
4480
4512
4544
4576
4608
4640
4672
4704
4736
4768
4800
4832
4864
4896
4928
4960
4992
5024
5056
5088
5120
5152
5184
5216
5248
5280
5312
5344
5376
5408
5440
5472
5504
5536
5568
5600
5632
5664
5696
5728
5760
5792
5824
5856
5888
5920
5952
5984
6016
6048
6080
6112
6144
6176
6208
6240
6272
6304
6336
6368
6400
6432
6464
6496
6528
6560
6592
6624

In [14]:
# Bundle the features with their labels and camera ids and write them to a .mat file.
result = {
    'gallery_f': gallery_feature.numpy(),
    'gallery_label': gallery_label,
    'gallery_cam': gallery_cam,
    'query_f': query_feature.numpy(),
    'query_label': query_label,
    'query_cam': query_cam,
}
#result = {'mquery_f':mquery_feature.numpy(),'mquery_label':mquery_label,'mquery_cam':mquery_cam}
scipy.io.savemat('./duke/pytorch_result.mat', result)
#scipy.io.savemat('multi_query.mat',result)