Skip to content

Commit

Permalink
Adding lock to control experiment creation during multiprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
pebeto committed Apr 9, 2024
1 parent ab7442d commit 70c458b
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 21 deletions.
10 changes: 6 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,21 @@ version = "0.4.1"
MLFlowClient = "64a0f543-368b-4a9a-827a-e71edb2a0b83"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MLJTuning = "03970b2e-30c4-11ea-3135-d1576263f10f"

[compat]
MLFlowClient = "0.4.4"
MLJBase = "1.0.1"
MLJModelInterface = "1.9.3"
MLFlowClient = "0.4.6"
MLJBase = "1.1.2"
MLJModelInterface = "1.9.5"
julia = "1.6"

[extras]
MLFlowClient = "64a0f543-368b-4a9a-827a-e71edb2a0b83"
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
MLJTuning = "03970b2e-30c4-11ea-3135-d1576263f10f"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "MLFlowClient", "MLJModels", "MLJDecisionTreeInterface", "StatisticalMeasures"]
test = ["MLFlowClient", "MLJDecisionTreeInterface", "MLJModels", "MLJTuning", "StatisticalMeasures", "Test"]
39 changes: 23 additions & 16 deletions src/base.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
LOG_EVALUATION_LOCK = ReentrantLock()

function log_evaluation(logger::Logger, performance_evaluation)
experiment = getorcreateexperiment(logger.service, logger.experiment_name;
artifact_location=logger.artifact_location)
run = createrun(logger.service, experiment;
tags=[
Dict(
"key" => "resampling",
"value" => string(performance_evaluation.resampling)
),
Dict("key" => "repeats", "value" => string(performance_evaluation.repeats)),
Dict("key" => "model type", "value" => name(performance_evaluation.model)),
]
)
lock(LOG_EVALUATION_LOCK)
try
experiment = getorcreateexperiment(logger.service, logger.experiment_name;
artifact_location=logger.artifact_location)
run = createrun(logger.service, experiment;
tags=[
Dict(
"key" => "resampling",
"value" => string(performance_evaluation.resampling)
),
Dict("key" => "repeats", "value" => string(performance_evaluation.repeats)),
Dict("key" => "model type", "value" => name(performance_evaluation.model)),
]
)

logmodelparams(logger.service, run, performance_evaluation.model)
logmachinemeasures(logger.service, run, performance_evaluation.measure,
performance_evaluation.measurement)
logmodelparams(logger.service, run, performance_evaluation.model)
logmachinemeasures(logger.service, run, performance_evaluation.measure,
performance_evaluation.measurement)

updaterun(logger.service, run, "FINISHED")
updaterun(logger.service, run, "FINISHED")
finally
unlock(LOG_EVALUATION_LOCK)
end
end

function save(logger::Logger, machine:: Machine)
Expand Down
38 changes: 38 additions & 0 deletions test/multiprocessing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
@testset verbose = true "multiprocessing" begin
logger = MLJFlow.Logger(ENV["MLFLOW_URI"];
experiment_name="MLJFlow multiprocessing tests",
artifact_location="/tmp/mlj-test")

X, y = make_moons(100)
DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree

model = DecisionTreeClassifier()
r = range(model, :max_depth, lower=1, upper=6)

function test_tuned_model(acceleration_method)
tuned_model = TunedModel(
model=model,
range=r,
logger=logger,
acceleration=acceleration_method,
n=100,
)
tuned_model_mach = machine(tuned_model, X, y)
fit!(tuned_model_mach)

experiment = getorcreateexperiment(logger.service, logger.experiment_name)
runs = searchruns(logger.service, experiment)

@assert length(runs) == 100

deleteexperiment(logger.service, experiment)
end

@testset "log_evaluation_with_cpu_threads" begin
test_tuned_model(CPUThreads())
end

@testset "log_evaluation_with_cpu_processes" begin
test_tuned_model(CPUProcesses())
end
end
4 changes: 3 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
using Test
using .Threads

using MLJFlow

using MLJBase
using MLJModels
using MLJTuning
using MLFlowClient
using MLJModelInterface
using StatisticalMeasures
Expand All @@ -21,4 +23,4 @@ end
include("base.jl")
include("types.jl")
include("service.jl")

include("multiprocessing.jl")

0 comments on commit 70c458b

Please sign in to comment.