Skip to content

Commit

Permalink
Turn session into contextmanager
Browse files Browse the repository at this point in the history
  • Loading branch information
c-w committed Aug 17, 2018
1 parent 581cc4a commit 2948dc5
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 176 deletions.
38 changes: 16 additions & 22 deletions app/faceanalysis/auth.py
Expand Up @@ -3,8 +3,8 @@
from itsdangerous import TimedJSONWebSignatureSerializer as Serializer
from passlib.apps import custom_app_context as password_context

from faceanalysis.models.database_manager import get_database_manager
from faceanalysis.models.models import User
from faceanalysis.models.models import get_db_session
from faceanalysis.settings import TOKEN_EXPIRATION
from faceanalysis.settings import TOKEN_SECRET_KEY

Expand Down Expand Up @@ -42,22 +42,18 @@ def load_user_from_auth_token(token: str) -> User:
except (BadSignature, SignatureExpired):
raise InvalidAuthToken()

db = get_database_manager()
session = db.get_session()
user = session.query(User)\
.filter(User.id == data['id'])\
.first()
session.close()
with get_db_session() as session:
user = session.query(User)\
.filter(User.id == data['id'])\
.first()
return user


def load_user(username: str, password: str) -> User:
db = get_database_manager()
session = db.get_session()
user = session.query(User) \
.filter(User.username == username) \
.first()
session.close()
with get_db_session() as session:
user = session.query(User) \
.filter(User.username == username) \
.first()

if not user:
raise UserDoesNotExist()
Expand All @@ -69,19 +65,17 @@ def load_user(username: str, password: str) -> User:


def register_user(username: str, password: str):
    """Create a new user account with the given credentials.

    Raises:
        DuplicateUser: if a user with this username already exists.
    """
    with get_db_session() as session:
        existing = session.query(User) \
            .filter(User.username == username) \
            .first()

    if existing is not None:
        raise DuplicateUser()

    # NOTE(review): check-then-insert races with a concurrent registration of
    # the same username; a unique constraint on User.username would make this
    # safe — confirm against the schema.
    new_user = User()
    new_user.username = username
    new_user.password_hash = password_context.encrypt(password)

    with get_db_session(commit=True) as session:
        session.add(new_user)
61 changes: 25 additions & 36 deletions app/faceanalysis/domain/docker.py
Expand Up @@ -7,23 +7,21 @@
from faceanalysis.domain.errors import ImageAlreadyProcessed
from faceanalysis.domain.errors import ImageDoesNotExist
from faceanalysis.log import get_logger
from faceanalysis.models.database_manager import get_database_manager
from faceanalysis.models.image_status_enum import ImageStatusEnum
from faceanalysis.models.models import Image
from faceanalysis.models.models import ImageStatus
from faceanalysis.models.models import Match
from faceanalysis.models.models import get_db_session
from faceanalysis.storage import store_image

logger = get_logger(__name__)


def process_image(img_id: str):
db = get_database_manager()
session = db.get_session()
img_status = session.query(ImageStatus) \
.filter(ImageStatus.img_id == img_id) \
.first()
session.close()
with get_db_session() as session:
img_status = session.query(ImageStatus) \
.filter(ImageStatus.img_id == img_id) \
.first()

if img_status is None:
raise ImageDoesNotExist()
Expand All @@ -36,12 +34,10 @@ def process_image(img_id: str):


def get_processing_status(img_id: str) -> Tuple[str, str]:
db = get_database_manager()
session = db.get_session()
img_status = session.query(ImageStatus) \
.filter(ImageStatus.img_id == img_id) \
.first()
session.close()
with get_db_session() as session:
img_status = session.query(ImageStatus) \
.filter(ImageStatus.img_id == img_id) \
.first()

if img_status is None:
raise ImageDoesNotExist()
Expand All @@ -52,12 +48,11 @@ def get_processing_status(img_id: str) -> Tuple[str, str]:

def upload_image(stream: IO[bytes], filename: str) -> str:
img_id = filename[:filename.find('.')]
db = get_database_manager()
session = db.get_session()
prev_img_upload = session.query(ImageStatus) \
.filter(ImageStatus.img_id == img_id) \
.first()
session.close()

with get_db_session() as session:
prev_img_upload = session.query(ImageStatus) \
.filter(ImageStatus.img_id == img_id) \
.first()

