Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
HB2-93 commands to clean-up and train activity_types
- Loading branch information
Cedric Hofstetter
committed
Nov 25, 2020
1 parent
8f2d567
commit efff438
Showing
5 changed files
with
258 additions
and
68 deletions.
There are no files selected for viewing
44 changes: 44 additions & 0 deletions
44
homebytwo/routes/management/commands/cleanup_activity_types.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from django.core.management import BaseCommand | ||
from django.db.models import Q | ||
|
||
from homebytwo.routes.models import Activity, ActivityType | ||
|
||
|
||
class Command(BaseCommand):
    """Delete unsupported activities and unused or unsupported activity types.

    Targets are:
    - activities whose activity type name is not listed in
      ``ActivityType.SUPPORTED_ACTIVITY_TYPES``;
    - activity types that have no related activity or whose name is not
      listed in ``ActivityType.SUPPORTED_ACTIVITY_TYPES``.

    With ``--dry-run`` nothing is deleted and the command only reports what
    would be removed.
    """

    help = "Remove activity types that are unsupported or have no associated activity. "

    def add_arguments(self, parser):
        """Register the ``--dry-run`` / ``--dryrun`` flag.

        Passing the flag stores ``False`` in ``options["delete"]``; when it is
        absent the option defaults to ``True`` and the command deletes.
        """
        parser.add_argument(
            "--dry-run",
            "--dryrun",
            action="store_false",
            dest="delete",
            default=True,
            help="Test the command outcome without deleting anything.",
        )

    def handle(self, *args, **options):
        """Count the targets, delete them unless dry-running, and report.

        Returns:
            A one-line summary starting with "Deleted " or "Would delete ",
            followed by the number of activities and activity types matched
            before any deletion took place.
        """
        # Local import keeps this block self-contained.
        from django.db import transaction

        supported_activity_types = ActivityType.SUPPORTED_ACTIVITY_TYPES
        activities = Activity.objects.exclude(
            activity_type__name__in=supported_activity_types
        )

        activity_types = ActivityType.objects.filter(
            Q(activities=None) | ~Q(name__in=supported_activity_types)
        ).distinct()

        # Count and delete inside a single transaction so that a failure while
        # deleting activity types does not leave activities already removed.
        with transaction.atomic():
            activities_count = activities.count()
            activity_types_count = activity_types.count()

            if options["delete"]:
                # Activities go first, then the activity types.
                activities.delete()
                activity_types.delete()
                message = "Deleted "
            else:
                message = "Would delete "

        message += "{} activities and {} activity_types.".format(
            activities_count, activity_types_count
        )
        return message
37 changes: 37 additions & 0 deletions
37
homebytwo/routes/management/commands/train_activity_types.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from django.core.management import BaseCommand | ||
|
||
from homebytwo.routes.models import ActivityType | ||
|
||
|
||
class Command(BaseCommand):
    """Train the prediction model of the selected, or all, activity types."""

    help = "Train prediction models for activity types. "

    def add_arguments(self, parser):
        """Register the positional ``activities`` list and the ``--limit`` option."""
        # Optional list of activity type names; when empty, all are trained.
        parser.add_argument(
            "activities",
            type=str,
            nargs="*",
            default=None,
            help="Choose activity types to train. ",
        )

        # Limit the number of activities used for training.
        parser.add_argument(
            "--limit",
            type=int,
            nargs="?",
            default=None,
            help="Limits the number of activities used for training. ",
        )

    def handle(self, *args, **options):
        """Train each selected activity type and report how many were handled.

        The value returned by ``train_prediction_model`` for each activity
        type is written to the command's stdout, one line per activity type.

        Returns:
            A summary line with the number of activity types in the selection.
        """
        if options["activities"]:
            activity_types = ActivityType.objects.filter(name__in=options["activities"])
        else:
            activity_types = ActivityType.objects.all()

        for activity_type in activity_types:
            result = activity_type.train_prediction_model(options["limit"])
            # Write through self.stdout rather than print() so the output
            # honours the stdout passed to the command (e.g. by call_command).
            self.stdout.write(str(result))

        return f"{activity_types.count()} activity_types trained successfully."
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
from io import StringIO | ||
|
||
from django.core.management import call_command | ||
|
||
import pytest | ||
from pandas import DataFrame | ||
|
||
from homebytwo.importers.tests.factories import ( | ||
StravaRouteFactory, | ||
SwitzerlandMobilityRouteFactory, | ||
) | ||
from homebytwo.routes.tests.factories import ActivityFactory, RouteFactory, \ | ||
ActivityTypeFactory | ||
|
||
from ..management.commands.fix_routes_data import interpolate_from_existing_data | ||
from ..models import ActivityType | ||
|
||
|
||
################### | ||
# fix_routes_data # | ||
################### | ||
|
||
|
||
def call_fix_command(*args, **kwargs):
    """Run the ``fix_routes_data`` command and return what it wrote to stdout.

    Positional and keyword arguments are forwarded to ``call_command``;
    stderr is redirected to a throwaway buffer.
    """
    captured_stdout = StringIO()
    discarded_stderr = StringIO()
    call_command(
        "fix_routes_data",
        *args,
        stdout=captured_stdout,
        stderr=discarded_stderr,
        **kwargs,
    )
    return captured_stdout.getvalue()
