Skip to content

Commit

Permalink
Merge pull request #65 from JWCook/sqlite-tempfile
Browse files Browse the repository at this point in the history
Add use_temp option to SQLite backend
  • Loading branch information
JWCook committed May 12, 2021
2 parents 591f0ed + 69d5208 commit 0efa0c5
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 11 deletions.
1 change: 1 addition & 0 deletions HISTORY.md
Expand Up @@ -10,6 +10,7 @@
* `Cache-Control: no-store`
* `Expires`
* Add support for HTTP timestamps (RFC 5322) in ``expire_after`` parameters
* Add a `use_temp` option to `SQLiteBackend` to use a tempfile
* Packaging is now handled with Poetry. For users, installation still works the same. For developers,
see [Contributing Guide](https://aiohttp-client-cache.readthedocs.io/en/latest/contributing.html) for details

Expand Down
38 changes: 29 additions & 9 deletions aiohttp_client_cache/backends/sqlite.py
@@ -1,7 +1,10 @@
import asyncio
import sqlite3
from contextlib import asynccontextmanager
from os.path import expanduser, splitext
from os import makedirs
from os.path import abspath, basename, dirname, expanduser, isabs, join
from pathlib import Path
from tempfile import gettempdir
from typing import AsyncIterable, AsyncIterator, Union

import aiosqlite
Expand All @@ -20,17 +23,16 @@ class SQLiteBackend(CacheBackend):
extension is specified)
"""

def __init__(self, cache_name: str = 'aiohttp-cache', **kwargs):
def __init__(self, cache_name: str = 'aiohttp-cache', use_temp: bool = False, **kwargs):
"""
Args:
cache_name: Database filename
use_temp: Store database in a temp directory (e.g., ``/tmp/http_cache.sqlite``).
Note: if ``cache_name`` is an absolute path, this option will be ignored.
"""
super().__init__(cache_name=cache_name, **kwargs)
path, ext = splitext(cache_name)
cache_path = expanduser(f'{path}{ext or ".sqlite"}')

self.responses = SQLitePickleCache(cache_path, 'responses', **kwargs)
self.redirects = SQLiteCache(cache_path, 'redirects', **kwargs)
self.responses = SQLitePickleCache(cache_name, 'responses', use_temp=use_temp, **kwargs)
self.redirects = SQLiteCache(cache_name, 'redirects', use_temp=use_temp, **kwargs)


class SQLiteCache(BaseCache):
Expand All @@ -45,13 +47,15 @@ class SQLiteCache(BaseCache):
Args:
filename: Database filename
table_name: Table name
use_temp: Store database in a temp directory (e.g., ``/tmp/http_cache.sqlite``).
Note: if ``cache_name`` is an absolute path, this option will be ignored.
kwargs: Additional keyword arguments for :py:func:`sqlite3.connect`
"""

def __init__(self, filename: str, table_name: str, **kwargs):
def __init__(self, filename: str, table_name: str, use_temp: bool = False, **kwargs):
super().__init__(**kwargs)
self.connection_kwargs = get_valid_kwargs(sqlite_template, kwargs)
self.filename = filename
self.filename = _get_cache_filename(filename, use_temp)
self.table_name = table_name

self._bulk_commit = False
Expand Down Expand Up @@ -173,3 +177,19 @@ async def values(self) -> AsyncIterable[ResponseOrKey]:

async def write(self, key, item):
await super().write(key, sqlite3.Binary(self.serialize(item)))


def _get_cache_filename(filename: Union[Path, str], use_temp: bool) -> str:
"""Get resolved path for database file"""
# Save to a temp directory, if specified
if use_temp and not isabs(filename):
filename = join(gettempdir(), filename)

# Expand relative and user paths (~/*), and add file extension if not specified
filename = abspath(expanduser(str(filename)))
if '.' not in basename(filename):
filename += '.sqlite'

# Make sure parent dirs exist
makedirs(dirname(filename), exist_ok=True)
return filename
12 changes: 10 additions & 2 deletions test/integration/test_sqlite_backend.py
@@ -1,8 +1,9 @@
import os
import pytest
from datetime import datetime
from tempfile import NamedTemporaryFile
from tempfile import NamedTemporaryFile, gettempdir

from aiohttp_client_cache.backends.sqlite import SQLiteBackend, SQLitePickleCache
from aiohttp_client_cache.backends.sqlite import SQLiteBackend, SQLiteCache, SQLitePickleCache

pytestmark = pytest.mark.asyncio
test_data = {'key_1': 'item_1', 'key_2': datetime.now(), 'key_3': 3.141592654}
Expand Down Expand Up @@ -76,3 +77,10 @@ async def test_clear(cache_client):
assert await cache_client.size() == 0
assert {k async for k in cache_client.keys()} == set()
assert {v async for v in cache_client.values()} == set()


def test_use_temp():
relative_path = SQLiteCache('test-db', 'test-table').filename
temp_path = SQLiteCache('test-db', 'test-table', use_temp=True).filename
assert not relative_path.startswith(gettempdir())
assert temp_path.startswith(gettempdir())

0 comments on commit 0efa0c5

Please sign in to comment.