Skip to content

Commit

Permalink
fgfgfg
Browse files Browse the repository at this point in the history
  • Loading branch information
asafgal committed Apr 1, 2020
1 parent 480dcde commit ad57ad6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
23 changes: 21 additions & 2 deletions antrax/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ def load(modelfile):
if c.classes is not None:
c.trained = True

if 'unknown_weight' not in c.prmtrs:
c.prmtrs['unknown_weight'] = 20

if 'multi_weight' not in c.prmtrs:
c.prmtrs['multi_weight'] = 0.1

if 'crop_size' not in c.prmtrs:
c.prmtrs['crop_size'] = None

return c

def __init__(self,
Expand All @@ -71,6 +80,8 @@ def __init__(self,
examplesdir=None,
target_size=64,
crop_size=None,
unknown_weight=20,
multi_weight=0.1,
scale=1,
loaded=False,
model=None,
Expand All @@ -95,6 +106,8 @@ def __init__(self,
self.prmtrs['background'] = background
self.prmtrs['target_size'] = target_size
self.prmtrs['crop_size'] = crop_size
self.prmtrs['unknown_weight'] = unknown_weight
self.prmtrs['multi_weight'] = multi_weight
self.prmtrs['scale'] = scale
self.prmtrs['loss'] = loss
self.prmtrs['optimizer'] = optimizer
Expand Down Expand Up @@ -452,7 +465,7 @@ def validate(self, examplesdir):

return error

def train(self, examplesdir, from_scratch=False, ne=5, unknown_weight=20, multi_weight=0.1, verbose=1, target_size=None, crop_size=None):
def train(self, examplesdir, from_scratch=False, ne=5, unknown_weight=None, multi_weight=None, verbose=1, target_size=None, crop_size=None):

if isinstance(examplesdir, list):
rm_after = True
Expand Down Expand Up @@ -482,7 +495,13 @@ def train(self, examplesdir, from_scratch=False, ne=5, unknown_weight=20, multi_
print('User asked, starting training from scratch')
self.reset_model()


if unknown_weight is not None:
self.prmtrs['unknown_weight'] = unknown_weight

if multi_weight is not None:
self.prmtrs['multi_weight'] = multi_weight


# create data generators
prepfun = None if self.prmtrs['scale'] == 1 else scale_and_crop

Expand Down
2 changes: 1 addition & 1 deletion antrax/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def train(classdir, *, name='classifier', scratch=False, ne=5, unknown_weight=2
f = glob(examplesdir + '/*/*.png')[0]
target_size = max(imread(f).shape)

c = axClassifier(name, nclasses=n, target_size=target_size, crop_size=crop_size)
c = axClassifier(name, nclasses=n, target_size=target_size, crop_size=crop_size, unknown_weight=unknown_weight, multi_weight=multi_weight)

c.train(examplesdir, from_scratch=scratch, ne=ne, multi_weight=multi_weight, unknown_weight=unknown_weight,
target_size=target_size, crop_size=crop_size)
Expand Down

0 comments on commit ad57ad6

Please sign in to comment.