Skip to content

Commit

Permalink
Added playlist type/methods + Type hinting and cache improvements (#4)
Browse files Browse the repository at this point in the history
- Added `get_playlist_tracks`, `iter_playlist_tracks` methods
- Added `Playlist` type
- Simplified cache check
- Improved type hinting
- Improved formattation with black
  • Loading branch information
xnetcat committed Mar 8, 2022
1 parent 7f17c91 commit 3d351dc
Show file tree
Hide file tree
Showing 20 changed files with 176 additions and 118 deletions.
8 changes: 3 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,16 @@
"Programming Language :: Python :: Implementation :: CPython",
"Topic :: Internet",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules"
"Topic :: Software Development :: Libraries :: Python Modules",
],
keywords="spotify spotipy spotipy2 api wrapper client library oauth",
project_urls={
"Tracker": "https://github.com/CyanBook/spotipy2/issues",
"Community": "https://github.com/CyanBook/spotipy2/discussions",
"Source": "https://github.com/CyanBook/spotipy2"
"Source": "https://github.com/CyanBook/spotipy2",
},
python_requires="~=3.7",
packages=find_packages(),
install_requires=install_requires,
extras_require={
"cache": ["pymongo"]
}
extras_require={"cache": ["pymongo"]},
)
2 changes: 1 addition & 1 deletion spotipy2/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from .client_credentials_flow import ClientCredentialsFlow
from .token import Token

__all__ = [BaseAuthFlow, ClientCredentialsFlow, Token]
__all__ = ["BaseAuthFlow", "ClientCredentialsFlow", "Token"]
14 changes: 4 additions & 10 deletions spotipy2/auth/client_credentials_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@

class ClientCredentialsFlow(BaseAuthFlow):
def __init__(
self,
client_id: str,
client_secret: str,
token: Optional[Token] = None
self, client_id: str, client_secret: str, token: Optional[Token] = None
) -> None:
self.client_id = client_id
self.client_secret = client_secret
Expand All @@ -26,15 +23,12 @@ async def get_access_token(self, http: ClientSession) -> Token:
return self.token

async with http.post(
API_URL,
data=GRANT_TYPE,
headers=await self.make_auth_header()
API_URL, data=GRANT_TYPE, headers=await self.make_auth_header()
) as r:
return await Token.from_dict(await r.json())

