Skip to content

Commit

Permalink
[Tests] Refactor to allow custom IMAGE_COUNT
Browse files Browse the repository at this point in the history
  • Loading branch information
lukeyeager committed Nov 11, 2016
1 parent 3f02bc6 commit 4308a53
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
16 changes: 6 additions & 10 deletions digits/dataset/images/classification/test_imageset_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,18 @@
import PIL.Image


IMAGE_SIZE = 10
IMAGE_COUNT = 10 # per category


def create_classification_imageset(folder, image_size=None, image_count=None, add_unbalanced_category=False):
def create_classification_imageset(
folder,
image_size=10,
image_count=10,
add_unbalanced_category=False,
):
"""
Creates a folder of folders of images for classification
If requested to add an unbalanced category then a category is added with
half the number of samples of other categories
"""
if image_size is None:
image_size = IMAGE_SIZE
if image_count is None:
image_count = IMAGE_COUNT

# Stores the relative path of each image of the dataset
paths = defaultdict(list)

Expand Down
20 changes: 12 additions & 8 deletions digits/dataset/images/classification/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from bs4 import BeautifulSoup
import PIL.Image

from .test_imageset_creator import create_classification_imageset, IMAGE_COUNT as DUMMY_IMAGE_COUNT
from .test_imageset_creator import create_classification_imageset
from digits import test_utils
import digits.test_views

Expand Down Expand Up @@ -64,6 +64,7 @@ class BaseViewsTestWithImageset(BaseViewsTest):
Provides an imageset and some functions
"""
# Inherited classes may want to override these default attributes
IMAGE_COUNT = 10 # per class
IMAGE_HEIGHT = 10
IMAGE_WIDTH = 10
IMAGE_CHANNELS = 3
Expand All @@ -78,8 +79,11 @@ def setUpClass(cls):
super(BaseViewsTestWithImageset, cls).setUpClass()
cls.imageset_folder = tempfile.mkdtemp()
# create imageset
cls.imageset_paths = create_classification_imageset(cls.imageset_folder,
add_unbalanced_category=cls.UNBALANCED_CATEGORY)
cls.imageset_paths = create_classification_imageset(
cls.imageset_folder,
image_count=cls.IMAGE_COUNT,
add_unbalanced_category=cls.UNBALANCED_CATEGORY,
)
cls.created_datasets = []

@classmethod
Expand Down Expand Up @@ -363,7 +367,7 @@ def check_image_count(self, type):
assert parse_info['val_count'] == 0
image_count = parse_info['test_count']
assert self.categoryCount() == parse_info['label_count']
assert image_count == DUMMY_IMAGE_COUNT * parse_info['label_count'], 'image count mismatch'
assert image_count == self.IMAGE_COUNT * parse_info['label_count'], 'image count mismatch'
assert self.delete_dataset(job_id) == 200, 'delete failed'
assert not self.dataset_exists(job_id), 'dataset exists after delete'

Expand All @@ -375,9 +379,9 @@ def test_max_per_class(self):
yield self.check_max_per_class, type

def check_max_per_class(self, type):
# create dataset, asking for at most DUMMY_IMAGE_COUNT/2 images per class
assert DUMMY_IMAGE_COUNT % 2 == 0
max_per_class = DUMMY_IMAGE_COUNT / 2
# create dataset, asking for at most IMAGE_COUNT/2 images per class
assert self.IMAGE_COUNT % 2 == 0
max_per_class = self.IMAGE_COUNT / 2
data = {'folder_pct_val': 0}
if type == 'train':
data['folder_train_max_per_class'] = max_per_class
Expand Down Expand Up @@ -418,7 +422,7 @@ def test_min_per_class(self):
def check_min_per_class(self, type):
# create dataset, asking for one more image per class
# than available in the "unbalanced" category
min_per_class = DUMMY_IMAGE_COUNT / 2 + 1
min_per_class = self.IMAGE_COUNT / 2 + 1
data = {'folder_pct_val': 0}
if type == 'train':
data['folder_train_min_per_class'] = min_per_class
Expand Down

0 comments on commit 4308a53

Please sign in to comment.