Skip to content

Commit

Permalink
Prepare Playlist API for Typegen (#2797)
Browse files Browse the repository at this point in the history
* Prepare Playlist API for Typegen

- Fixes duplicate operation ID
- Fixes reference to Tracks in trending_parser documentation
- Adds docs where missing
- Normalizes decorator order
- Prefers @api.expects() over @api.doc() params

* Update playlists to not use inheritance and instead use route level doc
  • Loading branch information
rickyrombo committed Apr 1, 2022
1 parent 01db549 commit f0f8b95
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 68 deletions.
4 changes: 2 additions & 2 deletions discovery-provider/src/api/v1/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,12 +433,12 @@ def __schema__(self):
full_trending_parser.add_argument(
"genre",
required=False,
description="Filter to trending tracks for a specified genre",
description="Filter trending to a specified genre",
)
full_trending_parser.add_argument(
"time",
required=False,
description="Get trending tracks over a specified time range",
description="Calculate trending over a specified time range",
type=str,
choices=("week", "month", "year", "allTime"),
)
Expand Down
144 changes: 78 additions & 66 deletions discovery-provider/src/api/v1/playlists.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
import logging

from flask.globals import request
from flask_restx import Namespace, Resource, fields, reqparse
from flask_restx import Namespace, Resource, fields
from src.api.v1.helpers import (
abort_bad_path_param,
abort_bad_request_param,
current_user_parser,
decode_with_abort,
extend_playlist,
extend_track,
extend_user,
full_trending_parser,
get_current_user_id,
get_default_max,
make_full_response,
make_response,
pagination_parser,
pagination_with_current_user_parser,
search_parser,
success_response,
trending_parser,
)
from src.api.v1.models.playlists import full_playlist_model, playlist_model
from src.api.v1.models.users import user_model_full
Expand Down Expand Up @@ -91,13 +96,13 @@ class Playlist(Resource):
@record_metrics
@ns.doc(
id="""Get Playlist""",
description="""Get a playlist by ID""",
params={"playlist_id": "A Playlist ID"},
responses={200: "Success", 400: "Bad request", 500: "Server error"},
)
@ns.marshal_with(playlists_response)
@cache(ttl_sec=5)
def get(self, playlist_id):
"""Fetch a playlist."""
playlist_id = decode_with_abort(playlist_id, ns)
playlist = get_playlist(playlist_id, None)
response = success_response([playlist] if playlist else [])
Expand All @@ -108,18 +113,20 @@ def get(self, playlist_id):
"playlist_tracks_response", ns, fields.List(fields.Nested(track))
)

full_playlist_parser = reqparse.RequestParser()
full_playlist_parser.add_argument("user_id", required=False)


@full_ns.route(PLAYLIST_ROUTE)
class FullPlaylist(Resource):
@ns.doc(
id="""Get Playlist""",
description="""Get a playlist by ID""",
params={"playlist_id": "A Playlist ID"},
)
@ns.expect(current_user_parser)
@ns.marshal_with(full_playlists_response)
@cache(ttl_sec=5)
def get(self, playlist_id):
"""Fetch a playlist."""
playlist_id = decode_with_abort(playlist_id, full_ns)
args = full_playlist_parser.parse_args()
args = current_user_parser.parse_args()
current_user_id = get_current_user_id(args)

playlist = get_playlist(playlist_id, current_user_id)
Expand All @@ -135,13 +142,13 @@ class PlaylistTracks(Resource):
@record_metrics
@ns.doc(
id="""Get Playlist Tracks""",
description="""Fetch tracks within a playlist.""",
params={"playlist_id": "A Playlist ID"},
responses={200: "Success", 400: "Bad request", 500: "Server error"},
)
@ns.marshal_with(playlist_tracks_response)
@cache(ttl_sec=5)
def get(self, playlist_id):
"""Fetch tracks within a playlist."""
decoded_id = decode_with_abort(playlist_id, ns)
tracks = get_tracks_for_playlist(decoded_id)
return success_response(tracks)
Expand All @@ -157,14 +164,13 @@ class PlaylistSearchResult(Resource):
@record_metrics
@ns.doc(
id="""Search Playlists""",
params={"query": "Search Query"},
description="""Search for a playlist""",
responses={200: "Success", 400: "Bad request", 500: "Server error"},
)
@ns.marshal_with(playlist_search_result)
@ns.expect(search_parser)
@ns.marshal_with(playlist_search_result)
@cache(ttl_sec=600)
def get(self):
"""Search for a playlist."""
args = search_parser.parse_args()
query = args["query"]
if not query:
Expand All @@ -183,23 +189,22 @@ def get(self):
return success_response(playlists)