if prev_img_upload is not None:
raise DuplicateImage()
Expand All @@ -66,37 +61,31 @@ def upload_image(stream: IO[bytes], filename: str) -> str:
img_status = ImageStatus(img_id=img_id,
status=ImageStatusEnum.uploaded.name,
error_msg=None)
session = db.get_session()
session.add(img_status)
db.safe_commit(session)
logger.debug('Image %s uploaded', img_id)

with get_db_session(commit=True) as session:
session.add(img_status)

logger.debug('Image %s uploaded', img_id)
return img_id


def list_images() -> List[str]:
    """Return the ids of every image stored in the database."""
    with get_db_session() as session:
        all_images = session.query(Image).all()
        ids = [image.img_id for image in all_images]

    logger.debug('Got %d images overall', len(ids))
    return ids


def lookup_matching_images(img_id: str) -> Tuple[List[str], List[float]]:
db = get_database_manager()
session = db.get_session()
query = session.query(Match) \
.filter(Match.this_img_id == img_id) \
.all()
session.close()
with get_db_session() as session:
matches = session.query(Match) \
.filter(Match.this_img_id == img_id) \
.all()

images = []
distances = []
for match in query:
for match in matches:
images.append(match.that_img_id)
distances.append(match.distance_score)

Expand Down
49 changes: 19 additions & 30 deletions app/faceanalysis/domain/faceapi.py
Expand Up @@ -11,9 +11,9 @@
from faceanalysis.domain.errors import DuplicateImage
from faceanalysis.domain.errors import ImageDoesNotExist
from faceanalysis.log import get_logger
from faceanalysis.models.database_manager import get_database_manager
from faceanalysis.models.image_status_enum import ImageStatusEnum
from faceanalysis.models.models import FaceApiMapping
from faceanalysis.models.models import get_db_session
from faceanalysis.settings import FACE_API_ACCESS_KEY
from faceanalysis.settings import FACE_API_ENDPOINT
from faceanalysis.settings import FACE_API_MODEL_ID
Expand Down Expand Up @@ -98,15 +98,12 @@ def upload_image(stream: IO[bytes], filename: str) -> str:
model_id, _ = _get_model_id()
img_id = filename[:filename.find('.')]

db = get_database_manager()
session = db.get_session()

mapping = session.query(FaceApiMapping)\
.filter(FaceApiMapping.img_id == img_id)\
.first()
with get_db_session() as session:
mapping = session.query(FaceApiMapping)\
.filter(FaceApiMapping.img_id == img_id)\
.first()

if mapping:
session.close()
raise DuplicateImage()

storage.store_image(stream, filename)
Expand All @@ -122,33 +119,27 @@ def upload_image(stream: IO[bytes], filename: str) -> str:
user_data=img_id)
face_id = response['persistedFaceId']

mapping = FaceApiMapping(face_id=face_id, img_id=img_id)
session.add(mapping)
db.safe_commit(session)
with get_db_session(commit=True) as session:
mapping = FaceApiMapping(face_id=face_id, img_id=img_id)
session.add(mapping)

return img_id


def list_images() -> List[str]:
    """Return the ids of all images recorded in the Face API mapping table."""
    with get_db_session() as session:
        rows = session.query(FaceApiMapping).all()
        image_ids = [row.img_id for row in rows]

    logger.debug('Got %d images', len(image_ids))
    return image_ids


def _fetch_faces_for_person(img_id: str) -> Tuple[List[str], str]:
db = get_database_manager()
session = db.get_session()
mapping = session.query(FaceApiMapping)\
.filter(FaceApiMapping.img_id == img_id) \
.first()
session.close()
with get_db_session() as session:
mapping = session.query(FaceApiMapping)\
.filter(FaceApiMapping.img_id == img_id) \
.first()

if not mapping:
logger.debug('No mapping found for image %s', img_id)
Expand Down Expand Up @@ -188,12 +179,10 @@ def _fetch_matching_faces(face_ids: List[str]) -> Dict[str, float]:


def _fetch_mappings_for_faces(face_ids: Iterable[str]) -> List[FaceApiMapping]:
    """Look up the image mappings for the given Face API face ids."""
    with get_db_session() as session:
        query = session.query(FaceApiMapping) \
            .filter(FaceApiMapping.face_id.in_(face_ids))
        return query.all()


Expand Down
63 changes: 31 additions & 32 deletions app/faceanalysis/face_matcher.py
Expand Up @@ -4,25 +4,24 @@
from typing import Tuple

