Skip to content
This repository has been archived by the owner on Aug 5, 2024. It is now read-only.

Commit

Permalink
[maintenance] generating separate .proto files per extension | #BAZEL…
Browse files Browse the repository at this point in the history
…-531 Done

Merge-request: BAZEL-MR-408
Merged-by: Katarzyna Mielnik <katarzyna.anna.mielnik@jetbrains.com>
  • Loading branch information
mielnikk authored and qodana-bot committed Aug 22, 2023
1 parent 6a8a7ac commit 3b4e1f7
Show file tree
Hide file tree
Showing 8 changed files with 215 additions and 84 deletions.
65 changes: 28 additions & 37 deletions aspects/core.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,22 @@ load("//aspects:rules/kt/kt_info.bzl", "extract_kotlin_info")
load("//aspects:rules/cpp/cpp_info.bzl", "extract_cpp_info")
load("//aspects:rules/scala/scala_info.bzl", "extract_scala_info", "extract_scala_toolchain_info")
load("//aspects:rules/java/java_info.bzl", "JAVA_RUNTIME_TOOLCHAIN_TYPE", "extract_java_info", "extract_java_runtime", "extract_java_toolchain")
load("//aspects:utils/utils.bzl", "create_struct", "file_location", "update_sync_output_groups")

def get_aspect_ids(ctx, target):
"""Returns the all aspect ids, filtering out self."""
aspect_ids = None
if hasattr(ctx, "aspect_ids"):
aspect_ids = ctx.aspect_ids
elif hasattr(target, "aspect_ids"):
aspect_ids = target.aspect_ids
else:
return None
return [aspect_id for aspect_id in aspect_ids if "bsp_target_info_aspect" not in aspect_id]

def abs(num):
if num < 0:
return -num
else:
return num
load("//aspects:utils/utils.bzl", "abs", "create_struct", "file_location", "get_aspect_ids", "update_sync_output_groups")

EXTENSIONS = [
extract_java_info,
extract_kotlin_info,
extract_java_toolchain,
extract_java_runtime,
extract_scala_info,
extract_scala_toolchain_info,
extract_python_info,
extract_cpp_info,
]

def create_all_extension_info(target, ctx, output_groups, dep_targets):
info = [create_extension_info(target = target, ctx = ctx, output_groups = output_groups, dep_targets = dep_targets) for create_extension_info in EXTENSIONS]
return [(file, data) for file, data in info if file != None]

def _collect_target_from_attr(rule_attrs, attr_name, result):
"""Collects the targets from the given attr into the result."""
Expand Down Expand Up @@ -149,14 +147,7 @@ def _bsp_target_info_aspect_impl(target, ctx):
for f in t.files.to_list()
]

java_target_info = extract_java_info(target, ctx, output_groups)
scala_toolchain_info = extract_scala_toolchain_info(target, ctx, output_groups)
scala_target_info = extract_scala_info(target, ctx, output_groups)
java_toolchain_info, java_toolchain_info_exported = extract_java_toolchain(target, ctx, dep_targets)
java_runtime_info, java_runtime_info_exported = extract_java_runtime(target, ctx, dep_targets)
cpp_target_info = extract_cpp_info(target, ctx)
kotlin_target_info = extract_kotlin_info(target, ctx)
python_target_info = extract_python_info(target, ctx)
aspect_ids = get_aspect_ids(ctx, target)

result = dict(
id = str(target.label),
Expand All @@ -165,23 +156,24 @@ def _bsp_target_info_aspect_impl(target, ctx):
dependencies = list(all_deps),
sources = sources,
resources = resources,
scala_target_info = scala_target_info,
scala_toolchain_info = scala_toolchain_info,
java_target_info = java_target_info,
java_toolchain_info = java_toolchain_info,
java_runtime_info = java_runtime_info,
cpp_target_info = cpp_target_info,
kotlin_target_info = kotlin_target_info,
python_target_info = python_target_info,
env = getattr(rule_attrs, "env", {}),
env_inherit = getattr(rule_attrs, "env_inherit", []),
)

extension_info = create_all_extension_info(target, ctx, output_groups, dep_targets)
extension_exported_properties = dict()
for (_, data) in extension_info:
if data != None:
extension_exported_properties.update(data)

info_files = [file for (file, _) in extension_info]
update_sync_output_groups(output_groups, "bsp-target-info", depset(info_files))

file_name = target.label.name
file_name = file_name + "-" + str(abs(hash(file_name)))
aspect_ids = get_aspect_ids(ctx, target)
if aspect_ids:
file_name = file_name + "-" + str(abs(hash(".".join(aspect_ids))))
file_name = "%s.general" % file_name
file_name = "%s.bsp-info.textproto" % file_name
info_file = ctx.actions.declare_file(file_name)
ctx.actions.write(info_file, create_struct(**result).to_proto())
Expand All @@ -192,9 +184,8 @@ def _bsp_target_info_aspect_impl(target, ctx):
kind = ctx.rule.kind,
export_deps = export_deps,
output_groups = output_groups,
**extension_exported_properties
)
exported_properties.update(java_toolchain_info_exported)
exported_properties.update(java_runtime_info_exported)

