## GaitSet Implementation

In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tordata
from torch.autograd import Variable
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np

from network.network import SetNet
from network.network_layer import *
from network.triplet_loss import *
from utils.triplet_sampler import *
from utils.data_load import *

from datetime import datetime

from sklearn.manifold import TSNE

import multiprocessing as mp

### Settings

In [None]:
# Compute on the default CUDA device when one is visible, otherwise on CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### data

In [None]:
# Load the train/test splits of OU-ISIR from the preprocessed .npy directory.
# NOTE(review): the meaning of the positional ``False`` flag is defined by
# load_OU_ISIR (utils.data_load) — confirm against its signature.
train_dataset, test_dataset = load_OU_ISIR(
    './data/OU_ISIR/npy/',
    False,
)

In [None]:
# Training batches are produced by the triplet sampler.
# NOTE(review): [16, 2] is presumably (identities per batch, samples per
# identity) — confirm in utils.triplet_sampler.
triplet_sampler = TripletSampler(train_dataset, [16, 2])
train_loader = tordata.DataLoader(
    dataset=train_dataset,
    batch_sampler=triplet_sampler,
    collate_fn=collate_fnn,
)

# The test set is read in stored order, one sample per batch.
test_loader = tordata.DataLoader(
    dataset=test_dataset,
    batch_size=1,
    sampler=tordata.sampler.SequentialSampler(test_dataset),
    collate_fn=collate_fnn,
)

### model

In [None]:
# Build the SetNet encoder and move its parameters to the selected device.
# NOTE(review): 128 is presumably the hidden/embedding width — verify in network.network.
encoder = SetNet(128)
encoder = encoder.to(device)

### loss

In [None]:
# Triplet loss on the selected device.
# NOTE(review): the meaning of (8, 0.2) comes from network.triplet_loss;
# 0.2 is presumably the margin — confirm.
criterion = TripletLoss(8, 0.2)
criterion = criterion.to(device)

# Adam over all encoder parameters as a single parameter group.
optimizer = optim.Adam(encoder.parameters(), lr=1e-5)

### train

In [None]:
# Train the encoder with the triplet loss for ``total_iter`` steps.
# The learning rate drops from 1e-5 to 1e-6 halfway through, and the
# encoder/optimizer state is saved on the final step.
total_iter = 20000
all_losses = []  # one float per 100 steps, for plotting

# Map every training identity to an integer id once, instead of rebuilding the
# list and doing an O(n) ``list.index`` lookup for every label of every batch.
# ``setdefault`` keeps the first occurrence, matching ``list.index`` semantics.
label_to_index = {}
for index, set_label in enumerate(list(train_dataset.set_label)):
    label_to_index.setdefault(set_label, index)

checkpoint_dir = './checkpoint/'
# torch.save does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)

