In [1]:
from typing import Dict, Optional

import torch
import torch.nn as nn

from tft import MyGatedResidualNetwork, MyResampleNorm

class MyVariableSelectionNetwork(nn.Module):
    """Variable selection network over a dict of named input variables.

    Each variable is encoded by its own per-variable network (output size
    ``hidden_size``). When there is more than one variable, a shared
    ``MyGatedResidualNetwork`` over the concatenation of all (prescaled) inputs
    produces one logit per variable. The logits are softmax-normalised and used
    to take a weighted sum of the per-variable encodings. With exactly one
    variable, no selection is performed: the variable is just encoded and given
    weight 1.

    Args:
        input_sizes: mapping variable name -> size of that variable after
            prescaling (real-valued inputs) or of its embedding (embedded inputs).
        hidden_size: size of each per-variable encoding and of the output.
        input_embedding_flags: mapping variable name -> True if the variable is
            already an embedding. Such variables get a ``MyResampleNorm`` encoder
            and no default prescaler. Names that are missing count as False.
        dropout: dropout rate passed to the gated residual networks built here.
        context_size: if not None, the weight network is built with a context
            input of this size (only used when there are several variables).
        single_variable_grns: optional pre-built per-variable encoders, keyed by
            variable name; they take precedence over the default encoders.
        prescalers: optional pre-built prescalers, keyed by variable name.
            Variables that are neither embedded nor listed here get
            ``nn.Linear(1, input_size)``.

    Raises:
        ValueError: if ``input_sizes`` is empty.
    """

    def __init__(
        self,
        input_sizes: Dict[str, int],
        hidden_size: int,
        input_embedding_flags: Optional[Dict[str, bool]] = None,
        dropout: float = 0.1,
        context_size: Optional[int] = None,
        single_variable_grns: Optional[Dict[str, "MyGatedResidualNetwork"]] = None,
        prescalers: Optional[Dict[str, nn.Linear]] = None,
    ):
        super().__init__()
        if not input_sizes:
            raise ValueError("input_sizes must contain at least one variable")

        self.hidden_size = hidden_size
        self.input_sizes = input_sizes
        # Copy into a fresh dict so instances never share a mutable default.
        self.input_embedding_flags = (
            dict(input_embedding_flags) if input_embedding_flags is not None else {}
        )
        self.dropout = dropout
        self.context_size = context_size
        if single_variable_grns is None:
            single_variable_grns = {}
        if prescalers is None:
            prescalers = {}

        # The weight network only exists when there is something to choose
        # between; forward() bypasses it for a single variable.
        if self.num_inputs > 1:
            if self.context_size is not None:
                self.flattened_grn = MyGatedResidualNetwork(
                    self.input_size_total,
                    min(self.hidden_size, self.num_inputs),
                    self.num_inputs,
                    self.dropout,
                    self.context_size,
                    residual=False,
                )
            else:
                self.flattened_grn = MyGatedResidualNetwork(
                    self.input_size_total,
                    min(self.hidden_size, self.num_inputs),
                    self.num_inputs,
                    self.dropout,
                    residual=False,
                )

        self.single_variable_grns = nn.ModuleDict()
        self.prescalers = nn.ModuleDict()
        for name, input_size in self.input_sizes.items():
            is_embedding = self.input_embedding_flags.get(name, False)
            if name in single_variable_grns:
                self.single_variable_grns[name] = single_variable_grns[name]
            elif is_embedding:
                self.single_variable_grns[name] = MyResampleNorm(
                    input_size, self.hidden_size
                )
            else:
                self.single_variable_grns[name] = MyGatedResidualNetwork(
                    input_size,
                    min(input_size, self.hidden_size),
                    output_size=self.hidden_size,
                    dropout=self.dropout,
                )
            # Real-valued (non-embedded) inputs are first scaled up from a single
            # feature to ``input_size`` features.
            if name in prescalers:
                self.prescalers[name] = prescalers[name]
            elif not is_embedding:
                self.prescalers[name] = nn.Linear(1, input_size)

        self.softmax = nn.Softmax(dim=-1)

    @property
    def input_size_total(self):
        """Total feature size of the concatenated (prescaled/embedded) inputs."""
        # Embedded and prescaled variables both contribute their declared size.
        return sum(self.input_sizes.values())

    @property
    def num_inputs(self):
        """Number of input variables."""
        return len(self.input_sizes)

    def _prescale(self, name: str, value: torch.Tensor) -> torch.Tensor:
        """Apply the prescaler registered for ``name``, if any; otherwise pass through."""
        if name in self.prescalers:
            return self.prescalers[name](value)
        return value

    def forward(
        self, x: Dict[str, torch.Tensor], context: Optional[torch.Tensor] = None
    ):
        """Encode every variable and combine them with learned selection weights.

        Args:
            x: mapping variable name -> tensor. Must contain every key of
                ``input_sizes``.
            context: optional context passed to the weight network. It is
                ignored when there is only one variable.

        Returns:
            ``(outputs, sparse_weights)``. ``outputs`` is the weighted sum of the
            per-variable encodings. ``sparse_weights`` has shape
            ``(*outputs.shape[:-1], 1, num_inputs)`` and sums to 1 over the last
            axis.
        """
        if self.num_inputs > 1:
            encodings = []
            weight_inputs = []
            for name in self.input_sizes:
                variable_embedding = self._prescale(name, x[name])
                weight_inputs.append(variable_embedding)
                encodings.append(self.single_variable_grns[name](variable_embedding))
            # Assumes each encoder returns (..., hidden_size): stacking gives
            # (..., hidden_size, num_inputs).
            stacked = torch.stack(encodings, dim=-1)

            # The weight network sees all variables at once.
            flat_embedding = torch.cat(weight_inputs, dim=-1)
            weight_logits = self.flattened_grn(flat_embedding, context)
            # (..., 1, num_inputs) so the weights broadcast over the hidden axis.
            sparse_weights = self.softmax(weight_logits).unsqueeze(-2)

            outputs = (stacked * sparse_weights).sum(dim=-1)
        else:
            # One variable: nothing to select, so just encode it and give it
            # weight 1, keeping the same weight layout as the multi-input path.
            name = next(iter(self.input_sizes))
            variable_embedding = self._prescale(name, x[name])
            outputs = self.single_variable_grns[name](variable_embedding)
            sparse_weights = torch.ones(
                *outputs.shape[:-1], 1, 1, device=outputs.device, dtype=outputs.dtype
            )

        return outputs, sparse_weights

