In [4]:
from cloudcasting.models import VariableHorizonModel
import numpy as np

class PersistenceModel(VariableHorizonModel):
    """A persistence model used solely for testing the validation procedure.

    The forecast for every future step is simply the most recent input frame,
    repeated ``rollout_steps`` times along axis -3.
    """

    def forward(self, X):
        """Return the latest input frame repeated ``self.rollout_steps`` times.

        Args:
            X: Input array with at least 3 dimensions. Axis -3 is treated as
                the time axis (it is the axis sliced and repeated below).

        Returns:
            A new array shaped like ``X`` except that axis -3 has length
            ``self.rollout_steps``. NaNs are replaced by 0 and all values are
            clipped to the range [0, 1].
        """
        # Keep only the most recent frame (slice keeps axis -3 with length 1).
        # `clip` leaves NaN untouched, so raw NaNs would otherwise be copied
        # into every forecast step; replace them with 0 first. nan_to_num
        # returns a copy, so the caller's input is never modified.
        latest_frame = np.nan_to_num(X[..., -1:, :, :], nan=0.0)

        # The NaN values in the input data could also have been filled with -1
        # upstream. Clip these (and any other out-of-range values) into [0, 1].
        latest_frame = latest_frame.clip(0, 1)

        return np.repeat(latest_frame, self.rollout_steps, axis=-3)

    def hyperparameters_dict(self):
        """Return the model's hyperparameters; persistence has none."""
        return {}


# Smoke test: a 5-D array of ones should come back with axis -3 (the axis
# PersistenceModel.forward repeats along) stretched to the 5 rollout steps.
example_input = np.ones((3,) * 5)

model = PersistenceModel(rollout_steps=5, history_steps=0)
res = model(example_input)
assert res.shape == (3, 3, 5, 3, 3)
