diff --git a/core/src/main/scala/besom/internal/BesomSyntax.scala b/core/src/main/scala/besom/internal/BesomSyntax.scala index 05a8338b..83462510 100644 --- a/core/src/main/scala/besom/internal/BesomSyntax.scala +++ b/core/src/main/scala/besom/internal/BesomSyntax.scala @@ -119,6 +119,26 @@ trait BesomSyntax: } end component + extension [A <: ProviderResource](pr: A) + def provider(using Context): Output[Option[ProviderResource]] = Output { + Context().resources.getStateFor(pr).map(_.custom.provider) + } + + extension [A <: CustomResource](cr: A) + def provider(using Context): Output[Option[ProviderResource]] = Output { + Context().resources.getStateFor(cr).map(_.provider) + } + + extension [A <: ComponentResource](cmpr: A) + def providers(using Context): Output[Map[String, ProviderResource]] = Output { + Context().resources.getStateFor(cmpr).map(_.providers) + } + + extension [A <: RemoteComponentResource](cb: A) + def providers(using Context): Output[Map[String, ProviderResource]] = Output { + Context().resources.getStateFor(cb).map(_.providers) + } + extension [A <: Resource: ResourceDecoder](companion: ResourceCompanion[A]) def get(name: Input[NonEmptyString], id: Input[ResourceId])(using ctx: Context): Output[A] = for diff --git a/core/src/test/scala/besom/internal/DummyContext.scala b/core/src/test/scala/besom/internal/DummyContext.scala index cf0d94c0..08d00b01 100644 --- a/core/src/test/scala/besom/internal/DummyContext.scala +++ b/core/src/test/scala/besom/internal/DummyContext.scala @@ -56,15 +56,18 @@ object DummyContext: monitor: Monitor = dummyMonitor, engine: Engine = dummyEngine, configMap: Map[NonEmptyString, String] = Map.empty, - configSecretKeys: Set[NonEmptyString] = Set.empty + configSecretKeys: Set[NonEmptyString] = Set.empty, + resources: Resources | Result[Resources] = Resources() ): Result[Context] = for taskTracker <- TaskTracker() stackPromise <- Promise[StackResource]() logger <- BesomLogger.local() config <- Config(runInfo.project, isProjectName = true, configMap = configMap, configSecretKeys = configSecretKeys) - resources <- Resources() - memo <- Memo() + resources <- resources match + case r: Resources => Result.pure(r) + case r: Result[Resources] => r + memo <- Memo() given Context = Context.create(runInfo, featureSupport, config, logger, monitor, engine, taskTracker, resources, memo, stackPromise) _ <- stackPromise.fulfill(StackResource()(using ComponentBase(Output(besom.types.URN.empty)))) yield summon[Context] diff --git a/core/src/test/scala/besom/internal/ResourceOpsTest.scala b/core/src/test/scala/besom/internal/ResourceOpsTest.scala index 20a0c965..156cdd81 100644 --- a/core/src/test/scala/besom/internal/ResourceOpsTest.scala +++ b/core/src/test/scala/besom/internal/ResourceOpsTest.scala @@ -6,6 +6,7 @@ import besom.internal.RunResult.given import besom.internal.RunOutput.* import besom.internal.logging.BesomMDC import besom.internal.logging.Key.LabelKey +import besom.aliases.Output class ResourceOpsTest extends munit.FunSuite: import ResourceOpsTest.fixtures.* @@ -300,6 +301,154 @@ class ResourceOpsTest extends munit.FunSuite: assertEquals(transitiveDeps, Set(cust1Urn, cust2Urn, dep1Urn, dep2Urn, remote1Urn)) } + + test("resource provider getters") { + object syntax extends BesomSyntax + val resources = Resources().unsafeRunSync() + given Context = DummyContext(resources = resources).unsafeRunSync() + // given BesomMDC[Label] = BesomMDC[Label](LabelKey, Label.fromNameAndType("test", "pkg:test:test")) + + val providerRes = TestProviderResource( + Output( + URN( + "urn:pulumi:stack::project::provider:resources:TestProviderResource::provider" + ) + ), + Output(ResourceId("provider")), + Output("provider") + ) + + resources + .add( + providerRes, + ProviderResourceState( + CustomResourceState( + CommonResourceState( + children = Set.empty, + provider = None, + providers = Map.empty, + version = "0.0.1", + pluginDownloadUrl = "", + name = "provider", + typ = "pulumi:providers:TestProviderResource", + keepDependency = false // providers never have keepDependency set to true + ), + Output(ResourceId("provider")) + ), + ProviderType.from("pulumi:providers:TestProviderResource").getPackage + ) + ) + .unsafeRunSync() + + val cust1Urn = URN( + "urn:pulumi:stack::project::custom:resources:TestCustomResource::cust1" + ) + val cust1 = TestCustomResource(Output(cust1Urn), Output(ResourceId("cust1")), Output(1)) + + resources + .add( + cust1, + CustomResourceState( + CommonResourceState( + children = Set.empty, + provider = Some(providerRes), + providers = Map.empty, + version = "0.0.1", + pluginDownloadUrl = "", + name = "cust1", + typ = "custom:resources:TestCustomResource", + keepDependency = false // custom resources never have keepDependency set to true + ), + Output(ResourceId("cust1")) + ) + ) + .unsafeRunSync() + + val comp1Urn = URN( + "urn:pulumi:stack::project::component:resources:TestComponentResource::comp1" + ) + val compBase1 = ComponentBase( + Output(comp1Urn) + ) + + val comp1 = TestComponentResource( + Output("comp1") + )(using compBase1) + + resources + .add( + compBase1, + ComponentResourceState( + CommonResourceState( + children = Set.empty, + provider = None, + providers = Map( + "provider" -> providerRes + ), + version = "0.0.1", + pluginDownloadUrl = "", + name = "comp1", + typ = "component:resources:TestComponentResource", + keepDependency = false + ) + ) + ) + .unsafeRunSync() + + val remote1Urn = URN( + "urn:pulumi:stack::project::component:resources:TestRemoteComponentResource::remote1" + ) + + val remote1 = TestRemoteComponentResource( + Output(remote1Urn), + Output("remote1") + ) + + resources + .add( + remote1, + ComponentResourceState( + CommonResourceState( + children = Set.empty, + provider = None, + providers = Map( + "provider" -> providerRes + ), + version = "0.0.1", + pluginDownloadUrl = "", + name = "remote1", + typ = "component:resources:TestRemoteComponentResource", + keepDependency = true + ) + ) + ) + .unsafeRunSync() + + import syntax.{provider, providers} + + providerRes.provider.unsafeRunSync().get match + case Some(_) => fail("providerRes.provider should be None") + case None => + + cust1.provider.unsafeRunSync().get match + case Some(p) => + assert(p == providerRes) + case None => + fail("cust1.provider should be Some(providerRes)") + + comp1.providers.unsafeRunSync().get match + case m if m.isEmpty => + fail("comp1.providers should not be empty") + case m => + assertEquals(m, Map("provider" -> providerRes)) + + remote1.providers.unsafeRunSync().get match + case m if m.isEmpty => + fail("remote1.providers should not be empty") + case m => + assertEquals(m, Map("provider" -> providerRes)) + } + end ResourceOpsTest object ResourceOpsTest: @@ -310,4 +459,6 @@ object ResourceOpsTest: case class TestComponentResource(str: Output[String])(using ComponentBase) extends ComponentResource + case class TestProviderResource(urn: Output[URN], id: Output[ResourceId], str: Output[String]) extends ProviderResource + def prepareTransitiveDependencyResolutionTree(resources: Resources): Result[Unit] = ???