-
Notifications
You must be signed in to change notification settings - Fork 68
/
train_diffusion.json
174 lines (174 loc) · 5.23 KB
/
train_diffusion.json
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
{
    "ckpt_dir": "$@bundle_root + '/models'",
    "train_batch_size_img": 2,
    "train_batch_size_slice": 50,
    "lr": 5e-05,
    "train_patch_size": [
        256,
        256
    ],
    "latent_shape": [
        "@latent_channels",
        64,
        64
    ],
    "load_autoencoder_path": "$@bundle_root + '/models/model_autoencoder.pt'",
    "load_autoencoder": "$@autoencoder_def.load_state_dict(torch.load(@load_autoencoder_path, map_location=@device))",
    "autoencoder": "$@autoencoder_def.to(@device)",
    "network_def": {
        "_target_": "generative.networks.nets.DiffusionModelUNet",
        "spatial_dims": "@spatial_dims",
        "in_channels": "@latent_channels",
        "out_channels": "@latent_channels",
        "num_channels": [
            32,
            64,
            128,
            256
        ],
        "attention_levels": [
            false,
            true,
            true,
            true
        ],
        "num_head_channels": [
            0,
            32,
            32,
            32
        ],
        "num_res_blocks": 2
    },
    "diffusion": "$@network_def.to(@device)",
    "optimizer": {
        "_target_": "torch.optim.Adam",
        "params": "$@diffusion.parameters()",
        "lr": "@lr"
    },
    "lr_scheduler": {
        "_target_": "torch.optim.lr_scheduler.MultiStepLR",
        "optimizer": "@optimizer",
        "milestones": [
            1000
        ],
        "gamma": 0.1
    },
    "scale_factor": "$scripts.utils.compute_scale_factor(@autoencoder,@train#dataloader,@device)",
    "noise_scheduler": {
        "_target_": "generative.networks.schedulers.DDPMScheduler",
        "_requires_": [
            "@load_autoencoder"
        ],
        "schedule": "scaled_linear_beta",
        "num_train_timesteps": 1000,
        "beta_start": 0.0015,
        "beta_end": 0.0195
    },
    "loss": {
        "_target_": "torch.nn.MSELoss"
    },
    "train": {
        "inferer": {
            "_target_": "generative.inferers.LatentDiffusionInferer",
            "scheduler": "@noise_scheduler",
            "scale_factor": "@scale_factor"
        },
        "crop_transforms": [
            {
                "_target_": "DivisiblePadd",
                "keys": "image",
                "k": [
                    32,
                    32,
                    1
                ]
            },
            {
                "_target_": "RandSpatialCropSamplesd",
                "keys": "image",
                "random_size": false,
                "roi_size": "$[@train_patch_size[0], @train_patch_size[1], 1]",
                "num_samples": "@train_batch_size_slice"
            },
            {
                "_target_": "SqueezeDimd",
                "keys": "image",
                "dim": 3
            }
        ],
        "preprocessing": {
            "_target_": "Compose",
            "transforms": "$@preprocessing_transforms + @train#crop_transforms"
        },
        "dataset": {
            "_target_": "monai.apps.DecathlonDataset",
            "root_dir": "@dataset_dir",
            "task": "Task01_BrainTumour",
            "section": "training",
            "cache_rate": 1.0,
            "num_workers": 8,
            "download": false,
            "transform": "@train#preprocessing"
        },
        "dataloader": {
            "_target_": "DataLoader",
            "dataset": "@train#dataset",
            "batch_size": "@train_batch_size_img",
            "shuffle": true,
            "num_workers": 0
        },
        "handlers": [
            {
                "_target_": "LrScheduleHandler",
                "lr_scheduler": "@lr_scheduler",
                "print_lr": true
            },
            {
                "_target_": "CheckpointSaver",
                "save_dir": "@ckpt_dir",
                "save_dict": {
                    "model": "@diffusion"
                },
                "save_interval": 0,
                "save_final": true,
                "epoch_level": true,
                "final_filename": "model.pt"
            },
            {
                "_target_": "StatsHandler",
                "tag_name": "train_diffusion_loss",
                "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)"
            },
            {
                "_target_": "TensorBoardStatsHandler",
                "log_dir": "@tf_dir",
                "tag_name": "train_diffusion_loss",
                "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)"
            }
        ],
        "trainer": {
            "_target_": "scripts.ldm_trainer.LDMTrainer",
            "device": "@device",
            "max_epochs": 1000,
            "train_data_loader": "@train#dataloader",
            "network": "@diffusion",
            "autoencoder_model": "@autoencoder",
            "optimizer": "@optimizer",
            "loss_function": "@loss",
            "latent_shape": "@latent_shape",
            "inferer": "@train#inferer",
            "key_train_metric": "$None",
            "train_handlers": "@train#handlers"
        }
    },
    "initialize": [
        "$monai.utils.set_determinism(seed=0)"
    ],
    "run": [
        "@load_autoencoder",
        "$@autoencoder.eval()",
        "$print('scale factor:',@scale_factor)",
        "$@train#trainer.run()"
    ]
}