In this section, we’ll explore the `training` block of the YAML configuration file. It defines how model training is set up: the trainer type, loss function, optimizers, data loaders, and other training parameters.

Here’s the training section from your YAML file:
```yaml
training:
  trainer:
    type: 'AdversarialTrainer'
    params:
      lambda_triplet: 1.0
      lambda_adv: 1.0
      lambda_gp: 10.0
      use_gradient_penalty: true
      discriminator:
        embedding_dim: 1024
  num_epochs: 10
  validate_every: 2
  save_every: 2
  save_best_only: true
  device: 'cpu'
  log_dir: 'path_to_log_directory'
  loss_function:
    type: 'TripletCosineLoss'
    params:
      margin: 0.1
      normalize: false
  optimizer:
    type: 'Adam'
    params:
      lr: 0.001
  discriminator_optimizer:
    type: 'Adam'
    params:
      lr: 0.001
  train_loader:
    path: 'path_to_triplet_train_dataset.pt'
    dataset_type: 'TripletDataset'
    batch_size: 32
    num_workers: 0
    shuffle: true
  val_loader:
    path: 'path_to_triplet_val_dataset.pt'
    dataset_type: 'TripletDataset'
    batch_size: 32
    num_workers: 0
    shuffle: false
  test_loader:
    path: 'path_to_triplet_test_dataset.pt'
    dataset_type: 'TripletDataset'
    batch_size: 32
    num_workers: 0
    shuffle: false
```
#### Overview

	•	Trainer Specification: Choose the trainer class to use and set its parameters.
	•	Training Parameters: Define epochs, validation frequency, saving frequency, etc.
	•	Loss Function: Specify the loss function and its parameters.
	•	Optimizers: Define optimizers for the model and discriminator (if applicable).
	•	Data Loaders: Set up the training, validation, and test datasets.

#### 1. Specifying the Trainer

Available Trainers

You have three trainer classes available:

	1.	Trainer: Basic trainer for input/target tasks on SimpleDataset, such as regressing a target embedding.
	2.	TripletTrainer: Trains on (anchor, positive, negative) triplets with a triplet loss.
	3.	AdversarialTrainer: Combines triplet loss with adversarial training.

Choosing the Trainer

In your YAML configuration:

```
trainer:
  type: 'AdversarialTrainer'
  params:
    lambda_triplet: 1.0
    lambda_adv: 1.0
    lambda_gp: 10.0
    use_gradient_penalty: true
    discriminator:
      embedding_dim: 1024
```
	•	type: Specifies the trainer class to use.
	•	params: Additional parameters specific to the chosen trainer.
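The `type`/`params` pair suggests a name-based lookup: the `type` string selects a class and `params` is passed to it as keyword arguments. The sketch below shows that pattern with a hypothetical registry and stand-in trainer classes. It is not the implementation of `build_trainer_from_config`, whose internals are not described here, and real trainers would also need a model, loaders, and so on.

```python
from typing import Any, Dict, Type


class Trainer:
    """Stand-in for the basic trainer (hypothetical, for illustration only)."""

    def __init__(self, **params: Any) -> None:
        self.params = params


class TripletTrainer(Trainer):
    """Stand-in for the triplet trainer."""


class AdversarialTrainer(Trainer):
    """Stand-in for the adversarial trainer; receives lambda_* etc. as kwargs."""


# Hypothetical registry mapping the YAML `type` string to a class.
TRAINER_REGISTRY: Dict[str, Type[Trainer]] = {
    "Trainer": Trainer,
    "TripletTrainer": TripletTrainer,
    "AdversarialTrainer": AdversarialTrainer,
}


def resolve_trainer(trainer_cfg: Dict[str, Any]) -> Trainer:
    """Look up the class named by `type` and pass `params` as keyword arguments."""
    name = trainer_cfg["type"]
    if name not in TRAINER_REGISTRY:
        raise ValueError(f"Unknown trainer type {name!r}; expected one of {sorted(TRAINER_REGISTRY)}")
    params = trainer_cfg.get("params") or {}  # handles `params: {}` and a missing key
    return TRAINER_REGISTRY[name](**params)


trainer = resolve_trainer({
    "type": "AdversarialTrainer",
    "params": {"lambda_triplet": 1.0, "lambda_adv": 1.0, "lambda_gp": 10.0,
               "use_gradient_penalty": True, "discriminator": {"embedding_dim": 1024}},
})
```

A registry keeps the YAML decoupled from import paths: adding a trainer means registering one more name, and a typo in `type` fails early with a clear message.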

Trainer-Specific Parameters

AdversarialTrainer

	•	lambda_triplet: Weight for the triplet loss component.
	•	lambda_adv: Weight for the adversarial loss component.
	•	lambda_gp: Weight for the gradient penalty term.
	•	use_gradient_penalty: Boolean flag to enable or disable gradient penalty.
	•	discriminator: Configuration for the discriminator network.
		•	embedding_dim: Dimension of the embeddings the discriminator receives; it should match the size of the embeddings produced by your model.
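One common way such weights enter the objective, assumed here rather than taken from the `AdversarialTrainer` source: the model loss is a weighted sum of the triplet and adversarial terms, and the discriminator loss gains a gradient penalty scaled by `lambda_gp` when `use_gradient_penalty` is true. The penalty below follows the WGAN-GP form, the batch mean of `(||grad||_2 - 1)^2`. The function names are hypothetical, and the inputs are plain floats (precomputed loss values and gradient norms) so the arithmetic is easy to follow.

```python
from typing import Sequence


def generator_loss(triplet_loss: float, adv_loss: float,
                   lambda_triplet: float = 1.0, lambda_adv: float = 1.0) -> float:
    """Weighted sum of the triplet and adversarial terms (assumed combination)."""
    return lambda_triplet * triplet_loss + lambda_adv * adv_loss


def gradient_penalty(grad_norms: Sequence[float]) -> float:
    """WGAN-GP style penalty: mean of (||grad||_2 - 1)^2 over the batch."""
    return sum((g - 1.0) ** 2 for g in grad_norms) / len(grad_norms)


def discriminator_loss(base_loss: float, grad_norms: Sequence[float],
                       lambda_gp: float = 10.0, use_gradient_penalty: bool = True) -> float:
    """Discriminator loss plus the optional, lambda_gp-weighted gradient penalty."""
    if not use_gradient_penalty:
        return base_loss
    return base_loss + lambda_gp * gradient_penalty(grad_norms)


# Weights from the example config: lambda_triplet=1.0, lambda_adv=1.0, lambda_gp=10.0.
g = generator_loss(triplet_loss=0.3, adv_loss=0.5)            # 1.0*0.3 + 1.0*0.5, i.e. about 0.8
d = discriminator_loss(base_loss=0.2, grad_norms=[1.5, 0.5])  # 0.2 + 10 * 0.25, i.e. about 2.7
```

With `lambda_gp: 10.0`, even moderate deviations of the gradient norm from 1 dominate the discriminator loss, which is why the penalty is usually the first knob to check when adversarial training is unstable.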

Other Trainers

	•	Trainer: The example configurations below pass no additional parameters (params: {}).
	•	TripletTrainer: The examples below likewise use params: {}.

#### 2. Mapping Trainers to Datasets and Loss Functions

Datasets

	•	SimpleDataset:
		•	Contains a list of graph data objects.
		•	Used with the basic Trainer.
		•	Compatible with losses that require only inputs and targets (e.g., MSELoss, CosineEmbeddingLoss).
	•	TripletDataset:
		•	Contains triplets of data: anchor, positive, and negative examples.
		•	Used with TripletTrainer and AdversarialTrainer.
		•	Compatible with triplet-based losses (e.g., TripletMarginLoss, TripletCosineLoss).

Loss Functions

	•	Losses for SimpleDataset and Trainer:
		•	MSELoss
		•	CosineEmbeddingLoss
	•	Losses for TripletDataset and TripletTrainer/AdversarialTrainer:
		•	TripletMarginLoss
		•	TripletCosineLoss

Combining Trainers, Datasets, and Losses

	•	Basic Trainer with SimpleDataset:
		•	Suitable for regression tasks or embedding matching.
		•	Use with MSELoss or CosineEmbeddingLoss.
	•	TripletTrainer with TripletDataset:
		•	For training models using triplet losses to enforce relative similarity.
		•	Use with TripletMarginLoss or TripletCosineLoss.
	•	AdversarialTrainer with TripletDataset:
		•	Combines triplet loss with adversarial training.
		•	Suitable for tasks requiring both embedding alignment and adversarial robustness.
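These pairings can be checked before training starts. The sketch encodes only the combinations listed in this section; `check_compatibility` is a hypothetical helper, and `training` is assumed to be the dict under the `training:` key as returned by `yaml.safe_load`.

```python
from typing import Any, Dict

# Pairings from the mapping above: trainer -> (expected dataset_type, allowed losses).
COMPATIBILITY = {
    "Trainer": ("SimpleDataset", {"MSELoss", "CosineEmbeddingLoss"}),
    "TripletTrainer": ("TripletDataset", {"TripletMarginLoss", "TripletCosineLoss"}),
    "AdversarialTrainer": ("TripletDataset", {"TripletMarginLoss", "TripletCosineLoss"}),
}


def check_compatibility(training: Dict[str, Any]) -> None:
    """Raise ValueError if trainer, dataset_type and loss do not match the mapping."""
    trainer_type = training["trainer"]["type"]
    if trainer_type not in COMPATIBILITY:
        raise ValueError(f"Unknown trainer type {trainer_type!r}")
    expected_dataset, allowed_losses = COMPATIBILITY[trainer_type]

    loss_type = training["loss_function"]["type"]
    if loss_type not in allowed_losses:
        raise ValueError(f"{trainer_type} expects one of {sorted(allowed_losses)}, got {loss_type!r}")

    # Every loader that is present must use the dataset type the trainer expects.
    for key in ("train_loader", "val_loader", "test_loader"):
        loader = training.get(key)
        if loader is not None and loader["dataset_type"] != expected_dataset:
            raise ValueError(f"{key}: {trainer_type} expects {expected_dataset}, "
                             f"got {loader['dataset_type']!r}")


config_training = {
    "trainer": {"type": "AdversarialTrainer"},
    "loss_function": {"type": "TripletCosineLoss"},
    "train_loader": {"dataset_type": "TripletDataset"},
    "val_loader": {"dataset_type": "TripletDataset"},
}
check_compatibility(config_training)  # passes silently
```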

#### 3. Loss Function Configuration

In your YAML:
```
loss_function:
  type: 'TripletCosineLoss'
  params:
    margin: 0.1
    normalize: false
```
	•	type: The loss function class to use.
	•	params: Parameters specific to the loss function.

Example: TripletCosineLoss

	•	margin: The margin enforced between the anchor-positive and anchor-negative similarities.
	•	normalize: Whether to normalize embeddings before computing cosine similarity.
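For intuition, a common formulation of a cosine triplet loss penalizes a triplet unless the anchor is more similar to the positive than to the negative by at least `margin`: `max(0, margin - cos(a, p) + cos(a, n))`. This is an assumed formulation, not necessarily what `TripletCosineLoss` computes. The sketch leaves out `normalize`: cosine similarity already divides by the vector norms, so what the flag changes depends on the project's implementation.

```python
import math
from typing import Sequence


def cosine_similarity(u: Sequence[float], v: Sequence[float]) -> float:
    """Dot product divided by the product of the Euclidean norms."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def triplet_cosine_loss(anchor: Sequence[float], positive: Sequence[float],
                        negative: Sequence[float], margin: float = 0.1) -> float:
    """Hinge on the similarity gap: zero once cos(a, p) >= cos(a, n) + margin."""
    return max(0.0, margin - cosine_similarity(anchor, positive)
               + cosine_similarity(anchor, negative))


a, p, n = [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]
easy = triplet_cosine_loss(a, p, n)  # cos(a,p)=1, cos(a,n)=0 -> max(0, 0.1 - 1 + 0) = 0.0
hard = triplet_cosine_loss(a, n, p)  # roles swapped -> 0.1 - 0 + 1, i.e. about 1.1
```

Triplets that already satisfy the margin contribute zero loss and zero gradient, so a small margin such as 0.1 lets training focus on the triplets that are still ordered incorrectly.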

#### 4. Optimizer Configuration

Model Optimizer
```
optimizer:
  type: 'Adam'
  params:
    lr: 0.001
```
	•	type: The optimizer class to use for the model.
	•	params: Parameters specific to the optimizer (e.g., learning rate).

Discriminator Optimizer (AdversarialTrainer)
```
discriminator_optimizer:
  type: 'Adam'
  params:
    lr: 0.001
```
	•	Only required when using AdversarialTrainer.
	•	Used to optimize the discriminator network separately.
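Because `type` matches a class name in `torch.optim` (here `Adam`), one plausible way to build both optimizers is a `getattr` lookup with `params` passed as keyword arguments. This assumes the config names are PyTorch optimizer class names; the helper `build_optimizer` and the two `nn.Linear` modules standing in for the model and the discriminator are hypothetical.

```python
from typing import Any, Dict, Iterable

import torch
from torch import nn


def build_optimizer(cfg: Dict[str, Any], parameters: Iterable[nn.Parameter]) -> torch.optim.Optimizer:
    """Instantiate torch.optim.<type>(parameters, **params)."""
    optimizer_cls = getattr(torch.optim, cfg["type"])  # e.g. torch.optim.Adam
    return optimizer_cls(parameters, **(cfg.get("params") or {}))


model = nn.Linear(84, 1024)         # stand-in for the embedding model
discriminator = nn.Linear(1024, 1)  # stand-in for the discriminator

optimizer = build_optimizer({"type": "Adam", "params": {"lr": 0.001}}, model.parameters())
disc_optimizer = build_optimizer({"type": "Adam", "params": {"lr": 0.001}},
                                 discriminator.parameters())
```

Giving each network its own optimizer over its own parameters means a discriminator update never modifies model weights, and the two learning rates can be tuned independently.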

#### 5. Data Loader Configuration

Dataset Preparation

	•	SimpleDataset:
		•	Prepared in massspec_dreams_embedding.ipynb.
		•	Contains data like:
```
Data(x=[64, 84], edge_index=[2, 128], edge_attr=[128, 7], y=[1, 1024],
     IDENTIFIER=[1], COLLISION_ENERGY=[1, 1], adduct=[1], precursor_mz=[1, 1])
```

	•	TripletDataset:
		•	Prepared in TripletMarginLoss_dataset_construction.ipynb.
		•	Contains triplets of the above data structures.

Data Loader Configuration in YAML

Training Data Loader
```
train_loader:
  path: 'path_to_triplet_train_dataset.pt'
  dataset_type: 'TripletDataset'
  batch_size: 32
  num_workers: 0
  shuffle: true
```
	•	path: Path to the saved dataset file.
	•	dataset_type: Type of the dataset class (SimpleDataset or TripletDataset).
	•	batch_size: Number of samples per batch.
	•	num_workers: Number of subprocesses to use for data loading.
	•	shuffle: Whether to shuffle the data at every epoch.
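Of these keys, `path` and `dataset_type` describe what to load, while `batch_size`, `num_workers`, and `shuffle` share their names with arguments of PyTorch's `torch.utils.data.DataLoader` (where `num_workers=0` means data is loaded in the main process). The sketch assumes the loader keys are forwarded that way; `split_loader_config` and `batches_per_epoch` are hypothetical helpers.

```python
import math
from typing import Any, Dict, Tuple

# YAML loader keys assumed to be forwarded unchanged as DataLoader keyword arguments.
DATALOADER_KEYS = ("batch_size", "num_workers", "shuffle")


def split_loader_config(cfg: Dict[str, Any]) -> Tuple[str, str, Dict[str, Any]]:
    """Separate what to load (path, dataset_type) from how to batch it."""
    kwargs = {key: cfg[key] for key in DATALOADER_KEYS if key in cfg}
    return cfg["path"], cfg["dataset_type"], kwargs


def batches_per_epoch(num_samples: int, batch_size: int) -> int:
    """Batches per epoch when the last partial batch is kept (DataLoader's drop_last=False)."""
    return math.ceil(num_samples / batch_size)


path, dataset_type, loader_kwargs = split_loader_config({
    "path": "path_to_triplet_train_dataset.pt",
    "dataset_type": "TripletDataset",
    "batch_size": 32,
    "num_workers": 0,
    "shuffle": True,
})
# loader_kwargs could then be used as DataLoader(dataset, **loader_kwargs)
# once `dataset` has been loaded from `path`.
```

For example, 1,000 triplets with `batch_size: 32` give 32 batches per epoch, the last one holding 8 triplets.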

Validation and Test Data Loaders

	•	Similar to the training data loader but typically with shuffle: false.

#### 6. Training Parameters
```
num_epochs: 10
validate_every: 2
save_every: 2
save_best_only: true
device: 'cpu'
log_dir: 'path_to_log_directory'
```
	•	num_epochs: Total number of training epochs.
	•	validate_every: Perform validation every N epochs.
	•	save_every: Save model checkpoints every N epochs.
	•	save_best_only: If true, only save the model when it achieves a new best validation loss.
	•	device: Device to use for training ('cpu' or 'cuda').
	•	log_dir: Directory to save logs and model checkpoints.
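How `validate_every`, `save_every`, and `save_best_only` interact is not spelled out above. The sketch assumes a 1-indexed epoch counter, validation on epochs divisible by `validate_every`, and saving on epochs divisible by `save_every`, restricted (when `save_best_only` is true) to epochs whose validation loss is a new best. `plan_epochs` is hypothetical and the validation losses are placeholder inputs, not results.

```python
from typing import Dict, List, Tuple


def plan_epochs(num_epochs: int, validate_every: int, save_every: int,
                save_best_only: bool, val_losses: Dict[int, float]) -> Tuple[List[int], List[int]]:
    """Return (epochs that validate, epochs that save a checkpoint)."""
    best = float("inf")
    validated, saved = [], []
    for epoch in range(1, num_epochs + 1):
        improved = False
        if epoch % validate_every == 0:
            validated.append(epoch)
            loss = val_losses[epoch]
            if loss < best:  # new best validation loss
                best, improved = loss, True
        if epoch % save_every == 0 and (improved or not save_best_only):
            saved.append(epoch)
    return validated, saved


# Config values from above; the losses are placeholders for illustration.
losses = {2: 0.9, 4: 0.7, 6: 0.8, 8: 0.6, 10: 0.65}
validated, saved = plan_epochs(10, 2, 2, True, losses)
# validated -> [2, 4, 6, 8, 10]; saved -> [2, 4, 8]
```

Keeping `validate_every` and `save_every` equal, as in the example config, ensures every save epoch has a fresh validation loss to compare against.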

#### 7. Understanding the Data Structures

SimpleDataset Data Structure

	•	Each element is a Data object containing:
		•	x: Node feature matrix.
		•	edge_index: Graph connectivity in COO format.
		•	edge_attr: Edge feature matrix.
		•	y: Target embedding (e.g., DreaMS spectral embedding).
		•	Additional Attributes: Such as IDENTIFIER, COLLISION_ENERGY, adduct, precursor_mz.

TripletDataset Data Structure

	•	Each element is a tuple of three Data objects: (anchor, positive, negative).
	•	Used for training with triplet losses.
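A map-style PyTorch dataset only needs `__len__` and `__getitem__`, so a triplet dataset can be as small as the class below. It is a hypothetical stand-in (`TripletDatasetSketch`), not the project's `TripletDataset`, and strings take the place of the `Data` objects.

```python
from typing import Any, List, Tuple

Triplet = Tuple[Any, Any, Any]  # (anchor, positive, negative), e.g. three Data objects


class TripletDatasetSketch:
    """Hypothetical minimal dataset: indexing returns one (anchor, positive, negative) tuple."""

    def __init__(self, triplets: List[Triplet]) -> None:
        for t in triplets:
            if len(t) != 3:
                raise ValueError("each element must be an (anchor, positive, negative) triple")
        self.triplets = triplets

    def __len__(self) -> int:
        return len(self.triplets)

    def __getitem__(self, idx: int) -> Triplet:
        return self.triplets[idx]


ds = TripletDatasetSketch([("anchor0", "pos0", "neg0"), ("anchor1", "pos1", "neg1")])
anchor, positive, negative = ds[1]
```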

#### 8. Dataset Preparation

	•	SimpleDataset:
		•	Prepared using the massspec_dreams_embedding.ipynb notebook.
		•	Involves featurizing molecules and associating them with embeddings.
	•	TripletDataset:
		•	Prepared using the TripletMarginLoss_dataset_construction.ipynb notebook.
		•	Based on the contrastive fine-tuning approach used in DreaMS for atlas prediction.
		•	Triplets are formed to ensure meaningful anchor-positive-negative relationships.
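The exact selection criteria live in `TripletMarginLoss_dataset_construction.ipynb` and are not described here. As a generic illustration only, the sketch forms triplets from items carrying a group label (a hypothetical stand-in for whatever notion of similarity the notebook uses): the positive shares the anchor's label and the negative does not. `make_triplets` is a hypothetical helper that returns index triplets.

```python
import random
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence, Tuple


def make_triplets(labels: Sequence[Hashable], seed: int = 0) -> List[Tuple[int, int, int]]:
    """Return (anchor, positive, negative) index triplets based on shared labels."""
    rng = random.Random(seed)  # fixed seed so the dataset is reproducible
    by_label: Dict[Hashable, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    triplets = []
    for anchor, label in enumerate(labels):
        positives = [i for i in by_label[label] if i != anchor]
        negatives = [i for i, other in enumerate(labels) if other != label]
        if not positives or not negatives:
            continue  # this anchor cannot form a valid triplet
        triplets.append((anchor, rng.choice(positives), rng.choice(negatives)))
    return triplets


labels = ["A", "A", "B", "B", "C"]
triplets = make_triplets(labels)  # index 4 ("C") has no positive, so it is skipped
```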

#### 9. Setting Up the YAML for Training

When setting up the training section in your YAML configuration, follow these steps:

	1.	Choose the Trainer:
		Decide on the trainer type based on your task and dataset.
		If using triplet losses and triplet datasets, choose TripletTrainer or AdversarialTrainer.
		For simple regression or matching tasks, use the basic Trainer.
	2.	Configure the Loss Function:
		Select the appropriate loss function for your trainer and dataset.
		Set any necessary parameters specific to the loss.
	3.	Set Up the Optimizers:
		Define the optimizer for the model.
		If using AdversarialTrainer, also define the optimizer for the discriminator.
	4.	Prepare and Specify the Datasets:
		Ensure your datasets are prepared and saved at the specified paths.
		Use the correct dataset_type in the data loader configurations.
	5.	Define Training Parameters:
		Set the number of epochs, validation frequency, saving frequency, etc.
		Choose the device for training ('cpu' or 'cuda').
	6.	Additional Trainer Parameters:
		If using AdversarialTrainer, specify parameters like lambda_triplet, lambda_adv, etc.
		These control the weighting of different loss components and training behaviors.
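The checklist above can be turned into a pre-flight check on the parsed config. `validate_training_section` is a hypothetical helper: the keys it treats as required are an assumption based on this section's example, and it enforces only the rules stated above (a discriminator optimizer and `discriminator.embedding_dim` for `AdversarialTrainer`, and a device of `'cpu'` or `'cuda'`, optionally with an index such as `'cuda:0'`).

```python
from typing import Any, Dict, List

# Assumed required keys, taken from the example configuration.
REQUIRED_KEYS = ("trainer", "num_epochs", "loss_function", "optimizer", "train_loader", "device")


def validate_training_section(training: Dict[str, Any]) -> List[str]:
    """Return a list of problems found; an empty list means the checks passed."""
    problems = [f"missing key: {key}" for key in REQUIRED_KEYS if key not in training]

    trainer_cfg = training.get("trainer") or {}
    if trainer_cfg.get("type") == "AdversarialTrainer":
        if "discriminator_optimizer" not in training:
            problems.append("AdversarialTrainer needs discriminator_optimizer")
        params = trainer_cfg.get("params") or {}
        if "embedding_dim" not in (params.get("discriminator") or {}):
            problems.append("AdversarialTrainer needs params.discriminator.embedding_dim")

    if "device" in training and str(training["device"]).split(":")[0] not in {"cpu", "cuda"}:
        problems.append(f"device must be 'cpu' or 'cuda', got {training['device']!r}")
    return problems


training = {
    "trainer": {"type": "AdversarialTrainer",
                "params": {"discriminator": {"embedding_dim": 1024}}},
    "num_epochs": 10,
    "loss_function": {"type": "TripletCosineLoss", "params": {"margin": 0.1}},
    "optimizer": {"type": "Adam", "params": {"lr": 0.001}},
    "train_loader": {"path": "path_to_triplet_train_dataset.pt", "dataset_type": "TripletDataset"},
    "device": "cpu",
}
print(validate_training_section(training))
# ['AdversarialTrainer needs discriminator_optimizer']
```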

#### 10. Example Configurations

Using Basic Trainer with SimpleDataset
```
trainer:
  type: 'Trainer'
  params: {}
loss_function:
  type: 'MSELoss'
  params: {}
train_loader:
  path: 'path_to_simple_dataset.pt'
  dataset_type: 'SimpleDataset'
  batch_size: 32
  num_workers: 0
  shuffle: true
```
	•	Suitable for regression tasks where you want to minimize the mean squared error between model outputs and targets.

Using TripletTrainer with TripletDataset
```
trainer:
  type: 'TripletTrainer'
  params: {}
loss_function:
  type: 'TripletMarginLoss'
  params:
    margin: 1.0
train_loader:
  path: 'path_to_triplet_train_dataset.pt'
  dataset_type: 'TripletDataset'
  batch_size: 32
  num_workers: 0
  shuffle: true
```
	•	Suitable for training models using triplet loss to enforce relative similarity.

Using AdversarialTrainer with TripletDataset
```
trainer:
  type: 'AdversarialTrainer'
  params:
    lambda_triplet: 1.0
    lambda_adv: 1.0
    lambda_gp: 10.0
    use_gradient_penalty: true
    discriminator:
      embedding_dim: 1024
loss_function:
  type: 'TripletCosineLoss'
  params:
    margin: 0.1
    normalize: false
train_loader:
  path: 'path_to_triplet_train_dataset.pt'
  dataset_type: 'TripletDataset'
  batch_size: 32
  num_workers: 0
  shuffle: true
```
	•	Combines triplet loss with adversarial training for more robust embedding alignment.

#### 11. Additional Notes

	•	Gradient Penalty in AdversarialTrainer:
		•	The gradient penalty term helps stabilize adversarial training by penalizing the discriminator when the norm of its gradient with respect to its inputs grows too large (in the common WGAN-GP formulation, when that norm deviates from 1).
		•	Controlled by lambda_gp and enabled or disabled with use_gradient_penalty.
	•	Discriminator Configuration:
		•	When using AdversarialTrainer, you need to define the discriminator network’s parameters.
		•	The embedding_dim should match the dimension of the embeddings produced by your model.
	•	Logging and Checkpointing:
		•	The log_dir specifies where logs and model checkpoints are saved.
		•	Model checkpoints are saved at intervals defined by save_every and can be configured to save only the best model using save_best_only.
	•	Device Selection:
		•	Specify 'cpu' or 'cuda' depending on whether you want to train on CPU or GPU.

#### Proceeding with Training

Once your YAML configuration is set up:

	1.	Load the Configuration:
```python
import yaml

# Parse the YAML file into nested Python dicts and lists.
with open('config.yaml', 'r') as file:
    config = yaml.safe_load(file)
```
	2.	Build the Model and Trainer:
```python
from mol2dreams.utils.parser import build_trainer_from_config

trainer = build_trainer_from_config(config)
```
	3.	Start Training:
```python
trainer.train()
```
	4.	Evaluate the Model:
```python
trainer.test()
```

