Skip to content

Commit

Permalink
again
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Apr 8, 2024
1 parent 19606ec commit 1b64d56
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion ext/PyTorchModelExt/generators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ The gradients are calculated through PyTorch using PythonCall.jl.
# Arguments
- `generator::AbstractGradientBasedGenerator`: The generator object that is used to generate the counterfactual explanation.
- `M::Models.PyTorchModel`: The PyTorch model for which the counterfactual is generated.
- `M::PyTorchModel`: The PyTorch model for which the counterfactual is generated.
- `ce::AbstractCounterfactualExplanation`: The counterfactual explanation object for which the gradient is calculated.
# Returns
Expand Down
4 changes: 2 additions & 2 deletions ext/PyTorchModelExt/utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
CounterfactualExplanations.pytorch_model_loader(model_path::String, model_file::String, class_name::String, pickle_path::String)
pytorch_model_loader(model_path::String, model_file::String, class_name::String, pickle_path::String)
Loads a previously saved PyTorch model.
Expand Down Expand Up @@ -44,7 +44,7 @@ function TaijaInteroperability.pytorch_model_loader(
end

"""
CounterfactualExplanations.preprocess_python_data(data::CounterfactualData)
preprocess_python_data(data::CounterfactualData)
Converts a `CounterfactualData` object to an input tensor and a label tensor.
Expand Down
10 changes: 5 additions & 5 deletions test/pytorch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ if VERSION >= v"1.8"
# Create and save model in the model_path directory
create_new_pytorch_model(data, model_path)
train_and_save_pytorch_model(data, model_location, pickle_path)
model_loaded = Models.pytorch_model_loader(
model_loaded = pytorch_model_loader(
model_location, model_file, class_name, pickle_path
)

model_pytorch = Models.PyTorchModel(model_loaded, data.likelihood)
model_pytorch = PyTorchModel(model_loaded, data.likelihood)

@testset "Test for errors" begin
@test_throws ArgumentError Models.PyTorchModel(
@test_throws ArgumentError PyTorchModel(
model_loaded, :regression
)
end
Expand Down Expand Up @@ -67,10 +67,10 @@ if VERSION >= v"1.8"
train_and_save_pytorch_model(
counterfactual_data, model_location, pickle_path
)
model_loaded = Models.pytorch_model_loader(
model_loaded = pytorch_model_loader(
model_location, model_file, class_name, pickle_path
)
M = Models.PyTorchModel(model_loaded, counterfactual_data.likelihood)
M = PyTorchModel(model_loaded, counterfactual_data.likelihood)

# Randomly selected factual:
Random.seed!(123)
Expand Down

0 comments on commit 1b64d56

Please sign in to comment.