return struct(
bsp_info = struct(**exported_properties),
Expand Down
10 changes: 6 additions & 4 deletions aspects/rules/cpp/cpp_info.bzl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
load("//aspects:utils/utils.bzl", "create_struct")
load("//aspects:utils/utils.bzl", "create_proto", "create_struct")

def extract_cpp_info(target, ctx):
def extract_cpp_info(target, ctx, **kwargs):
if CcInfo not in target:
return None
return None, None

return create_struct(
result = create_struct(
copts = getattr(ctx.rule.attr, "copts", []),
defines = getattr(ctx.rule.attr, "defines", []),
link_opts = getattr(ctx.rule.attr, "linkopts", []),
link_shared = getattr(ctx.rule.attr, "linkshared", False),
)

return create_proto(target, ctx, result, "cpp_target_info"), None
26 changes: 15 additions & 11 deletions aspects/rules/java/java_info.bzl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
load("//aspects:utils/java_utils.bzl", "get_java_provider")
load("//aspects:utils/utils.bzl", "create_struct", "file_location", "is_external", "map", "to_file_location", "update_sync_output_groups")
load("//aspects:utils/utils.bzl", "create_proto", "create_struct", "file_location", "is_external", "map", "to_file_location", "update_sync_output_groups")

def map_with_resolve_files(f, xs):
results = []
Expand Down Expand Up @@ -91,17 +91,17 @@ def extract_compile_jars(provider):

return compilation_info.compilation_classpath if compilation_info else transitive_compile_time_jars

def extract_java_info(target, ctx, output_groups):
def extract_java_info(target, ctx, output_groups, **kwargs):
provider = get_java_provider(target)
if not provider:
return None
return None, None

if hasattr(provider, "java_outputs") and provider.java_outputs:
java_outputs = provider.java_outputs
elif hasattr(provider, "outputs") and provider.outputs:
java_outputs = provider.outputs.jars
else:
return None
return None, None

resolve_files = []

Expand Down Expand Up @@ -130,7 +130,7 @@ def extract_java_info(target, ctx, output_groups):
if (is_external(target)):
update_sync_output_groups(output_groups, "external-deps-resolve", depset(resolve_files))

return create_struct(
info = create_struct(
jars = jars,
generated_jars = generated_jars,
runtime_classpath = runtime_classpath,
Expand All @@ -142,7 +142,9 @@ def extract_java_info(target, ctx, output_groups):
args = args,
)

def extract_java_toolchain(target, ctx, dep_targets):
return create_proto(target, ctx, info, "java_target_info"), None

def extract_java_toolchain(target, ctx, dep_targets, **kwargs):
toolchain = None

if hasattr(target, "java_toolchain"):
Expand All @@ -166,13 +168,14 @@ def extract_java_toolchain(target, ctx, dep_targets):
break

if toolchain_info != None:
return toolchain_info, dict(java_toolchain_info = toolchain_info)
info_file = create_proto(target, ctx, toolchain_info, "java_toolchain_info")
return info_file, dict(java_toolchain_info = toolchain_info)
else:
return None, dict()
return None, None

JAVA_RUNTIME_TOOLCHAIN_TYPE = "@bazel_tools//tools/jdk:runtime_toolchain_type"

def extract_java_runtime(target, ctx, dep_targets):
def extract_java_runtime(target, ctx, dep_targets, **kwargs):
runtime = None

if java_common.JavaRuntimeInfo in target: # Bazel 5.4.0 way
Expand All @@ -195,6 +198,7 @@ def extract_java_runtime(target, ctx, dep_targets):
break

if runtime_info != None:
return runtime_info, dict(java_runtime_info = runtime_info)
info_file = create_proto(target, ctx, runtime_info, "java_runtime_info")
return info_file, dict(java_runtime_info = runtime_info)
else:
return None, dict()
return None, None
15 changes: 9 additions & 6 deletions aspects/rules/kt/kt_info.bzl
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
load("//aspects:utils/utils.bzl", "create_struct")
load("//aspects:utils/utils.bzl", "create_proto", "create_struct")

def extract_kotlin_info(target, ctx):
def extract_kotlin_info(target, ctx, **kwargs):
# if KtJvmInfo not in target:
# return None
# return None, None

# provider = target[KtJvmInfo]

if not hasattr(target, "kt"):
return None
return None, None

provider = target.kt

# Only supports JVM platform now
if not hasattr(provider, "language_version"):
return None
return None, None

language_version = getattr(provider, "language_version", None)
api_version = language_version
Expand All @@ -34,4 +34,7 @@ def extract_kotlin_info(target, ctx):
if kotlinc_opts != None:
kotlin_info["kotlinc_opts"] = kotlinc_opts

return create_struct(**kotlin_info)
kotlin_target_info = create_struct(**kotlin_info)
info_file = create_proto(target, ctx, kotlin_target_info, "kotlin_target_info")

return info_file, None
12 changes: 7 additions & 5 deletions aspects/rules/python/python_info.bzl
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
load("//aspects:utils/utils.bzl", "create_struct", "file_location")
load("//aspects:utils/utils.bzl", "create_proto", "create_struct", "file_location")

def extract_python_info(target, ctx):
def extract_python_info(target, ctx, **kwargs):
if PyInfo not in target:
return None
return None, None

if PyRuntimeInfo in target:
provider = target[PyRuntimeInfo]
else:
provider = None
provider = None, None

return create_struct(
python_target_info = create_struct(
interpreter = file_location(getattr(provider, "interpreter", None)),
version = getattr(provider, "python_version", None),
)

return create_proto(target, ctx, python_target_info, "python_target_info"), None
17 changes: 10 additions & 7 deletions aspects/rules/scala/scala_info.bzl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
load("//aspects:utils/java_utils.bzl", "get_java_provider")
load("//aspects:utils/utils.bzl", "file_location", "is_external", "map", "update_sync_output_groups")
load("//aspects:utils/utils.bzl", "create_proto", "file_location", "is_external", "map", "update_sync_output_groups")

def find_scalac_classpath(runfiles):
result = []
Expand All @@ -13,26 +13,28 @@ def find_scalac_classpath(runfiles):
result.append(file)
return result if found_scala_compiler_jar and len(result) >= 3 else []

def extract_scala_toolchain_info(target, ctx, output_groups):
def extract_scala_toolchain_info(target, ctx, output_groups, **kwargs):
runfiles = target.default_runfiles.files.to_list()

classpath = find_scalac_classpath(runfiles)

if not classpath:
return None
return None, None

resolve_files = classpath
compiler_classpath = map(file_location, classpath)

if (is_external(target)):
update_sync_output_groups(output_groups, "external-deps-resolve", depset(resolve_files))

return struct(compiler_classpath = compiler_classpath)
scala_toolchain_info = struct(compiler_classpath = compiler_classpath)

def extract_scala_info(target, ctx, output_groups):
return create_proto(target, ctx, scala_toolchain_info, "scala_toolchain_info"), None

def extract_scala_info(target, ctx, output_groups, **kwargs):
provider = get_java_provider(target)
if not provider:
return None
return None, None

# proper solution, but it requires adding scala_toolchain to the aspect
# SCALA_TOOLCHAIN = "@io_bazel_rules_scala//scala:toolchain_type"
Expand Down Expand Up @@ -64,4 +66,5 @@ def extract_scala_info(target, ctx, output_groups):
scala_info = struct(
scalac_opts = scalac_opts,
)
return scala_info

return create_proto(target, ctx, scala_info, "scala_target_info"), None
35 changes: 35 additions & 0 deletions aspects/utils/utils.bzl
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
def abs(num):
if num < 0:
return -num
else:
return num

def map(f, xs):
return [f(x) for x in xs]

def filter(f, xs):
return [x for x in xs if f(x)]

def file_location(file):
if file == None:
return None
Expand Down Expand Up @@ -49,5 +58,31 @@ def update_sync_output_groups(groups_dict, key, new_set):
def update_set_in_dict(input_dict, key, other_set):
input_dict[key] = depset(transitive = [input_dict.get(key, depset()), other_set])

def get_aspect_ids(ctx, target):
"""Returns the all aspect ids, filtering out self."""
aspect_ids = None
if hasattr(ctx, "aspect_ids"):
aspect_ids = ctx.aspect_ids
elif hasattr(target, "aspect_ids"):
aspect_ids = target.aspect_ids
else:
return None
return [aspect_id for aspect_id in aspect_ids if "bsp_target_info_aspect" not in aspect_id]

def create_proto(target, ctx, data, name):
if data == None:
return None

aspect_ids = get_aspect_ids(ctx, target)
file_name = target.label.name
file_name = file_name + "-" + str(abs(hash(file_name)))
if aspect_ids:
file_name = file_name + "-" + str(abs(hash(".".join(aspect_ids))))
file_name = "%s.%s" % (file_name, name)
file_name = "%s.bsp-info.textproto" % file_name
info_file = ctx.actions.declare_file(file_name)
ctx.actions.write(info_file, data.to_proto())
return info_file

def is_external(target):
return not str(target.label).startswith("@//") and not str(target.label).startswith("//")
Loading

0 comments on commit 3b4e1f7

Please sign in to comment.