Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 29 additions & 26 deletions src/Providers/ProviderRegistry.php
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ class ProviderRegistry implements WithHttpTransporterInterface
/**
* @var array<string, class-string<ProviderInterface>> Mapping of provider IDs to class names.
*/
private array $providerClassNames = [];
private array $registeredIdsToClassNames = [];

/**
* @var array<class-string<ProviderInterface>, true> Set of registered class names for fast lookup.
* @var array<class-string<ProviderInterface>, string> Mapping of provider class names to IDs.
*/
private array $registeredClassNames = [];
private array $registeredClassNamesToIds = [];

/**
* @var array<class-string<ProviderInterface>, RequestAuthenticationInterface> Mapping of provider class names to
Expand Down Expand Up @@ -109,8 +109,8 @@ public function registerProvider(string $className): void
$this->setRequestAuthenticationForProvider($className, $this->providerAuthenticationInstances[$className]);
}

$this->providerClassNames[$metadata->getId()] = $className;
$this->registeredClassNames[$className] = true;
$this->registeredIdsToClassNames[$metadata->getId()] = $className;
$this->registeredClassNamesToIds[$className] = $metadata->getId();
}

/**
Expand All @@ -122,7 +122,7 @@ public function registerProvider(string $className): void
*/
public function getRegisteredProviderIds(): array
{
return array_keys($this->providerClassNames);
return array_keys($this->registeredIdsToClassNames);
}

/**
Expand All @@ -135,28 +135,35 @@ public function getRegisteredProviderIds(): array
*/
public function hasProvider(string $idOrClassName): bool
{
return isset($this->providerClassNames[$idOrClassName]) ||
isset($this->registeredClassNames[$idOrClassName]);
return isset($this->registeredIdsToClassNames[$idOrClassName]) ||
isset($this->registeredClassNamesToIds[$idOrClassName]);
}

/**
* Gets the class name for a registered provider.
*
* @since 0.1.0
*
* @param string $id The provider ID.
* @param string|class-string<ProviderInterface> $idOrClassName The provider ID or class name.
* @return string The provider class name.
* @throws InvalidArgumentException If the provider is not registered.
*/
public function getProviderClassName(string $id): string
public function getProviderClassName(string $idOrClassName): string
{
if (!isset($this->providerClassNames[$id])) {
throw new InvalidArgumentException(
sprintf('Provider not registered: %s', $id)
);
// If it's already a class name, return it
if (isset($this->registeredClassNamesToIds[$idOrClassName])) {
return $idOrClassName;
}

// If it's a registered ID, return its class name
if (isset($this->registeredIdsToClassNames[$idOrClassName])) {
return $this->registeredIdsToClassNames[$idOrClassName];
}

return $this->providerClassNames[$id];
// Not found
throw new InvalidArgumentException(
sprintf('Provider not registered: %s', $idOrClassName)
);
}

/**
Expand All @@ -171,17 +178,13 @@ public function getProviderClassName(string $id): string
public function getProviderId(string $idOrClassName): string
{
// If it's already an ID, return it
if (isset($this->providerClassNames[$idOrClassName])) {
if (isset($this->registeredIdsToClassNames[$idOrClassName])) {
return $idOrClassName;
}

// If it's a class name, find its ID
if (isset($this->registeredClassNames[$idOrClassName])) {
foreach ($this->providerClassNames as $id => $className) {
if ($className === $idOrClassName) {
return $id;
}
}
// If it's a registered class name, return its ID
if (isset($this->registeredClassNamesToIds[$idOrClassName])) {
return $this->registeredClassNamesToIds[$idOrClassName];
}

// Not found
Expand Down Expand Up @@ -225,7 +228,7 @@ public function findModelsMetadataForSupport(ModelRequirements $modelRequirement
{
$results = [];

foreach ($this->providerClassNames as $providerId => $className) {
foreach ($this->registeredIdsToClassNames as $providerId => $className) {
$providerResults = $this->findProviderModelsMetadataForSupport($providerId, $modelRequirements);
if (!empty($providerResults)) {
// Use static method from ProviderInterface
Expand Down Expand Up @@ -337,7 +340,7 @@ public function bindModelDependencies(ModelInterface $modelInstance): void
private function resolveProviderClassName(string $idOrClassName): string
{
// Handle both ID and class name
$className = $this->providerClassNames[$idOrClassName] ?? $idOrClassName;
$className = $this->registeredIdsToClassNames[$idOrClassName] ?? $idOrClassName;

if (!$this->hasProvider($idOrClassName)) {
throw new InvalidArgumentException(
Expand All @@ -359,7 +362,7 @@ public function setHttpTransporter(HttpTransporterInterface $httpTransporter): v
$this->setHttpTransporterOriginal($httpTransporter);

// Make sure all registered providers have the HTTP transporter hooked up as needed.
foreach ($this->providerClassNames as $className) {
foreach ($this->registeredIdsToClassNames as $className) {
$this->setHttpTransporterForProvider($className, $httpTransporter);
}
}
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/Providers/ProviderRegistryTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ public function testRegisterProviderWithValidProvider(): void
$this->assertTrue($this->registry->hasProvider('mock'));
$this->assertTrue($this->registry->hasProvider(MockProvider::class));
$this->assertEquals(MockProvider::class, $this->registry->getProviderClassName('mock'));
$this->assertEquals('mock', $this->registry->getProviderId(MockProvider::class));

// Ensure calling the lookup functions with a value in the correct format simply returns it as is.
$this->assertEquals(MockProvider::class, $this->registry->getProviderClassName(MockProvider::class));
$this->assertEquals('mock', $this->registry->getProviderId('mock'));
}

/**
Expand Down Expand Up @@ -105,6 +110,19 @@ public function testGetProviderClassNameWithUnregisteredProvider(): void
$this->registry->getProviderClassName('nonexistent');
}

/**
* Tests getProviderId with unregistered provider.
*
* @return void
*/
public function testGetProviderIdWithUnregisteredProvider(): void
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Provider not registered: ' . InvalidArgumentException::class);

$this->registry->getProviderId(InvalidArgumentException::class);
}

/**
* Tests isProviderConfigured with registered provider.
*
Expand Down
Loading