# Transfer Learning

Datasets computed at high levels of theory are expensive and therefore usually small.
A model trained on such data might not generalize well to unseen configurations.
Sometimes this can be remedied with transfer learning:
by first training a model on a large amount of data from a less expensive level of theory, only small adjustments to the parameters are required to accurately reproduce the potential energy surface at the more expensive level of theory.


Alternatively, the level of theory might not change, but the dataset is extended.
This is the case in learning-on-the-fly scenarios.
For a demonstration of using transfer learning for learning on the fly, see the corresponding example from the [IPSuite documentation](https://ipsuite.readthedocs.io/en/latest/).


Apax comes with discriminative transfer learning capabilities out of the box.
In this tutorial, we are going to fine-tune a model trained on benzene data at the DFT level of theory to CCSD(T).



First, download the appropriate dataset from the sGDML website.


To use transfer learning in apax, add the path to a pretrained model to the config.
Furthermore, we can freeze individual components or reduce their learning rates by adjusting the `optimizer` section of the config.

```yaml
optimizer:
    nn_lr: 0.004
    emb_lr: 0.0
```

A learning rate of 0.0 masks the respective weights during training steps, which effectively freezes them.
Here, we will freeze the descriptor, reinitialize the scaling and shifting parameters and reduce the learning rate of all other components.
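
To make the effect of a zero learning rate concrete, here is a minimal sketch in plain Python. It is not apax's optimizer: it uses simple gradient descent, and the parameter groups, values, and the `sgd_step` helper are hypothetical, chosen only to mirror the learning-rate keys in the config.

```python
# Hypothetical parameter groups; the names mirror the learning-rate keys in the config.
params = {"emb": [0.5, -0.2], "nn": [1.0, 2.0]}
grads = {"emb": [0.1, 0.1], "nn": [0.5, -0.5]}
learning_rates = {"emb": 0.0, "nn": 0.004}


def sgd_step(params, grads, learning_rates):
    """Apply one plain gradient-descent step with a separate learning rate per group."""
    updated = {}
    for group, weights in params.items():
        lr = learning_rates[group]
        if lr == 0.0:
            # Frozen group: skip the update so the weights are left untouched.
            updated[group] = list(weights)
        else:
            updated[group] = [w - lr * g for w, g in zip(weights, grads[group])]
    return updated


new_params = sgd_step(params, grads, learning_rates)
print(new_params["emb"])  # identical to the input values
print(new_params["nn"])   # shifted against the gradient
```

The frozen group keeps its pretrained values regardless of its gradients, while the other group is still adapted to the new data.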

We can now fine-tune the model by running
`apax train config.yaml`

In [1]:
from pathlib import Path

import yaml

from apax.utils.datasets import (
    download_benzene_DFT,
    download_md22_benzene_CCSDT,
    mod_md_datasets,
)
from apax.utils.helpers import mod_config

## Acquire Datasets

For this demonstration we will use the DFT and CC versions of the benzene MD17 dataset.
We start by downloading both and saving them in an appropriate format.

In [2]:
# Download DFT Data

data_path = Path("project")
dft_file_path = download_benzene_DFT(data_path)
dft_file_path = mod_md_datasets(dft_file_path)

In [3]:
# Download CCSD(T) Data

data_path = Path("project")
cc_file_path, _ = download_md22_benzene_CCSDT(data_path)
cc_file_path = mod_md_datasets(cc_file_path)

## Pretrain Model

First, we will train a model on the "large" (in relative terms) but less accurate DFT dataset.
A standard model with default optimizers will do just fine.

In [4]:
!apax template train --full

In [5]:
config_path = Path("config_full.yaml")

config_updates = {
    "n_epochs": 100,
    "data": {
        "n_train": 1000,
        "n_valid": 200,
        "batch_size": 8,
        "valid_batch_size": 100,
        "experiment": "benzene_dft",
        "directory": "project/models",
        "data_path": str(dft_file_path),
        "energy_unit": "kcal/mol",
        "pos_unit": "Ang",
    },
}
config_dict = mod_config(config_path, config_updates)

with open("config_full.yaml", "w") as conf:
    yaml.dump(config_dict, conf, default_flow_style=False)
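
How `mod_config` combines the updates with the template is not shown in this tutorial. The sketch below illustrates one plausible behaviour, a merge in which nested sections such as `data` are updated key by key rather than replaced wholesale. `merge_updates` is a hypothetical helper, not part of apax, and the template values are made up for illustration.

```python
import copy


def merge_updates(base, updates):
    """Return a copy of base with updates applied, merging nested dicts key by key."""
    merged = copy.deepcopy(base)
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are sections: descend instead of overwriting the section.
            merged[key] = merge_updates(merged[key], value)
        else:
            merged[key] = value
    return merged


# Hypothetical template values, for illustration only.
template = {"n_epochs": 1000, "data": {"batch_size": 32, "other_option": "unchanged"}}
updates = {"n_epochs": 100, "data": {"batch_size": 8, "experiment": "benzene_dft"}}

config = merge_updates(template, updates)
print(config)
```

Under this assumption, options that are not mentioned in the updates keep their template values, which is why only the settings that differ need to be listed.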

In [6]:
!apax train config_full.yaml

E0000 00:00:1732124930.427330  459463 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732124930.430484  459463 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
INFO | 17:48:53 | Running on [CudaDevice(id=0)]
INFO | 17:48:53 | Initializing Callbacks
INFO | 17:48:53 | Initializing Loss Function
INFO | 17:48:53 | Initializing Metrics
INFO | 17:48:53 | Running Input Pipeline
INFO | 17:48:53 | Reading data file project/benzene_mod.xyz
INFO | 17:49:00 | Found n_train: 1000, n_val: 200
INFO | 17:49:00 | Computing per element energy regression.
INFO | 17:49:01 | Building Standard model
INFO | 17:49:01 | initializing 1 model(s)
INFO | 17:49:08 | Initializing Optimizer
INFO | 17:49:08 | Beginning Training
Epochs: 100%|████████████████████████████████████| 100/100 [00:52<00:00,  1.89it/s, val_loss=0.0206]

## Baseline CC Training

Next, we require a CC baseline to quantify the effect of pretraining.
As with the DFT dataset, we will only use a small fraction of the data to emphasize the effects in the low-data regime.

In [7]:
config_path = Path("config_full.yaml")

config_updates = {
    "n_epochs": 100,
    "data": {
        "n_train": 50,
        "n_valid": 10,
        "batch_size": 4,
        "valid_batch_size": 10,
        "experiment": "benzene_cc_baseline",
        "directory": "project/models",
        "data_path": str(cc_file_path),
        "energy_unit": "kcal/mol",
        "pos_unit": "Ang",
    },
}
config_dict = mod_config(config_path, config_updates)

with open("config_cc_baseline.yaml", "w") as conf:
    yaml.dump(config_dict, conf, default_flow_style=False)

In [8]:
!apax train config_cc_baseline.yaml

E0000 00:00:1732125003.227050  460393 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732125003.230272  460393 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
INFO | 17:50:05 | Running on [CudaDevice(id=0)]
INFO | 17:50:05 | Initializing Callbacks
INFO | 17:50:05 | Initializing Loss Function
INFO | 17:50:05 | Initializing Metrics
INFO | 17:50:05 | Running Input Pipeline
INFO | 17:50:05 | Reading data file project/benzene_ccsd_t-train_mod.xyz
INFO | 17:50:05 | Found n_train: 50, n_val: 10
INFO | 17:50:05 | Computing per element energy regression.
INFO | 17:50:06 | Building Standard model
INFO | 17:50:06 | initializing 1 model(s)
INFO | 17:50:13 | Initializing Optimizer
INFO | 17:50:13 | Beginning Training
Epochs: 100%|██████████████████████████████████████| 100/100 [00:29<00:00,  3.39it/s, val_lo

## DFT -> CC Fine Tuning

Finally, we can fine-tune a model that was pretrained on DFT data.
The model architecture remains unchanged for all three runs.
However, for fine-tuning we need to specify the path to the base model and how to deal with its parameters.
For each parameter group, we can choose to freeze it, reset it, or keep training it.
It is certainly advisable to experiment with different strategies, but a good starting point is to freeze the embedding layer if the system we transfer to remains the same, and to reset the scale-shift layer if the level of theory changes (DFT and CC have different energy scales).
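
The difference between keeping and resetting a parameter group can be sketched in a few lines of plain Python. This is an illustration of the idea only, not apax's implementation: `transfer_parameters`, the group names, and all numbers are hypothetical.

```python
def transfer_parameters(pretrained, fresh, reset_layers):
    """Build the starting parameters for fine-tuning.

    Groups listed in reset_layers (or missing from the pretrained model) keep
    their fresh initialization; all other groups reuse the pretrained values.
    """
    transferred = {}
    for name, fresh_value in fresh.items():
        if name in reset_layers or name not in pretrained:
            transferred[name] = list(fresh_value)
        else:
            transferred[name] = list(pretrained[name])
    return transferred


# Hypothetical values: a pretrained model and a freshly initialized one.
pretrained = {"embedding": [0.3, 0.7], "nn": [1.2, -0.4], "scale_shift": [5.0, -120.0]}
fresh = {"embedding": [0.0, 0.0], "nn": [0.0, 0.0], "scale_shift": [1.0, 0.0]}

params = transfer_parameters(pretrained, fresh, reset_layers=["scale_shift"])
print(params)
```

Freezing is then a separate, orthogonal choice: a transferred group such as `embedding` can additionally be given a learning rate of 0.0 so that it keeps its pretrained values throughout fine-tuning.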

Make sure to carefully inspect the config options below.

In [9]:
config_path = Path("config_full.yaml")

config_updates = {
    "n_epochs": 100,
    "data": {
        "n_train": 50,
        "n_valid": 10,
        "batch_size": 4,
        "valid_batch_size": 10,
        "experiment": "benzene_cc_ft",
        "directory": "project/models",
        "data_path": str(cc_file_path),
        "energy_unit": "kcal/mol",
        "pos_unit": "Ang",
    },
    "optimizer": {
        "emb_lr": 0.00,  # freeze embedding layer
        "nn_lr": 0.0005,  # lower lr
        "scale_lr": 0.001,  # lower lr
        "shift_lr": 0.005,  # lower lr
    },
    "checkpoints": {
        "base_model_checkpoint": "project/models/benzene_dft",  # pretrained model
        "reset_layers": ["scale_shift"],  # reset scale-shift layer
    },
}
config_dict = mod_config(config_path, config_updates)

with open("config_cc_ft.yaml", "w") as conf:
    yaml.dump(config_dict, conf, default_flow_style=False)

In [10]:
!apax train config_cc_ft.yaml

E0000 00:00:1732125044.359099  461322 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732125044.362317  461322 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
INFO | 17:50:46 | Running on [CudaDevice(id=0)]
INFO | 17:50:46 | Initializing Callbacks
INFO | 17:50:46 | Initializing Loss Function
INFO | 17:50:46 | Initializing Metrics
INFO | 17:50:46 | Running Input Pipeline
INFO | 17:50:46 | Reading data file project/benzene_ccsd_t-train_mod.xyz
INFO | 17:50:46 | Found n_train: 50, n_val: 10
INFO | 17:50:46 | Computing per element energy regression.
INFO | 17:50:46 | Building Standard model
INFO | 17:50:46 | initializing 1 model(s)
INFO | 17:50:53 | Initializing Optimizer
INFO | 17:50:53 | loading checkpoint from project/models/benzene_dft/best
INFO | 17:50:53 | Transferring parameters from project/

As we can see, the fine-tuned model achieves a lower validation loss than the baseline CC model.
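
To compare runs programmatically, the final validation loss can be read from the last progress-bar line of each training log. The sketch below assumes the `val_loss=<number>` format visible in the pretraining output above; `final_val_loss` is a hypothetical helper, and how the log text is captured is left to the reader.

```python
import re

# Matches e.g. "val_loss=0.0206" or "val_loss=1.5e-3".
VAL_LOSS_PATTERN = re.compile(r"val_loss=(\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)")


def final_val_loss(progress_line):
    """Extract the validation loss from a progress-bar line, or None if it is absent."""
    match = VAL_LOSS_PATTERN.search(progress_line)
    return float(match.group(1)) if match else None


line = "Epochs: 100%|####| 100/100 [00:52<00:00,  1.89it/s, val_loss=0.0206]"
print(final_val_loss(line))

# A line that was cut off before the value yields None rather than a guess.
print(final_val_loss("Epochs: 100%|####| 100/100 [00:29<00:00,  3.39it/s, val_lo"))
```

Applying the same helper to the baseline and fine-tuned runs gives two numbers that can be compared directly.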

How much further can you improve the fine-tuning (or pretraining) setup?