Skip to content

Commit

Permalink
feat(stubs): assume latest version by default, optionally show latest…
Browse files Browse the repository at this point in the history
… only in search, general improvements in stub repo.

Signed-off-by: Braden Mars <bradenmars@bradenmars.me>
  • Loading branch information
BradenM committed Dec 11, 2022
1 parent 49b6df0 commit b55b483
Showing 1 changed file with 46 additions and 12 deletions.
58 changes: 46 additions & 12 deletions micropy/stubs/repo.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import collections
from collections import defaultdict
from typing import TYPE_CHECKING, Generator, Iterator
import inspect
from typing import TYPE_CHECKING, ClassVar, Generator, Iterator, Optional, Type

import attrs
import micropy.exceptions as exc
Expand All @@ -19,12 +19,22 @@
class StubRepository:
manifests: list[StubsManifest] = attrs.field(factory=list)

packages_index: dict[str, StubRepositoryPackage] = attrs.field(factory=dict)
versions_index: defaultdict[str, list[StubRepositoryPackage]] = attrs.field(
factory=lambda: collections.defaultdict(list)
packages_index: collections.ChainMap[str, StubRepositoryPackage] = attrs.field(
factory=collections.ChainMap
)
versions_index: collections.ChainMap[str, list[StubRepositoryPackage]] = attrs.field(
factory=collections.ChainMap
)

manifest_types: ClassVar[list[Type[StubsManifest]]] = []

def __attrs_post_init__(self) -> None:
if not any(StubRepository.manifest_types):
StubRepository.manifest_types = [
klass
for klass in get_all_subclasses(StubsManifest)
if not inspect.isabstract(klass)
]
self.build_indexes()

@property
Expand All @@ -38,10 +48,15 @@ def build_indexes(self) -> None:
pkg = next(iter(manifest.packages), None)
if pkg and manifest.resolve_package_absolute_versioned_name(pkg) in self.packages_index:
continue
packages_index = dict()
versions_index = collections.defaultdict(list)
for package in manifest.packages:
repo_package = StubRepositoryPackage(manifest=manifest, package=package)
self.packages_index[repo_package.absolute_versioned_name] = repo_package
self.versions_index[repo_package.absolute_name].append(repo_package)
packages_index[repo_package.absolute_versioned_name] = repo_package
versions_index[repo_package.name].append(repo_package)

self.packages_index = self.packages_index.new_child(packages_index)
self.versions_index = self.versions_index.new_child(dict(versions_index))

def add_repository(self, info: RepositoryInfo) -> StubRepository:
"""Creates a new `StubRepository` instance with a `StubManifest` derived from `info`.
Expand All @@ -55,7 +70,7 @@ def add_repository(self, info: RepositoryInfo) -> StubRepository:
"""
contents = info.fetch_source()
data = dict(repository=info, packages=contents)
for manifest_type in get_all_subclasses(StubsManifest):
for manifest_type in StubRepository.manifest_types:
try:
manifest = manifest_type.parse_obj(data)
except (
Expand All @@ -72,20 +87,34 @@ def add_repository(self, info: RepositoryInfo) -> StubRepository:
)
raise ValueError(f"Failed to determine manifest format for repo: {info}")

def search(self, query: str) -> Generator[StubRepositoryPackage, None, None]:
def search(
self, query: str, include_versions: bool = True
) -> Generator[StubRepositoryPackage, None, None]:
"""Search packages for `query`.
Args:
query: Search constraint.
include_versions: Whether to include versions in search results.
Returns:
A generator of `StubRepositoryPackage` objects.
"""
query = query.strip().lower()
for package_name in self.versions_index:
for package_name in self.versions_index.keys():
if query in package_name.lower() or package_name.lower() in query:
yield from self.versions_index[package_name]
if include_versions:
yield from self.versions_index[package_name]
continue
yield self.latest_for_package(self.versions_index[package_name][0])

def latest_for_package(
self, repo_package: StubRepositoryPackage
) -> Optional[StubRepositoryPackage]:
versions = self.versions_index[repo_package.name]
if len(versions) == 1:
return versions[0]
return max(versions, key=lambda x: x.package.version)

def resolve_package(self, name: str) -> str:
"""Resolve a package name to a package path.
Expand All @@ -101,6 +130,11 @@ def resolve_package(self, name: str) -> str:
"""
for package in self.search(name):
if package.match_exact(name):
if package.match_exact(name) or package.match_exact(
"/".join([package.repo_name, name])
):
return package.url
latest = self.latest_for_package(package)
if latest and latest.name == name:
return latest.url
raise exc.StubNotFound(name)

0 comments on commit b55b483

Please sign in to comment.