train.py (forked from WangLibo1995/GeoSeg)
import argparse
from datetime import datetime
from pathlib import Path

import torch

from tools.cfg import py2cfg


def get_args():
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg("-c", "--config_path", type=Path, help="Path to the config.", required=True)
    return parser.parse_args()

# Everything needed for training (data loaders, optimizer, model, loss) is
# built inside the Python config file and exposed as attributes of the cfg
# object returned by py2cfg.
args = get_args()
config = py2cfg(args.config_path)
training_loader = config.train_loader
validation_loader = config.val_loader
optimizer = config.optimizer
model = config.model.to("cuda:9")  # note: the device index is hard-coded here
loss_fn = config.loss
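
# py2cfg loads a plain Python module and exposes its top-level names as config
# attributes. A minimal sketch of such a module, inferred from the attributes
# accessed above (the class names and hyperparameters below are placeholders,
# not the actual GeoSeg config):
#
#   # my_config.py (hypothetical)
#   model = UNetFormer(num_classes=8)
#   loss = UnetFormerLoss(ignore_index=255)
#   optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4, weight_decay=0.01)
#   train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
#   val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)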

# Quick sanity check: print every parameter whose name contains "pos"
# (e.g. positional-embedding weights).
for name, param in model.named_parameters():
    if "pos" in name:
        print(name, param)


def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Use enumerate(training_loader) instead of iter(training_loader) so that
    # we can track the batch index and do some intra-epoch reporting.
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs = data["img"].to("cuda:9")
        labels = data["gt_semantic_seg"].to("cuda:9")

        # Zero the gradients for every batch
        optimizer.zero_grad()

        # Make predictions for this batch; with labels passed in, the model
        # also returns an auxiliary loss term (con_loss) alongside the logits.
        con_loss, outputs = model(inputs, labels)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        print("con_loss:", con_loss)
        print("loss:", loss)
        loss += con_loss[0]  # add the first auxiliary loss term
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000  # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            # tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss

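
# The loop above assumes a model whose forward pass returns an auxiliary loss
# together with the logits when labels are supplied, and only the logits
# otherwise. A rough sketch of that interface (an assumption based on how the
# model is called in this script, not the actual GeoSeg model code):
#
#   class SegModelWithAuxLoss(torch.nn.Module):
#       def forward(self, x, labels=None):
#           logits = self.decode(self.encode(x))
#           if labels is not None:
#               aux = self.aux_loss(x, labels)  # e.g. a contrastive term
#               return (aux,), logits           # indexable, hence con_loss[0]
#           return logits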

# Initializing here so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, None)

    # Gradients are not needed for validation, so switch to eval mode and
    # disable gradient tracking while reporting.
    model.train(False)

    running_vloss = 0.0
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            # Validation batches are dicts, just like the training batches.
            vinputs = vdata["img"].to("cuda:9")
            vlabels = vdata["gt_semantic_seg"].to("cuda:9")
            # Without labels the model is expected to return only the logits.
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss.item()

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    # writer.add_scalars('Training vs. Validation Loss',
    #                    {'Training': avg_loss, 'Validation': avg_vloss},
    #                    epoch_number + 1)
    # writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1
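
# Example invocation (the config path below is a placeholder; point -c at
# whichever config module defines the loaders, model, optimizer and loss):
#
#   python train.py -c config/my_dataset/my_model.py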