115 changes: 75 additions & 40 deletions aixplain/utils/asset_cache.py
@@ -6,7 +6,8 @@
from dataclasses import dataclass
from filelock import FileLock

from aixplain.utils import config
logging.getLogger("filelock").setLevel(logging.INFO)

from typing import TypeVar, Generic, Type
from typing import List

@@ -31,6 +32,7 @@ class Store(Generic[T]):
data (Dict[str, T]): Dictionary mapping asset IDs to their cached instances.
expiry (int): Unix timestamp when the cached data expires.
"""

data: Dict[str, T]
expiry: int

@@ -77,10 +79,14 @@ def __init__(
# create cache file and lock file name
self.cache_file = os.path.join(CACHE_FOLDER, f"{cache_filename}.json")
self.lock_file = os.path.join(CACHE_FOLDER, f"{cache_filename}.lock")

logger.info(f"Initializing AssetCache for {self.cls.__name__} with cache file: {self.cache_file}")

self.store = Store(data={}, expiry=self.compute_expiry())
self.load()

if not os.path.exists(self.cache_file):
logger.info(f"Cache file doesn't exist, creating new one: {self.cache_file}")
self.save()

def compute_expiry(self) -> int:
@@ -100,10 +106,7 @@ def compute_expiry(self) -> int:
try:
expiry = int(os.getenv("CACHE_EXPIRY_TIME", CACHE_DURATION))
except Exception as e:
logger.warning(
f"Failed to parse CACHE_EXPIRY_TIME: {e}, "
f"fallback to default value {CACHE_DURATION}"
)
logger.warning(f"Failed to parse CACHE_EXPIRY_TIME: {e}, " f"fallback to default value {CACHE_DURATION}")
# remove the CACHE_EXPIRY_TIME from the environment variables
del os.environ["CACHE_EXPIRY_TIME"]
expiry = CACHE_DURATION
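
For reference, the `CACHE_EXPIRY_TIME` override parsed here can be exercised as below. This is a minimal sketch: the 86400-second value is a stand-in for `CACHE_DURATION` (defined elsewhere in the module), and the final `time.time() + expiry` mirrors what `compute_expiry` presumably returns, given that `Store.expiry` is documented as a Unix timestamp.

```python
import os
import time

CACHE_DURATION = 86400  # stand-in; the real default lives elsewhere in asset_cache.py

# With a valid override, the environment variable wins.
os.environ["CACHE_EXPIRY_TIME"] = "3600"
expiry = int(os.getenv("CACHE_EXPIRY_TIME", CACHE_DURATION))
assert expiry == 3600

# With a garbage value, int() raises; the code above then deletes the
# bad variable and falls back to CACHE_DURATION.
os.environ["CACHE_EXPIRY_TIME"] = "not-a-number"
try:
    expiry = int(os.getenv("CACHE_EXPIRY_TIME", CACHE_DURATION))
except ValueError:
    del os.environ["CACHE_EXPIRY_TIME"]
    expiry = CACHE_DURATION

print(time.time() + expiry)  # the expiry timestamp the Store would carry
```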
@@ -118,12 +121,15 @@ def invalidate(self) -> None:
2. Deletes the cache file if it exists
3. Deletes the lock file if it exists
"""
logger.info(f"Invalidating cache for {self.cls.__name__}")
self.store = Store(data={}, expiry=self.compute_expiry())
# delete cache file and lock file
if os.path.exists(self.cache_file):
os.remove(self.cache_file)
logger.info(f"Removed cache file: {self.cache_file}")
if os.path.exists(self.lock_file):
os.remove(self.lock_file)
logger.info(f"Removed lock file: {self.lock_file}")

def load(self) -> None:
"""Load cached data from the cache file.
@@ -140,36 +146,38 @@ def load(self) -> None:
If any errors occur during loading (file not found, invalid JSON,
deserialization errors), the cache will be invalidated.
"""
logger.info(f"Loading cache for {self.cls.__name__} from {self.cache_file}")

if not os.path.exists(self.cache_file):
logger.info(f"Cache file doesn't exist: {self.cache_file}")
self.invalidate()
return

with FileLock(self.lock_file):
with open(self.cache_file, "r") as f:
try:
try:
with FileLock(self.lock_file):
logger.info(f"Acquired file lock for loading: {self.lock_file}")
with open(self.cache_file, "r") as f:
cache_data = json.load(f)
expiry = cache_data["expiry"]
raw_data = cache_data["data"]
parsed_data = {
k: self.cls.from_dict(v) for k, v in raw_data.items()
}

logger.info(f"Found {len(raw_data)} cached items for {self.cls.__name__}")

parsed_data = {k: self.cls.from_dict(v) for k, v in raw_data.items()}

self.store = Store(data=parsed_data, expiry=expiry)

if self.store.expiry < time.time():
logger.warning(f"Cache expired for {self.cls.__name__}")
logger.warning(
f"Cache expired for {self.cls.__name__} (expiry: {self.store.expiry}, current: {time.time()})"
)
self.invalidate()
else:
logger.info(f"Successfully loaded {len(parsed_data)} cached items for {self.cls.__name__}")

except Exception as e:
self.invalidate()
logger.warning(f"Failed to load cache data: {e}")

if self.store.expiry < time.time():
logger.warning(
f"Cache expired, invalidating cache for {self.cls.__name__}"
)
self.invalidate()
return
except Exception as e:
logger.error(f"Failed to load cache data for {self.cls.__name__}: {e}")
self.invalidate()

def save(self) -> None:
"""Save the current cache state to the cache file.
@@ -186,23 +194,39 @@ def save(self) -> None:
and an error will be logged, but the save operation will continue
for other assets.
"""
logger.info(f"Saving cache for {self.cls.__name__} with {len(self.store.data)} items")

os.makedirs(CACHE_FOLDER, exist_ok=True)
try:
os.makedirs(CACHE_FOLDER, exist_ok=True)
logger.info(f"Cache directory created/verified: {CACHE_FOLDER}")

with FileLock(self.lock_file):
logger.info(f"Acquired file lock for saving: {self.lock_file}")
with open(self.cache_file, "w") as f:
data_dict = {}
serialization_errors = 0

for asset_id, asset in self.store.data.items():
try:
data_dict[asset_id] = serialize(asset)
except Exception as e:
logger.error(f"Error serializing {asset_id}: {e}")
serialization_errors += 1

serializable_store = {
"expiry": self.store.expiry,
"data": data_dict,
}

with FileLock(self.lock_file):
with open(self.cache_file, "w") as f:
data_dict = {}
for asset_id, asset in self.store.data.items():
try:
data_dict[asset_id] = serialize(asset)
except Exception as e:
logger.error(f"Error serializing {asset_id}: {e}")
serializable_store = {
"expiry": self.store.expiry,
"data": data_dict,
}
json.dump(serializable_store, f, indent=4)

json.dump(serializable_store, f, indent=4)
if serialization_errors > 0:
logger.warning(f"Saved cache for {self.cls.__name__} with {serialization_errors} serialization errors")
else:
logger.info(f"Successfully saved cache for {self.cls.__name__} with {len(data_dict)} items")

except Exception as e:
logger.error(f"Failed to save cache for {self.cls.__name__}: {e}")

def get(self, asset_id: str) -> Optional[T]:
"""Retrieve a cached asset by its ID.
@@ -213,7 +237,12 @@ def get(self, asset_id: str) -> Optional[T]:
Returns:
Optional[T]: The cached asset instance if found, None otherwise.
"""
return self.store.data.get(asset_id)
result = self.store.data.get(asset_id)
if result:
logger.info(f"Cache hit for {self.cls.__name__} asset: {asset_id}")
else:
logger.info(f"Cache miss for {self.cls.__name__} asset: {asset_id}")
return result

def add(self, asset: T) -> None:
"""Add a single asset to the cache.
@@ -226,6 +255,7 @@ def add(self, asset: T) -> None:
This method automatically saves the updated cache to disk after
adding the asset.
"""
logger.info(f"Adding {self.cls.__name__} asset to cache: {asset.id}")
self.store.data[asset.id] = asset.__dict__
self.save()

@@ -242,6 +272,7 @@ def add_list(self, assets: List[T]) -> None:
This method automatically saves the updated cache to disk after
adding the assets.
"""
logger.info(f"Adding {len(assets)} {self.cls.__name__} assets to cache (replacing existing)")
self.store.data = {asset.id: asset for asset in assets}
self.save()

@@ -261,8 +292,13 @@ def has_valid_cache(self) -> bool:
bool: True if the cache has not expired and contains data,
False otherwise.
"""
return self.store.expiry >= time.time() and bool(self.store.data)

is_valid = self.store.expiry >= time.time() and bool(self.store.data)
logger.info(
f"Cache validity check for {self.cls.__name__}: {is_valid} (expiry: {self.store.expiry}, current: {time.time()}, data count: {len(self.store.data)})"
)
return is_valid


def serialize(obj: Any) -> Any:
"""Convert a Python object into a JSON-serializable format.

@@ -292,4 +328,3 @@ def serialize(obj: Any) -> Any:
return serialize(vars(obj))
else:
return str(obj)
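
End-to-end, the reworked `AssetCache` can be driven as in the sketch below. It rests on stated assumptions: `SimpleAsset` and its `from_dict` are hypothetical stand-ins for the SDK's real asset classes, and the constructor call assumes an `(cls, cache_filename)` signature, since the full `__init__` parameter list is collapsed in the diff above.

```python
from dataclasses import dataclass
from aixplain.utils.asset_cache import AssetCache

@dataclass
class SimpleAsset:
    # Hypothetical asset type: the cache only needs an `id` attribute
    # and a `from_dict` constructor for rehydration on load().
    id: str
    name: str

    @classmethod
    def from_dict(cls, data: dict) -> "SimpleAsset":
        return cls(**data)

cache = AssetCache(SimpleAsset, cache_filename="simple_assets")  # assumed signature

# Replaces the store's contents and persists to .cache/simple_assets.json.
cache.add_list([SimpleAsset("a1", "first"), SimpleAsset("a2", "second")])

if cache.has_valid_cache():          # logs expiry, current time, and item count
    hit = cache.get("a1")            # logged as a cache hit
    miss = cache.get("missing-id")   # logged as a cache miss, returns None

cache.invalidate()                   # resets the store, removes cache and lock files
```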

33 changes: 29 additions & 4 deletions aixplain/utils/cache_utils.py
@@ -4,6 +4,10 @@
import logging
from filelock import FileLock

logging.getLogger("filelock").setLevel(logging.INFO)

logger = logging.getLogger(__name__)

CACHE_FOLDER = ".cache"
CACHE_FILE = f"{CACHE_FOLDER}/cache.json"
LOCK_FILE = f"{CACHE_FILE}.lock"
@@ -38,13 +42,18 @@ def save_to_cache(cache_file: str, data: dict, lock_file: str) -> None:
- Logs an error if saving fails but doesn't raise an exception
- The data is saved with a timestamp for expiration checking
"""
logger.info(f"Attempting to save cache to {cache_file}")
try:
os.makedirs(os.path.dirname(cache_file), exist_ok=True)
logger.info(f"Cache directory created/verified: {os.path.dirname(cache_file)}")

with FileLock(lock_file):
logger.info(f"Acquired file lock: {lock_file}")
with open(cache_file, "w") as f:
json.dump({"timestamp": time.time(), "data": data}, f)
logger.info(f"Successfully saved cache to {cache_file}")
except Exception as e:
logging.error(f"Failed to save cache to {cache_file}: {e}")
logger.error(f"Failed to save cache to {cache_file}: {e}")


def load_from_cache(cache_file: str, lock_file: str) -> dict:
@@ -66,12 +75,28 @@ def load_from_cache(cache_file: str, lock_file: str) -> dict:
- Returns None if the cached data has expired based on CACHE_EXPIRY_TIME
- Uses thread-safe file locking for reading
"""
if os.path.exists(cache_file):
logger.info(f"Attempting to load cache from {cache_file}")

if not os.path.exists(cache_file):
logger.info(f"Cache file does not exist: {cache_file}")
return None

try:
with FileLock(lock_file):
logger.info(f"Acquired file lock for reading: {lock_file}")
with open(cache_file, "r") as f:
cache_data = json.load(f)
if time.time() - cache_data["timestamp"] < int(get_cache_expiry()):
cache_age = time.time() - cache_data["timestamp"]
expiry_time = int(get_cache_expiry())

logger.info(f"Cache age: {cache_age:.2f}s, expiry threshold: {expiry_time}s")

if cache_age < expiry_time:
logger.info(f"Successfully loaded valid cache from {cache_file}")
return cache_data["data"]
else:
logger.info(f"Cache expired (age: {cache_age:.2f}s > {expiry_time}s): {cache_file}")
return None
return None
except Exception as e:
logger.error(f"Failed to load cache from {cache_file}: {e}")
return None
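
The module-level helpers in cache_utils.py round-trip as shown below. This is a sketch with arbitrary paths and payload, relying only on the two functions in this diff.

```python
from aixplain.utils.cache_utils import save_to_cache, load_from_cache

# Illustrative paths; callers typically derive these from CACHE_FOLDER
# the way the module's own constants do.
cache_file = ".cache/example.json"
lock_file = ".cache/example.json.lock"

save_to_cache(cache_file, {"models": ["m1", "m2"]}, lock_file)

data = load_from_cache(cache_file, lock_file)
print(data)  # {'models': ['m1', 'm2']} while fresh; None once past the expiry window
```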