In [1]:
from pathlib import Path
from os.path import expanduser
from os import path
import numpy as np
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from torchsummary import summary
from torchvision import transforms, datasets
from torch.optim import lr_scheduler
import os



In [2]:
root_folder = os.path.dirname(os.getcwd())
path_to_train_data = path.join(root_folder,'data/train_data_re_id.npy');
path_to_labels = path.join(root_folder,'data/train_labels_re_id.npy');
#path_to_mean_std = path.join(root_folder,'data/train_data/synthesized/mean_std.npy');

sys.path.append(root_folder)
from custom_resnet import CustomResnet as cnn

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 15000;
nb_of_epochs = 75;
logging_interval = 30
dataset_divider = 0.2

In [4]:
data = np.load(path_to_train_data);
labels = np.load(path_to_labels);
torch_data = torch.tensor(data).float()
torch_labels = torch.tensor(labels).int()
nb_of_classes = torch_labels.unique().size()[0]

In [5]:
print(nb_of_classes)

499


In [6]:
net = cnn.ft_net( nb_of_classes, 0);
net.to(device)

ft_net(
  (model): ResNet(
    (conv1): Conv1d(1, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
    (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (maxpool): MaxPool1d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
        (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
        (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
        (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU

In [7]:
if (len(torch_labels.size()) == 2):
    torch_labels = torch.squeeze(torch_labels)

In [8]:
# vertical flip
train_data_original, test_data_original = cnn.GenerateTrainAndTestDataset(torch_data, torch_labels, 0.2)
train_data_flipped, test_data_flippedl = cnn.GenerateTrainAndTestDataset(torch_data * -1, torch_labels, 0.2)



In [9]:
train_data = torch.utils.data.ConcatDataset((train_data_original, train_data_flipped));
test_data = torch.utils.data.ConcatDataset((test_data_original, test_data_flippedl));


In [10]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=8)





In [11]:
print(torch_data.size())
print(torch_labels.size())

torch.Size([2500489, 1, 72])
torch.Size([2500489])


In [12]:
print(torch_data[0,:])

tensor([[-5.2435e-01, -3.7025e-01, -1.3434e-01, -2.4474e-01, -3.5353e-01,
         -3.3476e-01, -2.0875e-01, -4.7216e-01, -7.2495e-01, -1.2463e+00,
         -1.5984e+00, -1.4817e+00, -1.7369e+00, -1.8890e+00, -2.1955e+00,
         -2.2463e+00, -2.0012e+00, -1.5047e+00, -1.1344e+00, -1.1903e+00,
         -1.4721e+00, -1.3439e+00, -9.3072e-01,  6.1741e-02,  7.0020e-01,
          1.2683e+00,  1.9363e+00,  2.1746e+00,  1.9996e+00,  1.6970e+00,
          1.5125e+00,  7.9792e-01, -6.8663e-02, -2.5527e+00, -7.1471e+00,
         -1.3885e+01, -1.8149e+01, -1.5931e+01, -9.2331e+00, -2.6818e+00,
          2.2164e+00,  5.5776e+00,  7.8141e+00,  9.1098e+00,  9.7197e+00,
          9.6600e+00,  8.8542e+00,  7.7477e+00,  6.6816e+00,  5.5360e+00,
          4.2213e+00,  3.1576e+00,  2.3921e+00,  1.9880e+00,  1.7718e+00,
          1.7307e+00,  1.8030e+00,  1.3844e+00,  5.8021e-01, -1.8272e-01,
         -9.4894e-01, -1.1664e+00, -1.1201e+00, -8.6691e-01, -3.7466e-01,
         -1.2474e-02,  3.5129e-01,  3.

In [13]:

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[5, 25, 45, 65], gamma=0.1)
for i in range(nb_of_epochs):  # loop over the dataset multiple times
  loss = cnn.Train(net, device, train_loader, optimizer, criterion, exp_lr_scheduler, i, logging_interval)
  model_name = 'models/re_id/resnet18_num_classes_' + str(nb_of_classes) + "_epoch_" + str(i) + '.pt';
  path_to_model = path.join(root_folder, model_name)
  torch.save({'epoch': i,'model_state_dict': net.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss,}, path_to_model)
  cnn.Test(net, device, criterion, train_loader)
  cnn.Test(net, device, criterion, test_loader)


Test set: Average loss: 0.00006167180981719866, Accuracy: 2644361/4000982 (66.09279921779203448295%)


Test set: Average loss: 0.00006202644726727158, Accuracy: 659874/999996 (65.98766395065580070423%)


Test set: Average loss: 0.00005132034857524559, Accuracy: 2795382/4000982 (69.86739755390051698214%)


Test set: Average loss: 0.00005184676047065295, Accuracy: 696439/999996 (69.64417857671431022482%)


Test set: Average loss: 0.00004613907367456704, Accuracy: 2900811/4000982 (72.50247564222983953641%)


Test set: Average loss: 0.00004688628177973442, Accuracy: 721857/999996 (72.18598874395497944079%)


Test set: Average loss: 0.00004164083657087758, Accuracy: 2989100/4000982 (74.70915890148968685480%)


Test set: Average loss: 0.00004249623088981025, Accuracy: 742789/999996 (74.27919711678846681480%)


Test set: Average loss: 0.00004163079574937001, Accuracy: 3003827/4000982 (75.07724353671173389557%)


Test set: Average loss: 0.00004271812213119119, Accuracy: 745410/999996 (74.5412


Test set: Average loss: 0.00002939320984296501, Accuracy: 3321634/4000982 (83.02046847498938575427%)


Test set: Average loss: 0.00003094838757533580, Accuracy: 820484/999996 (82.04872819491278335136%)


Test set: Average loss: 0.00002900134677474853, Accuracy: 3330036/4000982 (83.23046692037104321571%)


Test set: Average loss: 0.00003059000300709158, Accuracy: 822380/999996 (82.23832895331581482878%)


Test set: Average loss: 0.00002852685247489717, Accuracy: 3342814/4000982 (83.54983851464464805758%)


Test set: Average loss: 0.00003016699884028640, Accuracy: 824858/999996 (82.48612994451977442623%)


Test set: Average loss: 0.00002796039007080253, Accuracy: 3358830/4000982 (83.95014024057094559339%)


Test set: Average loss: 0.00002966163810924627, Accuracy: 828917/999996 (82.89203156812627071304%)


Test set: Average loss: 0.00002752677028183825, Accuracy: 3369517/4000982 (84.21724966520719135588%)


Test set: Average loss: 0.00002926905835920479, Accuracy: 831167/999996 (83.1170


Test set: Average loss: 0.00002485974619048648, Accuracy: 3439746/4000982 (85.97254374051171055271%)


Test set: Average loss: 0.00002690447036002297, Accuracy: 846549/999996 (84.65523862095447782394%)


Test set: Average loss: 0.00002473297354299575, Accuracy: 3444804/4000982 (86.09896270465600309763%)


Test set: Average loss: 0.00002681612932065036, Accuracy: 847800/999996 (84.78033912135649075026%)


Test set: Average loss: 0.00002416823554085568, Accuracy: 3459772/4000982 (86.47307086110359364284%)


Test set: Average loss: 0.00002629862865433097, Accuracy: 851088/999996 (85.10914043656174499120%)


Test set: Average loss: 0.00002379374382144306, Accuracy: 3468914/4000982 (86.70156476584998017643%)


Test set: Average loss: 0.00002596595550130587, Accuracy: 853099/999996 (85.31024124096497018854%)


Test set: Average loss: 0.00002344025233469438, Accuracy: 3477677/4000982 (86.92058599613794456218%)


Test set: Average loss: 0.00002566008151916321, Accuracy: 854605/999996 (85.4608


Test set: Average loss: 0.00002239151581306942, Accuracy: 3509285/4000982 (87.71059204965180811087%)


Test set: Average loss: 0.00002468195634719450, Accuracy: 862207/999996 (86.22104488417953405133%)


Test set: Average loss: 0.00002234425483038649, Accuracy: 3509884/4000982 (87.72556337419163696723%)


Test set: Average loss: 0.00002464246244926471, Accuracy: 862311/999996 (86.23144492577969799640%)


Test set: Average loss: 0.00002229920755780768, Accuracy: 3511203/4000982 (87.75853028081606055366%)


Test set: Average loss: 0.00002459961615386419, Accuracy: 862624/999996 (86.26274505098020028981%)


Test set: Average loss: 0.00002226284050266258, Accuracy: 3512330/4000982 (87.78669836555125982613%)


Test set: Average loss: 0.00002457181471982040, Accuracy: 862931/999996 (86.29344517378069667757%)


Test set: Average loss: 0.00002221945578639861, Accuracy: 3513727/4000982 (87.82161479356817324060%)


Test set: Average loss: 0.00002453435809002258, Accuracy: 863039/999996 (86.3042


Test set: Average loss: 0.00002192195279349107, Accuracy: 3521395/4000982 (88.01326774276914477468%)


Test set: Average loss: 0.00002427564686513506, Accuracy: 864753/999996 (86.47564590258360794905%)


Test set: Average loss: 0.00002188621692766901, Accuracy: 3521845/4000982 (88.02451498157202536277%)


Test set: Average loss: 0.00002424490594421513, Accuracy: 864921/999996 (86.49244596978387278341%)


Test set: Average loss: 0.00002184337972721551, Accuracy: 3523207/4000982 (88.05855662434872499489%)


Test set: Average loss: 0.00002421208409941755, Accuracy: 865279/999996 (86.52824611298444779095%)


Test set: Average loss: 0.00002180505543947220, Accuracy: 3524002/4000982 (88.07842674623380219145%)


Test set: Average loss: 0.00002417384530417621, Accuracy: 865271/999996 (86.52744610978443517979%)


Test set: Average loss: 0.00002176129964936990, Accuracy: 3524988/4000982 (88.10307069614410124814%)


Test set: Average loss: 0.00002413347647234332, Accuracy: 865644/999996 (86.5647


Test set: Average loss: 0.00002168060746043921, Accuracy: 3527985/4000982 (88.17797730657123622677%)


Test set: Average loss: 0.00002405924897175282, Accuracy: 866413/999996 (86.64164656658626029184%)


Test set: Average loss: 0.00002167294405808207, Accuracy: 3528079/4000982 (88.18032672978783637063%)


Test set: Average loss: 0.00002405382656434085, Accuracy: 866397/999996 (86.64004656018623506952%)


Test set: Average loss: 0.00002166335070796777, Accuracy: 3528047/4000982 (88.17952692613962994983%)


Test set: Average loss: 0.00002404636143182870, Accuracy: 866375/999996 (86.63784655138620394155%)


Test set: Average loss: 0.00002166143713111524, Accuracy: 3528364/4000982 (88.18744998102965837461%)


Test set: Average loss: 0.00002404440965619870, Accuracy: 866460/999996 (86.64634658538633971148%)


Test set: Average loss: 0.00002166056765418034, Accuracy: 3528362/4000982 (88.18739999330163925606%)


Test set: Average loss: 0.00002404300721536856, Accuracy: 866501/999996 (86.6504


Test set: Average loss: 0.00002163268982258160, Accuracy: 3528961/4000982 (88.20237131784146811242%)


Test set: Average loss: 0.00002401915298833046, Accuracy: 866551/999996 (86.65544662178648138706%)


Test set: Average loss: 0.00002162542477890383, Accuracy: 3529066/4000982 (88.20499567356213788116%)


Test set: Average loss: 0.00002401234632998239, Accuracy: 866601/999996 (86.66044664178656375952%)


Test set: Average loss: 0.00002162471719202586, Accuracy: 3529347/4000982 (88.21201894934793585890%)


Test set: Average loss: 0.00002401123492745683, Accuracy: 866626/999996 (86.66294665178661205118%)


Test set: Average loss: 0.00002162214695999864, Accuracy: 3529189/4000982 (88.20806991883492287343%)


Test set: Average loss: 0.00002401146048214287, Accuracy: 866674/999996 (86.66774667098668771814%)


Test set: Average loss: 0.00002160929216188379, Accuracy: 3529484/4000982 (88.21544310871681204844%)


Test set: Average loss: 0.00002399856748525053, Accuracy: 866706/999996 (86.6709


Test set: Average loss: 0.00002161286101909354, Accuracy: 3529514/4000982 (88.21619292463699935070%)


Test set: Average loss: 0.00002400241464783903, Accuracy: 866691/999996 (86.66944667778670918779%)


Test set: Average loss: 0.00002161525662813801, Accuracy: 3529454/4000982 (88.21469329279662474619%)


Test set: Average loss: 0.00002400407174718566, Accuracy: 866683/999996 (86.66864667458669657663%)


Test set: Average loss: 0.00002160664189432282, Accuracy: 3529832/4000982 (88.22414097339103022932%)


Test set: Average loss: 0.00002399730692559388, Accuracy: 866822/999996 (86.68254673018691391917%)


Test set: Average loss: 0.00002160278745577671, Accuracy: 3529797/4000982 (88.22326618815080223612%)


Test set: Average loss: 0.00002399401637376286, Accuracy: 866789/999996 (86.67924671698686722721%)


Test set: Average loss: 0.00002160637086490169, Accuracy: 3529767/4000982 (88.22251637223061493387%)


Test set: Average loss: 0.00002399695767962839, Accuracy: 866792/999996 (86.6795

In [14]:
cnn.Test(net, device, criterion, test_loader)


Test set: Average loss: 0.00002399695767962839, Accuracy: 866792/999996 (86.67954671818687018003%)

