In [None]:
#|default_exp models

In [1]:
# | export

import torch
import torch.nn as nn

class PersistenceModel(nn.Module):
    """Naive "persistence" forecaster: repeats the last observed value.

    For every series in the input, the value at the final position of the
    last (time) axis is copied ``c_out`` times to form the forecast. The
    model defines no learnable parameters.
    """

    def __init__(self, c_in, c_out, custom_head=None, **kwargs):
        """
        Initialize the PersistenceModel.

        Parameters:
        - c_in (int): Number of input channels/features (stored as
          ``self.c_in``; not used by ``forward``).
        - c_out (int): The output size, i.e., the horizon of the forecast.
        - custom_head: Accepted but ignored (NOTE(review): presumably kept
          for constructor compatibility with other models — confirm).
        - kwargs: Any extra keyword arguments (e.g. ``seq_len``,
          ``pred_dim``) are accepted and ignored, so the constructor can be
          called with the same arguments as other models.
        """
        super().__init__()
        self.c_out = c_out
        self.c_in = c_in

    # No-op optimizer-style hooks: the model has nothing to update.
    # NOTE(review): presumably these let the model be passed where a
    # training loop expects an optimizer — confirm with the caller.
    def step(self):
        """Do nothing; there are no parameters to update."""

    def zero_grad(self, set_to_none=True):
        """Do nothing; there are no gradients to clear.

        ``set_to_none`` mirrors ``nn.Module.zero_grad`` /
        ``Optimizer.zero_grad`` so callers passing it do not fail.
        """

    def forward(self, x):
        """Forecast by repeating the last time step of ``x`` ``c_out`` times.

        Args:
            x (torch.Tensor): Input whose last dimension is treated as time,
                e.g. ``(batch, channels, seq_len)``.

        Returns:
            torch.Tensor: Shape ``(*x.shape[:-1], c_out)``. It is an
            ``expand`` view of ``x`` (no copy) and shares memory with the
            input; ``.clone()`` it before modifying it in place.

        Raises:
            ValueError: If the last (time) dimension of ``x`` has size 0.
        """
        if x.shape[-1] == 0:
            raise ValueError(
                "PersistenceModel.forward: input has no time steps "
                "(last dimension is 0)"
            )
        # Slice with -1: (rather than index with -1) so the time axis is kept
        # with size 1 and can be broadcast to the forecast horizon.
        last_values = x[..., -1:]  # (..., 1)
        return last_values.expand(*x.shape[:-1], self.c_out)  # (..., c_out)

    

In [2]:
# Test: the forecast must repeat each series' last value `c_out` times.
model = PersistenceModel(1, 6)
x = torch.rand(2, 4, 6)
output = model(x)
print(output.shape)  # torch.Size([2, 4, 6]) == (*x.shape[:-1], c_out)
print(output, "\n", x)

# Value-independent checks (hold for any random `x`).
assert output.shape == (*x.shape[:-1], model.c_out)
assert torch.equal(output, x[..., -1:].expand_as(output))

torch.Size([2, 4, 6])
tensor([[[0.4111, 0.4111, 0.4111, 0.4111, 0.4111, 0.4111],
         [0.9139, 0.9139, 0.9139, 0.9139, 0.9139, 0.9139],
         [0.7423, 0.7423, 0.7423, 0.7423, 0.7423, 0.7423],
         [0.6584, 0.6584, 0.6584, 0.6584, 0.6584, 0.6584]],

        [[0.8372, 0.8372, 0.8372, 0.8372, 0.8372, 0.8372],
         [0.2459, 0.2459, 0.2459, 0.2459, 0.2459, 0.2459],
         [0.7781, 0.7781, 0.7781, 0.7781, 0.7781, 0.7781],
         [0.3944, 0.3944, 0.3944, 0.3944, 0.3944, 0.3944]]]) 
 tensor([[[0.5400, 0.0749, 0.6305, 0.8589, 0.6191, 0.4111],
         [0.8351, 0.2516, 0.6318, 0.3392, 0.9685, 0.9139],
         [0.1158, 0.9753, 0.6490, 0.5867, 0.7059, 0.7423],
         [0.4138, 0.0823, 0.7121, 0.3450, 0.2967, 0.6584]],

        [[0.7524, 0.8181, 0.1844, 0.6702, 0.1589, 0.8372],
         [0.5773, 0.9710, 0.4841, 0.0901, 0.3357, 0.2459],
         [0.7829, 0.2129, 0.7987, 0.6491, 0.8732, 0.7781],
         [0.8781, 0.5614, 0.2783, 0.5092, 0.4622, 0.3944]]])
