Skip to content

Commit

Permalink
Adding option to choose the save format for pytorch models in the cas…
Browse files Browse the repository at this point in the history
…e of base models and TextEmbeddingModels
  • Loading branch information
FBerding committed Jan 22, 2024
1 parent 1219b4b commit 59c3e6a
Show file tree
Hide file tree
Showing 22 changed files with 477 additions and 65 deletions.
63 changes: 58 additions & 5 deletions R/saving_and_loading.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,36 @@ save_ai_model<-function(model,
if(methods::is(model,"TextEmbeddingClassifierNeuralNet") |
methods::is(model,"TextEmbeddingModel")){

#Check for valid save formats----------------------------------------------
# Validate 'save_format' before anything is written to disk. The set of valid
# formats depends on the kind of model and on its machine learning framework.
if(methods::is(model,"TextEmbeddingClassifierNeuralNet")){
  if(model$get_ml_framework()=="pytorch"){
    if(save_format%in%c("default","pt","safetensors")==FALSE){
      stop("For classifiers based on 'pytorch' only 'pt' and 'safetensors' are
valid save formats.")
    }
  } else if(model$get_ml_framework()=="tensorflow"){
    if(save_format%in%c("default","h5","tf","keras")==FALSE){
      stop("For classifiers based on 'tensorflow' only 'h5', 'tf', and 'keras' are
valid save formats.")
    }
  }
} else if(methods::is(model,"TextEmbeddingModel")){
  # No format check for 'glove_cluster' and 'lda' models.
  if(model$get_model_info()$model_method%in%c("glove_cluster","lda")==FALSE){
    if(model$get_ml_framework()=="pytorch"){
      if(save_format%in%c("default","pt","safetensors")==FALSE){
        stop("For TextEmbeddingModels based on 'pytorch' only 'pt' and 'safetensors' are
valid save formats.")
      }
    } else if(model$get_ml_framework()=="tensorflow"){
      # TextEmbeddingModel$save_model() accepts only 'h5' (or 'default') for
      # 'tensorflow' models, so 'tf' and 'keras' must be rejected here, in line
      # with the error message below.
      if(save_format%in%c("default","h5")==FALSE){
        stop("For TextEmbeddingModels based on 'tensorflow' only 'h5' is a
valid save format.")
      }
    }
  }
}


if(is.null(dir_name)){
if(append_ID==TRUE){
final_model_dir_path=paste0(model_dir,"/",model$get_model_info()$model_name)
Expand All @@ -124,15 +154,38 @@ save_ai_model<-function(model,
#Save R Interface------------------------------
save(model,file = paste0(final_model_dir_path,"/r_interface.rda"))

#---------------------
#Get Package Version
if(methods::is(model,"TextEmbeddingClassifierNeuralNet")){
model$save_model(dir_path = final_model_dir_path,save_format=save_format)
aifeducation_version<-model$get_package_versions()[[1]]$aifeducation
} else {
#TextEmbeddingModels
if(model$get_model_info()$model_method%in%c("glove_cluster","lda")==FALSE){
model$save_model(model_dir = final_model_dir_path)
aifeducation_version<-model$get_package_versions()$aifeducation
}

#Save ML-Model---------------------
if(utils::compareVersion(as.character(aifeducation_version),"0.3.0")<=0){
if(methods::is(model,"TextEmbeddingClassifierNeuralNet")){
model$save_model(dir_path = final_model_dir_path,
save_format=save_format)
} else {
#TextEmbeddingModels
if(model$get_model_info()$model_method%in%c("glove_cluster","lda")==FALSE){
model$save_model(model_dir = final_model_dir_path)
}
}
} else {
if(methods::is(model,"TextEmbeddingClassifierNeuralNet")){
model$save_model(dir_path = final_model_dir_path,
save_format=save_format)
} else {
#TextEmbeddingModels
if(model$get_model_info()$model_method%in%c("glove_cluster","lda")==FALSE){
model$save_model(model_dir = final_model_dir_path,
save_format=save_format)
}
}
}


} else {
stop("Function supports only objects of class TextEmbeddingClassifierNeuralNet or
TextEmbeddingModel")
Expand Down
12 changes: 10 additions & 2 deletions R/te_classifier_neuralnet_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ TextEmbeddingClassifierNeuralNet<-R6::R6Class(
if(attention_type %in% c("fourier","multihead")==FALSE){
stop("Optimzier must be 'fourier' oder 'multihead'.")
}
if(repeat_encoder>0 & attention_type=="multihead" & self_attention_heads<0){
if(repeat_encoder>0 & attention_type=="multihead" & self_attention_heads<=0){
stop("Encoder layer is set to 'multihead'. This requires self_attention_heads>=1.")
}

Expand Down Expand Up @@ -2045,7 +2045,7 @@ TextEmbeddingClassifierNeuralNet<-R6::R6Class(
#'\code{"tf"} for SavedModel
#'or \code{"h5"} for HDF5.
#'For 'pytorch' models \code{"safetensors"} for 'safetensors' or
#'\code{"pt"} for 'pytorch via pickle'.
#'\code{"pt"} for 'pytorch' via pickle.
#'Use \code{"default"} for the standard format. This is keras for
#''tensorflow'/'keras' models and safetensors for 'pytorch' models.
#'@return Function does not return a value. It saves the model to disk.
Expand Down Expand Up @@ -2221,6 +2221,14 @@ TextEmbeddingClassifierNeuralNet<-R6::R6Class(
#'information on the training infrastructure.
get_sustainability_data=function(){
return(private$sustainability)
},
#---------------------------------------------------------------------------
#'@description Method for requesting the machine learning framework used
#'for the classifier.
#'@return Returns a \code{string} describing the machine learning framework used
#'for the classifier
get_ml_framework=function(){
  # The value of the last expression is the function's result: the framework
  # identifier stored in the private field 'ml_framework'.
  private$ml_framework
}
),
private = list(
Expand Down
54 changes: 50 additions & 4 deletions R/text_embedding_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -719,22 +719,60 @@ TextEmbeddingModel<-R6::R6Class(
#'only for transformer models.
#'@param model_dir \code{string} containing the path to the relevant
#'model directory.
#'@param save_format Format for saving the model. For 'tensorflow'/'keras' models
#' \code{"h5"} for HDF5.
#'For 'pytorch' models \code{"safetensors"} for 'safetensors' or
#'\code{"pt"} for 'pytorch' via pickle.
#'Use \code{"default"} for the standard format. This is h5 for
#''tensorflow'/'keras' models and safetensors for 'pytorch' models.
#'@return Function does not return a value. It is used for saving a transformer model
#'to disk.
#'
#'@importFrom utils write.csv
save_model=function(model_dir){
save_model=function(model_dir,save_format="default"){
if((private$basic_components$method %in%private$supported_transformers)==TRUE){

if(save_format%in%c("default","h5","pt","safetensors")==FALSE){
stop("For TextEmbeddingModels save_format must be 'h5', 'pt', or 'safetensors'.")
}

if(save_format=="default"){
if(private$transformer_components$ml_framework=="tensorflow"){
save_format="h5"
} else if(private$transformer_components$ml_framework=="pytorch"){
save_format="safetensors"
}
}

model_dir_data_path<-paste0(model_dir,"/","model_data")

if(dir.exists(model_dir)==FALSE){
dir.create(model_dir)
cat("Creating Directory\n")
}

private$transformer_components$model$save_pretrained(save_directory=model_dir_data_path)
private$transformer_components$tokenizer$save_pretrained(model_dir_data_path)
# Write model and tokenizer to disk. For 'pytorch' models the serialization
# format is selected via the 'safe_serialization' argument of save_pretrained().
if(private$transformer_components$ml_framework=="pytorch"){
  if(save_format%in%c("pt","safetensors")==FALSE){
    # A value such as 'h5' passes the general format check of this method but
    # cannot be used for 'pytorch' models. Fail loudly instead of silently
    # writing no model files.
    stop("For TextEmbeddingModels based on 'pytorch' save_format must be 'pt' or 'safetensors'.")
  }

  # safetensors is used only if it is requested and the python module is installed.
  use_safetensors<-(save_format=="safetensors" &&
                      reticulate::py_module_available("safetensors")==TRUE)

  private$transformer_components$model$save_pretrained(save_directory=model_dir_data_path,
                                                       safe_serialization=use_safetensors)
  private$transformer_components$tokenizer$save_pretrained(model_dir_data_path)

  # Tell the user when the requested format could not be honored.
  if(save_format=="safetensors" && use_safetensors==FALSE){
    warning("Python library 'safetensors' is not available. Saving model in standard
pytorch format.")
  }
} else {
  # Other frameworks: rely on the default behavior of save_pretrained().
  private$transformer_components$model$save_pretrained(save_directory=model_dir_data_path)
  private$transformer_components$tokenizer$save_pretrained(model_dir_data_path)
}

#Saving Sustainability Data
sustain_matrix=private$sustainability$track_log
Expand Down Expand Up @@ -1523,6 +1561,14 @@ TextEmbeddingModel<-R6::R6Class(
#'information on the training infrastructure for every training run.
get_sustainability_data=function(){
return(private$sustainability$track_log)
},
#---------------------------------------------------------------------------
#'@description Method for requesting the machine learning framework used
#'for the text embedding model.
#'@return Returns a \code{string} describing the machine learning framework used
#'for the text embedding model.
get_ml_framework=function(){
  # Look up the framework entry inside the private transformer components;
  # the value of the last expression is the function's result.
  components<-private$transformer_components
  components$ml_framework
}
)
)
Expand Down Expand Up @@ -1655,7 +1701,7 @@ EmbeddedText<-R6::R6Class(
#'generated this embedding.
#'@return \code{string} Label of the corresponding text embedding model
get_model_label=function(){
  # Return the label stored in the private field 'model_label'. The former
  # second return() statement was unreachable and read an unrelated field
  # (the ml framework of transformer components), so it has been removed.
  return(private$model_label)
}
)
)
Expand Down
45 changes: 41 additions & 4 deletions R/transformer_bert.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@
#'
#'@param trace \code{bool} \code{TRUE} if information about the progress should be
#'printed to the console.
#'
#'@param pytorch_safetensors \code{bool} If \code{TRUE} a 'pytorch' model
#'is saved in safetensors format. If \code{FALSE} or if 'safetensors' is not
#'available, it is saved in the standard pytorch format (.bin). Only relevant for pytorch models.
#'
#'@return This function does not return an object. Instead the configuration
#'and the vocabulary of the new model are saved on disk.
#'@note To train the model, pass the directory of the model to the function
Expand Down Expand Up @@ -77,7 +82,8 @@ create_bert_model<-function(
sustain_iso_code=NULL,
sustain_region=NULL,
sustain_interval=15,
trace=TRUE){
trace=TRUE,
pytorch_safetensors=TRUE){

#Set Shiny Progress Tracking
pgr_max=10
Expand Down Expand Up @@ -115,6 +121,19 @@ create_bert_model<-function(
}
}

#Check possible save formats
if(ml_framework=="pytorch"){
if(pytorch_safetensors==TRUE & reticulate::py_module_available("safetensors")==TRUE){
pt_safe_save=TRUE
} else if(pytorch_safetensors==TRUE & reticulate::py_module_available("safetensors")==FALSE){
pt_safe_save=FALSE
warning("Python library 'safetensors' not available. Model will be saved
in the standard pytorch format.")
} else {
pt_safe_save=FALSE
}
}

update_aifeducation_progress_bar(value = 1,
total = pgr_max,
title = "BERT Model")
Expand Down Expand Up @@ -280,7 +299,7 @@ create_bert_model<-function(
bert_model$save_pretrained(save_directory=model_dir)
} else {
bert_model$save_pretrained(save_directory=model_dir,
safe_serilization=reticulate::py_module_available("safetensors"))
safe_serilization=pt_safe_save)
}


Expand Down Expand Up @@ -384,6 +403,10 @@ create_bert_model<-function(
#'information about the training process from pytorch on the console.
#'\code{pytorch_trace=1} prints a progress bar.
#'
#'@param pytorch_safetensors \code{bool} If \code{TRUE} a 'pytorch' model
#'is saved in safetensors format. If \code{FALSE} or if 'safetensors' is not
#'available, it is saved in the standard pytorch format (.bin). Only relevant for pytorch models.
#'
#'@return This function does not return an object. Instead the trained or fine-tuned
#'model is saved to disk.
#'@note This models uses a WordPiece Tokenizer like BERT and can be trained with
Expand Down Expand Up @@ -433,7 +456,8 @@ train_tune_bert_model=function(ml_framework=aifeducation_config$get_framework(),
sustain_interval=15,
trace=TRUE,
keras_trace=1,
pytorch_trace=1){
pytorch_trace=1,
pytorch_safetensors=TRUE){

#Set Shiny Progress Tracking
pgr_max=10
Expand Down Expand Up @@ -492,6 +516,19 @@ train_tune_bert_model=function(ml_framework=aifeducation_config$get_framework(),
}
}

#Check possible save formats
if(ml_framework=="pytorch"){
if(pytorch_safetensors==TRUE & reticulate::py_module_available("safetensors")==TRUE){
pt_safe_save=TRUE
} else if(pytorch_safetensors==TRUE & reticulate::py_module_available("safetensors")==FALSE){
pt_safe_save=FALSE
warning("Python library 'safetensors' not available. Model will be saved
in the standard pytorch format.")
} else {
pt_safe_save=FALSE
}
}

update_aifeducation_progress_bar(value = 1, total = pgr_max, title = "BERT Model")

#Start Sustainability Tracking-----------------------------------------------
Expand Down Expand Up @@ -835,7 +872,7 @@ train_tune_bert_model=function(ml_framework=aifeducation_config$get_framework(),
mlm_model$save_pretrained(save_directory=output_dir)
} else {
mlm_model$save_pretrained(save_directory=output_dir,
safe_serilization=reticulate::py_module_available("safetensors"))
safe_serilization=pt_safe_save)
}

update_aifeducation_progress_bar(value = 8, total = pgr_max, title = "BERT Model")
Expand Down

0 comments on commit 59c3e6a

Please sign in to comment.