Skip to content

Commit

Permalink
made config an explicit arg for model of nat var
Browse files Browse the repository at this point in the history
  • Loading branch information
arobey1 committed Dec 14, 2020
1 parent 3a25e8c commit f49bd05
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions core/models/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def init_G(fname, reverse, args):
if args.setup_verbose is True and args.local_rank == 0:
print(f'Loading MUNIT model: {fname}')

G = MUNITModelOfNatVar(fname, reverse=reverse, args=args).cuda()
G = MUNITModelOfNatVar(fname, reverse=reverse, config=args.config).cuda()

# save model to ONNX
# if args.local_rank == 0:
Expand Down Expand Up @@ -65,19 +65,19 @@ def forward(self, x, delta):
return x

class MUNITModelOfNatVar(nn.Module):
def __init__(self, fname: str, reverse: bool, args: dict):
def __init__(self, fname: str, reverse: bool, config: str):
"""Instantiantion of pre-trained MUNIT model.
Params:
fname: File name of trained MUNIT checkpoint file.
reverse: If True, returns model mapping from domain A-->B.
otherwise, model maps from B-->A.
args: train.py command line arguments.
config: Configuration .yaml file for MUNIT.
"""

super(MUNITModelOfNatVar, self).__init__()

self._config = self.__get_config(args.config)
self._config = self.__get_config(config)
self._fname = fname
self._reverse = reverse
self._gen_A, self._gen_B = self.__load()
Expand Down

0 comments on commit f49bd05

Please sign in to comment.