Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions python/pyspark/sql/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2288,11 +2288,14 @@ def addArtifacts(
messageParameters={"normalized_path": normalized_path},
)
         if archive:
-            self._sc.addArchive(*path)
+            for p in path:
+                self._sc.addArchive(p)
         elif pyfile:
-            self._sc.addPyFile(*path)
+            for p in path:
+                self._sc.addPyFile(p)
         elif file:
-            self._sc.addFile(*path)  # type: ignore[arg-type]
+            for p in path:
+                self._sc.addFile(p)

addArtifact = addArtifacts

Expand Down
32 changes: 32 additions & 0 deletions python/pyspark/sql/tests/test_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pyspark.sql.tests.connect.client.test_artifact import ArtifactTestsMixin
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.errors import PySparkRuntimeError
from pyspark.sql.functions import assert_true, lit, udf


class ArtifactTests(ArtifactTestsMixin, ReusedSQLTestCase):
Expand All @@ -36,6 +37,37 @@ def test_add_pyfile(self):
# file from different session.
self.check_add_pyfile(self.spark.newSession())

def test_add_multiple_pyfiles(self):
    """Check that several Python files added in one ``addArtifacts`` call are all usable.

    Three small modules are written to a temporary directory, registered
    together with ``pyfile=True``, and then imported from inside a UDF whose
    result is asserted to equal the sum of the three module values.
    """

    def run_check(session):
        # File name -> constant returned by that module's ``my_func``.
        modules = {
            "my_pyfile_a.py": 1,
            "my_pyfile_b.py": 2,
            "my_pyfile_c.py": 3,
        }
        with tempfile.TemporaryDirectory(prefix="check_add_multiple_pyfiles") as tmp_dir:
            written = []
            for file_name, constant in modules.items():
                target = os.path.join(tmp_dir, file_name)
                with open(target, "w") as out:
                    out.write(f"my_func = lambda: {constant}")
                written.append(target)

            @udf("int")
            def func(x):
                # Imported inside the UDF body so the lookup happens when the
                # UDF runs, not when this test module is loaded.
                import my_pyfile_a
                import my_pyfile_b
                import my_pyfile_c

                loaded = (my_pyfile_a, my_pyfile_b, my_pyfile_c)
                return sum(mod.my_func() for mod in loaded)

            # Register all files in a single call, then verify 1 + 2 + 3 == 6.
            session.addArtifacts(*written, pyfile=True)
            checked = assert_true(func("id") == lit(6))
            session.range(1).select(checked).show()

    run_check(self.spark)

    # Test multi sessions. Should be able to add the same
    # files from different session.
    run_check(self.spark.newSession())

def test_add_file(self):
self.check_add_file(self.spark)

Expand Down