top_parser = reqparse.RequestParser()
top_parser.add_argument("type", required=True)
top_parser.add_argument("limit", required=False, type=int)
top_parser.add_argument("offset", required=False, type=int)
top_parser = pagination_parser.copy()
top_parser.add_argument(
"type",
required=True,
choices=("album", "playlist"),
description="The collection type",
)


@ns.route("/top", doc=False)
class Top(Resource):
@record_metrics
@ns.doc(
id="""Top Playlists""",
params={"type": "album or playlist", "limit": "limit", "offset": "offset"},
)
@ns.doc(id="""Top Playlists""", description="""Gets top playlists.""")
@ns.marshal_with(playlists_response)
@cache(ttl_sec=30 * 60)
def get(self):
"""Gets top playlists."""
args = top_parser.parse_args()
if args.get("limit") is None:
args["limit"] = 100
Expand All @@ -218,27 +223,24 @@ def get(self):
return success_response(playlists)


playlist_favorites_route_parser = reqparse.RequestParser()
playlist_favorites_route_parser.add_argument("user_id", required=False)
playlist_favorites_route_parser.add_argument("limit", required=False, type=int)
playlist_favorites_route_parser.add_argument("offset", required=False, type=int)
playlist_favorites_response = make_full_response(
"following_response", full_ns, fields.List(fields.Nested(user_model_full))
)


@full_ns.route("/<string:playlist_id>/favorites")
class FullTrackFavorites(Resource):
@full_ns.expect(playlist_favorites_route_parser)
@full_ns.doc(
id="""Get Users that Favorited a Playlist""",
params={"user_id": "A User ID", "limit": "Limit", "offset": "Offset"},
id="""Get Users From Playlist Favorites""",
description="""Get users that favorited a playlist""",
params={"playlist_id": "A Playlist ID"},
responses={200: "Success", 400: "Bad request", 500: "Server error"},
)
@full_ns.expect(pagination_with_current_user_parser)
@full_ns.marshal_with(playlist_favorites_response)
@cache(ttl_sec=5)
def get(self, playlist_id):
args = playlist_favorites_route_parser.parse_args()
args = pagination_with_current_user_parser.parse_args()
decoded_id = decode_with_abort(playlist_id, full_ns)
limit = get_default_max(args.get("limit"), 10, 100)
offset = get_default_max(args.get("offset"), 0)
Expand All @@ -255,27 +257,23 @@ def get(self, playlist_id):
return success_response(users)


playlist_reposts_route_parser = reqparse.RequestParser()
playlist_reposts_route_parser.add_argument("user_id", required=False)
playlist_reposts_route_parser.add_argument("limit", required=False, type=int)
playlist_reposts_route_parser.add_argument("offset", required=False, type=int)
playlist_reposts_response = make_full_response(
"following_response", full_ns, fields.List(fields.Nested(user_model_full))
)


@full_ns.route("/<string:playlist_id>/reposts")
class FullPlaylistReposts(Resource):
@full_ns.expect(playlist_reposts_route_parser)
@full_ns.doc(
id="""Get Users that Reposted a Playlist""",
params={"user_id": "A User ID", "limit": "Limit", "offset": "Offset"},
id="""Get Users From Playlist Reposts""",
params={"playlist_id": "A Playlist ID"},
responses={200: "Success", 400: "Bad request", 500: "Server error"},
)
@full_ns.expect(pagination_with_current_user_parser)
@full_ns.marshal_with(playlist_reposts_response)
@cache(ttl_sec=5)
def get(self, playlist_id):
args = playlist_reposts_route_parser.parse_args()
args = pagination_with_current_user_parser.parse_args()
decoded_id = decode_with_abort(playlist_id, full_ns)
limit = get_default_max(args.get("limit"), 10, 100)
offset = get_default_max(args.get("offset"), 0)
Expand All @@ -294,28 +292,39 @@ def get(self, playlist_id):
trending_response = make_response(
"trending_playlists_response", ns, fields.List(fields.Nested(playlist_model))
)
trending_parser = reqparse.RequestParser()
trending_parser.add_argument("time", required=False)
trending_playlist_parser = trending_parser.copy()
trending_playlist_parser.remove_argument("genre")


