Skip to content

Commit

Permalink
Merge pull request #2835 from daspecster/vision-safe-search-system-test
Browse files Browse the repository at this point in the history
Add Vision system tests for detect_safe_search().
  • Loading branch information
daspecster committed Dec 7, 2016
2 parents 3598a1e + c3fb287 commit 8fd98e0
Showing 1 changed file with 60 additions and 8 deletions.
68 changes: 60 additions & 8 deletions system_tests/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ def _assert_coordinate(self, coordinate):
self.assertIsInstance(coordinate, (int, float))
self.assertNotEqual(coordinate, 0.0)

def _assert_likelihood(self, likelihood):
from google.cloud.vision.likelihood import Likelihood

levels = [Likelihood.UNKNOWN, Likelihood.VERY_LIKELY,
Likelihood.UNLIKELY, Likelihood.POSSIBLE, Likelihood.LIKELY,
Likelihood.VERY_UNLIKELY]
self.assertIn(likelihood, levels)


class TestVisionClientLogo(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -130,14 +138,6 @@ def tearDown(self):
for value in self.to_delete_by_case:
value.delete()

def _assert_likelihood(self, likelihood):
from google.cloud.vision.likelihood import Likelihood

levels = [Likelihood.UNKNOWN, Likelihood.VERY_LIKELY,
Likelihood.UNLIKELY, Likelihood.POSSIBLE, Likelihood.LIKELY,
Likelihood.VERY_UNLIKELY]
self.assertIn(likelihood, levels)

def _assert_landmarks(self, landmarks):
from google.cloud.vision.face import Landmark
from google.cloud.vision.face import LandmarkTypes
Expand Down Expand Up @@ -340,3 +340,55 @@ def test_detect_landmark_filename(self):
self.assertEqual(len(landmarks), 1)
landmark = landmarks[0]
self._assert_landmark(landmark)


class TestVisionClientSafeSearch(BaseVisionTestCase):
def setUp(self):
self.to_delete_by_case = []

def tearDown(self):
for value in self.to_delete_by_case:
value.delete()

def _assert_safe_search(self, safe_search):
from google.cloud.vision.safe import SafeSearchAnnotation

self.assertIsInstance(safe_search, SafeSearchAnnotation)
self._assert_likelihood(safe_search.adult)
self._assert_likelihood(safe_search.spoof)
self._assert_likelihood(safe_search.medical)
self._assert_likelihood(safe_search.violence)

def test_detect_safe_search_content(self):
client = Config.CLIENT
with open(FACE_FILE, 'rb') as image_file:
image = client.image(content=image_file.read())
safe_searches = image.detect_safe_search()
self.assertEqual(len(safe_searches), 1)
safe_search = safe_searches[0]
self._assert_safe_search(safe_search)

def test_detect_safe_search_gcs(self):
bucket_name = Config.TEST_BUCKET.name
blob_name = 'faces.jpg'
blob = Config.TEST_BUCKET.blob(blob_name)
self.to_delete_by_case.append(blob) # Clean-up.
with open(FACE_FILE, 'rb') as file_obj:
blob.upload_from_file(file_obj)

source_uri = 'gs://%s/%s' % (bucket_name, blob_name)

client = Config.CLIENT
image = client.image(source_uri=source_uri)
safe_searches = image.detect_safe_search()
self.assertEqual(len(safe_searches), 1)
safe_search = safe_searches[0]
self._assert_safe_search(safe_search)

def test_detect_safe_search_filename(self):
client = Config.CLIENT
image = client.image(filename=FACE_FILE)
safe_searches = image.detect_safe_search()
self.assertEqual(len(safe_searches), 1)
safe_search = safe_searches[0]
self._assert_safe_search(safe_search)

0 comments on commit 8fd98e0

Please sign in to comment.