-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
97 lines (79 loc) · 3.24 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import argparse
from cpc.config import get_argument_parser, parse_and_process_arguments, get_trainer
from cpc.data import (
IDRiDDataModule,
IDRiDDataModuleSemi, IDRiDDataModuleExtra,
IDRiDDataModuleBaseCP, IDRiDDataModulePseudoCP, IDRiDDataModuleCP,
ICHDataModule,
ICHDataModuleSemi, ICHDataModuleExtra,
ICHDataModuleBaseCP, ICHDataModulePseudoCP, ICHDataModuleCP
)
from cpc.model import (
BinaryUNet, BinaryUNetCP, BinaryUNetPseudo, BinaryUNetMT, BinaryUNetClassMix, BinaryUNetCutMix
)
# CLI name -> data-module class. The key order of this dict is the order of
# the `data_module` positional's `choices` in the entry parser.
DATA_MODULES = {
    # IDRiD dataset variants
    "idrid": IDRiDDataModule,
    "idrid-semi": IDRiDDataModuleSemi,
    "idrid-st": IDRiDDataModuleExtra,
    "idrid-base-cp": IDRiDDataModuleBaseCP,
    "idrid-st-cp": IDRiDDataModulePseudoCP,
    "idrid-cp": IDRiDDataModuleCP,
    # ICH dataset variants (same suffix scheme as IDRiD)
    "ich": ICHDataModule,
    "ich-semi": ICHDataModuleSemi,
    "ich-st": ICHDataModuleExtra,
    "ich-base-cp": ICHDataModuleBaseCP,
    "ich-st-cp": ICHDataModulePseudoCP,
    "ich-cp": ICHDataModuleCP,
}

# CLI name -> model class. The key order is the order of the `model`
# positional's `choices` in the entry parser.
LN_MODULES = {
    "unet": BinaryUNet,
    "unet-cp": BinaryUNetCP,
    "unet-pseudo": BinaryUNetPseudo,
    "unet-mt": BinaryUNetMT,
    "unet-classmix": BinaryUNetClassMix,
    "unet-cutmix": BinaryUNetCutMix,
}

# Model name -> set of DATA_MODULES keys whose batches that model can consume.
# The comments record the batch structure each group of data modules yields
# (as stated by the original author).
MODULE_COMPATIBILITY = {
    # batch: (img, mask)
    "unet": {"idrid", "ich"},
    # batch: (img, mask), (synth_img, synth_mask, background)
    "unet-cp": {"idrid-cp", "ich-cp"},
    # batch: (img, mask), (synth_img, synth_mask)
    "unet-pseudo": {
        "idrid-st", "idrid-base-cp", "idrid-st-cp",
        "ich-st", "ich-base-cp", "ich-st-cp",
    },
    # batch: (img, mask), img  -- labelled pair plus an unlabelled image
    "unet-mt": {"idrid-semi", "ich-semi"},
    "unet-classmix": {"idrid-semi", "ich-semi"},
    "unet-cutmix": {"idrid-semi", "ich-semi"},
}
def check_compatibility(dm_name, model_name, compatibility=None):
    """Verify that a data module can be paired with a model.

    Args:
        dm_name: CLI name of the data module (a key of ``DATA_MODULES``).
        model_name: CLI name of the model (a key of ``LN_MODULES``).
        compatibility: Optional mapping from model name to the collection of
            data-module names that model accepts. Defaults to the module-level
            ``MODULE_COMPATIBILITY`` table.

    Raises:
        ValueError: If ``model_name`` has no entry in the compatibility table,
            or if ``dm_name`` is not listed for that model.
    """
    if compatibility is None:
        compatibility = MODULE_COMPATIBILITY
    # Look the model up explicitly so an unknown name is reported with the
    # same exception type as an incompatible pair, not as a bare KeyError.
    try:
        allowed = compatibility[model_name]
    except KeyError:
        raise ValueError(
            f"Unknown model '{model_name}'; no compatibility entry is defined."
        ) from None
    if dm_name not in allowed:
        raise ValueError(f"'{model_name}' is not compatible with datamodule '{dm_name}'.")
def main(_args, _dm_cls, _model_cls):
    """Build the data module, model and trainer, then train and/or test.

    Args:
        _args: 4-tuple ``(data_args, model_args, trainer_args, other_args)`` of
            keyword-argument mappings (presumably produced by
            ``parse_and_process_arguments`` -- confirm in ``cpc.config``).
            ``other_args`` must provide the ``do_train`` / ``do_test`` flags.
        _dm_cls: Data-module class, instantiated with ``**data_args``.
        _model_cls: Model class, instantiated with ``**model_args``; when
            testing without training it is instead restored through
            ``load_from_checkpoint``.

    Raises:
        ValueError: If testing is requested without training and
            ``trainer_args["resume_from_checkpoint"]`` is empty/unset.
    """
    data_args, model_args, trainer_args, other_args = _args
    do_train = other_args["do_train"]
    do_test = other_args["do_test"]

    # Testing without training needs weights from a checkpoint. Validate that
    # up front so the user gets a clear message before the (potentially
    # expensive) data module, model and trainer are constructed.
    ckpt_path = None
    if do_test and not do_train:
        ckpt_path = trainer_args["resume_from_checkpoint"]
        if not ckpt_path:
            raise ValueError(
                "Testing without training requires 'resume_from_checkpoint' "
                "to point at a model checkpoint."
            )

    dm = _dm_cls(**data_args)
    model = _model_cls(**model_args)
    trainer = get_trainer(**trainer_args)

    if do_train:
        if trainer_args["auto_lr_find"]:
            # Run the learning-rate finder before fitting.
            trainer.tune(model, dm)
        trainer.fit(model, dm)

    if do_test:
        if not do_train:
            # No freshly trained weights in memory: restore them from disk.
            model = _model_cls.load_from_checkpoint(ckpt_path, strict=False, **model_args)
        trainer.test(model, dm)
if __name__ == "__main__":
    # Stage 1: parse only the two positional selectors so the matching classes
    # can contribute their own options to the full parser. add_help=False is
    # needed because this parser is reused as a parent below (argparse would
    # otherwise register -h/--help twice).
    entry_parser = argparse.ArgumentParser(add_help=False)
    entry_parser.add_argument('data_module', type=str, choices=list(DATA_MODULES.keys()))
    entry_parser.add_argument('model', type=str, choices=list(LN_MODULES.keys()))
    entry_args = entry_parser.parse_known_args()[0]

    # Report an incompatible pair as a regular CLI usage error (usage line and
    # exit status 2) rather than an uncaught traceback.
    try:
        check_compatibility(entry_args.data_module, entry_args.model)
    except ValueError as err:
        entry_parser.error(str(err))

    dm_cls = DATA_MODULES[entry_args.data_module]
    model_cls = LN_MODULES[entry_args.model]

    # Stage 2: full parser with the selected classes' options plus run-mode flags.
    parser = get_argument_parser(dm_cls, model_cls, parents=[entry_parser])
    parser.add_argument("--do_train", action='store_true')
    parser.add_argument("--do_test", action='store_true')
    args = parse_and_process_arguments(parser)
    main(args, dm_cls, model_cls)