
# Coming back after model training


After training a model, you might want to measure its performance, make
predictions, or do anything else with it, such as continuing training.

<div class="alert alert-info"><h4>Note</h4><p>This example assumes:
        * [PyTorch Lightning checkpoints](https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#lightningmodule-from-checkpoint)
          are enabled during training.
        * Training was performed with AIdsorb :doc:`../cli` or `AIdsorb +
          PyTorch Lightning <aidsorb_with_pytorch_and_lightning>`.</p></div>



In [None]:
import yaml
import torch
import lightning as L
from lightning.pytorch.cli import LightningCLI, LightningArgumentParser
from aidsorb.datamodules import PCDDataModule
from aidsorb.litmodels import PointNetLit
from aidsorb.visualize import draw_pcd

The following function lets us recreate:

* Trainer
* LightningModule (litmodel)
* Datamodule

with the same settings as in the ``.yaml`` configuration file. For more
information 👉 [here](https://github.com/Lightning-AI/pytorch-lightning/discussions/10363#discussioncomment-2326235).



In [None]:
def load_from_config(filename):
    r"""
    Load configuration, trainer, model and datamodule from a ``.yaml`` file.

    The trainer is recreated with its logger disabled. The top-level
    ``seed_everything`` and ``ckpt_path`` entries (if present) are dropped,
    since only the ``model``, ``data`` and ``trainer`` sections are registered
    in the parser.

    .. note::
        You are responsible for restoring the model's state (the weights of the model).

    Parameters
    ----------
    filename: str
        Absolute or relative path to the ``.yaml`` configuration file.

    Returns
    -------
    config
        The parsed configuration.
    trainer : lightning.Trainer
        Trainer recreated from the ``trainer`` section, with ``logger=False``.
    model : PointNetLit
        Model recreated from the ``model`` section (weights are NOT restored).
    data : PCDDataModule
        Datamodule recreated from the ``data`` section.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        config_dict = yaml.safe_load(f)

    # Recreate the trainer without a logger.
    config_dict.setdefault('trainer', {})['logger'] = False

    # These top-level entries are not registered in the parser below, so they
    # are removed before parsing. ``pop`` with a default tolerates configs
    # that do not contain them.
    for key in ('seed_everything', 'ckpt_path'):
        config_dict.pop(key, None)

    parser = LightningArgumentParser()
    parser.add_class_arguments(PointNetLit, 'model', fail_untyped=False)
    parser.add_class_arguments(PCDDataModule, 'data', fail_untyped=False)
    parser.add_class_arguments(L.Trainer, 'trainer', fail_untyped=False)
    config = parser.parse_object(config_dict)
    objects = parser.instantiate_classes(config)

    return config, objects.trainer, objects.model, objects.data

In [None]:
# Recreate the trainer, litmodel and datamodule with the settings stored in the
# training configuration file.
config, trainer, litmodel, dm = load_from_config(filename='path/to/logs/config.yaml')

## Restoring the model's state



In [None]:
# Load the checkpoint onto the CPU, so that a checkpoint saved during GPU
# training can also be restored on a machine without a GPU. The tensors are
# copied into ``litmodel``'s parameters by ``load_state_dict`` further below.
# NOTE(review): ``torch.load`` unpickles data; only load checkpoints you trust.
ckpt = torch.load('path/to/checkpoints/checkpoint.ckpt', map_location='cpu')

# Keep only the state-dict entries whose keys start with ``model.``.
model_weights = {k: v for k, v in ckpt['state_dict'].items() if k.startswith('model.')}

In [None]:
# Due to lazy initialization we need to pass a dummy input with correct shape.
# The output is discarded, so no autograd graph is needed for this pass.
in_channels = 5  # For xyz + Z + 1 additional feature.
# Assumed layout: (batch, in_channels, n_points) — TODO confirm against PointNetLit.
x = torch.randn(32, in_channels, 100)
with torch.no_grad():
    litmodel(x)

In [None]:
# Load back the weights.
# Restore the trained weights; every key must match the model's state dict.
litmodel.load_state_dict(model_weights, strict=True)

In [None]:
# Set the model in inference mode.
# Put the model into inference mode (same as ``litmodel.eval()``) and show the
# resulting ``training`` flag.
litmodel.train(mode=False)
litmodel.training

## Measure performance and make predictions



In [None]:
# Measure performance on test set.
# Measure performance on the test set. The datamodule is passed by keyword:
# the second positional parameter of ``Trainer.test`` is ``dataloaders``.
trainer.test(model=litmodel, datamodule=dm)

In [None]:
# Predict on the test set.
# Predict on the test set.
y_pred_test = torch.cat(trainer.predict(litmodel, dataloaders=dm.test_dataloader()))

# Predict on the train set. Kept in a separate variable so the test-set
# predictions above are not overwritten.
# ``trainer.test`` only sets up the 'test' stage of the datamodule, so set up
# the 'fit' stage as well before requesting the train dataloader
# (assumes PCDDataModule.setup follows Lightning's ``stage`` convention — confirm).
# NOTE(review): if the train dataloader shuffles, these predictions will not be
# in dataset order.
dm.setup(stage='fit')
y_pred_train = torch.cat(trainer.predict(litmodel, dataloaders=dm.train_dataloader()))