async def make_auth_header(self) -> dict:
return {
"Authorization": "Basic %s" % b64encode(
f"{self.client_id}:{self.client_secret}".encode()
).decode()
"Authorization": "Basic %s"
% b64encode(f"{self.client_id}:{self.client_secret}".encode()).decode()
}
8 changes: 3 additions & 5 deletions spotipy2/auth/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(
scopes: Optional[List[str]],
expires_in: int,
expires_at: datetime,
refresh_token: Optional[str] = None
refresh_token: Optional[str] = None,
) -> None:
self.access_token = access_token
self.token_type = token_type
Expand All @@ -22,15 +22,13 @@ def __init__(

@classmethod
async def from_dict(cls, d: dict) -> Token:
expires_at = datetime.now(
timezone.utc
) + timedelta(seconds=d["expires_in"])
expires_at = datetime.now(timezone.utc) + timedelta(seconds=d["expires_in"])

return cls(
access_token=d["access_token"],
token_type=d["token_type"],
scopes=d.get("scope", "").split() or None,
expires_in=d["expires_in"],
expires_at=expires_at,
refresh_token=d.get("refresh_token")
refresh_token=d.get("refresh_token"),
)
37 changes: 13 additions & 24 deletions spotipy2/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,15 @@ def __init__(
auth_flow: ClientCredentialsFlow,
mongodb_uri: Optional[str] = None,
*args,
**kwargs
**kwargs,
) -> None:
self.auth_flow = auth_flow
self.http = ClientSession(*args, **kwargs)

if mongodb_uri:
from pymongo import MongoClient

match = re.match(
r"(mongodb:\/\/\S+\/)(\w+)(\?\S+)?",
mongodb_uri
)
match = re.match(r"(mongodb:\/\/\S+\/)(\w+)(\?\S+)?", mongodb_uri)

if match:
db_url, db_name, db_params = match.groups()
Expand All @@ -39,19 +36,17 @@ def __init__(

db = MongoClient(db_url + (db_params or ""))
self.cache = db[db_name].spotipy2
self.is_cache = True
else:
self.cache = None
self.is_cache = False

async def _req(
self,
method: str,
endpoint: str,
params: Optional[dict] = None,
can_be_cached: bool = False
can_be_cached: bool = False,
) -> dict:
if self.is_cache and can_be_cached:
if self.cache and can_be_cached:
doc = self.cache.find_one({"_endpoint": endpoint})
if doc:
doc.pop("_endpoint")
Expand All @@ -61,37 +56,28 @@ async def _req(
headers = {"Authorization": f"Bearer {token.access_token}"}

async with self.http.request(
method,
f"{self.API_URL}{endpoint}",
params=params,
headers=headers
method, f"{self.API_URL}{endpoint}", params=params, headers=headers
) as r:
json = await r.json()

try:
assert r.status == 200
except AssertionError:
raise SpotifyException(
json["error"]["status"],
json["error"]["message"]
json["error"]["status"], json["error"]["message"]
)
else:
# Cache if possible
if self.is_cache and can_be_cached:
asyncio.create_task(
self.cache_resource(endpoint, json)
)
if self.cache and can_be_cached:
asyncio.create_task(self.cache_resource(endpoint, json))

return json

async def _get(self, endpoint: str, params: Optional[dict] = None) -> dict:
# Check if cache is enabled and request is a simple get [resource]
can_be_cached = self.is_cache and (
can_be_cached = self.cache is None and (
params is None
and re.match(
r"^(?!me|browse)([\w-]+)\/(\w+)$",
endpoint
) is not None
and re.match(r"^(?!me|browse)([\w-]+)\/(\w+)$", endpoint) is not None
)

return await self._req("GET", endpoint, params, can_be_cached)
Expand All @@ -106,6 +92,9 @@ async def __aexit__(self, exc_type, exc_value, traceback) -> None:
await self.stop()

async def cache_resource(self, endpoint, value) -> None:
if not self.cache:
return

try:
# Insert endpoint for future requests
value["_endpoint"] = endpoint
Expand Down
5 changes: 2 additions & 3 deletions spotipy2/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ def __init__(self, status: int, message: str) -> None:
self.message = message

def __repr__(self) -> str:
return "<SpotifyException(status={0}, message=\"{1}\")>".format(
self.status,
self.message
return '<SpotifyException(status={0}, message="{1}")>'.format(
self.status, self.message
)
5 changes: 4 additions & 1 deletion spotipy2/methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
from .artists import ArtistMethods
from .search import SearchMethods
from .tracks import TrackMethods
from .playlists import PlaylistMethods


class Methods(AlbumMethods, ArtistMethods, SearchMethods, TrackMethods):
class Methods(
AlbumMethods, ArtistMethods, SearchMethods, TrackMethods, PlaylistMethods
):
@staticmethod
def get_id(s: str) -> str:
if m := re.search("(?!.*/).+", s):
Expand Down
18 changes: 7 additions & 11 deletions spotipy2/methods/albums.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,22 @@

class AlbumMethods:
async def get_albums(
self: spotipy2.Spotify,
album_ids: List[str]
self: spotipy2.Spotify, album_ids: List[str] # type: ignore
) -> List[Album]:
albums = await self._get(
"albums",
params={"ids": ",".join([self.get_id(i) for i in album_ids])}
"albums", params={"ids": ",".join([self.get_id(i) for i in album_ids])}
)
return [Album.from_dict(a) for a in albums["albums"]]

async def get_album(self: spotipy2.Spotify, album_id: str) -> Album:
return Album.from_dict(
await self._get(f"albums/{self.get_id(album_id)}")
)
async def get_album(self: spotipy2.Spotify, album_id: str) -> Album: # type: ignore
return Album.from_dict(await self._get(f"albums/{self.get_id(album_id)}"))

async def get_album_tracks(
self: spotipy2.Spotify,
self: spotipy2.Spotify, # type: ignore
album_id: str,
market: Optional[str] = None,
limit: int = None,
offset: int = None
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> List[Track]:
params = self.wrapper(market=market, limit=limit, offset=offset)

Expand Down
36 changes: 15 additions & 21 deletions spotipy2/methods/artists.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,42 @@
from __future__ import annotations
from typing import List
from typing import List, Optional

import spotipy2
from spotipy2.types import Artist, SimplifiedAlbum, Track


class ArtistMethods:
async def get_artists(
self: spotipy2.Spotify,
artist_ids: List[str]
self: spotipy2.Spotify, artist_ids: List[str] # type: ignore
) -> List[Artist]:
artists = await self._get(
"artists",
params={"ids": ",".join([self.get_id(i) for i in artist_ids])}
"artists", params={"ids": ",".join([self.get_id(i) for i in artist_ids])}
)
return [Artist.from_dict(a) for a in artists["artists"]]

async def get_artist(self: spotipy2.Spotify, artist_id: str) -> Artist:
return Artist.from_dict(
await self._get(f"artists/{self.get_id(artist_id)}")
)
async def get_artist(
self: spotipy2.Spotify, artist_id: str # type: ignore
) -> Artist:
return Artist.from_dict(await self._get(f"artists/{self.get_id(artist_id)}"))

async def get_artist_top_tracks(
self: spotipy2.Spotify, artist_id: str, market: str
self: spotipy2.Spotify, artist_id: str, market: str # type: ignore
) -> List[Track]:
top_tracks = await self._get(
f"artists/{self.get_id(artist_id)}/top-tracks",
params={"market": market}
f"artists/{self.get_id(artist_id)}/top-tracks", params={"market": market}
)
return [Track.from_dict(track) for track in top_tracks["tracks"]]

async def get_artist_albums(
self: spotipy2.Spotify,
self: spotipy2.Spotify, # type: ignore
artist_id: str,
include_groups: str = None,
market: str = None,
limit: int = None,
offset: int = None
include_groups: Optional[str] = None,
market: Optional[str] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> List[SimplifiedAlbum]:
params = self.wrapper(
include_groups=include_groups,
market=market,
limit=limit,
offset=offset
include_groups=include_groups, market=market, limit=limit, offset=offset
)

artist_albums = await self._get(
Expand Down
51 changes: 51 additions & 0 deletions spotipy2/methods/playlists.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations
from typing import AsyncGenerator, List, Optional

import spotipy2
from spotipy2.types import Playlist, Track


class PlaylistMethods:
async def get_playlist(
self: spotipy2.Spotify, playlist_id: str # type: ignore
) -> Playlist:
return Playlist.from_dict(
await self._get(f"playlists/{self.get_id(playlist_id)}")
)

async def get_playlist_tracks(
self: spotipy2.Spotify, # type: ignore
playlist_id: str,
market: Optional[str] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> List[Track]:
params = self.wrapper(market=market, limit=limit, offset=offset)

playlist_tracks = await self._get(
f"playlists/{self.get_id(playlist_id)}/tracks", params=params
)

return [Track.from_dict(track["track"]) for track in playlist_tracks["items"]]

async def iter_playlist_tracks(
self: spotipy2.Spotify, # type: ignore
playlist_id: str,
market: Optional[str] = None,
limit: Optional[int] = None,
offset: int = 0,
) -> AsyncGenerator[Track, None]:
while True:
params = self.wrapper(market=market, limit=limit, offset=offset)

playlist_tracks = await self._get(
f"playlists/{self.get_id(playlist_id)}/tracks", params=params
)

for track in playlist_tracks["items"]:
yield Track.from_dict(track["track"])

offset += len(playlist_tracks["items"])

if not playlist_tracks["next"]:
return None

0 comments on commit 3d351dc

Please sign in to comment.