-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_VGG_fc.py
72 lines (61 loc) · 3.49 KB
/
train_VGG_fc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from __future__ import absolute_import, print_function
import os
import sys
import torch
from torch.utils.data import DataLoader
from torchvision import models
from got10k.datasets import ImageNetVID, GOT10k
from pairwise_cf import Pairwise
from siamfc_VGG_cf import TrackerSiamFC
from got10k_tmp.experiments import *
# Torchvision VGG-16 `features` layers used to initialise the SiamFC backbone,
# as (SiamFC parameter prefix, VGG-16 parameter prefix) pairs. The VGG indices
# 0, 2, 5, 7, 10, 12, 14, 17, 19, 21 are the first ten conv layers
# (conv1_1 .. conv4_3) in torchvision's layout; the `weight` and `bias` of each
# are copied.
# NOTE(review): every features2.<i> maps to VGG features.<i + 3>, which looks
# like features2 continues VGG's layer list where features1 stops -- confirm
# against the backbone defined in siamfc_VGG_cf.
VGG16_LAYER_MAP = (
    ('features.features1.0', 'features.0'),
    ('features.features1.2', 'features.2'),
    ('features.features2.2', 'features.5'),
    ('features.features2.4', 'features.7'),
    ('features.features2.7', 'features.10'),
    ('features.features2.9', 'features.12'),
    ('features.features2.11', 'features.14'),
    ('features.features2.14', 'features.17'),
    ('features.features2.16', 'features.19'),
    ('features.features2.18', 'features.21'),
)


def vgg16_pretrained_dict(vgg_dict, layer_map=VGG16_LAYER_MAP):
    """Rename VGG-16 parameters to the SiamFC backbone's state-dict keys.

    Args:
        vgg_dict: state dict of a torchvision VGG-16 (``model.state_dict()``).
        layer_map: iterable of ``(siamfc_prefix, vgg_prefix)`` pairs.

    Returns:
        dict mapping ``'<siamfc_prefix>.weight'`` / ``'<siamfc_prefix>.bias'``
        to the corresponding entry of ``vgg_dict``.

    Raises:
        KeyError: if ``vgg_dict`` lacks one of the referenced VGG keys.
    """
    return {
        '{}.{}'.format(dst, param): vgg_dict['{}.{}'.format(src, param)]
        for dst, src in layer_map
        for param in ('weight', 'bias')
    }


def main(name='VID',
         vid_root='/home/user/ILSVRC2015',
         got10k_root='data/GOT-10k',
         otb_root='/home/user/OTB100',
         net_dir='pretrained/siamfc_new',
         epoch_num=50):
    """Train a VGG-initialised SiamFC tracker, evaluating on OTB-2015 each epoch.

    After every epoch the weights are saved to ``net_dir/model_e<N>.pth``;
    whenever the OTB score beats the best seen so far they are also saved as
    ``net_dir/model_e<N>_BEST.pth``.

    Args:
        name: training set, ``'VID'`` (ImageNet VID train+val) or ``'GOT-10k'``
            (GOT-10k train).
        vid_root: ImageNet VID (ILSVRC2015) root directory.
        got10k_root: GOT-10k root directory.
        otb_root: OTB100 root directory used for per-epoch evaluation.
        net_dir: directory for checkpoints (created if missing).
        epoch_num: number of training epochs.

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    # setup dataset
    if name == 'GOT-10k':
        seq_dataset = GOT10k(got10k_root, subset='train')
    elif name == 'VID':
        seq_dataset = ImageNetVID(vid_root, subset=('train', 'val'))
    else:
        # Explicit check rather than `assert`, which `python -O` strips.
        raise ValueError("name must be 'VID' or 'GOT-10k', got %r" % (name,))
    pair_dataset = Pairwise(seq_dataset)

    # setup data loader
    cuda = torch.cuda.is_available()
    loader = DataLoader(
        pair_dataset, batch_size=8, shuffle=True,
        pin_memory=cuda, drop_last=True, num_workers=4)

    # setup tracker, initialising its backbone from ImageNet-pretrained VGG-16
    tracker = TrackerSiamFC()
    # NOTE(review): newer torchvision deprecates `pretrained=True` in favour of
    # `weights=...`; kept as-is for compatibility with older versions.
    vgg_dict = models.vgg16(pretrained=True).state_dict()
    model_state = tracker.net.state_dict()
    model_state.update(vgg16_pretrained_dict(vgg_dict))
    # Strict load: a mapped key that tracker.net does not have raises here.
    tracker.net.load_state_dict(model_state)
    # Only the conv weights listed in VGG16_LAYER_MAP are used; release the
    # rest of VGG-16 (including its large classifier) before training.
    del vgg_dict, model_state

    # path for saving checkpoints
    if not os.path.exists(net_dir):
        os.makedirs(net_dir)

    # Constant arguments, so build it once; a bad OTB path now fails before
    # training starts instead of after the first epoch.
    experiment = ExperimentOTB(otb_root, version=2015)

    best_auc = 0
    for epoch in range(epoch_num):
        for step, batch in enumerate(loader):
            # update_lr is True only on the first batch of each epoch
            # (presumably steps the LR schedule once per epoch -- see
            # TrackerSiamFC.step).
            loss = tracker.step(
                batch, backward=True, update_lr=(step == 0))
            if step % 100 == 0:
                print('Epoch [{}][{}/{}]: Loss: {:.3f}'.format(
                    epoch + 1, step + 1, len(loader), loss))

        # save checkpoint
        net_path = os.path.join(net_dir, 'model_e%d.pth' % (epoch + 1))
        torch.save(tracker.net.state_dict(), net_path)

        # Evaluate this epoch's weights on OTB-2015.
        test_tracker = TrackerSiamFC(net_path=net_path)
        # NOTE(review): upstream got10k's ExperimentOTB.run skips sequences
        # whose result file already exists under result_dir/<tracker.name>/.
        # With the same name every epoch, later epochs would be scored on
        # epoch 1's results, so give each epoch a distinct name. Confirm
        # got10k_tmp keeps that caching behaviour.
        test_tracker.name = '%s_e%d' % (test_tracker.name, epoch + 1)
        experiment.run(test_tracker, visualize=False)
        # NOTE(review): upstream got10k's report() returns a performance dict;
        # the comparison below assumes got10k_tmp returns a scalar AUC.
        auc = experiment.report([test_tracker.name])
        # Free the evaluation tracker (and the weights it loaded) before the
        # next training epoch.
        del test_tracker

        if auc > best_auc:
            net_path2 = os.path.join(
                net_dir, 'model_e%d_BEST.pth' % (epoch + 1))
            torch.save(tracker.net.state_dict(), net_path2)
            best_auc = auc
        print("now_best:{}".format(best_auc))


if __name__ == '__main__':
    main()