Permalink
Browse files

Backport full densenet loading

  • Loading branch information...
ahirner committed Oct 6, 2017
1 parent 5626ada commit 23371ca3ed6c3885e912a5015914576e078d0ae1
Showing with 69 additions and 283 deletions.
  1. +0 −256 load_pretrained.ipynb
  2. +29 −8 retrain.py
  3. +40 −19 retrain_benchmark_bees.ipynb
View

This file was deleted.

Oops, something went wrong.
View
@@ -76,9 +76,9 @@
print("Shootout of model(s) %s with batch_size %d running on CUDA %s " % \
(", ".join(models_to_test), batch_size, use_gpu) + \
"with CLR %s for %d classes on data in %s." % \
(use_clr, len(classes), data_dir))
(", ".join(models_to_test), batch_size, use_gpu) + \
"with CLR %s for %d classes on data in %s." % \
(use_clr, len(classes), data_dir))
# ### Generic pretrained model loading
@@ -107,11 +107,32 @@ def diff_states(dict_canonical, dict_subset):
def load_model_merged(name, num_classes):
model = models.__dict__[name](num_classes=num_classes)
#Densenets don't (yet) pass on num_classes, hack it in for 169
if name == 'densenet169':
model = torchvision.models.DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), num_classes=num_classes)
# Densenets don't (yet) pass on num_classes, hack it in
if "densenet" in name:
if name == 'densenet169':
return models.DenseNet(num_init_features=64, growth_rate=32, \
block_config=(6, 12, 32, 32),
num_classes=num_classes)
elif name == 'densenet121':
return models.DenseNet(num_init_features=64, growth_rate=32, \
block_config=(6, 12, 24, 16),
num_classes=num_classes)
elif name == 'densenet201':
return models.DenseNet(num_init_features=64, growth_rate=32, \
block_config=(6, 12, 48, 32),
num_classes=num_classes)
elif name == 'densenet161':
return models.DenseNet(num_init_features=96, growth_rate=48, \
block_config=(6, 12, 36, 24),
num_classes=num_classes)
else:
raise ValueError(
"Cirumventing missing num_classes kwargs not implemented for %s" % name)
pretrained_state = model_zoo.load_url(model_urls[name])
#Diff
Oops, something went wrong.

0 comments on commit 23371ca

Please sign in to comment.