diff --git a/NAMESPACE b/NAMESPACE index be968c8..48e3af6 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -9,6 +9,7 @@ export(bow_pp_create_basic_text_rep) export(bow_pp_create_vocab_draft) export(calc_standard_classification_measures) export(check_aif_py_modules) +export(clean_pytorch_log_transformers) export(combine_embeddings) export(create_bert_model) export(create_deberta_v2_model) @@ -20,6 +21,7 @@ export(get_coder_metrics) export(get_n_chunks) export(get_synthetic_cases) export(install_py_modules) +export(is.null_or_na) export(load_ai_model) export(matrix_to_array_c) export(save_ai_model) diff --git a/NEWS.md b/NEWS.md index 87310df..9c53f01 100644 --- a/NEWS.md +++ b/NEWS.md @@ -9,17 +9,22 @@ editor_options: **TextEmbeddingClassifiers** - Fixed a bug in GlobalAveragePooling1D_PT. Now the layer makes a correct pooling. - This change has an effect on models trained with version 0.3.1. + **This change has an effect on PyTorch models trained with version 0.3.1.** **TextEmbeddingModel** + - Replaced the parameter 'aggregation' with three new parameters allowing to explicitly choose the start and end layer to be included in the creation of embeddings. Furthermore, two options for the pooling method within each layer is added ("cls" and "average"). +- Added support for reporting the training and validation loss during training + the corresponding base model. **Transformer Models** - Fixed a bug in the creation of all transformer models except funnel. Now choosing the number of layers is working. +- A file 'history.log' is now saved within the model's folder reporting the loss + and validation loss during training for each epoch. **EmbeddedText** @@ -27,6 +32,11 @@ editor_options: the model's unique name is used for the validation. - Added new fields and updated methods to account for the new options in creating embeddings (layer selection and pooling type). + +**Graphical User Interface Aifeducation Studio** + +- Adapted the interface according to the changes made in this version. + # aifeducation 0.3.1 diff --git a/R/aif_gui.R b/R/aif_gui.R index f79bcca..106df7d 100644 --- a/R/aif_gui.R +++ b/R/aif_gui.R @@ -2308,7 +2308,7 @@ start_aifeducation_studio<-function(){ if(length(interface_architecture()[[2]])>0){ max_layer_transformer=interface_architecture()[[3]] - print(interface_architecture()[[1]]) + if(interface_architecture()[[1]]=="FunnelForMaskedLM"| interface_architecture()[[1]]=="FunnelModel"){ pool_type_choices=c("cls") @@ -2404,7 +2404,7 @@ start_aifeducation_studio<-function(){ #Create the interface shiny::observeEvent(input$lm_save_interface,{ - model_architecture=interface_architecture()[1] + model_architecture=interface_architecture()[[1]] print(model_architecture) if(model_architecture=="BertForMaskedLM"| model_architecture=="BertModel"){ @@ -2835,6 +2835,37 @@ start_aifeducation_studio<-function(){ ) ) + ), + #Language Model Training------------------------------------------ + shiny::tabPanel("Training", + shiny::fluidRow( + shinydashboard::box(title = "Training", + solidHeader = TRUE, + status = "primary", + width = 12, + shiny::sidebarLayout( + position="right", + sidebarPanel=shiny::sidebarPanel( + shiny::sliderInput(inputId = "lm_performance_text_size", + label = "Text Size", + min = 1, + max = 20, + step = 0.5, + value = 12), + shiny::numericInput(inputId = "lm_performance_y_min", + label = "Y Min", + value = 0), + shiny::numericInput(inputId = "lm_performance_y_max", + label = "Y Max", + value = 20), + ), + mainPanel =shiny::mainPanel( + shiny::plotOutput(outputId = "lm_performance_training_loss") + ) + ) + ) + ) + ), #Create Text Embeddings--------------------------------------------- shiny::tabPanel("Create Text Embeddings", @@ -2988,6 +3019,38 @@ start_aifeducation_studio<-function(){ } }) + output$lm_performance_training_loss<-shiny::renderPlot({ + plot_data=LanguageModel_for_Use()$last_training$history + + if(!is.null(plot_data)){ + y_min=input$lm_performance_y_min + y_max=input$lm_performance_y_max + + val_loss_min=min(plot_data$val_loss) + best_model_epoch=which(x=(plot_data$val_loss)==val_loss_min) + + plot<-ggplot2::ggplot(data=plot_data)+ + ggplot2::geom_line(ggplot2::aes(x=.data$epoch,y=.data$loss,color="train"))+ + ggplot2::geom_line(ggplot2::aes(x=.data$epoch,y=.data$val_loss,color="validation"))+ + ggplot2::geom_vline(xintercept = best_model_epoch, + linetype="dashed") + + plot=plot+ggplot2::theme_classic()+ + ggplot2::ylab("value")+ + ggplot2::coord_cartesian(ylim=c(y_min,y_max))+ + ggplot2::xlab("epoch")+ + ggplot2::scale_color_manual(values = c("train"="red", + "validation"="blue", + "test"="darkgreen"))+ + ggplot2::theme(text = ggplot2::element_text(size = input$lm_performance_text_size), + legend.position="bottom") + return(plot) + } else { + return(NULL) + } + },res = 72*2) + + #Document Page-------------------------------------------------------------- shinyFiles::shinyDirChoose(input=input, id="lm_db_select_model_for_documentation", diff --git a/R/aux_fct.R b/R/aux_fct.R index 0568fde..fa7cbbc 100644 --- a/R/aux_fct.R +++ b/R/aux_fct.R @@ -969,3 +969,65 @@ calc_standard_classification_measures<-function(true_values,predicted_values){ return(results) } + +#'Clean pytorch log of transformers +#' +#'Function for preparing and cleaning the log created by an object of class Trainer +#'from the python library 'transformer's +#' +#'@param log \code{data.frame} containing the log. +#' +#'@return Returns a \code{data.frame} containing epochs, loss, and val_loss. +#' +#'@family Auxiliary Functions +#'@keywords internal +#' +#'@export +clean_pytorch_log_transformers<-function(log){ + max_epochs<-max(log$epoch) + + cols=c("epoch","loss","val_loss") + + cleaned_log<-matrix(data = NA, + nrow = max_epochs, + ncol = length(cols)) + colnames(cleaned_log)=cols + for(i in 1:max_epochs){ + cleaned_log[i,"epoch"]=i + + tmp_loss=subset(log,log$epoch==i & is.na(log$loss)==FALSE) + tmp_loss=tmp_loss[1,"loss"] + cleaned_log[i,"loss"]=tmp_loss + + tmp_val_loss=subset(log,log$epoch==i & is.na(log$eval_loss)==FALSE) + tmp_val_loss=tmp_val_loss[1,"eval_loss"] + cleaned_log[i,"val_loss"]=tmp_val_loss + + } + return(as.data.frame(cleaned_log)) +} + +#'Check if NULL or NA +#' +#'Function for checking if an object is \code{NULL} or \codee{NA} +#' +#'@param object An object to test. +#' +#'@return Returns \code{FALSE} if the object is not \code{NULL} and not \code{NA}. +#'Returns \code{TRUE} in all other cases. +#' +#'@family Auxiliary Functions +#'@keywords internal +#' +#'@export +is.null_or_na<-function(object){ + if(is.null(object)==FALSE){ + if(anyNA(object)==FALSE){ + return(FALSE) + } else { + return(TRUE) + } + } else { + return(TRUE) + } +} diff --git a/R/onLoad.R b/R/onLoad.R index 2c71af3..54c4e5c 100644 --- a/R/onLoad.R +++ b/R/onLoad.R @@ -10,6 +10,7 @@ os<-NULL keras<-NULL accelerate<-NULL safetensors<-NULL +pandas<-NULL aifeducation_config<-NULL @@ -43,6 +44,7 @@ aifeducation_config<-NULL torcheval<<-reticulate::import("torcheval", delay_load = TRUE) accelerate<<-reticulate::import("accelerate", delay_load = TRUE) safetensors<<-reticulate::import("safetensors", delay_load = TRUE) + pandas<<-reticulate::import("pandas", delay_load = TRUE) codecarbon<<-reticulate::import("codecarbon", delay_load = TRUE) keras<<-reticulate::import("keras", delay_load = TRUE) diff --git a/R/text_embedding_model.R b/R/text_embedding_model.R index 6e2c854..75e8516 100644 --- a/R/text_embedding_model.R +++ b/R/text_embedding_model.R @@ -97,6 +97,14 @@ TextEmbeddingModel<-R6::R6Class( ) ), public = list( + + #'@field last_training ('list()')\cr + #'List for storing the history and the results of the last training. This + #'information will be overwritten if a new training is started. + last_training=list( + history=NULL + ), + #-------------------------------------------------------------------------- #'@description Method for creating a new text embedding model #'@param model_name \code{string} containing the name of the new model. @@ -376,6 +384,7 @@ TextEmbeddingModel<-R6::R6Class( } } + #Sustainability tracking sustainability_datalog_path=paste0(model_dir,"/","sustainability.csv") if(file.exists(sustainability_datalog_path)){ tmp_sustainability_data<-read.csv(sustainability_datalog_path) @@ -386,6 +395,15 @@ TextEmbeddingModel<-R6::R6Class( private$sustainability$track_log=NA } + #Training history + training_datalog_path=paste0(model_dir,"/","history.log") + if(file.exists(training_datalog_path)==TRUE){ + self$last_training$history=read.csv2(file = training_datalog_path) + } else { + self$last_training$history=NA + } + + #Check Embedding Configuration if(method=="funnel"){ max_layers_funnel=sum(private$transformer_components$model$config$block_repeats* @@ -778,6 +796,7 @@ TextEmbeddingModel<-R6::R6Class( } } + #Sustainability Data sustainability_datalog_path=paste0(model_dir,"/","sustainability.csv") if(file.exists(sustainability_datalog_path)){ tmp_sustainability_data<-read.csv(sustainability_datalog_path) @@ -788,6 +807,15 @@ TextEmbeddingModel<-R6::R6Class( private$sustainability$track_log=NA } + #Training History + training_datalog_path=paste0(model_dir,"/","history.log") + if(file.exists(training_datalog_path)){ + self$last_training$history=read.csv2(file = training_datalog_path) + } else { + self$last_training$history=NULL + } + + } else { message("Method only relevant for transformer models.") } @@ -859,6 +887,15 @@ TextEmbeddingModel<-R6::R6Class( row.names = FALSE ) + #Saving training history + if(is.null_or_na(self$last_training$history)==FALSE){ + write.csv2( + x=self$last_training$history, + file=paste0(model_dir,"/","history.log"), + row.names = FALSE, + quote = FALSE) + } + } else { message("Method only relevant for transformer models.") } @@ -1267,7 +1304,8 @@ TextEmbeddingModel<-R6::R6Class( if(private$transformer_components$emb_pool_type=="average"){ #Average Pooling over all tokens for(i in tmp_selected_layer){ - tensor_embeddings[i]=list(pooling(tensor_embeddings[[as.integer(i)]])) + tensor_embeddings[i]=list(pooling(x=tensor_embeddings[[as.integer(i)]], + mask=tokens$encodings["attention_mask"])) } } diff --git a/R/transformer_bert.R b/R/transformer_bert.R index c6f3193..9cc1aae 100644 --- a/R/transformer_bert.R +++ b/R/transformer_bert.R @@ -742,13 +742,20 @@ train_tune_bert_model=function(ml_framework=aifeducation_config$get_framework(), save_weights_only= TRUE ) + callback_history=tf$keras$callbacks$CSVLogger( + filename=paste0(output_dir,"/checkpoints/history.log"), + separator=",", + append=FALSE) + + callbacks=list(callback_checkpoint,callback_history) + #Add Callback if Shiny App is running if(requireNamespace("shiny",quietly=TRUE) & requireNamespace("shinyWidgets",quietly=TRUE)){ if(shiny::isRunning()){ shiny_app_active=TRUE reticulate::py_run_file(system.file("python/keras_callbacks.py", package = "aifeducation")) - callback_checkpoint=list(callback_checkpoint,py$ReportAiforeducationShiny()) + callbacks=list(callback_checkpoint,callback_history,py$ReportAiforeducationShiny()) } } @@ -772,7 +779,7 @@ train_tune_bert_model=function(ml_framework=aifeducation_config$get_framework(), epochs=as.integer(n_epoch), workers=as.integer(n_workers), use_multiprocessing=multi_process, - callbacks=list(callback_checkpoint), + callbacks=list(callbacks), verbose=as.integer(keras_trace)) if(trace==TRUE){ @@ -830,12 +837,15 @@ train_tune_bert_model=function(ml_framework=aifeducation_config$get_framework(), tokenizer = tokenizer) trainer$remove_callback(transformers$integrations$CodeCarbonCallback) + #Load Custom Callbacks + + #Add Callback if Shiny App is running if(requireNamespace("shiny") & requireNamespace("shinyWidgets")){ if(shiny::isRunning()){ - shiny_app_active=TRUE reticulate::py_run_file(system.file("python/pytorch_transformer_callbacks.py", package = "aifeducation")) + shiny_app_active=TRUE trainer$add_callback(py$ReportAiforeducationShiny_PT()) } } @@ -857,9 +867,20 @@ train_tune_bert_model=function(ml_framework=aifeducation_config$get_framework(), } if(ml_framework=="tensorflow"){ mlm_model$save_pretrained(save_directory=output_dir) + history_log=read.csv(file = paste0(output_dir,"/checkpoints/history.log")) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } else { mlm_model$save_pretrained(save_directory=output_dir, safe_serilization=pt_safe_save) + history_log=pandas$DataFrame(trainer$state$log_history) + history_log=clean_pytorch_log_transformers(history_log) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } update_aifeducation_progress_bar(value = 8, total = pgr_max, title = "BERT Model") diff --git a/R/transformer_deberta_v2.R b/R/transformer_deberta_v2.R index 4c6e3eb..7c374a2 100644 --- a/R/transformer_deberta_v2.R +++ b/R/transformer_deberta_v2.R @@ -771,12 +771,19 @@ train_tune_deberta_v2_model=function(ml_framework=aifeducation_config$get_framew save_freq="epoch", save_weights_only= TRUE) + callback_history=tf$keras$callbacks$CSVLogger( + filename=paste0(output_dir,"/checkpoints/history.log"), + separator=",", + append=FALSE) + + callbacks=list(callback_checkpoint,callback_history) + if(requireNamespace("shiny",quietly=TRUE) & requireNamespace("shinyWidgets",quietly=TRUE)){ if(shiny::isRunning()){ shiny_app_active=TRUE reticulate::py_run_file(system.file("python/keras_callbacks.py", package = "aifeducation")) - callback_checkpoint=list(callback_checkpoint,py$ReportAiforeducationShiny()) + callbacks=list(callback_checkpoint,callback_history,py$ReportAiforeducationShiny()) } } @@ -801,7 +808,7 @@ train_tune_deberta_v2_model=function(ml_framework=aifeducation_config$get_framew epochs=as.integer(n_epoch), workers=as.integer(n_workers), use_multiprocessing=multi_process, - callbacks=list(callback_checkpoint), + callbacks=list(callbacks), verbose=as.integer(keras_trace)) if(trace==TRUE){ @@ -900,9 +907,20 @@ train_tune_deberta_v2_model=function(ml_framework=aifeducation_config$get_framew if(ml_framework=="tensorflow"){ mlm_model$save_pretrained(save_directory=output_dir) + history_log=read.csv(file = paste0(output_dir,"/checkpoints/history.log")) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } else { mlm_model$save_pretrained(save_directory=output_dir, safe_serilization=pt_safe_save) + history_log=pandas$DataFrame(trainer$state$log_history) + history_log=clean_pytorch_log_transformers(history_log) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } update_aifeducation_progress_bar(value = 8, total = pgr_max, title = "DeBERTa V2 Model") diff --git a/R/transformer_funnel.R b/R/transformer_funnel.R index eb6c599..fcc4e25 100644 --- a/R/transformer_funnel.R +++ b/R/transformer_funnel.R @@ -752,13 +752,20 @@ train_tune_funnel_model=function(ml_framework=aifeducation_config$get_framework( save_weights_only= TRUE ) + callback_history=tf$keras$callbacks$CSVLogger( + filename=paste0(output_dir,"/checkpoints/history.log"), + separator=",", + append=FALSE) + + callbacks=list(callback_checkpoint,callback_history) + #Add Callback if Shiny App is running if(requireNamespace("shiny",quietly=TRUE) & requireNamespace("shinyWidgets",quietly=TRUE)){ if(shiny::isRunning()){ shiny_app_active=TRUE reticulate::py_run_file(system.file("python/keras_callbacks.py", package = "aifeducation")) - callback_checkpoint=list(callback_checkpoint,py$ReportAiforeducationShiny()) + callbacks=list(callback_checkpoint,callback_history,py$ReportAiforeducationShiny()) } } @@ -782,7 +789,7 @@ train_tune_funnel_model=function(ml_framework=aifeducation_config$get_framework( epochs=as.integer(n_epoch), workers=as.integer(n_workers), use_multiprocessing=multi_process, - callbacks=list(callback_checkpoint), + callbacks=list(callbacks), verbose=as.integer(keras_trace)) if(trace==TRUE){ @@ -871,9 +878,20 @@ train_tune_funnel_model=function(ml_framework=aifeducation_config$get_framework( } if(ml_framework=="tensorflow"){ mlm_model$save_pretrained(save_directory=output_dir) + history_log=read.csv(file = paste0(output_dir,"/checkpoints/history.log")) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } else { mlm_model$save_pretrained(save_directory=output_dir, safe_serilization=pt_safe_save) + history_log=pandas$DataFrame(trainer$state$log_history) + history_log=clean_pytorch_log_transformers(history_log) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } update_aifeducation_progress_bar(value = 8, total = pgr_max, title = "Funnel Model") diff --git a/R/transformer_longformer.R b/R/transformer_longformer.R index c1f1ad1..b771efa 100644 --- a/R/transformer_longformer.R +++ b/R/transformer_longformer.R @@ -695,13 +695,20 @@ train_tune_longformer_model=function(ml_framework=aifeducation_config$get_framew save_freq="epoch", save_weights_only= TRUE) + callback_history=tf$keras$callbacks$CSVLogger( + filename=paste0(output_dir,"/checkpoints/history.log"), + separator=",", + append=FALSE) + + callbacks=list(callback_checkpoint,callback_history) + #Add Callback if Shiny App is running if(requireNamespace("shiny",quietly=TRUE) & requireNamespace("shinyWidgets",quietly=TRUE)){ if(shiny::isRunning()){ shiny_app_active=TRUE reticulate::py_run_file(system.file("python/keras_callbacks.py", package = "aifeducation")) - callback_checkpoint=list(callback_checkpoint,py$ReportAiforeducationShiny()) + callbacks=list(callback_checkpoint,callback_history,py$ReportAiforeducationShiny()) } } @@ -726,7 +733,7 @@ train_tune_longformer_model=function(ml_framework=aifeducation_config$get_framew epochs=as.integer(n_epoch), workers=as.integer(n_workers), use_multiprocessing=multi_process, - callbacks=list(callback_checkpoint), + callbacks=list(callbacks), verbose=as.integer(keras_trace)) if(trace==TRUE){ @@ -806,9 +813,20 @@ train_tune_longformer_model=function(ml_framework=aifeducation_config$get_framew } if(ml_framework=="tensorflow"){ mlm_model$save_pretrained(save_directory=output_dir) + history_log=read.csv(file = paste0(output_dir,"/checkpoints/history.log")) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } else { mlm_model$save_pretrained(save_directory=output_dir, safe_serilization=pt_safe_save) + history_log=pandas$DataFrame(trainer$state$log_history) + history_log=clean_pytorch_log_transformers(history_log) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } update_aifeducation_progress_bar(value = 8, total = pgr_max, title = "Longformer Model") diff --git a/R/transformer_roberta.R b/R/transformer_roberta.R index 8bc3956..f386bc0 100644 --- a/R/transformer_roberta.R +++ b/R/transformer_roberta.R @@ -705,13 +705,20 @@ train_tune_roberta_model=function(ml_framework=aifeducation_config$get_framework save_freq="epoch", save_weights_only= TRUE) + callback_history=tf$keras$callbacks$CSVLogger( + filename=paste0(output_dir,"/checkpoints/history.log"), + separator=",", + append=FALSE) + + callbacks=list(callback_checkpoint,callback_history) + #Add Callback if Shiny App is running if(requireNamespace("shiny",quietly=TRUE) & requireNamespace("shinyWidgets",quietly=TRUE)){ if(shiny::isRunning()){ shiny_app_active=TRUE reticulate::py_run_file(system.file("python/keras_callbacks.py", package = "aifeducation")) - callback_checkpoint=list(callback_checkpoint,py$ReportAiforeducationShiny()) + callbacks=list(callback_checkpoint,callback_history,py$ReportAiforeducationShiny()) } } @@ -735,7 +742,7 @@ train_tune_roberta_model=function(ml_framework=aifeducation_config$get_framework epochs=as.integer(n_epoch), workers=as.integer(n_workers), use_multiprocessing=multi_process, - callbacks=list(callback_checkpoint), + callbacks=list(callbacks), verbose=as.integer(keras_trace)) if(trace==TRUE){ @@ -815,9 +822,20 @@ train_tune_roberta_model=function(ml_framework=aifeducation_config$get_framework } if(ml_framework=="tensorflow"){ mlm_model$save_pretrained(save_directory=output_dir) + history_log=read.csv(file = paste0(output_dir,"/checkpoints/history.log")) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } else { mlm_model$save_pretrained(save_directory=output_dir, safe_serilization=pt_safe_save) + history_log=pandas$DataFrame(trainer$state$log_history) + history_log=clean_pytorch_log_transformers(history_log) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } update_aifeducation_progress_bar(value = 8, total = pgr_max, title = "RoBERTa Model") diff --git a/inst/python/pytorch_te_classifier.py b/inst/python/pytorch_te_classifier.py index 588b536..87151a3 100644 --- a/inst/python/pytorch_te_classifier.py +++ b/inst/python/pytorch_te_classifier.py @@ -152,7 +152,10 @@ class GlobalAveragePooling1D_PT(torch.nn.Module): def __init__(self): super().__init__() - def forward(self,x): + def forward(self,x,mask=None): + if not mask is None: + mask_r=mask.reshape(mask.size()[0],mask.size()[1],1) + x=torch.mul(x,mask_r) x=torch.sum(x,dim=1)*(1/self.get_length(x)) return x diff --git a/inst/python/pytorch_transformer_callbacks.py b/inst/python/pytorch_transformer_callbacks.py index 32445ef..b23f121 100644 --- a/inst/python/pytorch_transformer_callbacks.py +++ b/inst/python/pytorch_transformer_callbacks.py @@ -1,4 +1,5 @@ import transformers +import pandas class ReportAiforeducationShiny_PT(transformers.TrainerCallback): def on_train_begin(self, args, state, control, **kwargs): @@ -13,4 +14,3 @@ def on_step_end(self, args, state, control, **kwargs): - diff --git a/man/TextEmbeddingModel.Rd b/man/TextEmbeddingModel.Rd index 4548a12..0ad06b1 100644 --- a/man/TextEmbeddingModel.Rd +++ b/man/TextEmbeddingModel.Rd @@ -20,6 +20,15 @@ Other Text Embedding: \code{\link{combine_embeddings}()} } \concept{Text Embedding} +\section{Public fields}{ +\if{html}{\out{