s_time = datetime.now()
for i, (seqs, view, label) in enumerate(train_loader):

    feature = encoder(seqs)

    target_label = torch.tensor(
        [label_to_index[l] for l in label], dtype=torch.int32, device=device)

    # Swap the first two axes of the (rank-3) encoder output and give every
    # slice along the new first axis its own copy of the labels.
    # NOTE(review): assumes feature is (batch, num_bins, dim) — TODO confirm.
    triplet_feature = feature.permute(1, 0, 2).contiguous()
    triplet_label = target_label.unsqueeze(0).repeat(triplet_feature.size(0), 1)

    hard_loss, dist_mean = criterion(triplet_feature, triplet_label)
    loss = hard_loss.mean()

    # Skip the optimizer step when the loss is (numerically) zero.
    if loss > 1e-9:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if (i + 1) % 100 == 0:
        loss_value = loss.item()
        print('Step [{}], Elapsed Time [{}], Loss [{}]'.format(
            i + 1, datetime.now() - s_time, loss_value))
        # Store a plain float: keeping the tensor would hold on to its autograd
        # graph (and device memory), and CUDA tensors cannot be plotted directly.
        all_losses.append(loss_value)
        s_time = datetime.now()

    if (i + 1) == total_iter:
        torch.save(encoder.state_dict(), os.path.join(checkpoint_dir, 'OU_ISIR_Encoder_1.ptm'))
        torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, 'OU_ISIR_Optimizer_1.ptm'))
        break

    # Learning-rate decay at the halfway point.
    if (i + 1) == (total_iter // 2):
        for g in optimizer.param_groups:
            g['lr'] = 1e-6

In [None]:
def drawLoss(loss_dict):
    """Plot one curve per entry of ``loss_dict`` and show the figure.

    Args:
        loss_dict: mapping from curve name (used as the legend label) to a
            sequence of loss values. Python numbers and one-element tensors
            are accepted; each value is converted with ``float``.
    """
    plt.style.use(['ggplot'])

    for name, values in loss_dict.items():
        # float() also unwraps one-element torch tensors (CUDA ones included),
        # which matplotlib cannot plot directly.
        ys = [float(v) for v in values]
        plt.plot(np.arange(len(ys)), ys, label=name)

    plt.xlabel("train step")
    plt.ylabel("loss")
    # Without a legend the ``label=`` given to each curve is never displayed.
    plt.legend()

    plt.show()

In [None]:
# Plot the loss values recorded during training.
drawLoss(loss_dict={'Loss': all_losses})

### test

In [None]:
# Restore encoder weights and optimizer state from disk.
# map_location lets a checkpoint saved from a GPU be loaded on a CPU-only machine.
# NOTE(review): the training cell saves OU_ISIR_Encoder_1.ptm / OU_ISIR_Optimizer_1.ptm,
# while these paths have no "_1" suffix — confirm which checkpoint is intended.
encoder.load_state_dict(torch.load('./checkpoint/OU_ISIR_Encoder.ptm', map_location=device))
optimizer.load_state_dict(torch.load('./checkpoint/OU_ISIR_Optimizer.ptm', map_location=device))

In [None]:
# Embed every test sample with the trained encoder.
# Inference only: switch layers to evaluation mode and skip autograd bookkeeping.
encoder.eval()

feature_list = []  # one (batch, flattened_dim) array per test batch
view_list = []
label_list = []

with torch.no_grad():
    for seqs, view, label in test_loader:
        feature = encoder(seqs)

        n = feature.size(0)
        # Flatten everything after the batch axis into one vector per sample.
        feature_list.append(feature.view(n, -1).cpu().numpy())
        view_list += view
        label_list += label

test = np.concatenate(feature_list, 0)

In [None]:
# 2-D t-SNE projection of the test embeddings, coloured by identity.
tsne = TSNE(2, perplexity=40, learning_rate=100, verbose=True, random_state=0)

# feature_list holds one (batch, dim) array per test batch; stacking them with
# numpy is equivalent to the old torch.tensor(list_of_arrays) round-trip,
# without its per-element copy.
features = np.concatenate(feature_list, axis=0)
tsne_feature = tsne.fit_transform(features.reshape(features.shape[0], -1))

xs = tsne_feature[:, 0]
ys = tsne_feature[:, 1]

# Encode identities as integer codes so ``c=`` is treated as values to
# colour-map, whether the labels are numbers or strings.
_, label_codes = np.unique(np.asarray(label_list), return_inverse=True)
plt.scatter(xs, ys, c=label_codes)

plt.show()

In [None]:
def cuda_dist(x, y, device=None):
    """Return the pairwise Euclidean distance matrix between rows of ``x`` and ``y``.

    Args:
        x: 2-D numpy array of shape (m, d).
        y: 2-D numpy array of shape (n, d).
        device: torch device to compute on. Defaults to CUDA when available,
            otherwise CPU (previously CUDA was always used and the call failed
            on machines without a GPU).

    Returns:
        torch.Tensor of shape (m, n) on ``device``; entry (i, j) is ``||x[i] - y[j]||``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.from_numpy(x).to(device)
    y = torch.from_numpy(y).to(device)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, expanded for all pairs at once.
    dist = (torch.sum(x ** 2, 1).unsqueeze(1)
            + torch.sum(y ** 2, 1).unsqueeze(0)
            - 2 * torch.matmul(x, y.t()))
    # Round-off can push squared distances slightly below zero; clamp before sqrt.
    dist = torch.sqrt(F.relu(dist))
    return dist

In [None]:
def evaluation(feature, view, label, seq_type=None, num_rank=5,
               probe_seq_list=('00',), gallery_seq_list=('01',)):
    """Compute rank-1..rank-``num_rank`` identification accuracy per view pair.

    For every (probe view, gallery view) pair, each probe feature is matched
    against the gallery features by Euclidean distance (``cuda_dist``). A probe
    counts as correct at rank k if any of its k nearest gallery samples has the
    same label.

    Args:
        feature: 2-D array of embeddings, one row per sample.
        view: per-sample view ids, same length as ``feature``.
        label: per-sample identity labels, same length as ``feature``.
        seq_type: optional per-sample sequence ids.
            When given, probes are restricted to ``probe_seq_list`` and the
            gallery to ``gallery_seq_list``.
            When ``None`` (default), only the view separates probe from
            gallery, so same-view entries compare a set with itself.
        num_rank: number of ranks to report.
        probe_seq_list: sequence ids used as probes; only used with ``seq_type``.
        gallery_seq_list: sequence ids used as the gallery; only used with ``seq_type``.

    Returns:
        Array of shape ``(len(probe_seq_list), n_views, n_views, num_rank)``.
        It holds accuracies in percent, rounded to two decimals, indexed by
        sorted view id. Entries whose probe or gallery set is empty are NaN.
    """
    feature = np.asarray(feature)
    view = np.asarray(view)
    label = np.asarray(label)
    if seq_type is not None:
        seq_type = np.asarray(seq_type)

    view_list = sorted(set(view.tolist()))
    view_num = len(view_list)

    acc = np.zeros([len(probe_seq_list), view_num, view_num, num_rank])
    for p, probe_seq in enumerate(probe_seq_list):
        for v1, probe_view in enumerate(view_list):
            pseq_mask = np.isin(view, [probe_view])
            if seq_type is not None:
                pseq_mask &= np.isin(seq_type, [probe_seq])
            probe_x = feature[pseq_mask, :]
            probe_y = label[pseq_mask]

            for v2, gallery_view in enumerate(view_list):
                gseq_mask = np.isin(view, [gallery_view])
                if seq_type is not None:
                    gseq_mask &= np.isin(seq_type, list(gallery_seq_list))
                gallery_x = feature[gseq_mask, :]
                gallery_y = label[gseq_mask]

                # Only reachable with a seq_type split: nothing to score.
                if len(probe_x) == 0 or len(gallery_x) == 0:
                    acc[p, v1, v2, :] = np.nan
                    continue

                dist = cuda_dist(probe_x, gallery_x)
                # Gallery indices ordered from nearest to farthest for each probe.
                idx = dist.sort(1)[1].cpu().numpy()
                hits = np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]]
                # cumsum > 0 marks "a match was found at this rank or earlier".
                acc[p, v1, v2, :] = np.round(
                    np.sum(np.cumsum(hits, 1) > 0, 0) * 100 / dist.shape[0], 2)

    return acc

In [None]:
def de_diag(acc, each_angle=False):
    """Average a view-pair accuracy matrix while excluding same-view pairs.

    ``acc`` is a square 2-D matrix indexed ``[probe_view, gallery_view]``. The
    diagonal is zeroed out, and each row is averaged over the remaining
    ``n_views - 1`` gallery views. The old code divided by a hard-coded 10,
    which is only correct for 11 views.

    Args:
        acc: square 2-D array (or nested sequence) of accuracies.
        each_angle: if True return the per-probe-view averages, otherwise their mean.

    Returns:
        1-D array of length ``n_views`` when ``each_angle`` is True, else a scalar.

    Raises:
        ValueError: if ``acc`` has fewer than two views (no off-diagonal entries).
    """
    acc = np.asarray(acc)
    n_other_views = acc.shape[1] - 1
    if n_other_views < 1:
        raise ValueError('de_diag needs at least two views, got %d' % acc.shape[1])
    result = np.sum(acc - np.diag(np.diag(acc)), 1) / float(n_other_views)
    if not each_angle:
        result = np.mean(result)
    return result

In [None]:
# Score the stacked test embeddings against their view and identity labels.
acc = evaluation(feature=test, view=view_list, label=label_list)

In [None]:
# Rank-1 accuracy matrix for the first probe group, indexed [probe_view, gallery_view].
test_acc = acc[0, :, :, 0]

# Each probe view is averaged over the other (n_views - 1) gallery views.
# The count is derived from the matrix; the old hard-coded 13 only fits 14 views.
n_other_views = test_acc.shape[1] - 1
res_del_same_view = np.sum(test_acc - np.diag(np.diag(test_acc)), 1) / float(n_other_views)
# Same-view pairs included. When evaluation runs without a sequence split,
# those entries compare the probe set with itself.
res_with_same_view = np.mean(test_acc, 1)

print('Without Same View Result')
print(res_del_same_view)
print(res_del_same_view.mean())
print('With Same View Result')
print(res_with_same_view)
print(res_with_same_view.mean())