|
||
|
||
def test_interpolate_from_existing():
    """Dropping data rows desynchronizes geom and data; interpolation fixes it."""
    route = RouteFactory.build()

    # Remove the first ten rows so data no longer matches the geometry.
    dropped_rows = list(range(10))
    route.data.drop(index=dropped_rows, inplace=True)
    assert len(route.geom) != len(route.data.altitude)

    assert interpolate_from_existing_data(route)
    assert len(route.geom) == len(route.data.altitude)
|
||
|
||
@pytest.mark.django_db
def test_fix_routes_data_no_routes():
    """Without any route created, the command reports zero fixes."""
    expected = "Re-imported 0 routes and restored 0 from data."
    output = call_fix_command("--verbosity", "2")
    assert expected in output
|
||
|
||
@pytest.mark.django_db
def test_fix_routes_data_all(
    settings,
    mock_route_details_response,
    mock_strava_streams_response,
):
    """Run the fix command over three routes and check the reported totals.

    One route uses factory defaults; a Strava route and a Switzerland Mobility
    route are both created with the same ten-row data frame. The Strava
    streams response is mocked, and the Switzerland Mobility route details
    request is mocked to answer 404. The command is expected to report one
    re-import and one restore.
    """
    settings.SWITZERLAND_MOBILITY_ROUTE_DATA_URL = "https://example.org/%d"

    # Route created with factory defaults.
    RouteFactory()

    # Ten rows of distance/altitude shared by the two routes below.
    bad_data = DataFrame({"distance": range(10), "altitude": range(10)})

    strava_route = StravaRouteFactory(data=bad_data)
    mock_strava_streams_response(strava_route.source_id)

    mobility_route = SwitzerlandMobilityRouteFactory(data=bad_data)
    mock_route_details_response(
        data_source=mobility_route.data_source,
        source_id=mobility_route.source_id,
        api_response_status=404,
        api_response_json="404.json",
    )

    output = call_fix_command("--verbosity", "2")
    assert "Re-imported 1 routes and restored 1 from data." in output
|
||
##########################
# cleanup_activity_types #
##########################
|
||
|
||
# Expected stdout of the cleanup_activity_types command. Placeholders are the
# number of activities and the number of activity types, in that order.
would_delete_message = "Would delete {} activities and {} activity_types.\n"
delete_message = "Deleted {} activities and {} activity_types.\n"
|
||
|
||
def call_cleanup_activity_types(*args, **kwargs):
    """Run ``cleanup_activity_types`` and return what it wrote to stdout.

    Positional and keyword arguments are forwarded to ``call_command``;
    stderr is redirected to a throwaway buffer.
    """
    captured_stdout = StringIO()
    discarded_stderr = StringIO()
    call_command(
        "cleanup_activity_types",
        *args,
        stdout=captured_stdout,
        stderr=discarded_stderr,
        **kwargs,
    )
    return captured_stdout.getvalue()