import numpy as np
from sqlalchemy.orm import Session

from faceanalysis.face_vectorizer import face_vector_from_text, FaceVector
from faceanalysis.face_vectorizer import face_vector_to_text
from faceanalysis.face_vectorizer import get_face_vectors
from faceanalysis.log import get_logger
from faceanalysis.models.database_manager import Session
from faceanalysis.models.database_manager import get_database_manager
from faceanalysis.models.image_status_enum import ImageStatusEnum
from faceanalysis.models.models import FeatureMapping
from faceanalysis.models.models import Image
from faceanalysis.models.models import ImageStatus
from faceanalysis.models.models import Match
from faceanalysis.models.models import get_db_session
from faceanalysis.settings import DISTANCE_SCORE_THRESHOLD
from faceanalysis.settings import FACE_VECTORIZE_ALGORITHM
from faceanalysis.storage import StorageError
from faceanalysis.storage import delete_image
from faceanalysis.storage import get_image_path

db = get_database_manager()
logger = get_logger(__name__)


Expand Down Expand Up @@ -59,11 +58,12 @@ def _store_matches(this_img_id: str,

def _load_image_ids_and_face_vectors() -> Tuple[List[str], np.array]:
logger.debug('getting all img ids and respective features')
session = db.get_session()

with get_db_session() as session:
rows = session.query(FeatureMapping)\
.all()

known_features = []
rows = session.query(FeatureMapping)\
.all()
session.close()
img_ids = []
for row in rows:
img_ids.append(row.img_id)
Expand Down Expand Up @@ -93,16 +93,16 @@ def _prepare_matches(matches: List[dict],
def _update_img_status(img_id: str,
                       status: Optional[ImageStatusEnum] = None,
                       error_msg: Optional[str] = None):
    """Update the status and/or error message of an image's ImageStatus row.

    Args:
        img_id: id of the image whose status row is updated.
        status: new processing status to record, if any.
        error_msg: new error message to record, if any.
    """
    update_fields = {}
    if status:
        update_fields['status'] = status.name
    if error_msg:
        update_fields['error_msg'] = error_msg

    # Guard: with no fields to change, Query.update({}) would emit an UPDATE
    # with an empty SET clause, which SQLAlchemy rejects at compile time.
    if not update_fields:
        return

    with get_db_session(commit=True) as session:
        session.query(ImageStatus)\
            .filter(ImageStatus.img_id == img_id)\
            .update(update_fields)


# pylint: disable=len-as-condition
Expand Down Expand Up @@ -134,27 +134,26 @@ def process_image(img_id: str):
logger.info('Found %d faces in image %s', len(face_vectors), img_id)
_update_img_status(img_id, status=ImageStatusEnum.face_vector_computed)

session = db.get_session()
_add_entry_to_session(Image, session, img_id=img_id)
matches = [] # type: List[dict]
for face_vector in face_vectors:
_store_face_vector(face_vector, img_id, session)

distances = _compute_distances(prev_face_vectors, face_vector)
for that_img_id, distance in zip(prev_img_ids, distances):
if img_id == that_img_id:
continue
distance = float(distance)
if distance >= DISTANCE_SCORE_THRESHOLD:
continue
_prepare_matches(matches, that_img_id, distance)

logger.info('Found %d face matches for image %s', len(matches), img_id)
for match in matches:
_store_matches(img_id, match["that_img_id"],
match["distance_score"], session)
with get_db_session(commit=True) as session:
_add_entry_to_session(Image, session, img_id=img_id)
matches = [] # type: List[dict]
for face_vector in face_vectors:
_store_face_vector(face_vector, img_id, session)

distances = _compute_distances(prev_face_vectors, face_vector)
for that_img_id, distance in zip(prev_img_ids, distances):
if img_id == that_img_id:
continue
distance = float(distance)
if distance >= DISTANCE_SCORE_THRESHOLD:
continue
_prepare_matches(matches, that_img_id, distance)

logger.info('Found %d face matches for image %s', len(matches), img_id)
for match in matches:
_store_matches(img_id, match["that_img_id"],
match["distance_score"], session)

db.safe_commit(session)
_update_img_status(img_id,
status=ImageStatusEnum.finished_processing,
error_msg=('No faces found in image'
Expand Down

0 comments on commit 2948dc5

Please sign in to comment.