-
Notifications
You must be signed in to change notification settings - Fork 10
/
train.py
90 lines (77 loc) · 3.5 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import random
import time

import torch
import torch.nn as nn
import visdom

from mcnn_model import MCNN
from my_dataloader import CrowdDataset
if __name__ == "__main__":
    # Training driver for MCNN crowd counting on ShanghaiTech part A:
    # trains with a summed pixel-wise MSE on density maps, evaluates count-MAE
    # on the test split after every epoch, checkpoints each epoch, and streams
    # loss/error curves plus sample density maps to a visdom server.
    torch.backends.cudnn.enabled = False  # NOTE(review): presumably disabled for determinism or variable input sizes — confirm
    vis = visdom.Visdom()

    # Fall back to CPU so the script still runs without CUDA
    # (the original hard-coded "cuda" and crashed on CPU-only machines).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    mcnn = MCNN().to(device)
    # `size_average=False` was deprecated (and later removed); `reduction='sum'`
    # is its exact equivalent: loss is summed over all density-map pixels.
    criterion = nn.MSELoss(reduction='sum').to(device)
    optimizer = torch.optim.SGD(mcnn.parameters(), lr=1e-6,
                                momentum=0.95)

    img_root = 'D:\\workspaceMaZhenwei\\GithubProject\\MCNN-pytorch\\data\\Shanghai_part_A\\train_data\\images'
    gt_dmap_root = 'D:\\workspaceMaZhenwei\\GithubProject\\MCNN-pytorch\\data\\Shanghai_part_A\\train_data\\ground_truth'
    dataset = CrowdDataset(img_root, gt_dmap_root, 4)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

    test_img_root = 'D:\\workspaceMaZhenwei\\GithubProject\\MCNN-pytorch\\data\\Shanghai_part_A\\test_data\\images'
    test_gt_dmap_root = 'D:\\workspaceMaZhenwei\\GithubProject\\MCNN-pytorch\\data\\Shanghai_part_A\\test_data\\ground_truth'
    test_dataset = CrowdDataset(test_img_root, test_gt_dmap_root, 4)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

    # training phase
    # makedirs(exist_ok=True) replaces the race-prone exists()/mkdir() pair.
    os.makedirs('./checkpoints', exist_ok=True)
    min_mae = float('inf')   # best test MAE seen so far
    min_epoch = 0            # epoch at which min_mae was achieved
    train_loss_list = []
    epoch_list = []
    test_error_list = []
    for epoch in range(0, 2000):
        # ---- training pass over the train split ----
        mcnn.train()
        epoch_loss = 0
        for i, (img, gt_dmap) in enumerate(dataloader):
            img = img.to(device)
            gt_dmap = gt_dmap.to(device)
            # forward propagation
            et_dmap = mcnn(img)
            # summed squared error between estimated and ground-truth density maps
            loss = criterion(et_dmap, gt_dmap)
            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        epoch_list.append(epoch)
        train_loss_list.append(epoch_loss / len(dataloader))
        torch.save(mcnn.state_dict(), './checkpoints/epoch_' + str(epoch) + ".param")

        # ---- evaluation pass over the test split ----
        mcnn.eval()
        mae = 0
        # no_grad(): evaluation needs no autograd graph — saves memory and time
        # (the original tracked gradients here and relied on per-item `del`).
        with torch.no_grad():
            for i, (img, gt_dmap) in enumerate(test_dataloader):
                img = img.to(device)
                gt_dmap = gt_dmap.to(device)
                # forward propagation
                et_dmap = mcnn(img)
                # count error: total people = sum of the density map
                mae += abs(et_dmap.data.sum() - gt_dmap.data.sum()).item()
                del img, gt_dmap, et_dmap
        if mae / len(test_dataloader) < min_mae:
            min_mae = mae / len(test_dataloader)
            min_epoch = epoch
        test_error_list.append(mae / len(test_dataloader))
        print("epoch:" + str(epoch) + " error:" + str(mae / len(test_dataloader)) + " min_mae:" + str(min_mae) + " min_epoch:" + str(min_epoch))
        vis.line(win=1, X=epoch_list, Y=train_loss_list, opts=dict(title='train_loss'))
        vis.line(win=2, X=epoch_list, Y=test_error_list, opts=dict(title='test_error'))

        # show a random test sample: input image, ground-truth and estimated density maps
        index = random.randint(0, len(test_dataloader) - 1)
        img, gt_dmap = test_dataset[index]
        vis.image(win=3, img=img, opts=dict(title='img'))
        vis.image(win=4, img=gt_dmap / (gt_dmap.max()) * 255, opts=dict(title='gt_dmap(' + str(gt_dmap.sum()) + ')'))
        img = img.unsqueeze(0).to(device)
        gt_dmap = gt_dmap.unsqueeze(0)
        with torch.no_grad():
            et_dmap = mcnn(img)
        # scale estimated map to [0, 255] for display
        et_dmap = et_dmap.squeeze(0).detach().cpu().numpy()
        vis.image(win=5, img=et_dmap / (et_dmap.max()) * 255, opts=dict(title='et_dmap(' + str(et_dmap.sum()) + ')'))
        # timestamp the end of each epoch (indentation lost in source scrape;
        # per-epoch logging assumed — TODO confirm against upstream repo)
        print(time.strftime('%Y.%m.%d %H:%M:%S', time.localtime(time.time())))