In [None]:
# Reload imported modules (e.g. `moment.*`) automatically before each cell
# runs, so edits to their source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from copy import deepcopy
from argparse import Namespace

from moment.utils.config import Config
from moment.utils.utils import parse_config
from moment.models.base import BaseModel
from moment.models.moment import MOMENTPipeline

In [None]:
def load_pipeline_from_checkpoint(
    args: Namespace, do_not_copy_head: bool = True
) -> MOMENTPipeline:
    """Build a fresh MOMENTPipeline and initialize it from a pretraining checkpoint.

    The checkpoint identified by ``args.pretraining_run_name`` and
    ``args.opt_steps`` is fetched with ``BaseModel.load_pretrained_weights``
    and loaded into a temporary pipeline. Every parameter of the new pipeline
    whose name and shape match a checkpoint parameter is then overwritten with
    the checkpoint value.

    Args:
        args: Model configuration. Must provide ``pretraining_run_name`` and
            ``opt_steps`` in addition to whatever ``MOMENTPipeline`` reads.
            Not modified: each constructor receives its own deep copy.
        do_not_copy_head: When True (the default), parameters whose names
            start with ``"head"`` keep the new pipeline's own initialization
            instead of the checkpoint's values.

    Returns:
        The newly built pipeline with the matching checkpoint weights copied in.
        Only parameters are copied; registered buffers (if any) keep the new
        pipeline's values.
    """
    model = MOMENTPipeline(deepcopy(args))
    checkpoint = BaseModel.load_pretrained_weights(
        run_name=args.pretraining_run_name,
        opt_steps=args.opt_steps,
    )
    pretrained_model = MOMENTPipeline(deepcopy(args))
    pretrained_model.load_state_dict(checkpoint["model_state_dict"])

    # Match parameters by name rather than by position: zipping the two
    # named_parameters() iterators misaligns every later pair if a single
    # parameter is missing or ordered differently.
    pretrained_params = dict(pretrained_model.named_parameters())
    for name, param in model.named_parameters():
        if do_not_copy_head and name.startswith("head"):
            continue
        source = pretrained_params.get(name)
        if source is None or source.shape != param.shape:
            # No compatible checkpoint tensor: keep the fresh initialization.
            continue
        # copy_ writes into the existing tensor instead of rebinding `.data`,
        # so the returned model does not share storage with `pretrained_model`.
        param.data.copy_(source.data)
    return model

In [None]:
# The model-hub YAML is used both as the run config and as its own defaults.
CONFIG_PATH = "../../configs/model_hub/model_hub.yaml"

config_loader = Config(
    config_file_path=CONFIG_PATH,
    default_config_file_path=CONFIG_PATH,
)
config = config_loader.parse()

args = parse_config(config)
# Force every model built in this notebook onto the CPU.
args.device = "cpu"
print(args)

In [None]:
# Hub user/organization the converted checkpoints are published under.
HUB_NAMESPACE = "KonradSzafer"
# Arguments that only locate the pretraining checkpoint; they are excluded
# from the config saved alongside the model.
CHECKPOINT_ONLY_ARGS = ("pretraining_run_name", "opt_steps")

# Each entry: (hub repo name, transformer backbone, pretraining run name, optimizer steps)
checkpoints = [
    ["MOMENT-large", "google/flan-t5-large", "fearless-planet-52-large", 55000],
]
for name, encoder_id, run_name, steps in checkpoints:
    repo_id = f"{HUB_NAMESPACE}/{name}"

    # Configure a per-checkpoint copy so the shared `args` from the config
    # cell is never mutated and this cell can be re-run safely.
    run_args = deepcopy(args)
    run_args.transformer_backbone = encoder_id
    run_args.pretraining_run_name = run_name
    run_args.opt_steps = steps
    run_args.model_kwargs = {}  # placeholder for model kwargs
    model = load_pipeline_from_checkpoint(run_args)

    # Snapshot the config into a fresh dict (vars() returns the live __dict__)
    # and drop the checkpoint-location arguments.
    hub_config = {
        key: value
        for key, value in vars(run_args).items()
        if key not in CHECKPOINT_ONLY_ARGS
    }
    model.save_pretrained(
        repo_id,
        config=hub_config,
        push_to_hub=True,
        private=True,
    )

    # Check that the pushed model can be loaded back from the hub.
    hub_model = MOMENTPipeline.from_pretrained(
        repo_id,
        model_kwargs={
            "task_name": "classification",
            "n_channels": 1,
            "num_class": 2,
        },
    )
    hub_model.init()
    print(hub_model)