In [9]:
# Smoke test of MyVariableSelectionNetwork on a small hand-written batch.
# Fix the seed so the randomly initialised weights, and hence the outputs shown
# below, are reproducible under Restart & Run All.
torch.manual_seed(0)
device = torch.device("cpu")  # or torch.device("cuda") if a GPU is available

x = {
    "fueltype": torch.tensor(
        [
            [1.3122, 2.0916, 0.4749, 2.5620, -2.2733],
            [0.3703, -1.0351, -0.2936, 1.7159, 0.6043],
            [0.6815, -0.5216, -0.5855, -1.4212, 0.9495],
            [-0.4051, -1.1760, 0.8423, -0.3982, 0.0264],
            [1.6051, 1.6294, -1.8574, -0.9640, -0.3509],
            [0.5635, -1.2075, -0.0809, -1.2652, -0.8209],
            [-2.4415, -0.5139, -0.5364, -0.1024, 0.3291],
            [0.5885, 0.1993, -1.9647, -0.0054, -1.3004],
        ],
        device=device,
    ),
    "encoder_length": torch.tensor(
        [[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0]], device=device
    ),
    "value_center": torch.tensor(
        [
            [0.5803],
            [1.6990],
            [1.4191],
            [-0.8215],
            [-0.7555],
            [-0.8013],
            [-0.7047],
            [-0.6153],
        ],
        device=device,
    ),
    "value_scale": torch.tensor(
        [
            [1.3127],
            [2.0438],
            [-0.2789],
            [-0.7635],
            [-0.7440],
            [-0.6782],
            [-0.5635],
            [-0.3283],
        ],
        device=device,
    ),
}

# "fueltype" is passed as an already-embedded 5-dim vector; the three real-valued
# inputs are 1-dim and get prescaled to 16 features each.
input_sizes = {"fueltype": 5, "encoder_length": 16, "value_center": 16, "value_scale": 16}
hidden_size = 64

vsn = MyVariableSelectionNetwork(
    input_sizes=input_sizes,
    hidden_size=hidden_size,
    input_embedding_flags={"fueltype": True},
    dropout=0.1,
)
vsn.to(device)

outputs, weights = vsn(x)

# Invariants: one hidden vector per sample, one softmax weight per variable.
batch_size = x["fueltype"].shape[0]
assert outputs.shape == (batch_size, hidden_size)
assert weights.shape == (batch_size, 1, len(input_sizes))
assert torch.allclose(weights.sum(dim=-1), torch.ones(batch_size, 1, device=device))

# Print the shapes of the outputs and weights
print("Outputs shape:", outputs.shape)
print("Weights shape:", weights.shape)

Outputs shape: torch.Size([8, 64])
Weights shape: torch.Size([8, 1, 4])


In [10]:
# Combined (variable-weighted) hidden vector for the first sample in the batch.
first_sample_output = outputs[0]
first_sample_output

tensor([-0.1079, -0.1849, -0.0550,  0.1918,  0.1365,  0.3658,  0.4432,  0.3992,
         0.7487,  0.7473,  0.7420,  0.6947,  0.6873,  0.4913,  0.4800,  0.5437,
         0.8107,  0.4517,  0.7231,  0.4370,  0.5048,  0.1436,  0.2879,  0.1371,
         0.0751,  0.1137, -0.1464, -0.2281, -0.3206, -0.3082, -0.3979, -0.3520,
        -0.1593, -0.0053,  0.2898, -0.0181, -0.0355, -0.2747, -0.2707,  0.0326,
         0.1103,  0.0876,  0.1589,  0.1618,  0.1073,  0.0519,  0.2245,  0.4107,
         0.3577,  0.2566, -0.0041, -0.0374, -0.0417, -0.1023, -0.1377, -0.1823,
        -0.3507, -0.5448, -0.7976, -0.8944, -1.1224, -1.5733, -1.9310, -2.0214],
       grad_fn=<SelectBackward0>)

In [11]:
# Variable-selection weights for the first sample: one per input variable,
# in the order of input_sizes.
first_sample_weights = weights[0]
first_sample_weights

tensor([[0.6045, 0.0711, 0.2730, 0.0514]], grad_fn=<SelectBackward0>)