Skip to content

Commit

Permalink
Bug Fix in Saving and Loading
Browse files Browse the repository at this point in the history
  • Loading branch information
FBerding committed Oct 2, 2023
1 parent ccb5a82 commit 069e9ec
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 3 deletions.
2 changes: 1 addition & 1 deletion R/saving_and_loading.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ load_ai_model<-function(model_dir,ml_framework="auto"){
#For aifeducation 0.2.1 and higher-----------------------------------------
if(methods::is(loaded_model,"TextEmbeddingClassifierNeuralNet")){
loaded_model$load_model(
model_dir=model_dir,
dir_path=model_dir,
ml_framework=ml_framework)
} else if (methods::is(loaded_model,"TextEmbeddingModel")){
if(loaded_model$get_model_info()$model_method%in%c("glove_cluster","lda")==FALSE){
Expand Down
4 changes: 2 additions & 2 deletions R/text_embedding_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -528,9 +528,9 @@ TextEmbeddingModel<-R6::R6Class(

#Change ml framework if requested
if(ml_framework=="tensorflow"){
private$transformer_components$ml_framework=="tensorflow"
private$transformer_components$ml_framework="tensorflow"
} else if(ml_framework=="pytorch"){
private$transformer_components$ml_framework=="pytorch"
private$transformer_components$ml_framework="pytorch"
}

#Search for the corresponding files
Expand Down
26 changes: 26 additions & 0 deletions tests/testthat/test-04_transformer_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -753,4 +753,30 @@ for(ai_method in ai_methods){
}
}

for(ai_method in ai_methods){
for(framework in ml_frameworks){
if(framework=="tensorflow"){
other_framework="pytorch"
} else {
other_framework="tensorflow"
}

tmp_path=testthat::test_path(
paste0(
"test_artefacts/tmp_full_models/",
other_framework,"/",
ai_method,"_embedding")
)

test_that(paste(ai_method,"load from",other_framework,"to",framework,"framework_check"),{
test<-load_ai_model(
model_dir = tmp_path,
ml_framework = framework)

tmp<-test$get_transformer_components()[[4]]

expect_equal(tmp,framework)
})
}
}

0 comments on commit 069e9ec

Please sign in to comment.