@ns.route(
"/trending",
defaults={"version": DEFAULT_TRENDING_VERSIONS[TrendingType.PLAYLISTS].name},
strict_slashes=False,
doc={
"get": {
"id": """Get Trending Playlists""",
"description": """Gets trending playlists for a time period""",
"responses": {200: "Success", 400: "Bad request", 500: "Server error"},
}
},
)
@ns.route(
"/trending/<string:version>",
doc={
"get": {
"id": """Get Trending Playlists With Version""",
"description": """Gets trending playlists for a time period based on the given trending strategy version""",
"params": {"version": "The strategy version of trending to use"},
"responses": {200: "Success", 400: "Bad request", 500: "Server error"},
}
},
)
@ns.route("/trending/<string:version>")
class TrendingPlaylists(Resource):
@record_metrics
@ns.doc(
id="""Trending Playlists""",
params={"time": "time range to query"},
responses={200: "Success", 400: "Bad request", 500: "Server error"},
)
@ns.expect(trending_parser)
@ns.expect(trending_playlist_parser)
@ns.marshal_with(trending_response)
@cache(ttl_sec=TRENDING_TTL_SEC)
def get(self, version):
"""Gets top trending playlists for time period on Audius"""
trending_playlist_versions = trending_strategy_factory.get_versions_for_type(
TrendingType.PLAYLISTS
).keys()
Expand All @@ -325,7 +334,7 @@ def get(self, version):
if not version_list:
abort_bad_path_param("version", ns)

args = trending_parser.parse_args()
args = trending_playlist_parser.parse_args()
time = args.get("time")
time = "week" if time not in ["week", "month", "year"] else time
args = {"time": time, "with_tracks": False}
Expand All @@ -345,35 +354,38 @@ def get(self, version):
fields.List(fields.Nested(full_playlist_model)),
)

full_trending_parser = trending_parser.copy()
full_trending_parser.add_argument("time", required=False)
full_trending_parser.add_argument("limit", required=False)
full_trending_parser.add_argument("offset", required=False)
full_trending_parser.add_argument("user_id", required=False)
full_trending_playlist_parser = full_trending_parser.copy()
full_trending_playlist_parser.remove_argument("genre")


@full_ns.route(
"/trending",
defaults={"version": DEFAULT_TRENDING_VERSIONS[TrendingType.PLAYLISTS].name},
strict_slashes=False,
doc={
"get": {
"id": """Get Trending Playlists""",
"description": """Returns trending playlists for a time period""",
"responses": {200: "Success", 400: "Bad request", 500: "Server error"},
}
},
)
@full_ns.route(
"/trending/<string:version>",
doc={
"get": {
"id": """Get Trending Playlists With Version""",
"description": """Returns trending playlists for a time period based on the given trending version""",
"params": {"version": "The strategy version of trending to use"},
"responses": {200: "Success", 400: "Bad request", 500: "Server error"},
}
},
)
@full_ns.route("/trending/<string:version>")
class FullTrendingPlaylists(Resource):
@full_ns.expect(full_trending_parser)
@full_ns.doc(
id="""Returns trending playlists for a time period based on the given trending version""",
params={
"user_id": "A User ID",
"limit": "Limit",
"offset": "Offset",
"time": "week / month / year",
},
responses={200: "Success", 400: "Bad request", 500: "Server error"},
)
@record_metrics
@full_ns.expect(full_trending_playlist_parser)
@full_ns.marshal_with(full_trending_playlists_response)
def get(self, version):
"""Get trending playlists"""
trending_playlist_versions = trending_strategy_factory.get_versions_for_type(
TrendingType.PLAYLISTS
).keys()
Expand All @@ -383,7 +395,7 @@ def get(self, version):
if not version_list:
abort_bad_path_param("version", full_ns)

args = full_trending_parser.parse_args()
args = full_trending_playlist_parser.parse_args()
strategy = trending_strategy_factory.get_strategy(
TrendingType.PLAYLISTS, version_list[0]
)
Expand Down

0 comments on commit f0f8b95

Please sign in to comment.