Skip to content

Commit

Permalink
Add yaml configuration support
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewekhalel committed Nov 26, 2018
1 parent 772604c commit b95c0c8
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 11 deletions.
3 changes: 3 additions & 0 deletions conf/imagenet.yaml
@@ -0,0 +1,3 @@
augs: [NO,FLIP_LR]
mean: ARITH
bits: 8
3 changes: 3 additions & 0 deletions conf/pascal_voc.yaml
@@ -0,0 +1,3 @@
augs: [NO,FLIP_UD,FLIP_LR]
mean: ARITH
bits: 8
10 changes: 4 additions & 6 deletions edafa/BasePredictor.py
Expand Up @@ -4,9 +4,9 @@
from .exceptions import *
import numpy as np
from abc import ABC, abstractmethod
import json
import warnings


class BasePredictor(ABC):
"""
An abstract class (wrapper for your model) to apply test time augmentation (TTA)
Expand Down Expand Up @@ -61,11 +61,9 @@ def _parse_conf(self,conf):
:param conf: configuration (json string or file path)
"""
try:
loaded = json.loads(conf)
except:
with open(conf) as f:
loaded = json.load(f)
loaded = conf_to_dict(conf)
if loaded is None:
raise ConfigurationUnrecognized("Unrecognized configuration!")

if "augs" in loaded:
self.augs = loaded["augs"]
Expand Down
7 changes: 7 additions & 0 deletions edafa/exceptions.py
Expand Up @@ -12,3 +12,10 @@ class MeanUnrecognized(Exception):
"""
def __init__(self, message):
self.message = message

class ConfigurationUnrecognized(Exception):
"""
An exception to indicate passed configuration is unrecognized
"""
def __init__(self, message):
self.message = message
14 changes: 13 additions & 1 deletion edafa/tests/test_logistics.py
Expand Up @@ -19,7 +19,7 @@ def __init__(self, *args, **kwargs):

def test_pass_json_file(self):
"""
Test configuration file loading
Test json configuration file loading
"""
p = Child(os.path.join(self.path,"conf/pascal_voc.json"))
self.assertTrue(p.augs == ["NO",
Expand All @@ -28,6 +28,18 @@ def test_pass_json_file(self):
self.assertTrue(p.mean == "ARITH")
self.assertTrue(p.bits == 8)

def test_pass_yaml_file(self):
"""
Test yaml configuration file loading
"""
p = Child(os.path.join(self.path,"conf/pascal_voc.yaml"))
self.assertTrue(p.augs == ["NO",
"FLIP_UD",
"FLIP_LR"])
self.assertTrue(p.mean == "ARITH")
self.assertTrue(p.bits == 8)


def test_pass_json_string(self):
"""
Test configuration as string
Expand Down
25 changes: 24 additions & 1 deletion edafa/utils.py
@@ -1,7 +1,9 @@
from __future__ import absolute_import
import cv2
import math
import numpy as np

import ruamel.yaml as yaml
import json

# EXTENSIONS = ['jpg','png','tif','tiff']
AUGS = ['NO','ROT90','ROT180','ROT270','FLIP_UD','FLIP_LR','BRIGHT','CONTRAST','GAUSSIAN', 'GAMMA']
Expand Down Expand Up @@ -239,3 +241,24 @@ def reverse(aug,img):
return img
elif aug == "GAMMA":
return img

def conf_to_dict(conf):
result = None
# json string?
try:
result = json.loads(conf)
except:
# json file?
try:
with open(conf) as f:
result = json.load(f)
except:
# yaml file?
try:
with open(conf) as stream:
result = yaml.safe_load(stream)
except:
pass
return result


3 changes: 0 additions & 3 deletions setup.py
Expand Up @@ -16,9 +16,6 @@ def readme():
'License :: OSI Approved :: MIT License',
'Operating System :: OS Independent',
'Programming Language :: Python',
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.6',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.1',
'Programming Language :: Python :: 3.2',
Expand Down

0 comments on commit b95c0c8

Please sign in to comment.