In [1]:
import torch.nn as nn
import torch.nn.functional as F


class NHITSBlock(nn.Module):
    """One fully connected forecasting block.

    Flattens a window of ``input_steps`` time steps with ``input_size``
    features each and maps it through two ReLU hidden layers of width
    ``hidden_size`` to an ``output_size``-dimensional forecast.

    Args:
        input_steps: Number of time steps in the input window.
        input_size: Number of features per time step.
        output_size: Width of the forecast produced for each sample.
        hidden_size: Width of the two hidden layers.
    """

    def __init__(self, input_steps, input_size, output_size, hidden_size):
        super().__init__()
        self.input_steps = input_steps
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size

        # Fully connected layers; fc1 consumes the flattened input window.
        self.fc1 = nn.Linear(input_steps * input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc_out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Produce a forecast from one input window per sample.

        Args:
            x: Tensor of shape ``(batch_size, input_steps, input_size)``. Any
                tensor holding ``input_steps * input_size`` elements per sample
                (e.g. an already-flattened window) is accepted as well.

        Returns:
            Tensor of shape ``(batch_size, output_size)``.
        """
        # (batch_size, input_steps, input_size) -> (batch_size, input_steps * input_size).
        # reshape (not view) accepts non-contiguous inputs such as transposed or
        # sliced windows, and the explicit width also works for an empty batch,
        # where a -1 dimension cannot be inferred.
        x = x.reshape(x.size(0), self.input_steps * self.input_size)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc_out(x)

In [2]:
# Earlier "simple task" variant of the model, disabled by wrapping it in a
# string literal: this cell defines no class, it only evaluates (and echoes)
# the string.
# NOTE(review): this draft would not work if revived as-is:
#   * it passes (input_size, output_size, hidden_size, input_steps) positionally,
#     but NHITSBlock.__init__ takes (input_steps, input_size, output_size, hidden_size);
#   * it feeds each block's (batch, output_size) output straight into the next
#     block, whose fc1 expects input_steps * input_size features, so chaining
#     fails unless output_size happens to equal that product.
# The residual `NHITS` model in a later cell appears to replace it; consider
# deleting this cell.
"""class NHITSModel(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, input_steps, num_blocks):
        super(NHITSModel, self).__init__()
        self.num_blocks = num_blocks
        self.blocks = nn.ModuleList([
            NHITSBlock(input_size, output_size, hidden_size, input_steps)
            for _ in range(num_blocks)
        ])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)  # Forward pass through each NHITS block
        return x"""

'class NHITSModel(nn.Module):\n    def __init__(self, input_size, output_size, hidden_size, input_steps, num_blocks):\n        super(NHITSModel, self).__init__()\n        self.num_blocks = num_blocks\n        self.blocks = nn.ModuleList([\n            NHITSBlock(input_size, output_size, hidden_size, input_steps)\n            for _ in range(num_blocks)\n        ])\n\n    def forward(self, x):\n        for block in self.blocks:\n            x = block(x)  # Forward pass through each NHITS block\n        return x'

In [2]:
# RESIDUAL MODEL: a stack of NHITSBlocks, each forecasting from what the
# previous blocks left unexplained in the input window.
class NHITS(nn.Module):
    """Stack of NHITSBlocks combined through residual subtraction.

    Each block forecasts from the current residual window. Its forecast is
    mapped back to ``input_size`` features by a linear projection, broadcast
    over every time step, and subtracted from the residual seen by the next
    block. The model output is the sum of all block forecasts.

    Args:
        input_steps: Number of time steps in the input window.
        input_size: Number of features per time step.
        output_size: Width of each block's forecast and of the model output.
        hidden_size: Hidden width of every block.
        num_stacks: Number of blocks; must be at least 1.

    Raises:
        ValueError: If ``num_stacks`` is less than 1.
    """

    def __init__(self, input_steps, input_size, output_size, hidden_size, num_stacks):
        super().__init__()
        # With zero blocks forward() would return the Python int 0 (sum of an
        # empty list) instead of a tensor, so reject it up front.
        if num_stacks < 1:
            raise ValueError(f"num_stacks must be at least 1, got {num_stacks}")
        self.num_stacks = num_stacks
        self.input_steps = input_steps
        self.input_size = input_size
        self.output_size = output_size

        # Create multiple NHITS blocks
        self.blocks = nn.ModuleList([
            NHITSBlock(input_steps, input_size, output_size, hidden_size)
            for _ in range(num_stacks)
        ])

        # Projection layer mapping a (batch, output_size) forecast back to
        # input_size features so it can be removed from the residual window.
        # NOTE(review): one projection is shared by all blocks and its result is
        # the same at every time step; if per-block, time-varying backcasts are
        # intended, this needs revisiting.
        self.projection = nn.Linear(output_size, input_size)

    def forward(self, x):
        """Run all blocks on successive residuals and sum their forecasts.

        Args:
            x: Input window; assumed ``(batch_size, input_steps, input_size)``
                (the projected update is expanded to this shape).

        Returns:
            Tensor of shape ``(batch_size, output_size)``: the sum of the
            per-block forecasts.
        """
        residual = x
        final_forecast = None
        last_index = len(self.blocks) - 1

        for index, block in enumerate(self.blocks):
            forecast = block(residual)
            final_forecast = forecast if final_forecast is None else final_forecast + forecast

            # The residual after the final block is never consumed; skip the work.
            if index == last_index:
                break

            # Project the forecast back to input_size, then broadcast the
            # (batch, 1, input_size) update over every time step of the residual.
            residual_update = self.projection(forecast).unsqueeze(1)
            residual = residual - residual_update.expand_as(residual)

        return final_forecast