Skip to content

Commit

Permalink
Merge pull request #94 from HomebyTwo/feature/hb2-93
Browse files Browse the repository at this point in the history
HB2-93 train prediction models for activity types too
  • Loading branch information
drixselecta committed Dec 1, 2020
2 parents 6938d62 + fcbdb2d commit c95e58d
Show file tree
Hide file tree
Showing 28 changed files with 967 additions and 382 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -1,4 +1,7 @@
# Change Log
## [0.16.0] - 2020-12-01 Fix schedule display in route title
- Display schedule in route title section when available

## [0.15.1] - 2020-12-01 Fix schedule display in route title
- Display schedule in route title section when available

Expand Down
13 changes: 6 additions & 7 deletions homebytwo/importers/tests/test_stravaroute.py
Expand Up @@ -16,7 +16,7 @@
from ...conftest import STRAVA_API_BASE_URL
from ...routes.fields import DataFrameField
from ...routes.models import ActivityType
from ...routes.tests.factories import PlaceFactory
from ...routes.tests.factories import PlaceFactory, ActivityFactory, ActivityTypeFactory
from ...utils.factories import AthleteFactory
from ...utils.tests import get_route_post_data
from ..models import StravaRoute
Expand Down Expand Up @@ -214,10 +214,9 @@ def test_get_import_strava_route_already_imported(
def test_post_import_strava_route_already_imported(
athlete, mock_import_route_call_response
):
route = StravaRouteFactory(
source_id=22798494,
athlete=athlete,
)
run = ActivityType.objects.get(name="Run")
ActivityFactory(athlete=athlete, activity_type=run)
route = StravaRouteFactory(source_id=22798494, athlete=athlete, activity_type=run)

response = mock_import_route_call_response(
data_source=route.data_source,
Expand All @@ -237,9 +236,9 @@ def test_post_import_strava_route_bad_distance(
mock_import_route_call_response,
):
route = StravaRouteFactory.build(
athlete=athlete,
athlete=athlete, activity_type=ActivityTypeFactory()
)

ActivityFactory(athlete=athlete, activity_type=route.activity_type)
post_data = get_route_post_data(route)
response = mock_import_route_call_response(
route.data_source,
Expand Down
121 changes: 71 additions & 50 deletions homebytwo/importers/tests/test_switzerlandmobilityroute.py
Expand Up @@ -4,7 +4,6 @@
from pathlib import Path
from re import compile as re_compile

import pytest
from django.conf import settings
from django.contrib.gis.geos import LineString, Point
from django.forms.models import model_to_dict
Expand All @@ -13,14 +12,15 @@
from django.urls import reverse
from django.utils.http import urlencode

import pytest
import responses
from pytest_django.asserts import assertContains, assertRedirects
from requests.exceptions import ConnectionError

from ...routes.fields import DataFrameField
from ...routes.forms import RouteForm
from ...routes.models import Checkpoint
from ...routes.tests.factories import PlaceFactory
from ...routes.tests.factories import ActivityFactory, ActivityTypeFactory, PlaceFactory
from ...utils.factories import AthleteFactory, UserFactory
from ...utils.tests import create_checkpoints_from_geom, get_route_post_data, read_data
from ..exceptions import SwitzerlandMobilityError, SwitzerlandMobilityMissingCredentials
Expand Down Expand Up @@ -643,49 +643,6 @@ def test_switzerland_mobility_routes_no_cookies(self):

self.assertEqual(response.url, redirect_url)

#########
# Forms #
#########

def test_switzerland_mobility_valid_login_form(self):
username = "test@test.com"
password = "123456"
data = {"username": username, "password": password}
form = SwitzerlandMobilityLogin(data=data)

self.assertTrue(form.is_valid())

def test_switzerland_mobility_invalid_login_form(self):
username = ""
password = ""
data = {"username": username, "password": password}
form = SwitzerlandMobilityLogin(data=data)

self.assertFalse(form.is_valid())

def test_switzerland_mobility_valid_route_form(self):
route = SwitzerlandMobilityRouteFactory.build()
route_data = model_to_dict(route)
route_data.update(
{
"activity_type": 1,
"start_place": route.start_place.id,
"end_place": route.end_place.id,
}
)
form = RouteForm(data=route_data)
self.assertTrue(form.is_valid())

def test_switzerland_mobility_invalid_route_form(self):
route = SwitzerlandMobilityRouteFactory.build()
route_data = model_to_dict(route)
route_data.update(
{"start_place": route.start_place.id, "end_place": route.end_place.id}
)
del route_data["activity_type"]
form = RouteForm(data=route_data)
self.assertFalse(form.is_valid())


@pytest.fixture
def mock_login_response(mocked_responses, settings):
Expand Down Expand Up @@ -916,7 +873,10 @@ def test_post_import_switzerland_mobility_route_no_checkpoints(
athlete, client, mock_import_route_call_response
):
source_id = 2191833
route = SwitzerlandMobilityRouteFactory.build(source_id=source_id)
route = SwitzerlandMobilityRouteFactory.build(
source_id=source_id, athlete=athlete, activity_type=ActivityTypeFactory()
)
ActivityFactory(athlete=athlete, activity_type=route.activity_type)

post_data = get_route_post_data(route)
response = mock_import_route_call_response(
Expand All @@ -940,7 +900,11 @@ def test_post_import_switzerland_mobility_route_with_checkpoints(
geom, _ = switzerland_mobility_data_from_json(route_json)

number_of_checkpoints = 5
route = SwitzerlandMobilityRouteFactory.build()
route = SwitzerlandMobilityRouteFactory.build(
athlete=athlete, activity_type=ActivityTypeFactory(name="Run")
)
ActivityFactory(athlete=athlete, activity_type=route.activity_type)

post_data = get_route_post_data(route)
post_data["checkpoints"] = create_checkpoints_from_geom(geom, number_of_checkpoints)

Expand All @@ -966,8 +930,9 @@ def test_post_import_switzerland_mobility_route_updated(
athlete, mock_import_route_call_response
):
route = SwitzerlandMobilityRouteFactory(
source_id=2191833, athlete=athlete, start_place=None, end_place=None
source_id=2191833, athlete=athlete, activity_type=ActivityTypeFactory()
)
ActivityFactory(athlete=athlete, activity_type=route.activity_type)

post_data = get_route_post_data(route)
response = mock_import_route_call_response(
Expand Down Expand Up @@ -1008,9 +973,9 @@ def test_get_switzerland_mobility_route_deleted_data(
assert response.status_code == 200


#####################
######################
# view routes:update #
#####################
######################


def test_get_update_switzerland_mobility_route_redirect_to_login_with_update_id(
Expand All @@ -1028,3 +993,59 @@ def test_get_update_switzerland_mobility_route_redirect_to_login_with_update_id(
params = urlencode({"update": route.pk})
redirect_url = reverse("switzerland_mobility_login") + "?" + params
assertRedirects(response, redirect_url)


#########
# Forms #
#########


def test_switzerland_mobility_login_form(athlete):
username = "test@test.com"
password = "123456"
data = {"username": username, "password": password}
form = SwitzerlandMobilityLogin(data=data)

assert form.is_valid()


def test_switzerland_mobility_login_form_invalid(athlete):
username = ""
password = ""
data = {"username": username, "password": password}
form = SwitzerlandMobilityLogin(data=data)

assert not form.is_valid()


def test_switzerland_mobility_route_form(athlete):
route = SwitzerlandMobilityRouteFactory.build(
athlete=athlete, activity_type=ActivityTypeFactory()
)
ActivityFactory(athlete=athlete, activity_type=route.activity_type)
route_data = model_to_dict(route)
route_data.update(
{
"activity_type": route.activity_type.id,
"start_place": route.start_place.id,
"end_place": route.end_place.id,
}
)
form = RouteForm(instance=route, data=route_data)

assert form.is_valid()


def test_switzerland_mobility_route_form_invalid(athlete):
route = SwitzerlandMobilityRouteFactory.build(
athlete=athlete, activity_type=ActivityTypeFactory()
)
ActivityFactory(athlete=athlete, activity_type=route.activity_type)
route_data = model_to_dict(route)
route_data.update(
{"start_place": route.start_place.id, "end_place": route.end_place.id}
)
del route_data["activity_type"]
form = RouteForm(instance=route, data=route_data)

assert not form.is_valid()
11 changes: 8 additions & 3 deletions homebytwo/routes/fields.py
Expand Up @@ -9,7 +9,7 @@
from django.core import checks
from django.core.exceptions import FieldDoesNotExist, ValidationError
from django.core.files.storage import default_storage
from django.db import connection
from django.db import connection as db_connection
from django.forms import MultipleChoiceField
from django.forms.widgets import CheckboxSelectMultiple
from django.utils.translation import gettext_lazy as _
Expand All @@ -29,7 +29,7 @@ def LineSubstring(line, start_location, end_location):
"ST_GeomFromText(%(line)s, %(srid)s), %(start)s, %(end)s));"
)

with connection.cursor() as cursor:
with db_connection.cursor() as cursor:
cursor.execute(
sql,
{
Expand Down Expand Up @@ -225,7 +225,8 @@ def save_dataframe_to_file(self, dataframe, model_instance):

def generate_filepath(self, instance):
"""
return a filepath based on the model's class name, dataframe_field and unique fields
return a filepath based on the model's class name
dataframe_field and unique fields
"""

# create filename based on instance and field name
Expand Down Expand Up @@ -270,6 +271,8 @@ def get_prep_value(self, value):
"""
convert NumPy array to a list.
"""
if value is None:
return value
return list(value)

def get_db_prep_value(self, value, connection, prepared=False):
Expand All @@ -284,6 +287,8 @@ def to_python(self, value):
"""
convert the list value to a NumPy array.
"""
if value is None:
return value
return array(value)

def from_db_value(self, value, *args, **kwargs):
Expand Down

0 comments on commit c95e58d

Please sign in to comment.