diff --git a/NAMESPACE b/NAMESPACE index be968c8..48e3af6 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -9,6 +9,7 @@ export(bow_pp_create_basic_text_rep) export(bow_pp_create_vocab_draft) export(calc_standard_classification_measures) export(check_aif_py_modules) +export(clean_pytorch_log_transformers) export(combine_embeddings) export(create_bert_model) export(create_deberta_v2_model) @@ -20,6 +21,7 @@ export(get_coder_metrics) export(get_n_chunks) export(get_synthetic_cases) export(install_py_modules) +export(is.null_or_na) export(load_ai_model) export(matrix_to_array_c) export(save_ai_model) diff --git a/NEWS.md b/NEWS.md index 87310df..9c53f01 100644 --- a/NEWS.md +++ b/NEWS.md @@ -9,17 +9,22 @@ editor_options: **TextEmbeddingClassifiers** - Fixed a bug in GlobalAveragePooling1D_PT. Now the layer makes a correct pooling. - This change has an effect on models trained with version 0.3.1. + **This change has an effect on PyTorch models trained with version 0.3.1.** **TextEmbeddingModel** + - Replaced the parameter 'aggregation' with three new parameters allowing to explicitly choose the start and end layer to be included in the creation of embeddings. Furthermore, two options for the pooling method within each layer is added ("cls" and "average"). +- Added support for reporting the training and validation loss during training + the corresponding base model. **Transformer Models** - Fixed a bug in the creation of all transformer models except funnel. Now choosing the number of layers is working. +- A file 'history.log' is now saved within the model's folder reporting the loss + and validation loss during training for each epoch. **EmbeddedText** @@ -27,6 +32,11 @@ editor_options: the model's unique name is used for the validation. - Added new fields and updated methods to account for the new options in creating embeddings (layer selection and pooling type). + +**Graphical User Interface Aifeducation Studio** + +- Adapted the interface according to the changes made in this version. + # aifeducation 0.3.1 diff --git a/R/aif_gui.R b/R/aif_gui.R index f79bcca..106df7d 100644 --- a/R/aif_gui.R +++ b/R/aif_gui.R @@ -2308,7 +2308,7 @@ start_aifeducation_studio<-function(){ if(length(interface_architecture()[[2]])>0){ max_layer_transformer=interface_architecture()[[3]] - print(interface_architecture()[[1]]) + if(interface_architecture()[[1]]=="FunnelForMaskedLM"| interface_architecture()[[1]]=="FunnelModel"){ pool_type_choices=c("cls") @@ -2404,7 +2404,7 @@ start_aifeducation_studio<-function(){ #Create the interface shiny::observeEvent(input$lm_save_interface,{ - model_architecture=interface_architecture()[1] + model_architecture=interface_architecture()[[1]] print(model_architecture) if(model_architecture=="BertForMaskedLM"| model_architecture=="BertModel"){ @@ -2835,6 +2835,37 @@ start_aifeducation_studio<-function(){ ) ) + ), + #Language Model Training------------------------------------------ + shiny::tabPanel("Training", + shiny::fluidRow( + shinydashboard::box(title = "Training", + solidHeader = TRUE, + status = "primary", + width = 12, + shiny::sidebarLayout( + position="right", + sidebarPanel=shiny::sidebarPanel( + shiny::sliderInput(inputId = "lm_performance_text_size", + label = "Text Size", + min = 1, + max = 20, + step = 0.5, + value = 12), + shiny::numericInput(inputId = "lm_performance_y_min", + label = "Y Min", + value = 0), + shiny::numericInput(inputId = "lm_performance_y_max", + label = "Y Max", + value = 20), + ), + mainPanel =shiny::mainPanel( + shiny::plotOutput(outputId = "lm_performance_training_loss") + ) + ) + ) + ) + ), #Create Text Embeddings--------------------------------------------- shiny::tabPanel("Create Text Embeddings", @@ -2988,6 +3019,38 @@ start_aifeducation_studio<-function(){ } }) + output$lm_performance_training_loss<-shiny::renderPlot({ + plot_data=LanguageModel_for_Use()$last_training$history + + if(!is.null(plot_data)){ + y_min=input$lm_performance_y_min + y_max=input$lm_performance_y_max + + val_loss_min=min(plot_data$val_loss) + best_model_epoch=which(x=(plot_data$val_loss)==val_loss_min) + + plot<-ggplot2::ggplot(data=plot_data)+ + ggplot2::geom_line(ggplot2::aes(x=.data$epoch,y=.data$loss,color="train"))+ + ggplot2::geom_line(ggplot2::aes(x=.data$epoch,y=.data$val_loss,color="validation"))+ + ggplot2::geom_vline(xintercept = best_model_epoch, + linetype="dashed") + + plot=plot+ggplot2::theme_classic()+ + ggplot2::ylab("value")+ + ggplot2::coord_cartesian(ylim=c(y_min,y_max))+ + ggplot2::xlab("epoch")+ + ggplot2::scale_color_manual(values = c("train"="red", + "validation"="blue", + "test"="darkgreen"))+ + ggplot2::theme(text = ggplot2::element_text(size = input$lm_performance_text_size), + legend.position="bottom") + return(plot) + } else { + return(NULL) + } + },res = 72*2) + + #Document Page-------------------------------------------------------------- shinyFiles::shinyDirChoose(input=input, id="lm_db_select_model_for_documentation", diff --git a/R/aux_fct.R b/R/aux_fct.R index 0568fde..fa7cbbc 100644 --- a/R/aux_fct.R +++ b/R/aux_fct.R @@ -969,3 +969,65 @@ calc_standard_classification_measures<-function(true_values,predicted_values){ return(results) } + +#'Clean pytorch log of transformers +#' +#'Function for preparing and cleaning the log created by an object of class Trainer +#'from the python library 'transformer's +#' +#'@param log \code{data.frame} containing the log. +#' +#'@return Returns a \code{data.frame} containing epochs, loss, and val_loss. +#' +#'@family Auxiliary Functions +#'@keywords internal +#' +#'@export +clean_pytorch_log_transformers<-function(log){ + max_epochs<-max(log$epoch) + + cols=c("epoch","loss","val_loss") + + cleaned_log<-matrix(data = NA, + nrow = max_epochs, + ncol = length(cols)) + colnames(cleaned_log)=cols + for(i in 1:max_epochs){ + cleaned_log[i,"epoch"]=i + + tmp_loss=subset(log,log$epoch==i & is.na(log$loss)==FALSE) + tmp_loss=tmp_loss[1,"loss"] + cleaned_log[i,"loss"]=tmp_loss + + tmp_val_loss=subset(log,log$epoch==i & is.na(log$eval_loss)==FALSE) + tmp_val_loss=tmp_val_loss[1,"eval_loss"] + cleaned_log[i,"val_loss"]=tmp_val_loss + + } + return(as.data.frame(cleaned_log)) +} + +#'Check if NULL or NA +#' +#'Function for checking if an object is \code{NULL} or \codee{NA} +#' +#'@param object An object to test. +#' +#'@return Returns \code{FALSE} if the object is not \code{NULL} and not \code{NA}. +#'Returns \code{TRUE} in all other cases. +#' +#'@family Auxiliary Functions +#'@keywords internal +#' +#'@export +is.null_or_na<-function(object){ + if(is.null(object)==FALSE){ + if(anyNA(object)==FALSE){ + return(FALSE) + } else { + return(TRUE) + } + } else { + return(TRUE) + } +} diff --git a/R/onLoad.R b/R/onLoad.R index 2c71af3..54c4e5c 100644 --- a/R/onLoad.R +++ b/R/onLoad.R @@ -10,6 +10,7 @@ os<-NULL keras<-NULL accelerate<-NULL safetensors<-NULL +pandas<-NULL aifeducation_config<-NULL @@ -43,6 +44,7 @@ aifeducation_config<-NULL torcheval<<-reticulate::import("torcheval", delay_load = TRUE) accelerate<<-reticulate::import("accelerate", delay_load = TRUE) safetensors<<-reticulate::import("safetensors", delay_load = TRUE) + pandas<<-reticulate::import("pandas", delay_load = TRUE) codecarbon<<-reticulate::import("codecarbon", delay_load = TRUE) keras<<-reticulate::import("keras", delay_load = TRUE) diff --git a/R/text_embedding_model.R b/R/text_embedding_model.R index 6e2c854..75e8516 100644 --- a/R/text_embedding_model.R +++ b/R/text_embedding_model.R @@ -97,6 +97,14 @@ TextEmbeddingModel<-R6::R6Class( ) ), public = list( + + #'@field last_training ('list()')\cr + #'List for storing the history and the results of the last training. This + #'information will be overwritten if a new training is started. + last_training=list( + history=NULL + ), + #-------------------------------------------------------------------------- #'@description Method for creating a new text embedding model #'@param model_name \code{string} containing the name of the new model. @@ -376,6 +384,7 @@ TextEmbeddingModel<-R6::R6Class( } } + #Sustainability tracking sustainability_datalog_path=paste0(model_dir,"/","sustainability.csv") if(file.exists(sustainability_datalog_path)){ tmp_sustainability_data<-read.csv(sustainability_datalog_path) @@ -386,6 +395,15 @@ TextEmbeddingModel<-R6::R6Class( private$sustainability$track_log=NA } + #Training history + training_datalog_path=paste0(model_dir,"/","history.log") + if(file.exists(training_datalog_path)==TRUE){ + self$last_training$history=read.csv2(file = training_datalog_path) + } else { + self$last_training$history=NA + } + + #Check Embedding Configuration if(method=="funnel"){ max_layers_funnel=sum(private$transformer_components$model$config$block_repeats* @@ -778,6 +796,7 @@ TextEmbeddingModel<-R6::R6Class( } } + #Sustainability Data sustainability_datalog_path=paste0(model_dir,"/","sustainability.csv") if(file.exists(sustainability_datalog_path)){ tmp_sustainability_data<-read.csv(sustainability_datalog_path) @@ -788,6 +807,15 @@ TextEmbeddingModel<-R6::R6Class( private$sustainability$track_log=NA } + #Training History + training_datalog_path=paste0(model_dir,"/","history.log") + if(file.exists(training_datalog_path)){ + self$last_training$history=read.csv2(file = training_datalog_path) + } else { + self$last_training$history=NULL + } + + } else { message("Method only relevant for transformer models.") } @@ -859,6 +887,15 @@ TextEmbeddingModel<-R6::R6Class( row.names = FALSE ) + #Saving training history + if(is.null_or_na(self$last_training$history)==FALSE){ + write.csv2( + x=self$last_training$history, + file=paste0(model_dir,"/","history.log"), + row.names = FALSE, + quote = FALSE) + } + } else { message("Method only relevant for transformer models.") } @@ -1267,7 +1304,8 @@ TextEmbeddingModel<-R6::R6Class( if(private$transformer_components$emb_pool_type=="average"){ #Average Pooling over all tokens for(i in tmp_selected_layer){ - tensor_embeddings[i]=list(pooling(tensor_embeddings[[as.integer(i)]])) + tensor_embeddings[i]=list(pooling(x=tensor_embeddings[[as.integer(i)]], + mask=tokens$encodings["attention_mask"])) } } diff --git a/R/transformer_bert.R b/R/transformer_bert.R index c6f3193..9cc1aae 100644 --- a/R/transformer_bert.R +++ b/R/transformer_bert.R @@ -742,13 +742,20 @@ train_tune_bert_model=function(ml_framework=aifeducation_config$get_framework(), save_weights_only= TRUE ) + callback_history=tf$keras$callbacks$CSVLogger( + filename=paste0(output_dir,"/checkpoints/history.log"), + separator=",", + append=FALSE) + + callbacks=list(callback_checkpoint,callback_history) + #Add Callback if Shiny App is running if(requireNamespace("shiny",quietly=TRUE) & requireNamespace("shinyWidgets",quietly=TRUE)){ if(shiny::isRunning()){ shiny_app_active=TRUE reticulate::py_run_file(system.file("python/keras_callbacks.py", package = "aifeducation")) - callback_checkpoint=list(callback_checkpoint,py$ReportAiforeducationShiny()) + callbacks=list(callback_checkpoint,callback_history,py$ReportAiforeducationShiny()) } } @@ -772,7 +779,7 @@ train_tune_bert_model=function(ml_framework=aifeducation_config$get_framework(), epochs=as.integer(n_epoch), workers=as.integer(n_workers), use_multiprocessing=multi_process, - callbacks=list(callback_checkpoint), + callbacks=list(callbacks), verbose=as.integer(keras_trace)) if(trace==TRUE){ @@ -830,12 +837,15 @@ train_tune_bert_model=function(ml_framework=aifeducation_config$get_framework(), tokenizer = tokenizer) trainer$remove_callback(transformers$integrations$CodeCarbonCallback) + #Load Custom Callbacks + + #Add Callback if Shiny App is running if(requireNamespace("shiny") & requireNamespace("shinyWidgets")){ if(shiny::isRunning()){ - shiny_app_active=TRUE reticulate::py_run_file(system.file("python/pytorch_transformer_callbacks.py", package = "aifeducation")) + shiny_app_active=TRUE trainer$add_callback(py$ReportAiforeducationShiny_PT()) } } @@ -857,9 +867,20 @@ train_tune_bert_model=function(ml_framework=aifeducation_config$get_framework(), } if(ml_framework=="tensorflow"){ mlm_model$save_pretrained(save_directory=output_dir) + history_log=read.csv(file = paste0(output_dir,"/checkpoints/history.log")) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } else { mlm_model$save_pretrained(save_directory=output_dir, safe_serilization=pt_safe_save) + history_log=pandas$DataFrame(trainer$state$log_history) + history_log=clean_pytorch_log_transformers(history_log) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } update_aifeducation_progress_bar(value = 8, total = pgr_max, title = "BERT Model") diff --git a/R/transformer_deberta_v2.R b/R/transformer_deberta_v2.R index 4c6e3eb..7c374a2 100644 --- a/R/transformer_deberta_v2.R +++ b/R/transformer_deberta_v2.R @@ -771,12 +771,19 @@ train_tune_deberta_v2_model=function(ml_framework=aifeducation_config$get_framew save_freq="epoch", save_weights_only= TRUE) + callback_history=tf$keras$callbacks$CSVLogger( + filename=paste0(output_dir,"/checkpoints/history.log"), + separator=",", + append=FALSE) + + callbacks=list(callback_checkpoint,callback_history) + if(requireNamespace("shiny",quietly=TRUE) & requireNamespace("shinyWidgets",quietly=TRUE)){ if(shiny::isRunning()){ shiny_app_active=TRUE reticulate::py_run_file(system.file("python/keras_callbacks.py", package = "aifeducation")) - callback_checkpoint=list(callback_checkpoint,py$ReportAiforeducationShiny()) + callbacks=list(callback_checkpoint,callback_history,py$ReportAiforeducationShiny()) } } @@ -801,7 +808,7 @@ train_tune_deberta_v2_model=function(ml_framework=aifeducation_config$get_framew epochs=as.integer(n_epoch), workers=as.integer(n_workers), use_multiprocessing=multi_process, - callbacks=list(callback_checkpoint), + callbacks=list(callbacks), verbose=as.integer(keras_trace)) if(trace==TRUE){ @@ -900,9 +907,20 @@ train_tune_deberta_v2_model=function(ml_framework=aifeducation_config$get_framew if(ml_framework=="tensorflow"){ mlm_model$save_pretrained(save_directory=output_dir) + history_log=read.csv(file = paste0(output_dir,"/checkpoints/history.log")) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } else { mlm_model$save_pretrained(save_directory=output_dir, safe_serilization=pt_safe_save) + history_log=pandas$DataFrame(trainer$state$log_history) + history_log=clean_pytorch_log_transformers(history_log) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } update_aifeducation_progress_bar(value = 8, total = pgr_max, title = "DeBERTa V2 Model") diff --git a/R/transformer_funnel.R b/R/transformer_funnel.R index eb6c599..fcc4e25 100644 --- a/R/transformer_funnel.R +++ b/R/transformer_funnel.R @@ -752,13 +752,20 @@ train_tune_funnel_model=function(ml_framework=aifeducation_config$get_framework( save_weights_only= TRUE ) + callback_history=tf$keras$callbacks$CSVLogger( + filename=paste0(output_dir,"/checkpoints/history.log"), + separator=",", + append=FALSE) + + callbacks=list(callback_checkpoint,callback_history) + #Add Callback if Shiny App is running if(requireNamespace("shiny",quietly=TRUE) & requireNamespace("shinyWidgets",quietly=TRUE)){ if(shiny::isRunning()){ shiny_app_active=TRUE reticulate::py_run_file(system.file("python/keras_callbacks.py", package = "aifeducation")) - callback_checkpoint=list(callback_checkpoint,py$ReportAiforeducationShiny()) + callbacks=list(callback_checkpoint,callback_history,py$ReportAiforeducationShiny()) } } @@ -782,7 +789,7 @@ train_tune_funnel_model=function(ml_framework=aifeducation_config$get_framework( epochs=as.integer(n_epoch), workers=as.integer(n_workers), use_multiprocessing=multi_process, - callbacks=list(callback_checkpoint), + callbacks=list(callbacks), verbose=as.integer(keras_trace)) if(trace==TRUE){ @@ -871,9 +878,20 @@ train_tune_funnel_model=function(ml_framework=aifeducation_config$get_framework( } if(ml_framework=="tensorflow"){ mlm_model$save_pretrained(save_directory=output_dir) + history_log=read.csv(file = paste0(output_dir,"/checkpoints/history.log")) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } else { mlm_model$save_pretrained(save_directory=output_dir, safe_serilization=pt_safe_save) + history_log=pandas$DataFrame(trainer$state$log_history) + history_log=clean_pytorch_log_transformers(history_log) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } update_aifeducation_progress_bar(value = 8, total = pgr_max, title = "Funnel Model") diff --git a/R/transformer_longformer.R b/R/transformer_longformer.R index c1f1ad1..b771efa 100644 --- a/R/transformer_longformer.R +++ b/R/transformer_longformer.R @@ -695,13 +695,20 @@ train_tune_longformer_model=function(ml_framework=aifeducation_config$get_framew save_freq="epoch", save_weights_only= TRUE) + callback_history=tf$keras$callbacks$CSVLogger( + filename=paste0(output_dir,"/checkpoints/history.log"), + separator=",", + append=FALSE) + + callbacks=list(callback_checkpoint,callback_history) + #Add Callback if Shiny App is running if(requireNamespace("shiny",quietly=TRUE) & requireNamespace("shinyWidgets",quietly=TRUE)){ if(shiny::isRunning()){ shiny_app_active=TRUE reticulate::py_run_file(system.file("python/keras_callbacks.py", package = "aifeducation")) - callback_checkpoint=list(callback_checkpoint,py$ReportAiforeducationShiny()) + callbacks=list(callback_checkpoint,callback_history,py$ReportAiforeducationShiny()) } } @@ -726,7 +733,7 @@ train_tune_longformer_model=function(ml_framework=aifeducation_config$get_framew epochs=as.integer(n_epoch), workers=as.integer(n_workers), use_multiprocessing=multi_process, - callbacks=list(callback_checkpoint), + callbacks=list(callbacks), verbose=as.integer(keras_trace)) if(trace==TRUE){ @@ -806,9 +813,20 @@ train_tune_longformer_model=function(ml_framework=aifeducation_config$get_framew } if(ml_framework=="tensorflow"){ mlm_model$save_pretrained(save_directory=output_dir) + history_log=read.csv(file = paste0(output_dir,"/checkpoints/history.log")) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } else { mlm_model$save_pretrained(save_directory=output_dir, safe_serilization=pt_safe_save) + history_log=pandas$DataFrame(trainer$state$log_history) + history_log=clean_pytorch_log_transformers(history_log) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } update_aifeducation_progress_bar(value = 8, total = pgr_max, title = "Longformer Model") diff --git a/R/transformer_roberta.R b/R/transformer_roberta.R index 8bc3956..f386bc0 100644 --- a/R/transformer_roberta.R +++ b/R/transformer_roberta.R @@ -705,13 +705,20 @@ train_tune_roberta_model=function(ml_framework=aifeducation_config$get_framework save_freq="epoch", save_weights_only= TRUE) + callback_history=tf$keras$callbacks$CSVLogger( + filename=paste0(output_dir,"/checkpoints/history.log"), + separator=",", + append=FALSE) + + callbacks=list(callback_checkpoint,callback_history) + #Add Callback if Shiny App is running if(requireNamespace("shiny",quietly=TRUE) & requireNamespace("shinyWidgets",quietly=TRUE)){ if(shiny::isRunning()){ shiny_app_active=TRUE reticulate::py_run_file(system.file("python/keras_callbacks.py", package = "aifeducation")) - callback_checkpoint=list(callback_checkpoint,py$ReportAiforeducationShiny()) + callbacks=list(callback_checkpoint,callback_history,py$ReportAiforeducationShiny()) } } @@ -735,7 +742,7 @@ train_tune_roberta_model=function(ml_framework=aifeducation_config$get_framework epochs=as.integer(n_epoch), workers=as.integer(n_workers), use_multiprocessing=multi_process, - callbacks=list(callback_checkpoint), + callbacks=list(callbacks), verbose=as.integer(keras_trace)) if(trace==TRUE){ @@ -815,9 +822,20 @@ train_tune_roberta_model=function(ml_framework=aifeducation_config$get_framework } if(ml_framework=="tensorflow"){ mlm_model$save_pretrained(save_directory=output_dir) + history_log=read.csv(file = paste0(output_dir,"/checkpoints/history.log")) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } else { mlm_model$save_pretrained(save_directory=output_dir, safe_serilization=pt_safe_save) + history_log=pandas$DataFrame(trainer$state$log_history) + history_log=clean_pytorch_log_transformers(history_log) + write.csv2(history_log, + file=paste0(output_dir,"/history.log"), + row.names=FALSE, + quote=FALSE) } update_aifeducation_progress_bar(value = 8, total = pgr_max, title = "RoBERTa Model") diff --git a/inst/python/pytorch_te_classifier.py b/inst/python/pytorch_te_classifier.py index 588b536..87151a3 100644 --- a/inst/python/pytorch_te_classifier.py +++ b/inst/python/pytorch_te_classifier.py @@ -152,7 +152,10 @@ class GlobalAveragePooling1D_PT(torch.nn.Module): def __init__(self): super().__init__() - def forward(self,x): + def forward(self,x,mask=None): + if not mask is None: + mask_r=mask.reshape(mask.size()[0],mask.size()[1],1) + x=torch.mul(x,mask_r) x=torch.sum(x,dim=1)*(1/self.get_length(x)) return x diff --git a/inst/python/pytorch_transformer_callbacks.py b/inst/python/pytorch_transformer_callbacks.py index 32445ef..b23f121 100644 --- a/inst/python/pytorch_transformer_callbacks.py +++ b/inst/python/pytorch_transformer_callbacks.py @@ -1,4 +1,5 @@ import transformers +import pandas class ReportAiforeducationShiny_PT(transformers.TrainerCallback): def on_train_begin(self, args, state, control, **kwargs): @@ -13,4 +14,3 @@ def on_step_end(self, args, state, control, **kwargs): - diff --git a/man/TextEmbeddingModel.Rd b/man/TextEmbeddingModel.Rd index 4548a12..0ad06b1 100644 --- a/man/TextEmbeddingModel.Rd +++ b/man/TextEmbeddingModel.Rd @@ -20,6 +20,15 @@ Other Text Embedding: \code{\link{combine_embeddings}()} } \concept{Text Embedding} +\section{Public fields}{ +\if{html}{\out{
}} +\describe{ +\item{\code{last_training}}{('list()')\cr +List for storing the history and the results of the last training. This +information will be overwritten if a new training is started.} +} +\if{html}{\out{
}} +} \section{Methods}{ \subsection{Public methods}{ \itemize{ diff --git a/man/array_to_matrix.Rd b/man/array_to_matrix.Rd index 1aefe95..9e4535c 100644 --- a/man/array_to_matrix.Rd +++ b/man/array_to_matrix.Rd @@ -21,6 +21,7 @@ Function transforming an array to a matrix. Other Auxiliary Functions: \code{\link{calc_standard_classification_measures}()}, \code{\link{check_embedding_models}()}, +\code{\link{clean_pytorch_log_transformers}()}, \code{\link{create_iota2_mean_object}()}, \code{\link{create_synthetic_units}()}, \code{\link{generate_id}()}, @@ -30,6 +31,7 @@ Other Auxiliary Functions: \code{\link{get_stratified_train_test_split}()}, \code{\link{get_synthetic_cases}()}, \code{\link{get_train_test_split}()}, +\code{\link{is.null_or_na}()}, \code{\link{matrix_to_array_c}()}, \code{\link{split_labeled_unlabeled}()}, \code{\link{summarize_tracked_sustainability}()}, diff --git a/man/calc_standard_classification_measures.Rd b/man/calc_standard_classification_measures.Rd index 07420b5..0209e8a 100644 --- a/man/calc_standard_classification_measures.Rd +++ b/man/calc_standard_classification_measures.Rd @@ -22,6 +22,7 @@ Function for calculating recall, precision, and f1. Other Auxiliary Functions: \code{\link{array_to_matrix}()}, \code{\link{check_embedding_models}()}, +\code{\link{clean_pytorch_log_transformers}()}, \code{\link{create_iota2_mean_object}()}, \code{\link{create_synthetic_units}()}, \code{\link{generate_id}()}, @@ -31,6 +32,7 @@ Other Auxiliary Functions: \code{\link{get_stratified_train_test_split}()}, \code{\link{get_synthetic_cases}()}, \code{\link{get_train_test_split}()}, +\code{\link{is.null_or_na}()}, \code{\link{matrix_to_array_c}()}, \code{\link{split_labeled_unlabeled}()}, \code{\link{summarize_tracked_sustainability}()}, diff --git a/man/check_embedding_models.Rd b/man/check_embedding_models.Rd index ad93643..13889c1 100644 --- a/man/check_embedding_models.Rd +++ b/man/check_embedding_models.Rd @@ -25,6 +25,7 @@ only with data generated through compatible embedding models. Other Auxiliary Functions: \code{\link{array_to_matrix}()}, \code{\link{calc_standard_classification_measures}()}, +\code{\link{clean_pytorch_log_transformers}()}, \code{\link{create_iota2_mean_object}()}, \code{\link{create_synthetic_units}()}, \code{\link{generate_id}()}, @@ -34,6 +35,7 @@ Other Auxiliary Functions: \code{\link{get_stratified_train_test_split}()}, \code{\link{get_synthetic_cases}()}, \code{\link{get_train_test_split}()}, +\code{\link{is.null_or_na}()}, \code{\link{matrix_to_array_c}()}, \code{\link{split_labeled_unlabeled}()}, \code{\link{summarize_tracked_sustainability}()}, diff --git a/man/clean_pytorch_log_transformers.Rd b/man/clean_pytorch_log_transformers.Rd new file mode 100644 index 0000000..dae33af --- /dev/null +++ b/man/clean_pytorch_log_transformers.Rd @@ -0,0 +1,40 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/aux_fct.R +\name{clean_pytorch_log_transformers} +\alias{clean_pytorch_log_transformers} +\title{Clean pytorch log of transformers} +\usage{ +clean_pytorch_log_transformers(log) +} +\arguments{ +\item{log}{\code{data.frame} containing the log.} +} +\value{ +Returns a \code{data.frame} containing epochs, loss, and val_loss. +} +\description{ +Function for preparing and cleaning the log created by an object of class Trainer +from the python library 'transformer's +} +\seealso{ +Other Auxiliary Functions: +\code{\link{array_to_matrix}()}, +\code{\link{calc_standard_classification_measures}()}, +\code{\link{check_embedding_models}()}, +\code{\link{create_iota2_mean_object}()}, +\code{\link{create_synthetic_units}()}, +\code{\link{generate_id}()}, +\code{\link{get_coder_metrics}()}, +\code{\link{get_folds}()}, +\code{\link{get_n_chunks}()}, +\code{\link{get_stratified_train_test_split}()}, +\code{\link{get_synthetic_cases}()}, +\code{\link{get_train_test_split}()}, +\code{\link{is.null_or_na}()}, +\code{\link{matrix_to_array_c}()}, +\code{\link{split_labeled_unlabeled}()}, +\code{\link{summarize_tracked_sustainability}()}, +\code{\link{to_categorical_c}()} +} +\concept{Auxiliary Functions} +\keyword{internal} diff --git a/man/create_iota2_mean_object.Rd b/man/create_iota2_mean_object.Rd index 352d897..1a2b561 100644 --- a/man/create_iota2_mean_object.Rd +++ b/man/create_iota2_mean_object.Rd @@ -36,6 +36,7 @@ Other Auxiliary Functions: \code{\link{array_to_matrix}()}, \code{\link{calc_standard_classification_measures}()}, \code{\link{check_embedding_models}()}, +\code{\link{clean_pytorch_log_transformers}()}, \code{\link{create_synthetic_units}()}, \code{\link{generate_id}()}, \code{\link{get_coder_metrics}()}, @@ -44,6 +45,7 @@ Other Auxiliary Functions: \code{\link{get_stratified_train_test_split}()}, \code{\link{get_synthetic_cases}()}, \code{\link{get_train_test_split}()}, +\code{\link{is.null_or_na}()}, \code{\link{matrix_to_array_c}()}, \code{\link{split_labeled_unlabeled}()}, \code{\link{summarize_tracked_sustainability}()}, diff --git a/man/create_synthetic_units.Rd b/man/create_synthetic_units.Rd index cc0e8b5..f66702c 100644 --- a/man/create_synthetic_units.Rd +++ b/man/create_synthetic_units.Rd @@ -40,6 +40,7 @@ Other Auxiliary Functions: \code{\link{array_to_matrix}()}, \code{\link{calc_standard_classification_measures}()}, \code{\link{check_embedding_models}()}, +\code{\link{clean_pytorch_log_transformers}()}, \code{\link{create_iota2_mean_object}()}, \code{\link{generate_id}()}, \code{\link{get_coder_metrics}()}, @@ -48,6 +49,7 @@ Other Auxiliary Functions: \code{\link{get_stratified_train_test_split}()}, \code{\link{get_synthetic_cases}()}, \code{\link{get_train_test_split}()}, +\code{\link{is.null_or_na}()}, \code{\link{matrix_to_array_c}()}, \code{\link{split_labeled_unlabeled}()}, \code{\link{summarize_tracked_sustainability}()}, diff --git a/man/generate_id.Rd b/man/generate_id.Rd index bb1347b..ad5da9a 100644 --- a/man/generate_id.Rd +++ b/man/generate_id.Rd @@ -21,6 +21,7 @@ Other Auxiliary Functions: \code{\link{array_to_matrix}()}, \code{\link{calc_standard_classification_measures}()}, \code{\link{check_embedding_models}()}, +\code{\link{clean_pytorch_log_transformers}()}, \code{\link{create_iota2_mean_object}()}, \code{\link{create_synthetic_units}()}, \code{\link{get_coder_metrics}()}, @@ -29,6 +30,7 @@ Other Auxiliary Functions: \code{\link{get_stratified_train_test_split}()}, \code{\link{get_synthetic_cases}()}, \code{\link{get_train_test_split}()}, +\code{\link{is.null_or_na}()}, \code{\link{matrix_to_array_c}()}, \code{\link{split_labeled_unlabeled}()}, \code{\link{summarize_tracked_sustainability}()}, diff --git a/man/get_coder_metrics.Rd b/man/get_coder_metrics.Rd index 44b6ae6..93a51cd 100644 --- a/man/get_coder_metrics.Rd +++ b/man/get_coder_metrics.Rd @@ -53,6 +53,7 @@ Other Auxiliary Functions: \code{\link{array_to_matrix}()}, \code{\link{calc_standard_classification_measures}()}, \code{\link{check_embedding_models}()}, +\code{\link{clean_pytorch_log_transformers}()}, \code{\link{create_iota2_mean_object}()}, \code{\link{create_synthetic_units}()}, \code{\link{generate_id}()}, @@ -61,6 +62,7 @@ Other Auxiliary Functions: \code{\link{get_stratified_train_test_split}()}, \code{\link{get_synthetic_cases}()}, \code{\link{get_train_test_split}()}, +\code{\link{is.null_or_na}()}, \code{\link{matrix_to_array_c}()}, \code{\link{split_labeled_unlabeled}()}, \code{\link{summarize_tracked_sustainability}()}, diff --git a/man/get_folds.Rd b/man/get_folds.Rd index 179d8ee..a69912c 100644 --- a/man/get_folds.Rd +++ b/man/get_folds.Rd @@ -46,6 +46,7 @@ Other Auxiliary Functions: \code{\link{array_to_matrix}()}, \code{\link{calc_standard_classification_measures}()}, \code{\link{check_embedding_models}()}, +\code{\link{clean_pytorch_log_transformers}()}, \code{\link{create_iota2_mean_object}()}, \code{\link{create_synthetic_units}()}, \code{\link{generate_id}()}, @@ -54,6 +55,7 @@ Other Auxiliary Functions: \code{\link{get_stratified_train_test_split}()}, \code{\link{get_synthetic_cases}()}, \code{\link{get_train_test_split}()}, +\code{\link{is.null_or_na}()}, \code{\link{matrix_to_array_c}()}, \code{\link{split_labeled_unlabeled}()}, \code{\link{summarize_tracked_sustainability}()}, diff --git a/man/get_n_chunks.Rd b/man/get_n_chunks.Rd index 65777ce..f9ec92a 100644 --- a/man/get_n_chunks.Rd +++ b/man/get_n_chunks.Rd @@ -25,6 +25,7 @@ Other Auxiliary Functions: \code{\link{array_to_matrix}()}, \code{\link{calc_standard_classification_measures}()}, \code{\link{check_embedding_models}()}, +\code{\link{clean_pytorch_log_transformers}()}, \code{\link{create_iota2_mean_object}()}, \code{\link{create_synthetic_units}()}, \code{\link{generate_id}()}, @@ -33,6 +34,7 @@ Other Auxiliary Functions: \code{\link{get_stratified_train_test_split}()}, \code{\link{get_synthetic_cases}()}, \code{\link{get_train_test_split}()}, +\code{\link{is.null_or_na}()}, \code{\link{matrix_to_array_c}()}, \code{\link{split_labeled_unlabeled}()}, \code{\link{summarize_tracked_sustainability}()}, diff --git a/man/get_stratified_train_test_split.Rd b/man/get_stratified_train_test_split.Rd index 83f0a3c..0fc0e52 100644 --- a/man/get_stratified_train_test_split.Rd +++ b/man/get_stratified_train_test_split.Rd @@ -27,6 +27,7 @@ Other Auxiliary Functions: \code{\link{array_to_matrix}()}, \code{\link{calc_standard_classification_measures}()}, \code{\link{check_embedding_models}()}, +\code{\link{clean_pytorch_log_transformers}()}, \code{\link{create_iota2_mean_object}()}, \code{\link{create_synthetic_units}()}, \code{\link{generate_id}()}, @@ -35,6 +36,7 @@ Other Auxiliary Functions: \code{\link{get_n_chunks}()}, \code{\link{get_synthetic_cases}()}, \code{\link{get_train_test_split}()}, +\code{\link{is.null_or_na}()}, \code{\link{matrix_to_array_c}()}, \code{\link{split_labeled_unlabeled}()}, \code{\link{summarize_tracked_sustainability}()}, diff --git a/man/get_synthetic_cases.Rd b/man/get_synthetic_cases.Rd index 50d6cf7..78c4d4c 100644 --- a/man/get_synthetic_cases.Rd +++ b/man/get_synthetic_cases.Rd @@ -50,6 +50,7 @@ Other Auxiliary Functions: \code{\link{array_to_matrix}()}, \code{\link{calc_standard_classification_measures}()}, \code{\link{check_embedding_models}()}, +\code{\link{clean_pytorch_log_transformers}()}, \code{\link{create_iota2_mean_object}()}, \code{\link{create_synthetic_units}()}, \code{\link{generate_id}()}, @@ -58,6 +59,7 @@ Other Auxiliary Functions: \code{\link{get_n_chunks}()}, \code{\link{get_stratified_train_test_split}()}, \code{\link{get_train_test_split}()}, +\code{\link{is.null_or_na}()}, \code{\link{matrix_to_array_c}()}, \code{\link{split_labeled_unlabeled}()}, \code{\link{summarize_tracked_sustainability}()}, diff --git a/man/get_train_test_split.Rd b/man/get_train_test_split.Rd index e9100bc..48dba2b 100644 --- a/man/get_train_test_split.Rd +++ b/man/get_train_test_split.Rd @@ -36,6 +36,7 @@ Other Auxiliary Functions: \code{\link{array_to_matrix}()}, \code{\link{calc_standard_classification_measures}()}, \code{\link{check_embedding_models}()}, +\code{\link{clean_pytorch_log_transformers}()}, \code{\link{create_iota2_mean_object}()}, \code{\link{create_synthetic_units}()}, \code{\link{generate_id}()}, @@ -44,6 +45,7 @@ Other Auxiliary Functions: \code{\link{get_n_chunks}()}, \code{\link{get_stratified_train_test_split}()}, \code{\link{get_synthetic_cases}()}, +\code{\link{is.null_or_na}()}, \code{\link{matrix_to_array_c}()}, \code{\link{split_labeled_unlabeled}()}, \code{\link{summarize_tracked_sustainability}()}, diff --git a/man/is.null_or_na.Rd b/man/is.null_or_na.Rd new file mode 100644 index 0000000..e85806e --- /dev/null +++ b/man/is.null_or_na.Rd @@ -0,0 +1,40 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/aux_fct.R +\name{is.null_or_na} +\alias{is.null_or_na} +\title{Check if NULL or NA} +\usage{ +is.null_or_na(object) +} +\arguments{ +\item{object}{An object to test.} +} +\value{ +Returns \code{FALSE} if the object is not \code{NULL} and not \code{NA}. +Returns \code{TRUE} in all other cases. +} +\description{ +Function for checking if an object is \code{NULL} or \codee{NA} +} +\seealso{ +Other Auxiliary Functions: +\code{\link{array_to_matrix}()}, +\code{\link{calc_standard_classification_measures}()}, +\code{\link{check_embedding_models}()}, +\code{\link{clean_pytorch_log_transformers}()}, +\code{\link{create_iota2_mean_object}()}, +\code{\link{create_synthetic_units}()}, +\code{\link{generate_id}()}, +\code{\link{get_coder_metrics}()}, +\code{\link{get_folds}()}, +\code{\link{get_n_chunks}()}, +\code{\link{get_stratified_train_test_split}()}, +\code{\link{get_synthetic_cases}()}, +\code{\link{get_train_test_split}()}, +\code{\link{matrix_to_array_c}()}, +\code{\link{split_labeled_unlabeled}()}, +\code{\link{summarize_tracked_sustainability}()}, +\code{\link{to_categorical_c}()} +} +\concept{Auxiliary Functions} +\keyword{internal} diff --git a/man/matrix_to_array_c.Rd b/man/matrix_to_array_c.Rd index ba94c28..a97c725 100644 --- a/man/matrix_to_array_c.Rd +++ b/man/matrix_to_array_c.Rd @@ -26,6 +26,7 @@ Other Auxiliary Functions: \code{\link{array_to_matrix}()}, \code{\link{calc_standard_classification_measures}()}, \code{\link{check_embedding_models}()}, +\code{\link{clean_pytorch_log_transformers}()}, \code{\link{create_iota2_mean_object}()}, \code{\link{create_synthetic_units}()}, \code{\link{generate_id}()}, @@ -35,6 +36,7 @@ Other Auxiliary Functions: \code{\link{get_stratified_train_test_split}()}, \code{\link{get_synthetic_cases}()}, \code{\link{get_train_test_split}()}, +\code{\link{is.null_or_na}()}, \code{\link{split_labeled_unlabeled}()}, \code{\link{summarize_tracked_sustainability}()}, \code{\link{to_categorical_c}()} diff --git a/man/split_labeled_unlabeled.Rd b/man/split_labeled_unlabeled.Rd index d0b061b..88a8f6d 100644 --- a/man/split_labeled_unlabeled.Rd +++ b/man/split_labeled_unlabeled.Rd @@ -33,6 +33,7 @@ Other Auxiliary Functions: \code{\link{array_to_matrix}()}, \code{\link{calc_standard_classification_measures}()}, \code{\link{check_embedding_models}()}, +\code{\link{clean_pytorch_log_transformers}()}, \code{\link{create_iota2_mean_object}()}, \code{\link{create_synthetic_units}()}, \code{\link{generate_id}()}, @@ -42,6 +43,7 @@ Other Auxiliary Functions: \code{\link{get_stratified_train_test_split}()}, \code{\link{get_synthetic_cases}()}, \code{\link{get_train_test_split}()}, +\code{\link{is.null_or_na}()}, \code{\link{matrix_to_array_c}()}, \code{\link{summarize_tracked_sustainability}()}, \code{\link{to_categorical_c}()} diff --git a/man/summarize_tracked_sustainability.Rd b/man/summarize_tracked_sustainability.Rd index e626761..b2015f6 100644 --- a/man/summarize_tracked_sustainability.Rd +++ b/man/summarize_tracked_sustainability.Rd @@ -22,6 +22,7 @@ Other Auxiliary Functions: \code{\link{array_to_matrix}()}, \code{\link{calc_standard_classification_measures}()}, \code{\link{check_embedding_models}()}, +\code{\link{clean_pytorch_log_transformers}()}, \code{\link{create_iota2_mean_object}()}, \code{\link{create_synthetic_units}()}, \code{\link{generate_id}()}, @@ -31,6 +32,7 @@ Other Auxiliary Functions: \code{\link{get_stratified_train_test_split}()}, \code{\link{get_synthetic_cases}()}, \code{\link{get_train_test_split}()}, +\code{\link{is.null_or_na}()}, \code{\link{matrix_to_array_c}()}, \code{\link{split_labeled_unlabeled}()}, \code{\link{to_categorical_c}()} diff --git a/man/to_categorical_c.Rd b/man/to_categorical_c.Rd index ae4e43d..3cdd76f 100644 --- a/man/to_categorical_c.Rd +++ b/man/to_categorical_c.Rd @@ -25,6 +25,7 @@ Other Auxiliary Functions: \code{\link{array_to_matrix}()}, \code{\link{calc_standard_classification_measures}()}, \code{\link{check_embedding_models}()}, +\code{\link{clean_pytorch_log_transformers}()}, \code{\link{create_iota2_mean_object}()}, \code{\link{create_synthetic_units}()}, \code{\link{generate_id}()}, @@ -34,6 +35,7 @@ Other Auxiliary Functions: \code{\link{get_stratified_train_test_split}()}, \code{\link{get_synthetic_cases}()}, \code{\link{get_train_test_split}()}, +\code{\link{is.null_or_na}()}, \code{\link{matrix_to_array_c}()}, \code{\link{split_labeled_unlabeled}()}, \code{\link{summarize_tracked_sustainability}()} diff --git a/tests/testthat/test-04_transformer_models.R b/tests/testthat/test-04_transformer_models.R index cda5df2..0be4c75 100644 --- a/tests/testthat/test-04_transformer_models.R +++ b/tests/testthat/test-04_transformer_models.R @@ -560,6 +560,15 @@ for(ai_method in ai_methods){ model_dir=testthat::test_path(paste0(path_01,"/",framework)) ) + test_that(paste0(ai_method,"training history after creation",framework),{ + history=bert_modeling$last_training$history + expect_equal(nrow(history),2) + expect_equal(ncol(history),3) + expect_true("epoch"%in%colnames(history)) + expect_true("loss"%in%colnames(history)) + expect_true("val_loss"%in%colnames(history)) + }) + test_that(paste0(ai_method,"embedding",framework,"get_transformer_components"),{ expect_equal(bert_modeling$get_transformer_components()$emb_layer_min,min_layer) expect_equal(bert_modeling$get_transformer_components()$emb_layer_max,max_layer) @@ -822,6 +831,14 @@ for(ai_method in ai_methods){ ) expect_s3_class(bert_modeling, class="TextEmbeddingModel") + + history=bert_modeling$last_training$history + expect_equal(nrow(history),2) + expect_equal(ncol(history),3) + expect_true("epoch"%in%colnames(history)) + expect_true("loss"%in%colnames(history)) + expect_true("val_loss"%in%colnames(history)) + }) } else { test_that(paste0(ai_method,"Save Total Model safetensors",framework), { @@ -845,6 +862,12 @@ for(ai_method in ai_methods){ ) expect_s3_class(bert_modeling, class="TextEmbeddingModel") + history=bert_modeling$last_training$history + expect_equal(nrow(history),2) + expect_equal(ncol(history),3) + expect_true("epoch"%in%colnames(history)) + expect_true("loss"%in%colnames(history)) + expect_true("val_loss"%in%colnames(history)) }) } @@ -874,6 +897,12 @@ for(ai_method in ai_methods){ ) expect_s3_class(bert_modeling, class="TextEmbeddingModel") + history=bert_modeling$last_training$history + expect_equal(nrow(history),2) + expect_equal(ncol(history),3) + expect_true("epoch"%in%colnames(history)) + expect_true("loss"%in%colnames(history)) + expect_true("val_loss"%in%colnames(history)) }) #------------------------------------------------------------------------- @@ -894,6 +923,12 @@ for(ai_method in ai_methods){ ) expect_s3_class(bert_modeling, class="TextEmbeddingModel") + history=bert_modeling$last_training$history + expect_equal(nrow(history),2) + expect_equal(ncol(history),3) + expect_true("epoch"%in%colnames(history)) + expect_true("loss"%in%colnames(history)) + expect_true("val_loss"%in%colnames(history)) }) #------------------------------------------------------------------------ diff --git a/vignettes/gui_aife_studio.Rmd b/vignettes/gui_aife_studio.Rmd index 649c934..b60d645 100644 --- a/vignettes/gui_aife_studio.Rmd +++ b/vignettes/gui_aife_studio.Rmd @@ -509,7 +509,11 @@ model. [![Figure 13: Text Embedding Model - Description](img_articles/gui_aife_studio_emb_interface_use_desc.jpg){width="100%"}](img_articles/gui_aife_studio_emb_interface_use_desc.jpg) Figure 13: Text Embedding Model - Description (click image to enlarge) - + +The tab *Training* shows the development of the loss and the validation loss +during the last training of the corresponding base model. If no plot is displayed no +history data is available. + The tab *Create Text Embeddings* (Figure 14) allows you to transform raw texts into a numerical representation of these texts, called text embeddings. These text embeddings can be used in downstream tasks such as classifying diff --git a/vignettes/img_articles/gui_aife_studio_emb_interface_create.JPG b/vignettes/img_articles/gui_aife_studio_emb_interface_create.JPG index 712e427..a799408 100644 Binary files a/vignettes/img_articles/gui_aife_studio_emb_interface_create.JPG and b/vignettes/img_articles/gui_aife_studio_emb_interface_create.JPG differ