In [None]:
#|default_exp models

In [1]:
# | export

import torch
import torch.nn as nn

class PersistenceModel(nn.Module):
    """
    Naive "persistence" forecaster: repeats the last observed value of each series
    across the forecast horizon.

    The model has no trainable parameters. It also exposes no-op `step` and
    `zero_grad` methods, presumably so the same object can stand in for an
    optimizer in a generic training loop (TODO confirm against callers).
    """

    def __init__(self, horizon):
        """
        Initialize the PersistenceModel.

        Parameters:
        - horizon (int): Number of future steps to forecast, i.e. the size of the
          last dimension of the output.
        """
        super().__init__()
        self.horizon = horizon

    def step(self, *args, **kwargs):
        """No-op: there are no parameters to update. Accepts and ignores any arguments."""

    def zero_grad(self, set_to_none: bool = True):
        """
        No-op: there are no gradients to reset.

        Accepts the same `set_to_none` argument as `nn.Module.zero_grad` so callers
        using the standard signature do not raise a TypeError.
        """

    def forward(self, x):
        """
        Forecast by repeating the last value along the final axis.

        Parameters:
        - x (torch.Tensor): Input of shape (..., seq_len). The last dimension is
          assumed to be time; any number of leading dimensions is accepted,
          e.g. (batch, n_vars, seq_len).

        Returns:
        - torch.Tensor of shape (..., horizon) in which every step equals x[..., -1].
          The result is an expanded view that shares memory with `x` (no copy is
          made); call `.clone()` on it before modifying it in place.
        """
        # Slice with -1: (not -1) so the time axis survives as size 1 and can be expanded.
        last_values = x[..., -1:]
        return last_values.expand(*x.shape[:-1], self.horizon)
    

In [3]:
# Test: every forecast step must equal the last observed value of its series.
torch.manual_seed(0)  # make the random input (and printed output) reproducible
model = PersistenceModel(6)
x = torch.rand(2, 4, 6)
output = model(x)
assert output.shape == (2, 4, 6)
assert torch.equal(output, x[:, :, -1:].expand(-1, -1, 6))

# A horizon different from the input length must also work (guards against shape mix-ups).
assert PersistenceModel(3)(x).shape == (2, 4, 3)

print(output.shape)  # torch.Size([2, 4, 6])
print(output, "\n", x)

torch.Size([2, 4, 6])
tensor([[[0.7197, 0.7197, 0.7197, 0.7197, 0.7197, 0.7197],
         [0.6504, 0.6504, 0.6504, 0.6504, 0.6504, 0.6504],
         [0.4552, 0.4552, 0.4552, 0.4552, 0.4552, 0.4552],
         [0.0602, 0.0602, 0.0602, 0.0602, 0.0602, 0.0602]],

        [[0.7562, 0.7562, 0.7562, 0.7562, 0.7562, 0.7562],
         [0.8653, 0.8653, 0.8653, 0.8653, 0.8653, 0.8653],
         [0.6855, 0.6855, 0.6855, 0.6855, 0.6855, 0.6855],
         [0.3235, 0.3235, 0.3235, 0.3235, 0.3235, 0.3235]]]) 
 tensor([[[0.3362, 0.5141, 0.4492, 0.6180, 0.3601, 0.7197],
         [0.0098, 0.4903, 0.8631, 0.8351, 0.2197, 0.6504],
         [0.0901, 0.8371, 0.8963, 0.9179, 0.7690, 0.4552],
         [0.1844, 0.3814, 0.9163, 0.4824, 0.1221, 0.0602]],

        [[0.1830, 0.2040, 0.0249, 0.2651, 0.9778, 0.7562],
         [0.5694, 0.1777, 0.0791, 0.9147, 0.0812, 0.8653],
         [0.6298, 0.8184, 0.4935, 0.6241, 0.0561, 0.6855],
         [0.7608, 0.9856, 0.5084, 0.3116, 0.2454, 0.3235]]])
