This Jupyter notebook implements the paper **"Multi-Camera Person Re-Identification using Spatiotemporal Context Modeling"**.<br><br>
Authors: Fatima Zulfiqar, Usama Ijaz Bajwa, Rana Hammad Raza

# Base

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import os.path as osp
import imdb


class BaseImgDataset(object):
    """Base class for the image re-ID datasets in this notebook.

    Subclasses set ``dataset_dir`` and fill ``train``, ``query`` and
    ``gallery`` with ``(img_path, pid, camid)`` tuples, where ``img_path`` is
    either a single file path or a list of paths (image + segmentation masks).
    """

    def __init__(self):
        # Set by generate_lmdb(); None until LMDB databases are requested.
        self.train_lmdb_path = None
        self.query_lmdb_path = None
        self.gallery_lmdb_path = None

    def generate_lmdb(self):
        """Store every train/query/gallery file in a per-split LMDB database.

        Keys are the UTF-8 encoded file paths and values are the raw file
        bytes. A split whose ``<dataset_dir>/<split>_lmdb`` path already
        exists is left untouched.

        Raises:
            AssertionError: if ``train``, ``query`` or ``gallery`` is not a list.
        """
        assert isinstance(self.train, list)
        assert isinstance(self.query, list)
        assert isinstance(self.gallery, list)

        print("Reminder: this function is under development, some datasets might not be applicable yet")

        self.train_lmdb_path = osp.join(self.dataset_dir, 'train_lmdb')
        self.query_lmdb_path = osp.join(self.dataset_dir, 'query_lmdb')
        self.gallery_lmdb_path = osp.join(self.dataset_dir, 'gallery_lmdb')

        def _entry_files(img_path):
            # Segmentation datasets store [image, mask, ...]; plain ones a single path.
            if isinstance(img_path, (list, tuple)):
                return list(img_path)
            return [img_path]

        def _write_lmdb(write_path, data_list):
            if osp.exists(write_path):
                return

            # Local import: the cell header imports `imdb` (a typo), not `lmdb`.
            import lmdb

            print("Generating lmdb files to '{}'".format(write_path))

            num_data = len(data_list)
            max_map_size = int(num_data * 500**2 * 3)  # be careful with this
            env = lmdb.open(write_path, map_size=max_map_size)
            try:
                for img_path, _pid, _camid in data_list:
                    # One write transaction per entry, covering all of its files.
                    with env.begin(write=True) as txn:
                        for file_path in _entry_files(img_path):
                            with open(file_path, 'rb') as imgf:
                                # LMDB keys must be bytes, not str.
                                txn.put(file_path.encode('utf-8'), imgf.read())
            finally:
                env.close()

        _write_lmdb(self.train_lmdb_path, self.train)
        _write_lmdb(self.query_lmdb_path, self.query)
        _write_lmdb(self.gallery_lmdb_path, self.gallery)

# Load Market1501 Dataset

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import glob
import re
import sys
import urllib
import tarfile
import zipfile
import os.path as osp
from scipy.io import loadmat
import numpy as np
import h5py
#from scipy.misc import imsave

#from base import BaseImgDataset


class Market1501_seg(BaseImgDataset):
    """
    Market1501 with part-segmentation masks for the training split

    Reference:
    Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.

    URL: http://www.liangzheng.org/Project/project_reid.html

    Dataset statistics:
    # identities: 1501 (+1 for background)
    # images: 12936 (train) + 3368 (query) + 15913 (gallery)

    Every entry of ``train``/``query``/``gallery`` is ``(img_path, pid, camid)``
    with a 0-based ``camid``. For ``train``, ``img_path`` is the list
    ``[image, head, upper_clothes, lower_clothes, shoes, foreground]`` taken
    from ``<dataset_dir>/Market1501_train_seg_part4/<image stem>/``, and pids
    are relabelled to ``0..num_train_pids-1``.
    """
    dataset_dir = '../Market-1501-v15.09.15'  # path to dataset

    def __init__(self, verbose=True, use_lmdb=False, **kwargs):
        super(Market1501_seg, self).__init__()
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')

        self.train_seg_dir = osp.join(self.dataset_dir, 'Market1501_train_seg_part4')

        self._check_before_run()

        train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, self.train_seg_dir, relabel=True)
        query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False)
        gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False)
        # Total ids = train + query ids (gallery ids are not added).
        num_total_pids = num_train_pids + num_query_pids
        num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs

        if verbose:
            print("=> Market1501 loaded")
            print("Dataset statistics:")
            print("  ------------------------------")
            print("  subset   | # ids | # images")
            print("  ------------------------------")
            print("  train    | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
            print("  query    | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
            print("  gallery  | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
            print("  ------------------------------")
            print("  total    | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
            print("  ------------------------------")

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids = num_train_pids
        self.num_query_pids = num_query_pids
        self.num_gallery_pids = num_gallery_pids

        if use_lmdb:
            self.generate_lmdb()

    def _check_before_run(self):
        """Check if all files are available before going deeper.

        Raises:
            RuntimeError: naming the first required directory that is missing.
        """
        required = (self.dataset_dir, self.train_dir, self.query_dir,
                    self.gallery_dir, self.train_seg_dir)
        for path in required:
            if not osp.exists(path):
                # Report the path that is actually missing (the seg-dir check
                # used to report train_dir instead).
                raise RuntimeError("'{}' is not available".format(path))

    def _seg_paths(self, seg_dir_path, img_path):
        """Return ``[img_path, head, upper_clothes, lower_clothes, shoes, foreground]``.

        The masks are read from ``<seg_dir_path>/<image file stem>/``; a
        missing mask raises AssertionError.
        """
        seg_dir = osp.join(seg_dir_path, osp.splitext(osp.basename(img_path))[0])
        paths = [img_path]
        for mask_name in ('head.png', 'upper_clothes.png', 'lower_clothes.png',
                          'shoes.png', 'foreground.png'):
            mask_path = osp.join(seg_dir, mask_name)
            assert osp.exists(mask_path), "missing segmentation mask '{}'".format(mask_path)
            paths.append(mask_path)
        return paths

    def _process_dir(self, dir_path, seg_dir_path=None, relabel=False):
        """Index every ``*.jpg`` in ``dir_path``.

        Args:
            dir_path: folder of images named ``<pid>_c<cam>...jpg``.
            seg_dir_path: if given, attach the five segmentation masks of each
                image (see ``_seg_paths``).
            relabel: map pids to contiguous labels ``0..num_pids-1``.

        Returns:
            ``(dataset, num_pids, num_imgs)`` where ``dataset`` is a list of
            ``(img_path, pid, camid)`` tuples; junk images (pid == -1) are skipped.
        """
        # Sorted so dataset order and the pid -> label mapping are reproducible
        # regardless of filesystem / set iteration order.
        img_paths = sorted(glob.glob(osp.join(dir_path, '*.jpg')))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        records = []
        for img_path in img_paths:
            # Match on the file name only so digits in parent folders cannot interfere.
            pid, camid = map(int, pattern.search(osp.basename(img_path)).groups())
            if pid == -1:
                continue  # junk images are just ignored
            records.append((img_path, pid, camid))

        pid_container = {pid for _, pid, _ in records}
        pid2label = {pid: label for label, pid in enumerate(sorted(pid_container))}

        dataset = []
        for img_path, pid, camid in records:
            assert 0 <= pid <= 1501  # pid == 0 means background
            assert 1 <= camid <= 6
            camid -= 1  # index starts from 0
            if relabel:
                pid = pid2label[pid]
            if seg_dir_path is not None:
                img_path = self._seg_paths(seg_dir_path, img_path)
            dataset.append((img_path, pid, camid))

        return dataset, len(pid_container), len(dataset)
    
if __name__ == '__main__':
    # Sanity check when run directly: index Market-1501 and print its statistics.
    Market1501_seg(verbose=True)

# Load DukeMTMC Dataset

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import glob
import re
import sys
import urllib
import tarfile
import zipfile
import os.path as osp
from scipy.io import loadmat
import numpy as np
import h5py
#from scipy.misc import imsave

#from base import BaseImgDataset


class DukeMTMC_seg(BaseImgDataset):
    """
    DukeMTMC-reID with part-segmentation masks for the training split

    Reference:
    Ristani et al. Performance Measures and a Data Set for Multi-Target,
    Multi-Camera Tracking. ECCVW 2016.
    Zheng et al. Unlabeled Samples Generated by GAN Improve the Person
    Re-identification Baseline in vitro. ICCV 2017.

    URL: https://github.com/layumi/DukeMTMC-reID_evaluation

    Dataset statistics:
    # identities: 1404 (train + query)
    # images: 16522 (train) + 2228 (query) + 17661 (gallery)
    # cameras: 8

    Every entry of ``train``/``query``/``gallery`` is ``(img_path, pid, camid)``
    with a 0-based ``camid``. For ``train``, ``img_path`` is the list
    ``[image, head, upper_clothes, lower_clothes, shoes, foreground]`` taken
    from ``<dataset_dir>/DukeMTMC_train_seg_part4/<image stem>/``, and pids
    are relabelled to ``0..num_train_pids-1``.
    """
    dataset_dir = '../DukeMTMC-reID'  # path to dataset (same '../' layout as Market1501_seg)

    def __init__(self, verbose=True, use_lmdb=False, **kwargs):
        super(DukeMTMC_seg, self).__init__()
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')

        self.train_seg_dir = osp.join(self.dataset_dir, 'DukeMTMC_train_seg_part4')

        self._check_before_run()

        train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, self.train_seg_dir, relabel=True)
        query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False)
        gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False)
        # Total ids = train + query ids (gallery ids are not added).
        num_total_pids = num_train_pids + num_query_pids
        num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs

        if verbose:
            print("=> DukeMTMC-ReID loaded")
            print("Dataset statistics:")
            print("  ------------------------------")
            print("  subset   | # ids | # images")
            print("  ------------------------------")
            print("  train    | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
            print("  query    | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
            print("  gallery  | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
            print("  ------------------------------")
            print("  total    | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
            print("  ------------------------------")

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids = num_train_pids
        self.num_query_pids = num_query_pids
        self.num_gallery_pids = num_gallery_pids

        if use_lmdb:
            self.generate_lmdb()

    def _check_before_run(self):
        """Check if all files are available before going deeper.

        Raises:
            RuntimeError: naming the first required directory that is missing.
        """
        required = (self.dataset_dir, self.train_dir, self.query_dir,
                    self.gallery_dir, self.train_seg_dir)
        for path in required:
            if not osp.exists(path):
                # Report the path that is actually missing (the seg-dir check
                # used to report train_dir instead).
                raise RuntimeError("'{}' is not available".format(path))

    def _seg_paths(self, seg_dir_path, img_path):
        """Return ``[img_path, head, upper_clothes, lower_clothes, shoes, foreground]``.

        The masks are read from ``<seg_dir_path>/<image file stem>/``; a
        missing mask raises AssertionError.
        """
        seg_dir = osp.join(seg_dir_path, osp.splitext(osp.basename(img_path))[0])
        paths = [img_path]
        for mask_name in ('head.png', 'upper_clothes.png', 'lower_clothes.png',
                          'shoes.png', 'foreground.png'):
            mask_path = osp.join(seg_dir, mask_name)
            assert osp.exists(mask_path), "missing segmentation mask '{}'".format(mask_path)
            paths.append(mask_path)
        return paths

    def _process_dir(self, dir_path, seg_dir_path=None, relabel=False):
        """Index every ``*.jpg`` in ``dir_path``.

        Args:
            dir_path: folder of images named ``<pid>_c<cam>...jpg``.
            seg_dir_path: if given, attach the five segmentation masks of each
                image (see ``_seg_paths``).
            relabel: map pids to contiguous labels ``0..num_pids-1``.

        Returns:
            ``(dataset, num_pids, num_imgs)`` where ``dataset`` is a list of
            ``(img_path, pid, camid)`` tuples; junk images (pid == -1) are skipped.
        """
        # Sorted so dataset order and the pid -> label mapping are reproducible
        # regardless of filesystem / set iteration order.
        img_paths = sorted(glob.glob(osp.join(dir_path, '*.jpg')))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        records = []
        for img_path in img_paths:
            # Match on the file name only so digits in parent folders cannot interfere.
            pid, camid = map(int, pattern.search(osp.basename(img_path)).groups())
            if pid == -1:
                continue  # junk images are just ignored
            records.append((img_path, pid, camid))

        pid_container = {pid for _, pid, _ in records}
        pid2label = {pid: label for label, pid in enumerate(sorted(pid_container))}

        dataset = []
        for img_path, pid, camid in records:
            assert 0 <= pid <= 7140
            assert 1 <= camid <= 8
            camid -= 1  # index starts from 0
            if relabel:
                pid = pid2label[pid]
            if seg_dir_path is not None:
                img_path = self._seg_paths(seg_dir_path, img_path)
            dataset.append((img_path, pid, camid))

        return dataset, len(pid_container), len(dataset)
    
if __name__ == '__main__':
    # Sanity check when run directly: index DukeMTMC-reID and print its statistics.
    DukeMTMC_seg(verbose=True)

# Load MSMT17 Dataset

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import glob
import re
import sys
import urllib
import tarfile
import zipfile
import os.path as osp
from scipy.io import loadmat
import numpy as np
import h5py

#from ..dataset import ImageDataset

# Log
# 22.01.2019
# - add v2
# - v1 and v2 differ in dir names
# - note that faces in v2 are blurred
TRAIN_DIR_KEY = 'train_dir'
TEST_DIR_KEY = 'test_dir'
# Sub-folder names of the train/test images for each MSMT17 release, keyed by
# the release's top-level folder name (checked in this order).
VERSION_DICT = {
    version: {TRAIN_DIR_KEY: train_name, TEST_DIR_KEY: test_name}
    for version, train_name, test_name in (
        ('MSMT17_V1', 'train', 'test'),
        ('MSMT17_V2', 'mask_train_v2', 'mask_test_v2'),
    )
}


class MSMT17_seg(BaseImgDataset):
    """MSMT17 with part-segmentation masks for the training split.

    Reference:
        Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018.
    URL: `<http://www.pkuvmc.com/publications/msmt17.html>`_

    Dataset statistics:
        - identities: 4101.
        - images: 32621 (train) + 11659 (query) + 82161 (gallery).
        - cameras: 15.

    The first release folder from ``VERSION_DICT`` found under ``dataset_dir``
    is used. Pass ``combineall=True`` to append the validation images to
    ``train``.
    """
    dataset_dir = '/content/'  # set path to dataset

    def __init__(self, verbose=True, use_lmdb=False, **kwargs):
        super(MSMT17_seg, self).__init__()

        # Pick the first release whose top-level folder exists.
        main_dir = next(
            (version for version in VERSION_DICT
             if osp.exists(osp.join(self.dataset_dir, version))),
            None,
        )
        assert main_dir is not None, 'Dataset folder not found'
        train_dir = VERSION_DICT[main_dir][TRAIN_DIR_KEY]
        test_dir = VERSION_DICT[main_dir][TEST_DIR_KEY]
        root = osp.join(self.dataset_dir, main_dir)

        self.train_dir = osp.join(root, train_dir)
        self.test_dir = osp.join(root, test_dir)
        self.train_seg_dir = osp.join(root, train_dir, 'MSMT17_train_seg_part4')

        self.list_train_path = osp.join(root, 'list_train.txt')
        # Same file as list_train_path; kept for backward compatibility.
        self.list_train_seg_path = osp.join(root, 'list_train.txt')
        self.list_val_path = osp.join(root, 'list_val.txt')
        self.list_query_path = osp.join(root, 'list_query.txt')
        self.list_gallery_path = osp.join(root, 'list_gallery.txt')

        self._check_before_run()

        train, num_train_pids, num_train_imgs = self.process_dir(self.train_dir, self.list_train_path, self.train_seg_dir, relabel=True)
        val, num_val_pids, num_val_imgs = self.process_dir(self.train_dir, self.list_val_path, relabel=False)
        query, num_query_pids, num_query_imgs = self.process_dir(self.test_dir, self.list_query_path, relabel=False)
        gallery, num_gallery_pids, num_gallery_imgs = self.process_dir(self.test_dir, self.list_gallery_path, relabel=False)

        num_total_pids = num_train_pids + num_query_pids
        num_total_imgs = num_train_imgs + num_val_imgs + num_query_imgs + num_gallery_imgs

        # Note: to fairly compare with published methods on the conventional ReID setting,
        #       do not add val images to the training set.
        if kwargs.get('combineall'):
            train += val

        if verbose:
            print("=> MSMT17 loaded")
            print("Dataset statistics:")
            print("  ------------------------------")
            print("  subset   | # ids | # images")
            print("  ------------------------------")
            print("  train    | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
            print("  val      | {:5d} | {:8d}".format(num_val_pids, num_val_imgs))
            print("  query    | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
            print("  gallery  | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
            print("  ------------------------------")
            print("  total    | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
            print("  ------------------------------")

        self.train = train
        self.val = val
        self.query = query
        self.gallery = gallery

        self.num_train_pids = num_train_pids
        self.num_val_pids = num_val_pids
        self.num_query_pids = num_query_pids
        self.num_gallery_pids = num_gallery_pids

        if use_lmdb:
            self.generate_lmdb()

    def _check_before_run(self):
        """Check if all files are available before going deeper.

        Raises:
            RuntimeError: naming the first required file or directory that is missing.
        """
        required = (self.dataset_dir, self.train_dir, self.test_dir,
                    self.list_train_path,  # previously unchecked
                    self.list_val_path, self.list_query_path,
                    self.list_gallery_path, self.train_seg_dir)
        for path in required:
            if not osp.exists(path):
                raise RuntimeError("'{}' is not available".format(path))

    def _seg_paths(self, seg_dir_path, img_path):
        """Return ``[img_path, head, upper_clothes, lower_clothes, shoes, foreground]``.

        ``img_path`` is first normalised to forward slashes. The masks are
        read from ``<seg_dir_path>/<image file stem>/``; a missing mask raises
        AssertionError.
        """
        img_path = img_path.replace('\\', '/')
        seg_dir = osp.join(seg_dir_path, osp.splitext(osp.basename(img_path))[0])
        paths = [img_path]
        for mask_name in ('head.png', 'upper_clothes.png', 'lower_clothes.png',
                          'shoes.png', 'foreground.png'):
            mask_path = osp.join(seg_dir, mask_name)
            assert osp.exists(mask_path), "missing segmentation mask '{}'".format(mask_path)
            paths.append(mask_path)
        return paths

    def process_dir(self, dir_path, list_path, seg_dir_path=None, relabel=False):
        """Parse an MSMT17 list file of ``<relative path> <pid>`` lines.

        Args:
            dir_path: folder the relative image paths are joined to.
            list_path: the list file to read.
            seg_dir_path: if given, attach the five segmentation masks of each
                image (see ``_seg_paths``).
            relabel: accepted for parity with the other datasets but unused;
                pids are taken from the list file as-is.

        Returns:
            ``(dataset, num_pids, num_imgs)`` with ``(img_path, pid, camid)`` tuples.
        """
        dataset = []
        pid_container = set()
        with open(list_path, 'r') as txt:
            for line in txt:
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing empty lines
                rel_path, pid = line.rsplit(None, 1)
                pid = int(pid)  # no need to relabel
                pid_container.add(pid)
                # Camera id is the third '_'-separated field; make it 0-based.
                camid = int(rel_path.split('_')[2]) - 1
                img_path = osp.join(dir_path, rel_path)
                if seg_dir_path is not None:
                    img_path = self._seg_paths(seg_dir_path, img_path)
                dataset.append((img_path, pid, camid))

        return dataset, len(pid_container), len(dataset)
     
if __name__ == '__main__':
    # Sanity check when run directly: index MSMT17 and print its statistics.
    MSMT17_seg(verbose=True)

# Initialize Dataset

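Run the Market-1501 dataset cell (`market_seg.py`) and keep the `market1501_seg` entry in `__imgreid_factory` below uncommented to train on Market-1501. To train on DukeMTMC-reID or MSMT17 instead, uncomment the corresponding `DukeMTMC_seg` or `MSMT17_seg` entry.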

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


# Registry of selectable datasets: name -> loader class. Uncomment an entry
# to make that dataset available to init_imgreid_dataset().
__imgreid_factory = {
    'market1501_seg': Market1501_seg,
    # 'DukeMTMC_seg': DukeMTMC_seg,
    # 'MSMT17_seg': MSMT17_seg,
}



def get_names():
    """Return the registered dataset names as a list, in registration order."""
    return [dataset_name for dataset_name in __imgreid_factory]


def init_imgreid_dataset(name, **kwargs):
    """Instantiate the dataset registered under ``name``, forwarding ``kwargs``.

    Raises:
        KeyError: if ``name`` is not a registered dataset.
    """
    known = list(__imgreid_factory.keys())
    if name not in known:
        raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(name, known))
    dataset_cls = __imgreid_factory[name]
    return dataset_cls(**kwargs)

# `models` Folder

# Spatial Attention Module (SAM)

In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import math


class Gconv(nn.Module):
    """Graph-convolution update that fuses each node with its weighted neighbours.

    forward(W, x): ``x`` holds node features of shape (bs, n, c) and ``W`` is
    the (bs, n, n) adjacency used in a batched matmul. Each node's feature is
    concatenated with ``W @ x``, passed through a 1x1 conv (2c -> c), BN and
    ReLU, and the result is returned as (bs, n, c).
    """
    def __init__(self, in_channels):
        super(Gconv, self).__init__()
        self.fsm = nn.Sequential(
            nn.Conv2d(in_channels * 2, in_channels, 1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
        )
        # He-style normal init for convs; identity affine for BatchNorm.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, W, x):
        bs, n, c = x.size()
        fused = torch.cat([x, torch.bmm(W, x)], 2)
        # Treat every node as a 1x1 "image" so the 1x1 conv acts per node.
        out = self.fsm(fused.view(-1, fused.size(2), 1, 1))
        return out.view(bs, n, c)


class Wcompute(nn.Module):
    """Learn a row-normalised adjacency between graph nodes.

    forward(x, W_id, y): ``x`` is (bs, N, C) node features and ``y`` a context
    vector that is viewed as (bs, C, 1, 1) and broadcast to every node pair.
    Each pair (i, j) is scored by a 1x1 conv + BN on ``[|x_i - x_j|, y]``.
    ``W_id * 1e8`` is subtracted before the softmax, so entries where ``W_id``
    is 1 (the diagonal when SAM passes an identity) get ~zero weight.
    Returns (bs, N, N) weights whose rows sum to 1.
    """
    def __init__(self, in_channels):
        super(Wcompute, self).__init__()
        self.in_channels = in_channels
        self.relation = nn.Sequential(
            nn.Conv2d(in_channels * 2, 1, 1),
            nn.BatchNorm2d(1),
        )
        # He-style normal init for convs; identity affine for BatchNorm.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x, W_id, y):
        bs, N, C = x.size()

        # Pairwise |x_i - x_j| laid out channels-first: (bs, C, N, N).
        rows = x.unsqueeze(2)
        pair_diff = torch.transpose(torch.abs(rows - torch.transpose(rows, 1, 2)), 1, 3)
        context = y.view(bs, C, 1, 1).expand_as(pair_diff)

        scores = self.relation(torch.cat((pair_diff, context), 1))
        scores = torch.transpose(scores, 1, 3).squeeze(3)

        scores = scores - W_id.expand_as(scores) * 1e8
        return F.softmax(scores, dim=2)


class SAM(nn.Module):
    """Spatial attention module: one graph step over a set of node features.

    forward(x, y): ``x`` is (bs, N, C) node features and ``y`` the context
    vector consumed by ``Wcompute``. An adjacency is learned with self-edges
    masked out, then ``Gconv`` aggregates neighbours. Returns (bs, N, C).
    """
    def __init__(self, in_channels):
        super(SAM, self).__init__()
        self.in_channels = in_channels
        self.module_w = Wcompute(in_channels)
        self.module_l = Gconv(in_channels)

    def forward(self, x, y):
        bs, N, C = x.size()

        # Self-edge mask, created on x's device. This used to be a hard-coded
        # .cuda(), which failed on CPU and picked the wrong GPU under
        # multi-device placement. The dtype stays the default (float32): under
        # fp16, `W_id * 1e8` in Wcompute would overflow to inf and 0 * inf = NaN.
        W_init = torch.eye(N, device=x.device).unsqueeze(0)
        W_init = W_init.repeat(bs, 1, 1)
        W = self.module_w(x, W_init, y)
        return self.module_l(W, x)

# Interconnection and Accumulation Operations

In [None]:
from __future__ import absolute_import

import torch
import math
from torch import nn
from torch.nn import functional as F
import numpy as np


def generate_grid(h, w):
    """Return flattened (x, y) float coordinates of an h-by-w grid in row-major order.

    Position ``k = row * w + col`` has ``xv[k] == col`` and ``yv[k] == row``.
    """
    cols, rows = np.meshgrid(np.arange(w, dtype=np.float64),
                             np.arange(h, dtype=np.float64))
    return cols.ravel(), rows.ravel()

def generate_gaussian(height, width, alpha_x, alpha_y):
    """Return a (h*w, h*w) row-stochastic Gaussian spatial-affinity matrix.

    Row ``p`` is a query position and column ``q`` a key position, both in the
    row-major order produced by ``generate_grid``. The logit is
    ``-(dx^2 / (2*alpha_x^2) + dy^2 / (2*alpha_y^2))``, and a softmax is taken
    over each row. The result is a float32 torch tensor.
    """
    xv, yv = generate_grid(height, width)
    # Broadcast every query position (rows) against every key position (cols).
    dx2 = np.square(xv[np.newaxis, :] - xv[:, np.newaxis])
    dy2 = np.square(yv[np.newaxis, :] - yv[:, np.newaxis])
    logits = -1 * (dx2 / (2 * alpha_x**2) + dy2 / (2 * alpha_y**2))
    return F.softmax(torch.from_numpy(logits).float(), dim=-1)

# Interconnection and accumulation operations
class _IABlockND(nn.Module):
    """Interconnection and accumulation block over a (b, c, h, w) feature map.

    Interconnection: every spatial position aggregates all positions. The
    weights are the softmax of the elementwise product of the fixed Gaussian
    prior ``self.Dis`` and a softmax-normalised, scaled dot-product feature
    affinity.
    Accumulation: channels are re-mixed by a softmax-normalised channel
    affinity.
    Each step is a residual branch whose BatchNorm (``W1`` / ``W2``) is
    zero-initialised, so both branches output zero at initialisation.

    The input must have ``height * width`` spatial positions, because
    ``self.Dis`` is built for that grid.
    """

    def __init__(self, in_channels, height, width,
            alpha_x, alpha_y):
        super(_IABlockND, self).__init__()

        self.in_channels = in_channels
        bn = nn.BatchNorm2d

        # Fixed (h*w, h*w) prior. Kept as a plain attribute (not a buffer) so
        # existing state_dicts stay compatible; forward() moves it to x's device.
        self.Dis = generate_gaussian(height=height, width=width, alpha_x=alpha_x, alpha_y=alpha_y)
        self.W1 = bn(self.in_channels)
        self.W2 = bn(self.in_channels)

        # Zero-init both residual BNs so each branch starts as a no-op.
        for norm in (self.W1, self.W2):
            nn.init.constant_(norm.weight.data, 0.0)
            nn.init.constant_(norm.bias.data, 0.0)

    def forward(self, x):
        '''
        :param x: (b, c, h, w) with h * w equal to height * width given at construction
        :return: tensor with the same shape as x
        :raises ValueError: if the number of spatial positions does not match self.Dis
        '''
        batch_size = x.size(0)

        g_x = x.view(batch_size, self.in_channels, -1)
        g_x = g_x.permute(0, 2, 1)  # [B, H*W, C]

        num_positions = g_x.size(1)
        if num_positions != self.Dis.size(0):
            raise ValueError(
                "expected {} spatial positions (height*width), got {}".format(
                    self.Dis.size(0), num_positions))

        # Spatial prior on the input's device (was a hard-coded .cuda()).
        f_loc = self.Dis.to(x.device).unsqueeze(0).expand(batch_size, -1, -1)

        theta_x = g_x  # [B, H*W, C]
        phi_x = x.view(batch_size, self.in_channels, -1)  # [B, C, H*W]
        f = torch.matmul(theta_x, phi_x)  # [B, H*W, H*W]
        f = f / np.sqrt(self.in_channels)
        f = F.softmax(f, dim=-1)

        # Combine prior and feature affinity (product over the two maps), renormalise.
        f = F.softmax(f_loc * f, dim=-1)

        y = torch.matmul(f, g_x)  # [B, H*W, C]
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.in_channels, *x.size()[2:])
        x = self.W1(y) + x

        # Accumulation: channel affinity on the updated features.
        g_x = x.view(batch_size, self.in_channels, -1)  # [B, C, H*W]
        f = torch.matmul(g_x, g_x.permute(0, 2, 1))  # [B, C, C]
        f = F.softmax(f, dim=-1)
        y = torch.matmul(f, g_x)
        y = y.view(batch_size, self.in_channels, *x.size()[2:])
        return self.W2(y) + x



class IABlock2D(_IABlockND):
    """2-D interconnection/accumulation block; extra keyword arguments are ignored."""

    def __init__(self, in_channels, height, width, alpha_x, alpha_y, **kwargs):
        super(IABlock2D, self).__init__(
            in_channels,
            height=height,
            width=width,
            alpha_x=alpha_x,
            alpha_y=alpha_y,
        )

# Interconnection, Accumulation, and Transformation Operations

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import math

class ConvBlock(nn.Module):
    """Conv2d (with bias) followed by BatchNorm2d; no activation."""

    def __init__(self, in_c, out_c, k, s=1, p=0):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, x):
        features = self.conv(x)
        return self.bn(features)


class SpatialAttn(nn.Module):
    """Spatial attention: ``number`` sigmoid maps from a 1x1 ConvBlock."""

    def __init__(self, in_channels, number):
        super(SpatialAttn, self).__init__()
        self.conv = ConvBlock(in_channels, number, 1)

    def forward(self, x):
        return torch.sigmoid(self.conv(x))


class IAT(nn.Module):
    """Interconnection, accumulation and transformation block.

    forward(x) takes a 4-D feature map ``x`` of shape (b, c, h, w), with
    c == in_channels. It applies a channel-affinity residual step ("CAM"),
    then a graph step ("SAM") over 4 attention-pooled features plus a global
    node, whose output is broadcast-added back onto the feature map.
    Returns ``(z, a)``: ``z`` has the input's shape, and ``a`` holds the
    (b, 4, h, w) sigmoid attention maps from ``self.sa``.

    Because W1/W2 end in zero-initialised BatchNorms, ``z`` equals ``x`` at
    initialisation.
    NOTE(review): W2 sees (b, c, 1, 1) inputs, so training-mode BatchNorm
    presumably needs b > 1.
    NOTE(review): the concatenation in forward() has 2 * (in_channels // 2)
    channels, which matches W2's in_channels only for even in_channels.
    """
    def __init__(self, in_channels):
        super(IAT, self).__init__()

        inter_stride = 2
        self.in_channels = in_channels
        conv_nd = nn.Conv2d
        bn = nn.BatchNorm2d
        # Channel width of the graph (SAM) branch.
        self.inter_channels = in_channels // inter_stride

        # 4 sigmoid attention maps; each pools one c-dim feature in forward().
        self.sa = SpatialAttn(in_channels, number=4)

        # 1x1 conv reducing in_channels -> inter_channels (see reduce_dimension).
        self.g = conv_nd(self.in_channels, self.inter_channels, kernel_size=1, stride=1, padding=0, bias=True)
        self.SAM = SAM(self.inter_channels)

        # Residual projection of the channel-affinity (CAM) branch.
        self.W1 = nn.Sequential(
                conv_nd(self.in_channels, self.in_channels,
                    kernel_size=1, stride=1, padding=0, bias=True),
                bn(self.in_channels)
            )
        
        # Residual projection of the graph (SAM) branch; fed a (b, c, 1, 1) vector.
        self.W2 = nn.Sequential(
                conv_nd(self.in_channels, self.in_channels,
                    kernel_size=1, stride=1, padding=0, bias=True),
                bn(self.in_channels)
            )
        
        # init
        for m in self.modules():
            if isinstance(m, conv_nd):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, bn):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # Zero the last BN of each residual branch so both branches output 0
        # at initialisation (overrides the fill_(1) above for these two BNs).
        nn.init.constant_(self.W1[1].weight.data, 0.0)
        nn.init.constant_(self.W1[1].bias.data, 0.0)
        nn.init.constant_(self.W2[1].weight.data, 0.0)
        nn.init.constant_(self.W2[1].bias.data, 0.0)


    def reduce_dimension(self, x, global_node):
        """Project node features and the global node through ``self.g`` in one pass.

        x: (bs, K, c) node features; global_node: (bs, c). The global node is
        appended as an extra column, so a single 1x1 conv reduces both to
        ``inter_channels``.
        Returns ``(x, global_node)`` with shapes (bs, K, inter_channels) and
        (bs, inter_channels).
        """
        bs, c = global_node.size()

        x = x.transpose(1, 2).unsqueeze(3) 
        x = torch.cat((x, global_node.view(bs, c, 1, 1)), 2) 
        x = self.g(x).squeeze(3) 

        global_node = x[:,:,-1] 
        x = x[:,:,:-1].transpose(1, 2) 
        return x, global_node


    def forward(self, x):
        # CAM
        # Channel affinity: softmax of the (b, c, c) Gram matrix over the
        # flattened spatial positions, used to re-mix channels.
        batch_size = x.size(0)

        g_x = x.view(batch_size, self.in_channels, -1)
        theta_x = g_x 
        phi_x = g_x.permute(0, 2, 1) 
        f = torch.matmul(theta_x, phi_x) 
        f = F.softmax(f, dim=-1)
        y = torch.matmul(f, g_x)
        y = y.view(batch_size, self.in_channels, *x.size()[2:])
        y = self.W1(y)
        z = y + x

        # SAM
        x = z
        inputs = x
        b, c, h, w = x.size()
        # Global node: spatial mean of each channel, (b, c).
        u = x.view(b, c, -1).mean(2)

        a = self.sa(x) 
        # Attention-weighted pooling: (b, 4, h*w) @ (b, h*w, c) -> (b, 4, c).
        x = torch.bmm(a.view(b, -1, h * w), x.view(b, c, -1).transpose(1, 2)) 
        x, u = self.reduce_dimension(x, u)
        # Graph step over the 4 pooled nodes, conditioned on the global node u.
        y = self.SAM(x, u) 

        y = torch.mean(y, 1) #[b, c//2]
        # [graph summary, global node] -> (b, 2 * inter_channels)
        u = torch.cat((y, u), 1) 

        # Project to (b, c, 1, 1) and broadcast-add over every spatial position.
        y = self.W2(u.view(u.size(0), u.size(1), 1, 1))
        z = y + inputs
        return z, a

# ResNets

In [None]:
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo


# Public names of this cell's ResNet definitions.
__all__ = [
    'ResNet',
    'resnet18',
    'resnet34',
    'resnet50_s1',
    'resnet101',
    'resnet152',
]


# Checkpoint download URLs keyed by architecture name (used with model_zoo).
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet34='https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
    resnet101='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    resnet152='https://download.pytorch.org/models/resnet152-b121ed2d.pth',
)


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 Conv2d with padding 1 (keeps spatial size at stride 1)."""
    conv = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return conv


class BasicBlock(nn.Module):
    """Two stacked 3x3 convolutions with a residual shortcut (ResNet-18/34 unit).

    Args:
        inplanes (int): channels of the incoming tensor.
        planes (int): output channels of both convolutions.
        stride (int): stride of the first convolution.
        downsample (nn.Module or None): projection applied to the input for the
            shortcut when its shape differs from the block output.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Registration order is the same as in the stock torchvision block
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck (the ResNet-50/101/152 unit).

    The final 1x1 convolution widens the output to ``planes * expansion``
    channels; ``stride`` is applied by the middle 3x3 convolution.

    Args:
        inplanes (int): channels of the incoming tensor.
        planes (int): width of the inner 1x1 / 3x3 convolutions.
        stride (int): stride of the 3x3 convolution.
        downsample (nn.Module or None): projection applied to the input for the
            shortcut when its shape differs from the block output.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        widened = planes * self.expansion
        # Same creation order as before, so weight initialisation draws are unchanged
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)


class ResNet(nn.Module):
    """ResNet whose last stage (``layer4``) keeps stride 1.

    Keeping ``layer4`` at stride 1 doubles the final feature-map size compared
    with the stock network. The classifier head therefore uses adaptive global
    average pooling: a fixed ``AvgPool2d(7)`` only matched a 7x7 map, and for
    other input sizes it either left extra spatial positions (so ``fc`` got the
    wrong width) or failed outright.

    Args:
        block: residual block class exposing ``expansion`` (``BasicBlock`` or
            ``Bottleneck``).
        layers (list[int]): number of blocks in each of the four stages.
        num_classes (int): output size of the final ``fc`` layer.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        # Last stage intentionally keeps stride 1 (no downsampling)
        self.layer4 = self._make_layer(block, 512, layers[3])
        # Parameter-free, so pretrained state dicts still load unchanged
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # He-style init: std = sqrt(2 / (k_h * k_w * out_channels))
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block may change stride or width. It gets a 1x1
        projection shortcut when either changes. ``self.inplanes`` is advanced
        to the stage's output width.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18_s1(pretrained=True, **kwargs):
    """Build a ResNet-18 (``BasicBlock``, stages ``[2, 2, 2, 2]``) whose layer4 has stride 1.

    Args:
        pretrained (bool): when True (the default), load the ImageNet weights
            from ``model_urls['resnet18']``.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).

    Returns:
        ResNet: the constructed network.
    """
    net = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if not pretrained:
        return net
    net.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return net


def resnet34_s1(pretrained=True, **kwargs):
    """Build a ResNet-34 (``BasicBlock``, stages ``[3, 4, 6, 3]``) whose layer4 has stride 1.

    Args:
        pretrained (bool): when True (the default), load the ImageNet weights
            from ``model_urls['resnet34']``.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).

    Returns:
        ResNet: the constructed network.
    """
    net = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net
    net.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return net


def resnet50_s1(pretrained=True, **kwargs):
    """Build a ResNet-50 (``Bottleneck``, stages ``[3, 4, 6, 3]``) whose layer4 has stride 1.

    Args:
        pretrained (bool): when True (the default), load the ImageNet weights
            from ``model_urls['resnet50']``.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).

    Returns:
        ResNet: the constructed network.
    """
    net = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net
    net.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return net


def resnet101_s1(pretrained=False, **kwargs):
    """Build a ResNet-101 (``Bottleneck``, stages ``[3, 4, 23, 3]``) whose layer4 has stride 1.

    Unlike the 18/34/50 factories, ``pretrained`` defaults to False here.

    Args:
        pretrained (bool): when True, load the ImageNet weights from
            ``model_urls['resnet101']``.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).

    Returns:
        ResNet: the constructed network.
    """
    net = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if not pretrained:
        return net
    net.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return net


def resnet152(pretrained=False, **kwargs):
    """Build a ResNet-152 (``Bottleneck``, stages ``[3, 8, 36, 3]``).

    The name has no ``_s1`` suffix, but it uses the same ``ResNet`` class, so
    its layer4 also has stride 1.

    Args:
        pretrained (bool): when True, load the ImageNet weights from
            ``model_urls['resnet152']``.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).

    Returns:
        ResNet: the constructed network.
    """
    net = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if not pretrained:
        return net
    net.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return net

# STCANet2D

In [None]:
from __future__ import absolute_import
from __future__ import division

import torch
from torch import nn
from torch.nn import functional as F
import torchvision

class STCANet(nn.Module):
    """ResNet-50 (stride-1 last stage) with ``IAT`` modules after layer2 and layer3.

    Args:
        num_classes (int): number of identities for the linear classifier.

    ``forward`` returns ``(logits, feature, a_head, a_upper, a_lower, a_shoes)``.
    ``feature`` is the batch-normalised 2048-d pooled embedding. Each ``a_*``
    is a two-element list holding channel ``k`` (k = 0..3) of the attention
    maps returned by IAT2 and IAT3. The part names come from the original
    author; that channel ``k`` really encodes that body part is presumably
    enforced by the training loss — confirm against the loss code.
    """

    def __init__(self, num_classes):
        super(STCANet, self).__init__()
        backbone = resnet50_s1(pretrained=True)

        # Stem, reused from the pretrained backbone (a fresh ReLU is created)
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = backbone.maxpool

        # Residual stages
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        # Attention modules sized to the layer2 / layer3 output widths
        self.IAT2 = IAT(512)
        self.IAT3 = IAT(1024)

        self.feat_dim = 2048
        self.bn = nn.BatchNorm1d(self.feat_dim)
        self.classifier = nn.Linear(self.feat_dim, num_classes)

    def forward(self, x):
        out = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        out = self.layer1(out)
        out, attn2 = self.IAT2(self.layer2(out))
        out, attn3 = self.IAT3(self.layer3(out))
        out = self.layer4(out)

        # Global average pooling over the full spatial extent
        feat = F.avg_pool2d(out, out.size()[2:]).view(out.size(0), -1)
        feat = self.bn(feat)
        logits = self.classifier(feat)

        # Channel k of each attention tensor, paired across the two IAT stages
        a_head, a_upper, a_lower, a_shoes = (
            [attn2[:, k:k + 1], attn3[:, k:k + 1]] for k in range(4)
        )

        return logits, feat, a_head, a_upper, a_lower, a_shoes

# IAResnet

In [None]:
from __future__ import absolute_import
from __future__ import division

import torch
from torch import nn
from torch.nn import functional as F
import torchvision

class IAResNet50_location(nn.Module):
    """ResNet-50 with ``IABlock2D`` modules inserted inside layer2 and layer3.

    Args:
        last_s1 (bool): when True, use the stride-1-last-stage ``resnet50_s1``;
            otherwise use torchvision's stock ``resnet50``. Both are ImageNet-pretrained.
        num_classes (int): number of identities for the linear classifier.
        **kwargs: accepted for factory compatibility and ignored.

    In training mode ``forward`` returns ``(logits, feature)``; in eval mode it
    returns only the batch-normalised 2048-d feature.
    """

    def __init__(self, last_s1, num_classes, **kwargs):

        super(IAResNet50_location, self).__init__()

        if not last_s1:
            resnet50 = torchvision.models.resnet50(pretrained=True)
        else:
            resnet50 = resnet50_s1(pretrained=True)

        self.conv1 = resnet50.conv1
        self.bn1 = resnet50.bn1
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = resnet50.maxpool

        # One IA block after the last bottleneck of layer2 (index 3) and of layer3 (index 5)
        self.layer1 = self._inflate_reslayer(resnet50.layer1)
        self.layer2 = self._inflate_reslayer(resnet50.layer2, IA_idx=[3], height=32,
                                    width=16, alpha_x=10, alpha_y=20, IA_channels=512)
        self.layer3 = self._inflate_reslayer(resnet50.layer3, IA_idx=[5], height=16,
                                    width=8, alpha_x=5, alpha_y=10, IA_channels=1024)
        self.layer4 = self._inflate_reslayer(resnet50.layer4)

        self.bn = nn.BatchNorm1d(2048)
        self.classifier = nn.Linear(2048, num_classes)

    def _inflate_reslayer(self, reslayer, height=0, width=0,
                    alpha_x=0, alpha_y=0, IA_idx=(), IA_channels=0):
        """Copy ``reslayer``'s blocks into a new Sequential, adding IA blocks.

        An ``IABlock2D`` is appended right after every block whose index is in
        ``IA_idx``. The default is an immutable empty tuple, so no shared
        mutable default is involved and the layer is copied unchanged.
        """
        reslayers = []
        for i, layer2d in enumerate(reslayer):
            reslayers.append(layer2d)

            if i in IA_idx:
                IA_block = IABlock2D(in_channels=IA_channels, height=height,
                        width=width, alpha_x=alpha_x, alpha_y=alpha_y)
                reslayers.append(IA_block)

        return nn.Sequential(*reslayers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Global average pooling -> (b, 2048)
        f = F.avg_pool2d(x, x.size()[2:])
        f = f.view(f.size(0), -1)
        f = self.bn(f)
        if not self.training:
            return f
        y = self.classifier(f)

        return y, f

# Model Initialization

In [None]:
from __future__ import absolute_import

# Registry mapping a model name to its constructor.
__model_factory = {'STCANet': STCANet}


def get_names():
    """Return the registered model names as a list."""
    return [model_name for model_name in __model_factory]


def init_model(name, *args, **kwargs):
    """Instantiate the registered model ``name`` with the given arguments.

    Raises:
        KeyError: if ``name`` is not registered.
    """
    if name not in get_names():
        raise KeyError("Unknown model: {}".format(name))
    constructor = __model_factory[name]
    return constructor(*args, **kwargs)

# Metrics Folder

# distance

In [None]:
from __future__ import division, print_function, absolute_import
import torch
from torch.nn import functional as F


def compute_distance_matrix(input1, input2, metric='euclidean'):
    """A wrapper function for computing distance matrix.
    Args:
        input1 (torch.Tensor): 2-D feature matrix.
        input2 (torch.Tensor): 2-D feature matrix.
        metric (str, optional): "euclidean" (squared euclidean) or "cosine".
            Default is "euclidean".
    Returns:
        torch.Tensor: distance matrix of shape (input1.size(0), input2.size(0)).
    Raises:
        ValueError: if ``metric`` is not one of the supported names.
    Examples::
       >>> input1 = torch.rand(10, 2048)
       >>> input2 = torch.rand(100, 2048)
       >>> distmat = compute_distance_matrix(input1, input2)
       >>> distmat.size() # (10, 100)
    """
    # check input
    assert isinstance(input1, torch.Tensor)
    assert isinstance(input2, torch.Tensor)
    assert input1.dim() == 2, 'Expected 2-D tensor, but got {}-D'.format(
        input1.dim()
    )
    assert input2.dim() == 2, 'Expected 2-D tensor, but got {}-D'.format(
        input2.dim()
    )
    assert input1.size(1) == input2.size(1)

    if metric == 'euclidean':
        distmat = euclidean_squared_distance(input1, input2)
    elif metric == 'cosine':
        distmat = cosine_distance(input1, input2)
    else:
        raise ValueError(
            'Unknown distance metric: {}. '
            'Please choose either "euclidean" or "cosine"'.format(metric)
        )

    return distmat


def euclidean_squared_distance(input1, input2):
    """Pairwise squared euclidean distances between the rows of two matrices.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, with the
    cross term added in place through ``addmm_``.

    Args:
        input1 (torch.Tensor): (m, d) feature matrix.
        input2 (torch.Tensor): (n, d) feature matrix.
    Returns:
        torch.Tensor: (m, n) distance matrix.
    """
    rows, cols = input1.size(0), input2.size(0)
    sq_norm_rows = input1.pow(2).sum(dim=1, keepdim=True).expand(rows, cols)
    sq_norm_cols = input2.pow(2).sum(dim=1, keepdim=True).expand(cols, rows).t()
    dist = sq_norm_rows + sq_norm_cols  # fresh tensor, safe to modify in place
    dist.addmm_(input1, input2.t(), beta=1, alpha=-2)
    return dist


def cosine_distance(input1, input2):
    """Pairwise cosine distance (1 - cosine similarity) between rows.

    Args:
        input1 (torch.Tensor): (m, d) feature matrix.
        input2 (torch.Tensor): (n, d) feature matrix.
    Returns:
        torch.Tensor: (m, n) distance matrix.
    """
    unit_a = F.normalize(input1, p=2, dim=1)
    unit_b = F.normalize(input2, p=2, dim=1)
    similarity = torch.mm(unit_a, unit_b.t())
    return 1 - similarity

# Utils Folder

# averagemeter

In [None]:
from __future__ import absolute_import
from __future__ import division


class AverageMeter(object):
    """Track the latest value and the running (count-weighted) mean.

    Adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the running sum, the count and the average."""
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new samples and refresh ``avg``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

# iotools

In [None]:
from __future__ import absolute_import

import os
import os.path as osp
import errno
import json
import shutil

import torch


def mkdir_if_missing(directory):
    """Create ``directory`` (including missing parents) if it does not exist.

    An empty string is what ``osp.dirname`` returns for a bare filename and
    means the current directory, so it is treated as a no-op rather than
    passed to ``os.makedirs`` (which raises for '').
    """
    if directory == '' or osp.exists(directory):
        return
    try:
        os.makedirs(directory)
    except OSError as e:
        # Someone else may have created it between the check and makedirs
        if e.errno != errno.EEXIST:
            raise


def check_isfile(path):
    """Return True if ``path`` is an existing file; otherwise warn and return False."""
    if osp.isfile(path):
        return True
    print("=> Warning: no file found at '{}' (ignored)".format(path))
    return False


def read_json(fpath):
    """Load and return the JSON document stored at ``fpath``."""
    with open(fpath, 'r') as handle:
        return json.load(handle)


def write_json(obj, fpath):
    """Serialise ``obj`` as indented JSON to ``fpath``.

    Parent directories are created when ``fpath`` has one. A bare filename
    (empty dirname) is written to the current directory, matching the guard
    used by ``save_checkpoint``.
    """
    parent = osp.dirname(fpath)
    if parent:
        mkdir_if_missing(parent)
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))


def save_checkpoint(state, is_best=False, fpath='checkpoint.pth.tar'):
    """``torch.save`` ``state`` to ``fpath``; if ``is_best``, also copy it to
    ``best_model.pth.tar`` in the same directory.

    The parent directory is created when ``fpath`` has one.
    """
    parent = osp.dirname(fpath)
    if parent:
        mkdir_if_missing(parent)
    torch.save(state, fpath)
    if not is_best:
        return
    shutil.copy(fpath, osp.join(parent, 'best_model.pth.tar'))

# Logger

In [None]:
from __future__ import absolute_import

import sys
import os
import os.path as osp

#from .iotools import mkdir_if_missing


class Logger(object):
    """
    Write console output to external text file.
    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.

    Intended to be installed as ``sys.stdout``. Every ``write`` goes to the
    original console stream and, if ``fpath`` was given, to that file.
    ``close`` releases only the log file. The console stream belongs to the
    caller and must stay open; closing it (as the upstream version did, also
    from ``__del__``) breaks every later ``print``.
    """
    def __init__(self, fpath=None, mode='w'):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            parent = osp.dirname(fpath)
            if parent:  # a bare filename has no directory to create
                mkdir_if_missing(parent)
            self.file = open(fpath, mode)

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        # Close only the file we opened. Reset it to None so that later
        # flush/close calls (e.g. from __del__) do not touch a closed file.
        if self.file is not None:
            self.file.close()
            self.file = None

# Reidtools (Visualization of ranked results)

## New

In [None]:
from __future__ import print_function, absolute_import
import numpy as np
import shutil
import os.path as osp
import cv2

#from .tools import mkdir_if_missing

# Names exported by ``from ... import *`` for this cell.
__all__ = ['visualize_ranked_results']

# Layout constants (pixels) for the image-mode ranking grid.
GRID_SPACING = 10  # horizontal gap between consecutive gallery tiles
QUERY_EXTRA_SPACING = 90  # extra gap separating the query tile from the gallery tiles
BW = 5 # border width
# Border colours for correct / wrong matches. (0, 0, 255) is red only in BGR
# order, so these presumably follow OpenCV's BGR convention — confirm.
GREEN = (0, 255, 0)
RED = (0, 0, 255)


def visualize_ranked_results(
    distmat, dataset, data_type, width=128, height=256, save_dir='', topk=10
):
    """Visualizes ranked results.
    Supports both image-reid and video-reid.
    For image-reid, each query and its top-k ranks are drawn in one grid image.
    For video-reid, ranks are saved in folders, each containing a tracklet.
    Args:
        distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
        dataset: object exposing ``dataset.query`` and ``dataset.gallery``
            sequences (attribute access, not a tuple). Only the first three
            fields of each entry, (img_path(s), pid, camid), are read; extra
            fields are ignored.
        data_type (str): "image" or "video". Any value other than "image"
            takes the folder-copy (video) path.
        width (int, optional): resized image width. Default is 128.
        height (int, optional): resized image height. Default is 256.
        save_dir (str): directory to save output images.
        topk (int, optional): denoting top-k images in the rank list to be visualized.
            Default is 10.
    Note:
        Gallery entries sharing both pid and camid with the query are skipped,
        as in standard re-ID evaluation.
        NOTE(review): in image mode the result of ``cv2.imread`` is not checked
        for None, so a missing or unreadable image will fail inside cv2.
    """
    num_q, num_g = distmat.shape
    mkdir_if_missing(save_dir)

    print('# query: {}\n# gallery {}'.format(num_q, num_g))
    print('Visualizing top-{} ranks ...'.format(topk))

    #query, gallery = dataset
    #assert num_q == len(query)
    #assert num_g == len(gallery)
    assert num_q == len(dataset.query)
    assert num_g == len(dataset.gallery)
    # Gallery indices per query, sorted from closest to farthest
    indices = np.argsort(distmat, axis=1)

    def _cp_img_to(src, dst, rank, prefix, matched=False):
        """
        Args:
            src: image path or tuple (for vidreid)
            dst: target directory
            rank: int, denoting ranked position, starting from 1
            prefix: string
            matched: bool
        """
        if isinstance(src, (tuple, list)):
            # Tracklet: copy every frame into a per-rank sub-folder
            if prefix == 'gallery':
                suffix = 'TRUE' if matched else 'FALSE'
                dst = osp.join(
                    dst, prefix + '_top' + str(rank).zfill(3)
                ) + '_' + suffix
            else:
                dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3))
            mkdir_if_missing(dst)
            for img_path in src:
                shutil.copy(img_path, dst)
        else:
            dst = osp.join(
                dst, prefix + '_top' + str(rank).zfill(3) + '_name_' +
                osp.basename(src)
            )
            shutil.copy(src, dst)

    for q_idx in range(num_q):
        qimg_path, qpid, qcamid = dataset.query[q_idx][:3]
        # For tracklets, the first frame's path names the output
        qimg_path_name = qimg_path[0] if isinstance(
            qimg_path, (tuple, list)
        ) else qimg_path

        if data_type == 'image':
            qimg = cv2.imread(qimg_path)
            qimg = cv2.resize(qimg, (width, height))
            qimg = cv2.copyMakeBorder(
                qimg, BW, BW, BW, BW, cv2.BORDER_CONSTANT, value=(0, 0, 0)
            )
            # resize twice to ensure that the border width is consistent across images
            qimg = cv2.resize(qimg, (width, height))
            num_cols = topk + 1
            # White canvas: query tile, extra gap, then topk tiles separated by GRID_SPACING
            grid_img = 255 * np.ones(
                (
                    height,
                    num_cols*width + topk*GRID_SPACING + QUERY_EXTRA_SPACING, 3
                ),
                dtype=np.uint8
            )
            grid_img[:, :width, :] = qimg
        else:
            qdir = osp.join(
                save_dir, osp.basename(osp.splitext(qimg_path_name)[0])
            )
            mkdir_if_missing(qdir)
            _cp_img_to(qimg_path, qdir, rank=0, prefix='query')

        rank_idx = 1
        for g_idx in indices[q_idx, :]:
            gimg_path, gpid, gcamid = dataset.gallery[g_idx][:3]
            # Same identity seen by the same camera is excluded from the ranking
            invalid = (qpid == gpid) & (qcamid == gcamid)

            if not invalid:
                matched = gpid == qpid
                if data_type == 'image':
                    border_color = GREEN if matched else RED
                    gimg = cv2.imread(gimg_path)
                    gimg = cv2.resize(gimg, (width, height))
                    gimg = cv2.copyMakeBorder(
                        gimg,
                        BW,
                        BW,
                        BW,
                        BW,
                        cv2.BORDER_CONSTANT,
                        value=border_color
                    )
                    gimg = cv2.resize(gimg, (width, height))
                    # Column span of tile number rank_idx on the canvas
                    start = rank_idx*width + rank_idx*GRID_SPACING + QUERY_EXTRA_SPACING
                    end = (
                        rank_idx+1
                    ) * width + rank_idx*GRID_SPACING + QUERY_EXTRA_SPACING
                    grid_img[:, start:end, :] = gimg
                else:
                    _cp_img_to(
                        gimg_path,
                        qdir,
                        rank=rank_idx,
                        prefix='gallery',
                        matched=matched
                    )

                rank_idx += 1
                if rank_idx > topk:
                    break

        if data_type == 'image':
            imname = osp.basename(osp.splitext(qimg_path_name)[0])
            cv2.imwrite(osp.join(save_dir, imname + '.jpg'), grid_img)

        if (q_idx+1) % 100 == 0:
            print('- done {}/{}'.format(q_idx + 1, num_q))

    print('Done. Images have been saved to "{}" ...'.format(save_dir))

## Old

In [None]:
from __future__ import absolute_import
from __future__ import print_function

import numpy as np
import os
import os.path as osp
import shutil

#from .iotools import mkdir_if_missing


def visualize_ranked_results(distmat, dataset, save_dir='log/ranked_results', topk=20):
    """
    Visualize ranked results by copying files into one folder per query.

    Support both imgreid and vidreid

    Args:
    - distmat: distance matrix of shape (num_query, num_gallery).
    - dataset: has dataset.query and dataset.gallery. Only the first three fields
               of each entry, (img_path, pid, camid), are used; extra fields are
               ignored. For imgreid img_path is a string, while for vidreid it
               is a tuple or list of frame paths.
    - save_dir: directory to save output images.
    - topk: int, denoting top-k images in the rank list to be visualized.
    """
    num_q, num_g = distmat.shape

    print("Visualizing top-{} ranks in '{}' ...".format(topk, save_dir))
    print("# query: {}. # gallery {}".format(num_q, num_g))

    assert num_q == len(dataset.query)
    assert num_g == len(dataset.gallery)

    # Gallery indices per query, sorted from closest to farthest
    indices = np.argsort(distmat, axis=1)
    mkdir_if_missing(save_dir)

    def _cp_img_to(src, dst, rank, prefix):
        """
        - src: image path, or tuple/list of frame paths (for vidreid)
        - dst: target directory
        - rank: int, denoting ranked position, starting from 1
        - prefix: string
        """
        # Accept lists too (e.g. splits loaded from JSON), as the newer visualizer does
        if isinstance(src, (tuple, list)):
            dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3))
            mkdir_if_missing(dst)
            for img_path in src:
                shutil.copy(img_path, dst)
        else:
            dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3) + '_name_' + osp.basename(src))
            shutil.copy(src, dst)

    for q_idx in range(num_q):
        qimg_path, qpid, qcamid = dataset.query[q_idx][:3]
        qdir = osp.join(save_dir, 'query' + str(q_idx + 1).zfill(5))
        mkdir_if_missing(qdir)
        _cp_img_to(qimg_path, qdir, rank=0, prefix='query')

        rank_idx = 1
        for g_idx in indices[q_idx, :]:
            gimg_path, gpid, gcamid = dataset.gallery[g_idx][:3]
            # Same identity seen by the same camera is excluded from the ranking
            invalid = (qpid == gpid) & (qcamid == gcamid)
            if not invalid:
                _cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery')
                rank_idx += 1
                if rank_idx > topk:
                    break

# torchtools

In [None]:
from __future__ import absolute_import
from __future__ import division

import torch
import torch.nn as nn


def adjust_learning_rate(optimizer, base_lr, epoch, stepsize, gamma=0.1):
    """Step decay: set every param group's lr to base_lr * gamma ** (epoch // stepsize)."""
    num_decays = epoch // stepsize
    new_lr = base_lr * (gamma ** num_decays)
    for group in optimizer.param_groups:
        group['lr'] = new_lr


def set_bn_to_eval(m):
    """Put ``m`` in eval mode if its class name contains 'BatchNorm'.

    Meant for ``model.apply(set_bn_to_eval)``. Running mean/var stop updating,
    while the affine scale/shift parameters remain trainable.
    """
    if 'BatchNorm' in m.__class__.__name__:
        m.eval()


def count_num_param(model):
    """Return the parameter count in millions, excluding ``model.classifier``.

    The classifier (when present and an ``nn.Module``) is left out because
    it is not used at test time.
    """
    total = sum(p.numel() for p in model.parameters()) / 1e+06
    head = getattr(model, 'classifier', None)
    if isinstance(head, nn.Module):
        total -= sum(p.numel() for p in head.parameters()) / 1e+06
    return total

# Reranking

In [None]:
"""
q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery]
q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query]
g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery]
k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3)
Returns:
  final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery]
"""

from __future__ import division, print_function, absolute_import
import numpy as np

# Only ``re_ranking`` is exported by ``from ... import *`` for this cell.
__all__ = ['re_ranking']


def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3):
    """k-reciprocal re-ranking of a query-gallery distance matrix.

    Args:
        q_g_dist: (num_query, num_gallery) query-gallery distances.
        q_q_dist: (num_query, num_query) query-query distances.
        g_g_dist: (num_gallery, num_gallery) gallery-gallery distances.
        k1: neighbourhood size for the k-reciprocal sets. The expansion
            candidates use round(k1 / 2).
        k2: neighbourhood size for local query expansion; 1 disables it.
        lambda_value: weight of the original distance in the final blend.
            The Jaccard distance gets (1 - lambda_value).

    Returns:
        numpy array (num_query, num_gallery) of re-ranked distances.
    """

    # Local naming: `gallery_num` below counts ALL samples (queries + gallery),
    # not only the gallery set used by callers.
    # Full (q+g) x (q+g) distance matrix: [[QQ, QG], [QG^T, GG]]
    original_dist = np.concatenate(
        [
            np.concatenate([q_q_dist, q_g_dist], axis=1),
            np.concatenate([q_g_dist.T, g_g_dist], axis=1)
        ],
        axis=0
    )
    # Square, scale each column by its maximum, then transpose
    original_dist = np.power(original_dist, 2).astype(np.float32)
    original_dist = np.transpose(
        1. * original_dist / np.max(original_dist, axis=0)
    )
    V = np.zeros_like(original_dist).astype(np.float32)
    # Row i: sample indices sorted by distance from sample i
    initial_rank = np.argsort(original_dist).astype(np.int32)

    query_num = q_g_dist.shape[0]
    gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1]
    all_num = gallery_num

    for i in range(all_num):
        # k-reciprocal neighbors
        # j is kept if j is in i's top-(k1+1) AND i is in j's top-(k1+1)
        forward_k_neigh_index = initial_rank[i, :k1 + 1]
        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
        fi = np.where(backward_k_neigh_index == i)[0]
        k_reciprocal_index = forward_k_neigh_index[fi]
        k_reciprocal_expansion_index = k_reciprocal_index
        # Expansion: merge a candidate's own round(k1/2)-reciprocal set when
        # more than 2/3 of it already overlaps i's reciprocal set
        for j in range(len(k_reciprocal_index)):
            candidate = k_reciprocal_index[j]
            candidate_forward_k_neigh_index = initial_rank[
                candidate, :int(np.around(k1 / 2.)) + 1]
            candidate_backward_k_neigh_index = initial_rank[
                candidate_forward_k_neigh_index, :int(np.around(k1 / 2.)) + 1]
            fi_candidate = np.where(
                candidate_backward_k_neigh_index == candidate
            )[0]
            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[
                fi_candidate]
            if len(
                np.
                intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)
            ) > 2. / 3 * len(candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(
                    k_reciprocal_expansion_index, candidate_k_reciprocal_index
                )

        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
        # Row i of V: exp(-distance) weights over the expanded set, normalised to sum 1
        weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
        V[i, k_reciprocal_expansion_index] = 1. * weight / np.sum(weight)
    # Only the query rows of the original distance are needed from here on
    original_dist = original_dist[:query_num, ]
    if k2 != 1:
        # Local query expansion: replace each row by the mean of its top-k2 neighbours' rows
        V_qe = np.zeros_like(V, dtype=np.float32)
        for i in range(all_num):
            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
        V = V_qe
        del V_qe
    del initial_rank
    # Inverted index: for each column, the rows of V with a non-zero entry
    invIndex = []
    for i in range(gallery_num):
        invIndex.append(np.where(V[:, i] != 0)[0])

    jaccard_dist = np.zeros_like(original_dist, dtype=np.float32)

    for i in range(query_num):
        # temp_min accumulates sum(min(V[i], V[j])) for every sample j
        temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float32)
        indNonZero = np.where(V[i, :] != 0)[0]
        indImages = []
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(
                V[i, indNonZero[j]], V[indImages[j], indNonZero[j]]
            )
        # Rows of V sum to 1, so sum(max) = 2 - sum(min)
        jaccard_dist[i] = 1 - temp_min / (2.-temp_min)

    final_dist = jaccard_dist * (1-lambda_value) + original_dist*lambda_value
    del original_dist
    del V
    del jaccard_dist
    # Keep only the query x gallery block
    final_dist = final_dist[:query_num, query_num:]
    
    return final_dist

# Dataset loader

In [None]:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import os
from PIL import Image
import numpy as np
import os.path as osp
import imdb
import io

import torch
from torch.utils.data import Dataset


def read_image(img_path, mode='RGB', max_retries=None):
    """Open ``img_path`` with PIL and convert it to ``mode``, retrying on IOError.

    Retrying guards against transient IOErrors under heavy IO load. The file
    handle is closed deterministically after each attempt.

    Args:
        img_path (str): path to the image file.
        mode (str): PIL mode to convert to (e.g. 'RGB', 'L').
        max_retries (int or None): how many times to retry after a failed
            read before re-raising. ``None`` (the default) keeps the original
            behaviour of retrying forever. A persistently corrupt file will
            then loop indefinitely, so pass a bound where that matters.

    Raises:
        IOError: if ``img_path`` does not exist, or (with ``max_retries`` set)
            the read keeps failing.
    """
    if not osp.exists(img_path):
        raise IOError("{} does not exist".format(img_path))
    failures = 0
    while True:
        try:
            with Image.open(img_path) as raw:
                return raw.convert(mode)
        except IOError:
            failures += 1
            if max_retries is not None and failures > max_retries:
                raise
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))


class ImageDataset(Dataset):
    """Person re-ID image dataset over a list of (img_path, pid, camid) records.

    Each item is ``(image, pid, camid)``. The image is loaded with
    ``read_image`` (RGB) and passed through ``transform`` when one is given.
    """
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img_path, pid, camid = self.dataset[index]
        image = read_image(img_path)
        if self.transform is None:
            return image, pid, camid
        return self.transform(image), pid, camid


class ImageDataset_seg(Dataset):
    """Person re-ID dataset pairing each RGB image with five grayscale masks.

    Each record is ``(paths, pid, camid)``, where ``paths`` unpacks into
    (image, head, upper_body, lower_body, shoes, foreground). The six images
    are concatenated along dim 0 into one tensor (3+1+1+1+1+1 channels when
    the transform yields CHW tensors). When a transform is given, its
    ``randomize_parameters()`` is called once per item, so the same random
    augmentation is presumably applied to the image and all masks — confirm
    with the transform class.
    """
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        paths, pid, camid = self.dataset[index]
        (img_path, head_path, upper_body_path,
         lower_body_path, shoes_path, foreground_path) = paths

        frames = [read_image(img_path, mode='RGB')]
        for mask_path in (head_path, upper_body_path, lower_body_path,
                          shoes_path, foreground_path):
            frames.append(read_image(mask_path, mode='L'))

        if self.transform is not None:
            self.transform.randomize_parameters()
            frames = [self.transform(frame) for frame in frames]

        stacked = torch.cat(frames, 0)
        return stacked, pid, camid

# evaluation metric

In [None]:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import numpy as np
import copy
from collections import defaultdict
import sys

# Optional compiled evaluator. CYTHON_EVAL_AVAI records whether the import
# succeeded; presumably the evaluation code uses it to choose between the
# Cython and pure-Python paths — confirm at the call site.
try:
    from eval_lib.cython_eval import eval_market1501_wrap
    CYTHON_EVAL_AVAI = True
    print("Cython evaluation is AVAILABLE")
except ImportError:
    CYTHON_EVAL_AVAI = False
    print("Warning: Cython evaluation is UNAVAILABLE")


def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, N=100):
    """Evaluation with cuhk03 metric.

    Key: for each query, gallery samples sharing both its pid and camid are
    discarded. Then one image per remaining gallery identity is randomly
    sampled. Sampling is repeated N times (default: N=100), and the CMC curve
    and AP of each query are averaged over the repeats.

    Args:
        distmat (np.ndarray): (num_query, num_gallery) distances; gallery
            samples are ranked by ascending distance.
        q_pids, q_camids (np.ndarray): person / camera ids of the queries.
        g_pids, g_camids (np.ndarray): person / camera ids of the gallery.
        max_rank (int): CMC curve is truncated to this many ranks
            (clipped to num_gallery).
        N (int): number of random gallery-sampling repeats per query.

    Returns:
        tuple: ``(cmc, mAP)``. ``cmc`` is the float32 CMC curve averaged over
        valid queries and ``mAP`` is the mean average precision.

    Raises:
        AssertionError: if no query identity appears in the gallery.
    """
    num_q, num_g = distmat.shape
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))
    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    # compute cmc curve for each query
    all_cmc = []
    all_AP = []
    num_valid_q = 0.  # number of valid queries
    for q_idx in range(num_q):
        # get query pid and camid
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # remove gallery samples that have the same pid and camid with query
        order = indices[q_idx]
        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
        keep = np.invert(remove)

        # binary vector over the kept ranking; 1 marks a correct match
        orig_cmc = matches[q_idx][keep]
        if not np.any(orig_cmc):
            # this condition is true when query identity does not appear in gallery
            continue

        # positions (within orig_cmc) of every kept sample, grouped by pid
        kept_g_pids = g_pids[order][keep]
        g_pids_dict = defaultdict(list)
        for idx, pid in enumerate(kept_g_pids):
            g_pids_dict[pid].append(idx)

        cmc, AP = 0., 0.
        for _ in range(N):
            # builtin `bool`: the `np.bool` alias was removed in NumPy 1.24
            mask = np.zeros(len(orig_cmc), dtype=bool)
            for idxs in g_pids_dict.values():
                # randomly sample one image for each gallery person
                mask[np.random.choice(idxs)] = True
            masked_orig_cmc = orig_cmc[mask]
            _cmc = masked_orig_cmc.cumsum()
            _cmc[_cmc > 1] = 1
            cmc += _cmc[:max_rank].astype(np.float32)
            # AP: mean of precision@k over the ranks k that hold a match
            num_rel = masked_orig_cmc.sum()
            precision = masked_orig_cmc.cumsum() / np.arange(1, masked_orig_cmc.size + 1)
            AP += (precision * masked_orig_cmc).sum() / num_rel
        cmc /= N
        AP /= N
        all_cmc.append(cmc)
        all_AP.append(AP)
        num_valid_q += 1.

    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)

    return all_cmc, mAP


def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
    """Evaluation with market1501 metric.

    Key: for each query identity, its gallery images from the same camera
    view are discarded.

    Args:
        distmat (np.ndarray): (num_query, num_gallery) distances; gallery
            samples are ranked by ascending distance.
        q_pids, q_camids (np.ndarray): person / camera ids of the queries.
        g_pids, g_camids (np.ndarray): person / camera ids of the gallery.
        max_rank (int): length of the returned CMC curve (clipped to
            num_gallery).

    Returns:
        tuple: ``(cmc, mAP)``. ``cmc`` is a float32 curve of length
        ``max_rank`` averaged over valid queries, and ``mAP`` is the mean
        average precision.

    Raises:
        AssertionError: if no query identity appears in the gallery.
    """
    num_q, num_g = distmat.shape
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))
    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    # compute cmc curve for each query
    all_cmc = []
    all_AP = []
    num_valid_q = 0.  # number of valid queries
    for q_idx in range(num_q):
        # get query pid and camid
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # remove gallery samples that have the same pid and camid with query
        order = indices[q_idx]
        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
        keep = np.invert(remove)

        # binary vector over the kept ranking; 1 marks a correct match
        orig_cmc = matches[q_idx][keep]
        if not np.any(orig_cmc):
            # this condition is true when query identity does not appear in gallery
            continue

        cmc = orig_cmc.cumsum()
        cmc[cmc > 1] = 1
        cmc = cmc[:max_rank]
        if cmc.size < max_rank:
            # Fewer than max_rank gallery samples survived the filter. The
            # query has a match, so its curve has already reached 1; pad with 1
            # so every curve has the same length and stacks into a 2-D array.
            cmc = np.pad(cmc, (0, max_rank - cmc.size), mode='constant', constant_values=1)

        all_cmc.append(cmc)
        num_valid_q += 1.

        # compute average precision
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
        num_rel = orig_cmc.sum()
        precision = orig_cmc.cumsum() / np.arange(1, orig_cmc.size + 1)
        AP = (precision * orig_cmc).sum() / num_rel
        all_AP.append(AP)

    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)

    return all_cmc, mAP


def evaluate_model(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, use_metrics_cuhk03=False, use_cython=True):
    """Compute ``(cmc, mAP)`` with the CUHK03 or Market-1501 protocol.

    Args:
        distmat, q_pids, g_pids, q_camids, g_camids, max_rank: forwarded
            unchanged to the selected evaluation function.
        use_metrics_cuhk03 (bool): use ``eval_cuhk03`` instead of Market-1501.
        use_cython (bool): for Market-1501, prefer the compiled
            ``eval_market1501_wrap`` when it imported successfully
            (``CYTHON_EVAL_AVAI``); otherwise fall back to the Python version.
    """
    eval_args = (distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
    if use_metrics_cuhk03:
        return eval_cuhk03(*eval_args)
    if use_cython and CYTHON_EVAL_AVAI:
        return eval_market1501_wrap(*eval_args)
    return eval_market1501(*eval_args)

# Loss

In [None]:
from __future__ import absolute_import
from __future__ import division

import sys

import torch
from torch import nn
from torch.nn import functional as F


def DeepSupervision(criterion, xs, y):
    """Average a loss over several predictions of the same target.

    Args:
    - criterion: loss function called as ``criterion(x, y)``
    - xs: tuple of inputs (e.g. outputs of several heads)
    - y: ground truth shared by every input

    Returns the mean of ``criterion(x, y)`` over ``xs``.
    """
    total = sum((criterion(x, y) for x in xs), 0.)
    return total / len(xs)


class MaskLoss(nn.Module):
    """Loss between a predicted spatial map and a ground-truth mask.

    The target is bilinearly resized to the prediction's spatial size. Both
    tensors are flattened per sample, and the chosen loss is averaged over
    all elements.

    Args:
    - mode (str): 'l2' (nn.MSELoss), 'l1' (nn.L1Loss) or 'ce' (nn.BCELoss,
      which expects predictions that are probabilities in [0, 1]).

    Raises:
    - KeyError: if ``mode`` is not one of the supported values.
    """
    def __init__(self, mode='l2'):
        super(MaskLoss, self).__init__()
        if mode == 'l2':
            self.loss = nn.MSELoss()
        elif mode == 'l1':
            self.loss = nn.L1Loss()
        elif mode == 'ce':
            self.loss = nn.BCELoss()
        else:
            # Fail fast instead of leaving self.loss unset, which would only
            # surface later as an AttributeError inside forward().
            raise KeyError('Unsupported mask loss mode: {}'.format(mode))

    def forward(self, inputs, targets):
        """
        Args:
        - inputs: prediction spatial map with shape (batch_size, 1, h, w)
        - targets: ground truth labels with shape (batch_size, 1, h1, w1)
        """
        b, _, h, w = inputs.size()
        # bring the mask to the prediction's resolution before comparing
        targets = F.interpolate(targets, (h, w), mode='bilinear', align_corners=True)
        return self.loss(inputs.view(b, -1), targets.view(b, -1))


class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.

    Reference:
    Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K.

    Args:
    - num_classes (int): number of classes.
    - epsilon (float): weight.
    - use_gpu (bool): kept for backward compatibility. The smoothed targets
      are now always built on the same device as ``inputs``.
    """
    def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
        - inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
        - targets: ground truth class indices with shape (batch_size,)

        Returns the smoothed cross entropy, averaged over the batch.
        """
        log_probs = self.logsoftmax(inputs)
        # One-hot matrix created directly on the logits' device and dtype, so
        # the loss works on CPU or GPU regardless of the use_gpu flag.
        index = targets.detach().unsqueeze(1).to(log_probs.device)
        one_hot = torch.zeros_like(log_probs).scatter_(1, index, 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        return (- smoothed * log_probs).mean(0).sum()


class TripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining.

    Reference:
    Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.

    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.

    For every anchor in the batch, the farthest sample with the same label
    (hardest positive) and the closest sample with a different label (hardest
    negative) are selected. The loss is ``nn.MarginRankingLoss`` with y = 1,
    i.e. ``mean(max(0, d_ap - d_an + margin))``. Every anchor needs at least
    one sample of a different label in the batch.

    Args:
    - margin (float): margin for triplet.
    - distance (str): 'euclidean' or 'cosine' (negative cosine similarity).
    - use_gpu (bool): kept for backward compatibility. Labels are now moved to
      the device of ``inputs`` automatically.

    Raises:
    - KeyError: if ``distance`` is not supported.
    """
    def __init__(self, margin=0.3, distance='euclidean', use_gpu=True):
        super(TripletLoss, self).__init__()
        if distance not in ['euclidean', 'cosine']:
            raise KeyError('Unsupported distance: {}'.format(distance))
        self.distance = distance
        self.margin = margin
        self.use_gpu = use_gpu
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        """
        Args:
        - inputs: feature matrix with shape (batch_size, feat_dim)
        - targets: ground truth labels with shape (batch_size,)
        """
        n = inputs.size(0)

        # Pairwise distance matrix (n, n)
        if self.distance == 'euclidean':
            # ||a||^2 + ||b||^2 - 2 a.b for all pairs. Keyword beta/alpha
            # replace the removed positional addmm_(beta, alpha, m1, m2) form.
            dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
            dist = dist + dist.t()
            dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
            dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
        elif self.distance == 'cosine':
            fnorm = torch.norm(inputs, p=2, dim=1, keepdim=True)
            l2norm = inputs.div(fnorm.expand_as(inputs))
            dist = - torch.mm(l2norm, l2norm.t())

        # Labels must live on the same device as the distance matrix.
        targets = targets.to(dist.device)

        # For each anchor, find the hardest positive and negative
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
            dist_an.append(dist[i][~mask[i]].min().unsqueeze(0))
        dist_ap = torch.cat(dist_ap)
        dist_an = torch.cat(dist_an)

        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        return self.ranking_loss(dist_an, dist_ap, y)

# Optimizer

In [None]:
from __future__ import absolute_import

import torch


def init_optim(optim, params, lr, weight_decay):
    """Build a torch optimizer by name.

    Supported names: 'adam', 'amsgrad' (Adam with amsgrad=True), 'sgd' and
    'rmsprop' (both with momentum 0.9).

    Raises:
        KeyError: for any other name.
    """
    factories = {
        'adam': lambda: torch.optim.Adam(params, lr=lr, weight_decay=weight_decay),
        'amsgrad': lambda: torch.optim.Adam(params, lr=lr, weight_decay=weight_decay, amsgrad=True),
        'sgd': lambda: torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay),
        'rmsprop': lambda: torch.optim.RMSprop(params, lr=lr, momentum=0.9, weight_decay=weight_decay),
    }
    if optim not in factories:
        raise KeyError("Unsupported optimizer: {}".format(optim))
    return factories[optim]()

# Samplers

In [None]:
from __future__ import absolute_import
from __future__ import division

from collections import defaultdict
import numpy as np
import copy
import random

import torch
from torch.utils.data.sampler import Sampler


class RandomIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    Indices are grouped by pid and emitted in chunks of ``num_instances``
    indices that share one pid. Chunks from all pids are shuffled together.
    Identities with fewer than ``num_instances`` samples are padded by
    sampling with replacement. Leftover indices that do not fill a chunk are
    dropped.

    Args:
    - data_source (Dataset): dataset to sample from; items unpack as (_, pid, _).
    - num_instances (int): number of instances per identity.
    """
    def __init__(self, data_source, num_instances=4):
        self.data_source = data_source
        self.num_instances = num_instances
        self.index_dic = defaultdict(list)
        for position, (_, pid, _) in enumerate(data_source):
            self.index_dic[pid].append(position)
        self.pids = list(self.index_dic.keys())
        self.num_identities = len(self.pids)

        # Epoch length: each pid yields floor(max(count, K) / K) full chunks.
        k = self.num_instances
        self.length = sum(
            (max(len(self.index_dic[pid]), k) // k) * k for pid in self.pids
        )

    def __iter__(self):
        k = self.num_instances
        chunks = []
        for pid in self.pids:
            pool = list(self.index_dic[pid])
            if len(pool) < k:
                pool = np.random.choice(pool, size=k, replace=True)
            random.shuffle(pool)
            usable = len(pool) - len(pool) % k
            chunks.extend(list(pool[start:start + k]) for start in range(0, usable, k))

        random.shuffle(chunks)
        return iter([idx for chunk in chunks for idx in chunk])

    def __len__(self):
        return self.length

# Spatial transforms

In [None]:
from __future__ import absolute_import

import random
import math
import numbers
import collections
import numpy as np
import torch
from PIL import Image, ImageOps
try:
    import accimage
except ImportError:
    accimage = None


class Compose(object):
    """Chain several transforms into a single callable.

    Transforms run in list order, each on the previous one's output.
    ``randomize_parameters`` is forwarded to every child transform, so each
    child must implement it.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.
    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result

    def randomize_parameters(self):
        for transform in self.transforms:
            transform.randomize_parameters()


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.
    Converts a PIL.Image or numpy.ndarray (H x W x C, or H x W) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range
    [0.0, 1.0] (values are divided by ``norm_value``).

    PIL images in integer modes 'I' / 'I;16' are returned as integer tensors
    without scaling.
    """

    def __init__(self, norm_value=255):
        self.norm_value = norm_value

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
        Returns:
            Tensor: Converted image.
        """
        if isinstance(pic, np.ndarray):
            # handle numpy array; a 2-D (H, W) array is treated as one channel
            if pic.ndim == 2:
                pic = pic[:, :, None]
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            # backward compatibility
            return img.float().div(self.norm_value)

        if accimage is not None and isinstance(pic, accimage.Image):
            nppic = np.zeros(
                [pic.channels, pic.height, pic.width], dtype=np.float32)
            pic.copyto(nppic)
            return torch.from_numpy(nppic)

        # handle PIL Image. np.array(..., copy=False) raises under NumPy 2
        # whenever a copy is needed, and torch.ByteStorage.from_buffer is
        # deprecated, so always build an owned, writable numpy buffer instead.
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16))
        else:
            img = torch.from_numpy(np.frombuffer(pic.tobytes(), dtype=np.uint8).copy())
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)
        # put it from HWC to CHW format
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if img.dtype == torch.uint8:
            return img.float().div(self.norm_value)
        return img

    def randomize_parameters(self):
        pass


class Normalize(object):
    """Channel-wise ``(channel - mean) / std`` normalization, done in place.

    Single-channel tensors (e.g. the 'L'-mode masks) are returned unchanged.

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        Returns:
            Tensor: The same tensor, normalized in place.
        """
        is_single_channel = tensor.size(0) == 1
        if not is_single_channel:
            for channel, mu, sigma in zip(tensor, self.mean, self.std):
                channel.sub_(mu).div_(sigma)
        return tensor

    def randomize_parameters(self):
        # no random state; present for the Compose interface
        pass


class NormalizeSub(object):
    """Channel-wise ``(channel - mean) / std`` normalization, done in place.

    Multi-channel tensors use the configured ``mean`` / ``std``. Single-channel
    tensors are normalized with fixed mean 0.5 and std 0.25.

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        Returns:
            Tensor: The same tensor, normalized in place.
        """
        if tensor.size(0) == 1:
            means, stds = [0.5], [0.25]
        else:
            means, stds = self.mean, self.std
        for channel, mu, sigma in zip(tensor, means, stds):
            channel.sub_(mu).div_(sigma)
        return tensor

    def randomize_parameters(self):
        # no random state; present for the Compose interface
        pass


class Scale(object):
    """Rescale the input PIL.Image to the given size.
    Args:
        size (sequence or int): Desired output size. If size is a sequence
            like (h, w), the image is resized to exactly that size (the pair
            is reversed into PIL's (w, h) order for ``resize``). If size is an
            int, the smaller edge of the image is matched to this number and
            the aspect ratio is kept, i.e. if height > width the image is
            rescaled to (size * height / width, size) as (h, w).
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # ``collections.Iterable`` was removed in Python 3.10; the ABC lives
        # in ``collections.abc``.
        from collections.abc import Iterable
        assert isinstance(size, int) or (isinstance(size, Iterable) and
                                         len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.
        Returns:
            PIL.Image: Rescaled image.
        """
        if isinstance(self.size, int):
            w, h = img.size
            if (w <= h and w == self.size) or (h <= w and h == self.size):
                return img
            if w < h:
                ow = self.size
                oh = int(self.size * h / w)
                return img.resize((ow, oh), self.interpolation)
            else:
                oh = self.size
                ow = int(self.size * w / h)
                return img.resize((ow, oh), self.interpolation)
        else:
            # (h, w) -> PIL's (w, h)
            return img.resize(self.size[::-1], self.interpolation)

    def randomize_parameters(self):
        pass


class RandomCrop(object):
    """Crop a PIL.Image at a random location.

    The offset is stored as fractions (``tl_x``, ``tl_y``) of the free space
    by ``randomize_parameters``, so a single draw can be reused for several
    images. ``randomize_parameters()`` must run before the first call.

    Args:
        size (sequence or int): Desired (h, w) of the crop. An int gives a
            square crop (size, size).
    """

    def __init__(self, size):
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.
        Returns:
            PIL.Image: Cropped image.
        """
        crop_h, crop_w = self.size
        img_w, img_h = img.size
        left = int(round((img_w - crop_w) * self.tl_x))
        top = int(round((img_h - crop_h) * self.tl_y))
        return img.crop((left, top, left + crop_w, top + crop_h))

    def randomize_parameters(self):
        self.tl_x = random.random()
        self.tl_y = random.random()

        
class CenterCrop(object):
    """Crop a PIL.Image around its center.

    Args:
        size (sequence or int): Desired (h, w) of the crop. An int gives a
            square crop (size, size).
    """

    def __init__(self, size):
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.
        Returns:
            PIL.Image: Cropped image.
        """
        crop_h, crop_w = self.size
        img_w, img_h = img.size
        left = int(round((img_w - crop_w) / 2.))
        top = int(round((img_h - crop_h) / 2.))
        box = (left, top, left + crop_w, top + crop_h)
        return img.crop(box)

    def randomize_parameters(self):
        # deterministic transform; present for the Compose interface
        pass


class CornerCrop(object):
    """Square crop of side ``size`` at the center or at one of the four corners.

    Args:
        size (int): side length of the square crop, in pixels.
        crop_position (str or None): one of 'c', 'tl', 'tr', 'bl', 'br'. When
            None, a position is drawn uniformly at random on each
            ``randomize_parameters()`` call.

    Raises:
        ValueError: from ``__call__`` when no valid crop position is set
            (e.g. ``crop_position=None`` and ``randomize_parameters`` was
            never called).
    """

    def __init__(self, size, crop_position=None):
        self.size = size
        self.randomize = crop_position is None
        self.crop_position = crop_position
        self.crop_positions = ['c', 'tl', 'tr', 'bl', 'br']

    def __call__(self, img):
        image_width = img.size[0]
        image_height = img.size[1]
        side = self.size

        # Every position yields a side x side box; only its top-left differs.
        if self.crop_position == 'c':
            x1 = int(round((image_width - side) / 2.))
            y1 = int(round((image_height - side) / 2.))
        elif self.crop_position == 'tl':
            x1, y1 = 0, 0
        elif self.crop_position == 'tr':
            x1, y1 = image_width - side, 0
        elif self.crop_position == 'bl':
            x1, y1 = 0, image_height - side
        elif self.crop_position == 'br':
            x1, y1 = image_width - side, image_height - side
        else:
            raise ValueError(
                'Unsupported crop position: {!r}'.format(self.crop_position))

        return img.crop((x1, y1, x1 + side, y1 + side))

    def randomize_parameters(self):
        if self.randomize:
            self.crop_position = self.crop_positions[random.randint(
                0,
                len(self.crop_positions) - 1)]


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL.Image randomly with a probability of 0.5.

    The coin toss is drawn by ``randomize_parameters`` and stored in ``p``,
    so one draw can be reused for several images. It must run before the
    first call.
    """

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.
        Returns:
            PIL.Image: Randomly flipped image.
        """
        if self.p >= 0.5:
            return img
        return img.transpose(Image.FLIP_LEFT_RIGHT)

    def randomize_parameters(self):
        self.p = random.random()


class MultiScaleCornerCrop(object):
    """Crop the given PIL.Image to randomly selected size.
    A crop of size is selected from scales of the original size.
    A position of cropping is randomly selected from 4 corners and 1 center.
    This crop is finally resized to given size.
    Args:
        scales: cropping scales of the original size
        size: side length of the square output
        interpolation: Default: PIL.Image.BILINEAR
        crop_positions: candidate positions among 'c', 'tl', 'tr', 'bl', 'br'

    ``randomize_parameters()`` must run before the first call.

    Raises:
        ValueError: from ``__call__`` if the drawn position is not supported.
    """

    def __init__(self,
                 scales,
                 size,
                 interpolation=Image.BILINEAR,
                 crop_positions=['c', 'tl', 'tr', 'bl', 'br']):
        self.scales = scales
        self.size = size
        self.interpolation = interpolation
        # Copy so instances never alias the shared default list.
        self.crop_positions = list(crop_positions)

    def __call__(self, img):
        image_width = img.size[0]
        image_height = img.size[1]
        crop_size = int(min(image_width, image_height) * self.scale)

        if self.crop_position == 'c':
            center_x = image_width // 2
            center_y = image_height // 2
            box_half = crop_size // 2
            box = (center_x - box_half, center_y - box_half,
                   center_x + box_half, center_y + box_half)
        elif self.crop_position == 'tl':
            box = (0, 0, crop_size, crop_size)
        elif self.crop_position == 'tr':
            box = (image_width - crop_size, 0, image_width, crop_size)
        elif self.crop_position == 'bl':
            box = (0, image_height - crop_size, crop_size, image_height)
        elif self.crop_position == 'br':
            box = (image_width - crop_size, image_height - crop_size,
                   image_width, image_height)
        else:
            raise ValueError(
                'Unsupported crop position: {!r}'.format(self.crop_position))

        img = img.crop(box)
        return img.resize((self.size, self.size), self.interpolation)

    def randomize_parameters(self):
        self.scale = self.scales[random.randint(0, len(self.scales) - 1)]
        # Index crop_positions by its own length (it previously used
        # len(self.scales), skipping positions or raising IndexError).
        self.crop_position = self.crop_positions[random.randint(
            0,
            len(self.crop_positions) - 1)]


class MultiScaleRandomCrop(object):
    """Square crop at a random scale and location, then resize.

    The crop side is ``scale * min(width, height)``, with ``scale`` drawn from
    ``scales``. Its top-left corner is placed at fractions ``tl_x`` / ``tl_y``
    of the free space, and the crop is resized to ``(size, size)``.
    ``randomize_parameters()`` must run before the first call.

    Args:
        scales: candidate scales relative to the shorter image edge.
        size: side length of the square output.
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, scales, size, interpolation=Image.BILINEAR):
        self.scales = scales
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        width, height = img.size[0], img.size[1]
        side = int(min(width, height) * self.scale)

        left = self.tl_x * (width - side)
        top = self.tl_y * (height - side)
        cropped = img.crop((left, top, left + side, top + side))
        return cropped.resize((self.size, self.size), self.interpolation)

    def randomize_parameters(self):
        self.scale = self.scales[random.randint(0, len(self.scales) - 1)]
        self.tl_x = random.random()
        self.tl_y = random.random()


class Random2DTranslation(object):
    """
    With a probability, first increase image size to (1 + 1/8), and then perform random crop.

    Otherwise the image is simply resized to the target size.
    ``randomize_parameters()`` must run before the first call.

    Args:
        size (int or sequence): target (height, width); an int gives a square output.
        p (float): probability of performing this transformation. Default: 0.5.
        interpolation: PIL resampling filter. Default: PIL.Image.BILINEAR.
    """
    def __init__(self, size, p=0.5, interpolation=Image.BILINEAR):
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size
        self.height, self.width = self.size
        self.p = p
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        if self.cropping:
            enlarged_w = int(round(self.width * 1.125))
            enlarged_h = int(round(self.height * 1.125))
            enlarged = img.resize((enlarged_w, enlarged_h), self.interpolation)
            left = int(round(self.tl_x * (enlarged_w - self.width)))
            top = int(round(self.tl_y * (enlarged_h - self.height)))
            return enlarged.crop((left, top, left + self.width, top + self.height))
        # PIL expects (width, height)
        return img.resize((self.width, self.height), self.interpolation)

    def randomize_parameters(self):
        self.cropping = random.uniform(0, 1) < self.p
        self.tl_x = random.random()
        self.tl_y = random.random()


class RandomErasing(object):
    """ Randomly selects a rectangle region in an image and erases its pixels.
        'Random Erasing Data Augmentation' by Zhong et al.
        See https://arxiv.org/pdf/1708.04896.pdf

    Single-channel tensors are never erased. For 3-channel tensors each
    channel is filled with the matching ``mean`` entry. Other channel counts
    only have channel 0 filled (with ``mean[0]``). Up to 100 random rectangles
    are tried; if none fits, the image is returned unchanged.

    Args:
         probability: The probability that the Random Erasing operation will be performed.
         sl: Minimum proportion of erased area against input image.
         sh: Maximum proportion of erased area against input image.
         r1: Minimum aspect ratio of erased area.
         mean: Erasing value.
    """

    def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]):
        self.probability = probability
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1

    def __call__(self, img):
        """
        img: [C, H, W] tensor, modified in place and returned.
        """
        if img.size(0) == 1 or random.uniform(0, 1) > self.probability:
            return img

        channels = img.size()[0]
        height = img.size()[1]
        width = img.size()[2]
        area = height * width

        for _ in range(100):
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1/self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w >= width or h >= height:
                continue  # rectangle does not fit; try another

            top = random.randint(0, height - h)
            left = random.randint(0, width - w)
            filled = 3 if channels == 3 else 1
            for c in range(filled):
                img[c, top:top+h, left:left+w] = self.mean[c]
            return img

        return img

    def randomize_parameters(self):
        pass

# Train for MSMT17

In [None]:
from __future__ import print_function
from __future__ import division

import os
import sys
import time
import datetime
import argparse
import os.path as osp
import numpy as np

import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
import torchvision.transforms as T

# Training / evaluation configuration as module-level constants. They appear
# to stand in for the command-line arguments of the original training script
# (argparse is imported but no parser is built here).
parser = "Train image model with cross entropy loss and hard triplet loss"  # description text only
# Datasets
dataset = 'MSMT17_seg' # dataset name passed to init_imgreid_dataset; other options: 'market1501_seg', 'DukeMTMC_seg'
workers = 4 # number of DataLoader worker processes (default: 4)
height = 256 # image height after resizing (default: 256)
width = 128 # image width after resizing (default: 128)
split_id = 0 # split index passed to init_imgreid_dataset (default: 0)


# CUHK03-specific settings
cuhk03_labeled = False # forwarded to init_imgreid_dataset
cuhk03_classic_split = False # forwarded to init_imgreid_dataset
use_metric_cuhk03 = False # presumably selects the CUHK03 protocol in evaluate_model -- usage not visible here

# Optimization options
optim = 'adam' # optimizer name; init_optim accepts 'adam', 'amsgrad', 'sgd', 'rmsprop'
max_epoch = 60 # total number of training epochs (default: 60)
start_epoch = 0 # manual start epoch, useful when restarting (default: 0)
train_batch = 64 # training batch size (default: 64)
test_batch = 32 # batch size of the query/gallery loaders (default: 32)
lr = 0.00035 # initial learning rate (default: 0.00035)
stepsize = [20, 40] # epochs at which the learning rate is decayed (presumably LR-scheduler milestones)
gamma = 0.1 # learning rate decay factor (default: 0.1)
weight_decay = 5e-04  # weight decay (default: 5e-04)
margin = 0.3 # margin for triplet loss
num_instances = 4 # images per identity drawn by RandomIdentitySampler

# Architecture
arch = 'STCANet' # model architecture name
save_dir = './result/market/STCANet' # directory for checkpoints and log files
# NOTE(review): save_dir points at a 'market' folder while dataset is MSMT17_seg -- confirm this is intended.

# spatial attention loss
mode = 'ce' # MaskLoss mode: 'l1', 'l2' or 'ce'
alpha = 0.5 # weight of the spatial attention (mask) loss

# data process
flip_cnt = 1 # NOTE(review): meaning not visible in this section (possibly test-time flip count) -- confirm

# Miscs
distance = 'cosine' # distance metric, presumably for TripletLoss ('euclidean' or 'cosine')
seed = 1 # random seed for torch (CPU and, when available, CUDA)
resume = '' # checkpoint to resume from; presumably '' means train from scratch
evaluate = True # True: append to log_test.txt (presumably evaluation only); False: write log_train.txt
rerank = False # presumably enables re-ranking at test time (default: False)
eval_step = 10 # presumably evaluate every eval_step epochs
start_eval = 0 # presumably first epoch at which evaluation starts
gpu_devices = '2' # value written to CUDA_VISIBLE_DEVICES (e.g. '2' or '3')


def main():
    """Train the configured model, or only evaluate it when ``evaluate`` is set.

    Every setting is read from the module-level configuration of this cell
    (``dataset``, ``arch``, ``lr``, ``resume``, ``evaluate``, ...). Output is
    redirected to a log file inside ``save_dir``. During training the model is
    tested after the first epoch and every ``eval_step`` epochs, and a
    checkpoint is saved each time.
    """
    # CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is first
    # initialised in the process, so set it before any torch.cuda query.
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_devices
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    torch.manual_seed(seed)
    use_gpu = torch.cuda.is_available()

    if not evaluate:
        sys.stdout = Logger(osp.join(save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(save_dir, 'log_test.txt'), mode='a')

    if use_gpu:
        print("Currently using GPU {}".format(gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    datasets = init_imgreid_dataset(
        name=dataset, split_id=split_id,
        cuhk03_labeled=cuhk03_labeled, cuhk03_classic_split=cuhk03_classic_split,
    )

    transform_train = Compose([
        Scale((height, width), interpolation=3),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        RandomErasing(0.5),
    ])

    transform_test = T.Compose([
        T.Resize((height, width), interpolation=3),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = use_gpu
    trainloader = DataLoader(
        ImageDataset_seg(datasets.train, transform=transform_train),
        sampler=RandomIdentitySampler(datasets.train, num_instances=num_instances),
        batch_size=train_batch, num_workers=workers,
        pin_memory=pin_memory, drop_last=True,
    )
    queryloader = DataLoader(
        ImageDataset(datasets.query, transform=transform_test),
        batch_size=test_batch, shuffle=False, num_workers=workers,
        pin_memory=pin_memory, drop_last=False,
    )
    galleryloader = DataLoader(
        ImageDataset(datasets.gallery, transform=transform_test),
        batch_size=test_batch, shuffle=False, num_workers=workers,
        pin_memory=pin_memory, drop_last=False,
    )

    print("Initializing model: {}".format(arch))
    model = init_model(name=arch, num_classes=datasets.num_train_pids)

    print(model)
    print("Model size: {:.3f} M".format(count_num_param(model)))

    criterion_xent = CrossEntropyLabelSmooth(num_classes=datasets.num_train_pids, use_gpu=use_gpu)
    criterion_htri = TripletLoss(margin=margin, distance=distance)
    criterion_mask = MaskLoss(mode=mode)

    optimizer = init_optim(optim, model.parameters(), lr, weight_decay)
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=stepsize, gamma=gamma)

    first_epoch = start_epoch
    if resume and check_isfile(resume):
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # the weights follow the model to the GPU below.
        checkpoint = torch.load(resume, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
        # 'epoch' is the 0-based index of the last finished epoch (see the
        # save_checkpoint call below), so training continues with the next one.
        first_epoch = checkpoint['epoch'] + 1
        rank1 = checkpoint['rank1']
        print("Loaded checkpoint from '{}'".format(resume))
        print("- start_epoch: {}\n- rank1: {}".format(first_epoch, rank1))
        # Fast-forward the LR schedule so resumed epochs use the learning rate
        # they would have had in an uninterrupted run (PyTorch may warn that
        # the scheduler stepped before the optimizer; that is expected here).
        for _ in range(first_epoch):
            scheduler.step()

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if evaluate:
        print("Evaluate only")
        test(model, queryloader, galleryloader, use_gpu=use_gpu)
        return

    start_time = time.time()
    train_time = 0
    best_rank1 = -np.inf
    best_epoch = 0
    print("==> Start training")

    for epoch in range(first_epoch, max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion_xent, criterion_htri, criterion_mask,
              optimizer, trainloader, use_gpu=use_gpu)
        train_time += round(time.time() - start_train_time)
        # Since PyTorch 1.1 the scheduler must step after optimizer.step();
        # stepping first would apply every LR decay one epoch early.
        scheduler.step()

        if (epoch + 1) % eval_step == 0 or epoch == 0:
            print("==> Test")
            rank1 = test(model, queryloader, galleryloader, use_gpu=use_gpu)
            is_best = rank1 > best_rank1

            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            # Save the unwrapped module so the checkpoint loads without DataParallel.
            state_dict = getattr(model, 'module', model).state_dict()

            save_checkpoint({
                'state_dict': state_dict,
                'rank1': rank1,
                'epoch': epoch,
            }, is_best, osp.join(save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
    #print("==========\nArgs:{}\n==========".format(args))


def train(epoch, model, criterion_xent, criterion_htri, criterion_mask, optimizer, trainloader, use_gpu=True):
    """Run one training epoch and print the epoch's averaged statistics.

    Each batch from ``trainloader`` is unpacked as ``(sequences, pids, _)``.
    Channels 0-2 of ``sequences`` are fed to the model as the image. Channels
    3, 4, 5 and 6 are used as head / upper-body / lower-body / shoes masks
    that supervise the four attention maps the model returns.

    The optimized loss is
    ``xent + triplet + alpha * mean(head, upper, lower, shoes mask losses)``.

    Args:
        epoch: 0-based epoch index (printed as ``epoch + 1``).
        model: network returning ``(logits, features, a_head, a_upper, a_lower, a_shoes)``.
        criterion_xent: classification loss on the logits.
        criterion_htri: triplet loss on the features.
        criterion_mask: loss between an attention map and its mask, applied
            through ``DeepSupervision``.
        optimizer: optimizer stepped once per batch.
        trainloader: iterable of training batches.
        use_gpu: when True, batches are moved to the GPU with ``.cuda()``.
    """
    xent_losses = AverageMeter()
    htri_losses = AverageMeter()
    mask_losses = AverageMeter()
    accs = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()
    end = time.time()
    for sequences, pids, _ in trainloader:
        if use_gpu:
            sequences, pids = sequences.cuda(), pids.cuda()

        # measure data loading time
        data_time.update(time.time() - end)

        optimizer.zero_grad()

        # forward
        imgs = sequences[:, :3]
        head_masks = sequences[:, 3:4]
        upper_masks = sequences[:, 4:5]
        lower_masks = sequences[:, 5:6]
        shoes_masks = sequences[:, 6:7]

        outputs, features, a_head, a_upper, a_lower, a_shoes = model(imgs)
        _, preds = torch.max(outputs.data, 1)
        xent_loss = criterion_xent(outputs, pids)
        htri_loss = criterion_htri(features, pids)
        loss = xent_loss + htri_loss

        head_loss = DeepSupervision(criterion_mask, a_head, head_masks)
        upper_loss = DeepSupervision(criterion_mask, a_upper, upper_masks)
        lower_loss = DeepSupervision(criterion_mask, a_lower, lower_masks)
        shoes_loss = DeepSupervision(criterion_mask, a_shoes, shoes_masks)
        mask_loss = (head_loss + upper_loss + lower_loss + shoes_loss) / 4.0

        total_loss = loss + alpha * mask_loss

        # backward + optimize
        total_loss.backward()
        optimizer.step()

        # statistics: .item() keeps every meter on plain Python floats instead
        # of accumulating device tensors in the accuracy meter.
        batch_size = pids.size(0)
        accs.update((preds == pids).float().mean().item(), batch_size)
        xent_losses.update(xent_loss.item(), batch_size)
        htri_losses.update(htri_loss.item(), batch_size)
        mask_losses.update(mask_loss.item(), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    print('Epoch{0} '
          'Time:{batch_time.sum:.1f}s '
          'Data:{data_time.sum:.1f}s '
          'xentLoss:{xent_loss.avg:.4f} '
          'triLoss:{tri_loss.avg:.4f} '
          'MaskLoss:{mask_loss.avg:.4f} '
          'Acc:{acc.avg:.2%} '.format(
           epoch+1, batch_time=batch_time,
           data_time=data_time, xent_loss=xent_losses,
           tri_loss=htri_losses, mask_loss=mask_losses, acc=accs))


def fliplr(img, use_gpu):
    """Return a copy of ``img`` flipped along dimension 3 (the width axis of NCHW).

    ``torch.flip`` keeps the result on ``img``'s own device. The old version
    built the index tensor on the GPU whenever ``use_gpu`` was set, which
    crashed for CPU tensors. ``use_gpu`` is still accepted so existing callers
    keep working, but it no longer affects the result.
    """
    return torch.flip(img, dims=[3])


def test(model, queryloader, galleryloader, use_gpu = True, ranks=(1, 5, 10, 20), rerank = rerank):
    """Evaluate ``model`` on a query/gallery split and return the Rank-1 score.

    For each image, the feature (second model output) is summed over
    ``flip_cnt`` passes: the original image, then the horizontally flipped one
    when ``flip_cnt >= 2``. The query-gallery distance matrix uses the global
    ``distance`` metric and is optionally re-ranked. mAP and the CMC values at
    ``ranks`` are printed.

    Args:
        model: network (optionally wrapped in ``nn.DataParallel``) with a ``feat_dim`` attribute.
        queryloader, galleryloader: iterables of ``(imgs, pids, camids)`` batches.
        use_gpu: move image batches to the GPU before the forward pass.
        ranks: CMC ranks to print (1-based).
        rerank: apply ``re_ranking`` to the distance matrix.

    Returns:
        ``cmc[0]``, the Rank-1 accuracy.
    """
    batch_time = AverageMeter()
    # Works for both a DataParallel-wrapped model (GPU) and a bare model (CPU).
    feat_dim = getattr(model, 'module', model).feat_dim

    model.eval()

    def _extract(loader):
        """Return ``(features, pids, camids)`` for every image in ``loader``."""
        feats, all_pids, all_camids = [], [], []
        for imgs, pids, camids in loader:
            end = time.time()
            # Move the batch once, before any flip, so fliplr and the model see
            # a tensor on the right device.
            if use_gpu:
                imgs = imgs.cuda()
            features = torch.FloatTensor(imgs.size(0), feat_dim).zero_()
            for i in range(flip_cnt):
                if i == 1:
                    imgs = fliplr(imgs, use_gpu=use_gpu)
                f = model(imgs)[1]
                features = features + f.data.cpu()
            batch_time.update(time.time() - end)

            feats.append(features)
            all_pids.extend(pids)
            all_camids.extend(camids)
        return torch.cat(feats, 0), np.asarray(all_pids), np.asarray(all_camids)

    with torch.no_grad():
        qf, q_pids, q_camids = _extract(queryloader)
        print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))

        gf, g_pids, g_camids = _extract(galleryloader)
        print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1)))

    print("==> BatchTime(s)/BatchSize(img): {:.3f}/{}".format(batch_time.avg, test_batch))

    distmat = compute_distance_matrix(qf, gf, distance).numpy()

    if rerank:
        print('Applying person re-ranking ...')
        # Hand all three matrices to re_ranking as numpy arrays, matching the
        # query-gallery matrix above.
        distmat_qq = compute_distance_matrix(qf, qf, distance).numpy()
        distmat_gg = compute_distance_matrix(gf, gf, distance).numpy()
        distmat = re_ranking(distmat, distmat_qq, distmat_gg)

    print("Computing CMC and mAP")
    # Uses the module-level `use_metric_cuhk03` setting (no local override).
    cmc, mAP = evaluate_model(distmat, q_pids, g_pids, q_camids, g_camids, use_metrics_cuhk03=use_metric_cuhk03)

    print("Results ----------")
    print("mAP: {:.1%}".format(mAP))
    print("CMC curve")
    for r in ranks:
        print("Rank-{:<3}: {:.1%}".format(r, cmc[r-1]))
    print("------------------")

    return cmc[0]


# Script/notebook entry point: run the cell's main() when executed directly.
if __name__ == "__main__":
    main()

# Visualize Ranked Results

In [None]:
from __future__ import print_function
from __future__ import division

import os
import sys
import time
import datetime
import argparse
import os.path as osp
import numpy as np

import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
import torchvision.transforms as T

parser = "Train image model with cross entropy loss and hard triplet loss"

# ---- Data --------------------------------------------------------------
dataset = 'market1501_seg'  # choices: 'market1501_seg', 'MSMT17_seg', 'DukeMTMC_seg'
workers = 4                 # DataLoader worker processes
height, width = 256, 128    # resize target (H, W)
split_id = 0                # split index

# ---- CUHK03 only ---------------------------------------------------------
cuhk03_labeled = False
cuhk03_classic_split = False
use_metric_cuhk03 = False

# ---- Optimisation -------------------------------------------------------
optim = 'adam'
max_epoch = 60
start_epoch = 0             # manual starting epoch (for restarts)
train_batch = 64
test_batch = 32
lr = 3.5e-4                 # initial learning rate
stepsize = [20, 40]         # MultiStepLR milestones (epochs)
gamma = 0.1                 # LR decay factor per milestone
weight_decay = 5e-4
margin = 0.3                # triplet-loss margin
num_instances = 4           # passed to RandomIdentitySampler

# ---- Model / output -----------------------------------------------------
arch = 'STCANet'
save_dir = './result/market/STCANet'  # logs and checkpoints go here

# ---- Spatial attention loss ---------------------------------------------
mode = 'ce'                 # mask loss mode: 'l1', 'l2' or 'ce'
alpha = 0.5                 # weight of the mask loss in the total loss

# ---- Test-time processing -----------------------------------------------
flip_cnt = 1                # 1 = original images only, 2 = also flipped

# ---- Misc ---------------------------------------------------------------
distance = 'cosine'
seed = 1
resume = ''                 # checkpoint to load ('' = none)
evaluate = True             # evaluate only, skip training
rerank = False
visrank = True              # save ranked-result visualisations
eval_step = 10
start_eval = 0
gpu_devices = '2'           # value for CUDA_VISIBLE_DEVICES


def main():
    """Train the configured model, or (when ``evaluate`` is set) run ``visualization``.

    Every setting is read from the module-level configuration of this cell
    (``dataset``, ``arch``, ``lr``, ``resume``, ``evaluate``, ...). Output is
    redirected to a log file inside ``save_dir``. During training the model is
    tested after the first epoch and every ``eval_step`` epochs, and a
    checkpoint is saved each time.
    """
    # CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is first
    # initialised in the process, so set it before any torch.cuda query.
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_devices
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    torch.manual_seed(seed)
    use_gpu = torch.cuda.is_available()

    if not evaluate:
        sys.stdout = Logger(osp.join(save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(save_dir, 'log_test.txt'), mode='a')

    if use_gpu:
        print("Currently using GPU {}".format(gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    datasets = init_imgreid_dataset(
        name=dataset, split_id=split_id,
        cuhk03_labeled=cuhk03_labeled, cuhk03_classic_split=cuhk03_classic_split,
    )

    transform_train = Compose([
        Scale((height, width), interpolation=3),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        RandomErasing(0.5),
    ])

    transform_test = T.Compose([
        T.Resize((height, width), interpolation=3),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = use_gpu
    trainloader = DataLoader(
        ImageDataset_seg(datasets.train, transform=transform_train),
        sampler=RandomIdentitySampler(datasets.train, num_instances=num_instances),
        batch_size=train_batch, num_workers=workers,
        pin_memory=pin_memory, drop_last=True,
    )
    queryloader = DataLoader(
        ImageDataset(datasets.query, transform=transform_test),
        batch_size=test_batch, shuffle=False, num_workers=workers,
        pin_memory=pin_memory, drop_last=False,
    )
    galleryloader = DataLoader(
        ImageDataset(datasets.gallery, transform=transform_test),
        batch_size=test_batch, shuffle=False, num_workers=workers,
        pin_memory=pin_memory, drop_last=False,
    )

    print("Initializing model: {}".format(arch))
    model = init_model(name=arch, num_classes=datasets.num_train_pids)

    print(model)
    print("Model size: {:.3f} M".format(count_num_param(model)))

    criterion_xent = CrossEntropyLabelSmooth(num_classes=datasets.num_train_pids, use_gpu=use_gpu)
    criterion_htri = TripletLoss(margin=margin, distance=distance)
    criterion_mask = MaskLoss(mode=mode)

    optimizer = init_optim(optim, model.parameters(), lr, weight_decay)
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=stepsize, gamma=gamma)

    first_epoch = start_epoch
    if resume and check_isfile(resume):
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # the weights follow the model to the GPU below.
        checkpoint = torch.load(resume, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
        # 'epoch' is the 0-based index of the last finished epoch (see the
        # save_checkpoint call below), so training continues with the next one.
        first_epoch = checkpoint['epoch'] + 1
        rank1 = checkpoint['rank1']
        print("Loaded checkpoint from '{}'".format(resume))
        print("- start_epoch: {}\n- rank1: {}".format(first_epoch, rank1))
        # Fast-forward the LR schedule so resumed epochs use the learning rate
        # they would have had in an uninterrupted run (PyTorch may warn that
        # the scheduler stepped before the optimizer; that is expected here).
        for _ in range(first_epoch):
            scheduler.step()

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if evaluate:
        print("Evaluate only")
        # visualization() also prints the same mAP/CMC report as test().
        visualization(model, datasets, queryloader, galleryloader, use_gpu)
        return

    start_time = time.time()
    train_time = 0
    best_rank1 = -np.inf
    best_epoch = 0
    print("==> Start training")

    for epoch in range(first_epoch, max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion_xent, criterion_htri, criterion_mask,
              optimizer, trainloader, use_gpu=use_gpu)
        train_time += round(time.time() - start_train_time)
        # Since PyTorch 1.1 the scheduler must step after optimizer.step();
        # stepping first would apply every LR decay one epoch early.
        scheduler.step()

        if (epoch + 1) % eval_step == 0 or epoch == 0:
            print("==> Test")
            rank1 = test(model, queryloader, galleryloader, use_gpu=use_gpu)
            is_best = rank1 > best_rank1

            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            # Save the unwrapped module so the checkpoint loads without DataParallel.
            state_dict = getattr(model, 'module', model).state_dict()

            save_checkpoint({
                'state_dict': state_dict,
                'rank1': rank1,
                'epoch': epoch,
            }, is_best, osp.join(save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
    #print("==========\nArgs:{}\n==========".format(args))


def train(epoch, model, criterion_xent, criterion_htri, criterion_mask, optimizer, trainloader, use_gpu=True):
    """Run one training epoch and print the epoch's averaged statistics.

    Each batch from ``trainloader`` is unpacked as ``(sequences, pids, _)``.
    Channels 0-2 of ``sequences`` are fed to the model as the image. Channels
    3, 4, 5 and 6 are used as head / upper-body / lower-body / shoes masks
    that supervise the four attention maps the model returns.

    The optimized loss is
    ``xent + triplet + alpha * mean(head, upper, lower, shoes mask losses)``.

    Args:
        epoch: 0-based epoch index (printed as ``epoch + 1``).
        model: network returning ``(logits, features, a_head, a_upper, a_lower, a_shoes)``.
        criterion_xent: classification loss on the logits.
        criterion_htri: triplet loss on the features.
        criterion_mask: loss between an attention map and its mask, applied
            through ``DeepSupervision``.
        optimizer: optimizer stepped once per batch.
        trainloader: iterable of training batches.
        use_gpu: when True, batches are moved to the GPU with ``.cuda()``.
    """
    xent_losses = AverageMeter()
    htri_losses = AverageMeter()
    mask_losses = AverageMeter()
    accs = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()
    end = time.time()
    for sequences, pids, _ in trainloader:
        if use_gpu:
            sequences, pids = sequences.cuda(), pids.cuda()

        # measure data loading time
        data_time.update(time.time() - end)

        optimizer.zero_grad()

        # forward
        imgs = sequences[:, :3]
        head_masks = sequences[:, 3:4]
        upper_masks = sequences[:, 4:5]
        lower_masks = sequences[:, 5:6]
        shoes_masks = sequences[:, 6:7]

        outputs, features, a_head, a_upper, a_lower, a_shoes = model(imgs)
        _, preds = torch.max(outputs.data, 1)
        xent_loss = criterion_xent(outputs, pids)
        htri_loss = criterion_htri(features, pids)
        loss = xent_loss + htri_loss

        head_loss = DeepSupervision(criterion_mask, a_head, head_masks)
        upper_loss = DeepSupervision(criterion_mask, a_upper, upper_masks)
        lower_loss = DeepSupervision(criterion_mask, a_lower, lower_masks)
        shoes_loss = DeepSupervision(criterion_mask, a_shoes, shoes_masks)
        mask_loss = (head_loss + upper_loss + lower_loss + shoes_loss) / 4.0

        total_loss = loss + alpha * mask_loss

        # backward + optimize
        total_loss.backward()
        optimizer.step()

        # statistics: .item() keeps every meter on plain Python floats instead
        # of accumulating device tensors in the accuracy meter.
        batch_size = pids.size(0)
        accs.update((preds == pids).float().mean().item(), batch_size)
        xent_losses.update(xent_loss.item(), batch_size)
        htri_losses.update(htri_loss.item(), batch_size)
        mask_losses.update(mask_loss.item(), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    print('Epoch{0} '
          'Time:{batch_time.sum:.1f}s '
          'Data:{data_time.sum:.1f}s '
          'xentLoss:{xent_loss.avg:.4f} '
          'triLoss:{tri_loss.avg:.4f} '
          'MaskLoss:{mask_loss.avg:.4f} '
          'Acc:{acc.avg:.2%} '.format(
           epoch+1, batch_time=batch_time,
           data_time=data_time, xent_loss=xent_losses,
           tri_loss=htri_losses, mask_loss=mask_losses, acc=accs))


def fliplr(img, use_gpu):
    """Return a copy of ``img`` flipped along dimension 3 (the width axis of NCHW).

    ``torch.flip`` keeps the result on ``img``'s own device. The old version
    built the index tensor on the GPU whenever ``use_gpu`` was set, which
    crashed for CPU tensors. ``use_gpu`` is still accepted so existing callers
    keep working, but it no longer affects the result.
    """
    return torch.flip(img, dims=[3])


def test(model, queryloader, galleryloader, use_gpu = True, ranks=(1, 5, 10, 20), rerank = rerank):
    """Evaluate ``model`` on a query/gallery split and return the Rank-1 score.

    For each image, the feature (second model output) is summed over
    ``flip_cnt`` passes: the original image, then the horizontally flipped one
    when ``flip_cnt >= 2``. The query-gallery distance matrix uses the global
    ``distance`` metric and is optionally re-ranked. mAP and the CMC values at
    ``ranks`` are printed.

    Args:
        model: network (optionally wrapped in ``nn.DataParallel``) with a ``feat_dim`` attribute.
        queryloader, galleryloader: iterables of ``(imgs, pids, camids)`` batches.
        use_gpu: move image batches to the GPU before the forward pass.
        ranks: CMC ranks to print (1-based).
        rerank: apply ``re_ranking`` to the distance matrix.

    Returns:
        ``cmc[0]``, the Rank-1 accuracy.
    """
    batch_time = AverageMeter()
    # Works for both a DataParallel-wrapped model (GPU) and a bare model (CPU).
    feat_dim = getattr(model, 'module', model).feat_dim

    model.eval()

    def _extract(loader):
        """Return ``(features, pids, camids)`` for every image in ``loader``."""
        feats, all_pids, all_camids = [], [], []
        for imgs, pids, camids in loader:
            end = time.time()
            # Move the batch once, before any flip, so fliplr and the model see
            # a tensor on the right device.
            if use_gpu:
                imgs = imgs.cuda()
            features = torch.FloatTensor(imgs.size(0), feat_dim).zero_()
            for i in range(flip_cnt):
                if i == 1:
                    imgs = fliplr(imgs, use_gpu=use_gpu)
                f = model(imgs)[1]
                features = features + f.data.cpu()
            batch_time.update(time.time() - end)

            feats.append(features)
            all_pids.extend(pids)
            all_camids.extend(camids)
        return torch.cat(feats, 0), np.asarray(all_pids), np.asarray(all_camids)

    with torch.no_grad():
        qf, q_pids, q_camids = _extract(queryloader)
        print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))

        gf, g_pids, g_camids = _extract(galleryloader)
        print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1)))

    print("==> BatchTime(s)/BatchSize(img): {:.3f}/{}".format(batch_time.avg, test_batch))

    distmat = compute_distance_matrix(qf, gf, distance).numpy()

    if rerank:
        print('Applying person re-ranking ...')
        # Hand all three matrices to re_ranking as numpy arrays, matching the
        # query-gallery matrix above.
        distmat_qq = compute_distance_matrix(qf, qf, distance).numpy()
        distmat_gg = compute_distance_matrix(gf, gf, distance).numpy()
        distmat = re_ranking(distmat, distmat_qq, distmat_gg)

    print("Computing CMC and mAP")
    # Uses the module-level `use_metric_cuhk03` setting (no local override).
    cmc, mAP = evaluate_model(distmat, q_pids, g_pids, q_camids, g_camids, use_metrics_cuhk03=use_metric_cuhk03)

    print("Results ----------")
    print("mAP: {:.1%}".format(mAP))
    print("CMC curve")
    for r in ranks:
        print("Rank-{:<3}: {:.1%}".format(r, cmc[r-1]))
    print("------------------")

    return cmc[0]

def visualization(model, datasets, queryloader, galleryloader, use_gpu = True, ranks=(1, 5, 10, 20), rerank = rerank):
    """Evaluate like ``test`` and, when ``visrank`` is set, visualize the rankings.

    Features are extracted and the (optionally re-ranked) distance matrix is
    built exactly as in ``test``. If the global ``visrank`` is True, the matrix
    is passed to ``visualize_ranked_results`` together with ``datasets``. mAP
    and the CMC values at ``ranks`` are then printed.

    Args:
        model: network (optionally wrapped in ``nn.DataParallel``) with a ``feat_dim`` attribute.
        datasets: dataset object forwarded to ``visualize_ranked_results``.
        queryloader, galleryloader: iterables of ``(imgs, pids, camids)`` batches.
        use_gpu: move image batches to the GPU before the forward pass.
        ranks: CMC ranks to print (1-based).
        rerank: apply ``re_ranking`` to the distance matrix.

    Returns:
        ``cmc[0]``, the Rank-1 accuracy.
    """
    batch_time = AverageMeter()
    # Works for both a DataParallel-wrapped model (GPU) and a bare model (CPU).
    feat_dim = getattr(model, 'module', model).feat_dim

    model.eval()

    def _extract(loader):
        """Return ``(features, pids, camids)`` for every image in ``loader``."""
        feats, all_pids, all_camids = [], [], []
        for imgs, pids, camids in loader:
            end = time.time()
            # Move the batch once, before any flip, so fliplr and the model see
            # a tensor on the right device.
            if use_gpu:
                imgs = imgs.cuda()
            features = torch.FloatTensor(imgs.size(0), feat_dim).zero_()
            for i in range(flip_cnt):
                if i == 1:
                    imgs = fliplr(imgs, use_gpu=use_gpu)
                f = model(imgs)[1]
                features = features + f.data.cpu()
            batch_time.update(time.time() - end)

            feats.append(features)
            all_pids.extend(pids)
            all_camids.extend(camids)
        return torch.cat(feats, 0), np.asarray(all_pids), np.asarray(all_camids)

    with torch.no_grad():
        qf, q_pids, q_camids = _extract(queryloader)
        print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))

        gf, g_pids, g_camids = _extract(galleryloader)
        print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1)))

    print("==> BatchTime(s)/BatchSize(img): {:.3f}/{}".format(batch_time.avg, test_batch))

    distmat = compute_distance_matrix(qf, gf, distance).numpy()

    if rerank:
        print('Applying person re-ranking ...')
        # Hand all three matrices to re_ranking as numpy arrays, matching the
        # query-gallery matrix above.
        distmat_qq = compute_distance_matrix(qf, qf, distance).numpy()
        distmat_gg = compute_distance_matrix(gf, gf, distance).numpy()
        distmat = re_ranking(distmat, distmat_qq, distmat_gg)

    if visrank:
        print("Visualize ranked results")
        # NOTE(review): `datasets` is the object returned by
        # init_imgreid_dataset; confirm visualize_ranked_results expects that
        # rather than a (query, gallery) tuple.
        visualize_ranked_results(distmat, datasets)
        print("Done")

    print("Computing CMC and mAP")
    # Uses the module-level `use_metric_cuhk03` setting (no local override).
    cmc, mAP = evaluate_model(distmat, q_pids, g_pids, q_camids, g_camids, use_metrics_cuhk03=use_metric_cuhk03)

    print("Results ----------")
    print("mAP: {:.1%}".format(mAP))
    print("CMC curve")
    for r in ranks:
        print("Rank-{:<3}: {:.1%}".format(r, cmc[r-1]))
    print("------------------")

    return cmc[0]

# Script/notebook entry point: run the cell's main() when executed directly.
if __name__ == "__main__":
    main()