diff --git a/chebai/models/base.py b/chebai/models/base.py
index e657963f..a6c9a4e6 100644
--- a/chebai/models/base.py
+++ b/chebai/models/base.py
@@ -49,6 +49,10 @@ def __init__(
         assert input_dim is not None, "input_dim must be specified"
         self.out_dim = out_dim
         self.input_dim = input_dim
+        print(
+            f"Input dimension for the model: {self.input_dim}",
+            f"Output dimension for the model: {self.out_dim}",
+        )
 
         self.save_hyperparameters(
             ignore=[
diff --git a/chebai/models/ffn.py b/chebai/models/ffn.py
index 18e9df4d..bb0dfb41 100644
--- a/chebai/models/ffn.py
+++ b/chebai/models/ffn.py
@@ -1,9 +1,11 @@
+from collections import OrderedDict
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import Tensor, nn
 
 from chebai.models import ChebaiBaseNet
+from chebai.models.electra import filter_dict
 
 
 class FFN(ChebaiBaseNet):
@@ -15,6 +17,8 @@ def __init__(
             1024,
         ],
         use_adam_optimizer: bool = False,
+        pretrained_checkpoint: Optional[str] = None,
+        load_prefix: Optional[str] = "model.",
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -32,6 +36,39 @@ def __init__(
         layers.append(torch.nn.Linear(current_layer_input_size, self.out_dim))
         self.model = nn.Sequential(*layers)
 
+        if pretrained_checkpoint is not None:
+            # NOTE(review): weights_only=False unpickles arbitrary objects --
+            # only load checkpoints from trusted sources.
+            ckpt_file = torch.load(
+                pretrained_checkpoint, map_location=self.device, weights_only=False
+            )
+            if load_prefix is not None:
+                state_dict = filter_dict(ckpt_file["state_dict"], load_prefix)
+            else:
+                state_dict = ckpt_file["state_dict"]
+
+            model_sd = self.model.state_dict()
+            filtered = OrderedDict()
+            skipped = set()
+            for k, v in state_dict.items():
+                target = model_sd.get(k)
+                if target is not None and target.shape == v.shape:
+                    filtered[k] = v  # shapes match: take the checkpoint weight
+                else:
+                    # Shape mismatch (e.g. the final linear layer "2.weight" /
+                    # "2.bias" when out_dim differs) or a key unknown to this
+                    # model: keep the freshly initialised weight instead and
+                    # record the key so the user can see what was not loaded.
+                    skipped.add(k)
+                    if target is not None:
+                        filtered[k] = target
+
+            self.model.load_state_dict(filtered)
+            print(
+                f"Loaded (shape-matched) weights from {pretrained_checkpoint}",
+                f"Skipped the following weights: {skipped}",
+            )
+
     def _get_prediction_and_labels(self, data, labels, model_output):
         d = model_output["logits"]
         loss_kwargs = data.get("loss_kwargs", dict())