-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
29 lines (23 loc) · 1.03 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from networks.Qnetworks import setup_networks
from options.options import gather_options, print_options
import torch
import numpy as np
import random
from utils import train
def validate_config(config):
    """Reject option sets that this training script cannot handle.

    Training is only supported on XCAT / fake-CT data; real clinical CT
    (``config.realCT``) may only be used for testing.

    Raises:
        ValueError: if ``config.realCT`` is truthy.
    """
    # An explicit raise (not ``assert``) so the check survives ``python -O``.
    if config.realCT:
        raise ValueError(
            "cannot train on real clinical data (realCT); it can only be used "
            "for testing. Train on XCAT/fakeCT data instead."
        )


def seed_everything(seed):
    """Seed PyTorch's, NumPy's and Python's global RNGs with ``seed``."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def main():
    """Parse training options, seed RNGs, build the Q-networks and train them."""
    # 1. gather options
    parser = gather_options(phase="train")
    config = parser.parse_args()
    config.use_cuda = torch.cuda.is_available()
    config.device = torch.device(
        f"cuda:{config.gpu_device}" if config.use_cuda else "cpu"
    )
    print_options(config, parser)
    # Only XCAT/fakeCT data can be used for training with this framework.
    validate_config(config)
    # Seed before building the networks so their initialisation is repeatable.
    seed_everything(config.seed)
    # 2. instantiate the local and target Q-networks
    qnetwork_local, qnetwork_target = setup_networks(config)
    # 3. launch training
    train(config, qnetwork_local, qnetwork_target, name=config.name, sweep=False)


if __name__ == "__main__":
    main()