# Coming back after model training

In [1]:
import yaml
import torch
from lightning.pytorch.cli import LightningCLI, LightningArgumentParser
import lightning as L
from aidsorb.datamodules import PCDDataModule
from aidsorb.litmodels import PointNetLit
from aidsorb.visualize import draw_pcd

The following function lets us recreate:

* Trainer
* LightningModule (litmodel)
* Datamodule

with the same settings as in the ``.yaml`` configuration file. For more information, see [here](https://github.com/Lightning-AI/pytorch-lightning/discussions/10363#discussioncomment-2326235).

The remaining step is to load back the (trained) weights.

In [2]:
def load_from_config(filename):
    r"""
    Load configuration, trainer, model and datamodule from a ``.yaml`` file.

    The trainer's logger is disabled, and the top-level ``seed_everything`` and
    ``ckpt_path`` entries are dropped (if present) before the ``model``, ``data``
    and ``trainer`` sections are parsed and instantiated.

    .. note::
        1. This function assumes that all we need is to perform inference.
        2. You are responsible for restoring the model's state (the weights of the model).

    Parameters
    ----------
    filename: str
        Absolute or relative path to the ``.yaml`` configuration file.

    Returns
    -------
    config
        The parsed configuration.
    trainer: lightning.Trainer
        Trainer instantiated from the ``trainer`` section (with ``logger=False``).
    model: PointNetLit
        LightningModule instantiated from the ``model`` section (weights not restored).
    datamodule: PCDDataModule
        Datamodule instantiated from the ``data`` section.
    """
    with open(filename, 'r') as f:
        config_dict = yaml.safe_load(f)

    # Only inference is needed, so don't attach a logger to the trainer.
    config_dict.setdefault('trainer', {})['logger'] = False

    # These keys are not arguments of the classes registered below (presumably the
    # parser would reject them); drop them, tolerating configs that lack them.
    for key in ('seed_everything', 'ckpt_path'):
        config_dict.pop(key, None)

    parser = LightningArgumentParser()
    parser.add_class_arguments(PointNetLit, 'model', fail_untyped=False)
    parser.add_class_arguments(PCDDataModule, 'data', fail_untyped=False)
    parser.add_class_arguments(L.Trainer, 'trainer', fail_untyped=False)
    config = parser.parse_object(config_dict)
    objects = parser.instantiate_classes(config)

    return config, objects.trainer, objects.model, objects.data

In [3]:
# Recreate the trainer, LightningModule and datamodule exactly as configured for training.
config_path = 'lightning_logs/version_0/config.yaml'
config, trainer, litmodel, dm = load_from_config(config_path)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


## Restoring the model's state

In [4]:
# Map all tensors to the CPU: `litmodel` still lives on the CPU at this point, and this
# lets the notebook run on machines without CUDA even if the checkpoint holds GPU tensors.
# NOTE(security): `torch.load` unpickles the file; only load checkpoints you trust.
ckpt = torch.load('lightning_logs/version_0/checkpoints/best.ckpt', map_location='cpu')

# Keep only the entries stored under the ``model.`` prefix.
model_weights = {k: v for k, v in ckpt['state_dict'].items() if k.startswith('model.')}

In [5]:
# Due to lazy initialization, parameters are only created on the first forward pass,
# so we pass a dummy input with the correct number of channels (axis 1).
in_channels = 7  # xyz + Z + 3 feats.
dummy_input = torch.randn(32, in_channels, 100)  # presumably (batch, channels, points) — TODO confirm

# The pass only materializes the parameters, so skip building an autograd graph.
with torch.no_grad():
    litmodel(dummy_input)

# Load back the weights (overwriting the random initialization above).
litmodel.load_state_dict(model_weights)

<All keys matched successfully>

In [6]:
# Set the model in inference mode.
litmodel.eval()
litmodel.training

False

## Measure performance

In [7]:
# Evaluate the restored weights on the datamodule's test split.
trainer.test(model=litmodel, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: |                                                              | 0/? [00:00<?, ?it/s]

[{'r2_score': 0.8966765999794006,
  'mae': 0.0951441079378128,
  'mse': 0.028326408937573433}]

## Make predictions

In [8]:
# Predict on the test set; `predict` returns one tensor per batch, so stack them.
test_loader = dm.test_dataloader()
batch_preds = trainer.predict(model=litmodel, dataloaders=test_loader)
y_pred = torch.cat(batch_preds)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                           | 0/? [00:00<?, ?it/s]

In [9]:
# Shape of the stacked predictions, presumably (n_test_samples, n_outputs).
y_pred.size()

torch.Size([24331, 1])