Skip to content

Commit

Permalink
imports fixes; minor loading fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
avoin committed Feb 19, 2020
1 parent 479157a commit 8fbb3da
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 17 deletions.
5 changes: 0 additions & 5 deletions GANs/distributions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1 @@
from .normal import NormalDistribution
from .two_spheres import TwoSpheresDistribution
from .torus import TorusDistribution
from .discrete import DiscerteDistribution
from .mixed import MixedDistribution
from .sphere import SphereDistribution
8 changes: 4 additions & 4 deletions GANs/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch
from GANs.models.sngan.generators.sn_gen_resnet import SN_RES_GEN_CONFIGS, make_resnet_generator
from GANs.models.sngan.generators.sn_sngen_resnet import make_snresnet_generator
from GANs.models.sngan.discriminators.sn_dis_resnet import SN_RES_DIS_CONFIGS, ResnetDiscriminator

from GANs.distributions import NormalDistribution

Expand Down Expand Up @@ -73,7 +72,7 @@ def make_models(args):
return generator


def load_model_from_state_dict(root_dir, model_index=None, cuda=True):
def load_model_from_state_dict(root_dir, model_index=None, cuda=True, verbose=False):
args = json.load(open(os.path.join(root_dir, 'args.json')))

if model_index is None:
Expand All @@ -82,11 +81,12 @@ def load_model_from_state_dict(root_dir, model_index=None, cuda=True):
[int(name.split('.')[0].split('_')[-1]) for name in models
if name.startswith('generator')])

print('using generator generator_{}.pt'.format(model_index))
if verbose:
print('using generator generator_{}.pt'.format(model_index))
generator_model_path = os.path.join(root_dir, 'generator_{}.pt'.format(model_index))

args = Args(**args)
generator, _ = make_models(args)
generator = make_models(args)
generator.load_state_dict(
torch.load(generator_model_path, map_location=torch.device('cpu')), strict=False)
if cuda:
Expand Down
1 change: 0 additions & 1 deletion GANs/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from .conv_gan import *
4 changes: 2 additions & 2 deletions constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@
WEIGHTS = {
'BigGAN': 'models/pretrained/BigGAN/138k/G_ema.pth',
'ProgGAN': 'models/pretrained/ProgGAN/100_celeb_hq_network-snapshot-010403.pth',
'SN_MNIST': 'models/GANs/SN_MNIST',
'Anime_64': 'models/GANs/SN_Anime',
'SN_MNIST': 'models/pretrained/GANs/SN_MNIST',
'Anime_64': 'models/pretrained/GANs/SN_Anime',
}
11 changes: 7 additions & 4 deletions loading.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import json
import torch
from collections import OrderedDict

from run_train import DEFORMATOR_TYPE_DICT
from models.gan_load import make_big_gan, make_proggan, make_external
Expand All @@ -9,7 +10,7 @@
from constants import WEIGHTS


def load_from_dir(root_dir, model_index=None, G_weights=None):
def load_from_dir(root_dir, model_index=None, G_weights=None, verbose=False):
args = json.load(open(os.path.join(root_dir, 'args.json')))

models_dir = os.path.join(root_dir, 'models')
Expand All @@ -19,13 +20,15 @@ def load_from_dir(root_dir, model_index=None, G_weights=None):
[int(name.split('.')[0].split('_')[-1]) for name in models
if name.startswith('deformator')])

print('using max index {}'.format(model_index))
if verbose:
print('using max index {}'.format(model_index))


if G_weights is None:
G_weights = args['gan_weights']
if G_weights is None or not os.path.isfile(G_weights):
print('Using default local G weights')
if verbose:
print('Using default local G weights')
G_weights = WEIGHTS[args['gan_type']]

if args['gan_type'] == 'BigGAN':
Expand Down Expand Up @@ -53,7 +56,7 @@ def load_from_dir(root_dir, model_index=None, G_weights=None):
directions_json = os.path.join(root_dir, 'directions.json')
if os.path.isfile(directions_json):
with open(directions_json, 'r') as f:
directions_dict = json.load(f)
directions_dict = json.load(f, object_pairs_hook=OrderedDict)
setattr(deformator, 'directions_dict', directions_dict)


Expand Down
2 changes: 1 addition & 1 deletion models/gan_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def make_proggan(weights_root):


def make_external(gan_dir):
gan = load_model_from_state_dict(gan_dir)[0]
gan = load_model_from_state_dict(gan_dir)
G = gan.model.eval()
setattr(G, 'dim_z', gan.distribution.dim)

Expand Down

0 comments on commit 8fbb3da

Please sign in to comment.