Skip to content

Commit

Permalink
come on
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Apr 9, 2024
1 parent 88ac1ed commit 642ab2a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
6 changes: 3 additions & 3 deletions test/pytorch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ if VERSION >= v"1.8"
max_iter = 1000
conv =
CounterfactualExplanations.Convergence.DecisionThresholdConvergence(
γ,
max_iter,
decision_threshold = γ,
max_iter = max_iter,
)
counterfactual = CounterfactualExplanations.generate_counterfactual(
x,
Expand All @@ -143,7 +143,7 @@ if VERSION >= v"1.8"
γ = minimum([1 / length(counterfactual_data.y_levels), 0.5])
conv =
CounterfactualExplanations.Convergence.DecisionThresholdConvergence(
γ,
decision_threshold = γ,
)
counterfactual = CounterfactualExplanations.generate_counterfactual(
x,
Expand Down
4 changes: 3 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using TaijaInteroperability

"""
_load_synthetic()
Expand Down Expand Up @@ -134,7 +136,7 @@ function train_and_save_pytorch_model(
NeuralNetwork = neural_network_class.NeuralNetwork
model = NeuralNetwork()

x_python, y_python = preprocess_python_data(data)
x_python, y_python = TaijaInteroperability.preprocess_python_data(data)

optimizer = torch.optim.Adam(model.parameters(); lr = 0.1)
loss_fun = torch.nn.BCEWithLogitsLoss()
Expand Down

0 comments on commit 642ab2a

Please sign in to comment.