|
||
|
||
@pytest.fixture
def create_activities():
    """
    Give every existing activity type one activity, so that none of them
    is picked up for deletion for lack of activities.
    """
    for existing_type in ActivityType.objects.all():
        ActivityFactory(activity_type=existing_type)
|
||
|
||
@pytest.mark.django_db
def test_clean_up_activity_types(create_activities):
    """With every activity type in use, nothing is deleted."""
    expected = delete_message.format(0, 0)
    output = call_cleanup_activity_types()
    assert output == expected
|
||
|
||
@pytest.mark.django_db
def test_clean_up_activity_types_dry_run(create_activities):
    """A dry run with every activity type in use reports zero candidates."""
    expected = would_delete_message.format(0, 0)
    output = call_cleanup_activity_types("--dry-run")
    assert output == expected
|
||
|
||
@pytest.mark.django_db
def test_clean_up_activity_types_unsupported(create_activities):
    """A Yoga activity and its activity type are both reported as deleted."""
    yoga_type = ActivityTypeFactory(name=ActivityType.YOGA)
    ActivityFactory(activity_type=yoga_type)

    expected = delete_message.format(1, 1)
    assert call_cleanup_activity_types() == expected
|
||
|
||
@pytest.mark.django_db
def test_clean_up_activity_types_empty(create_activities):
    """An activity type without any activity is reported by a dry run."""
    ActivityTypeFactory(name=ActivityType.INLINESKATE)

    expected = would_delete_message.format(0, 1)
    output = call_cleanup_activity_types("--dry-run")
    assert output == expected
|
||
|
||
def call_train_activity_types(*args, **kwargs):
    """Run ``train_activity_types`` and return what it wrote to stdout.

    Positional and keyword arguments are forwarded to ``call_command``;
    stderr is redirected to a throwaway buffer.
    """
    captured_stdout = StringIO()
    discarded_stderr = StringIO()
    call_command(
        "train_activity_types",
        *args,
        stdout=captured_stdout,
        stderr=discarded_stderr,
        **kwargs,
    )
    return captured_stdout.getvalue()
|
||
|
||
# Expected summary line of the train_activity_types command; the placeholder
# is the number of activity types.
trained_message = "{} activity_types trained successfully.\n"
|
||
|
||
@pytest.mark.django_db
def test_train_activity_types():
    """Without arguments, every activity type is trained and gets a score."""
    expected = trained_message.format(ActivityType.objects.count())

    # Give each activity type one activity before training.
    for existing_type in ActivityType.objects.all():
        ActivityFactory(activity_type=existing_type)

    assert expected in call_train_activity_types()

    for trained_type in ActivityType.objects.all():
        assert trained_type.model_score != 0.0
|
||
|
||
@pytest.mark.django_db
def test_train_activity_types_single_activity():
    """Naming one activity type trains that type only."""
    ActivityFactory(activity_type__name="Run")

    output = call_train_activity_types("Run")
    assert trained_message.format(1) in output

    run_type = ActivityType.objects.get(name="Run")
    assert run_type.model_score != 0.0
|
||
|
||
@pytest.mark.django_db
def test_train_activity_types_two_activities():
    """Naming two activity types trains both of them."""
    ActivityFactory(activity_type__name="Run")
    ActivityFactory(activity_type__name="Ride")

    output = call_train_activity_types("Run", "Ride")
    assert trained_message.format(2) in output

    run_type = ActivityType.objects.get(name="Run")
    assert run_type.model_score != 0.0
    ride_type = ActivityType.objects.get(name="Ride")
    assert ride_type.model_score != 0.0
|
||
|
||
@pytest.mark.django_db
def test_train_activity_types_limit():
    """Training still succeeds when ``--limit`` is passed."""
    ActivityFactory(activity_type__name="Run")

    output = call_train_activity_types("Run", "--limit", 1)
    assert trained_message.format(1) in output

    run_type = ActivityType.objects.get(name="Run")
    assert run_type.model_score != 0.0