In [37]:
import torch
import torch.nn as nn

class CustomLayer(nn.Module):
    def __init__(self, n_features):
        super(CustomLayer, self).__init__()
        self.n_features = n_features
        # Initialize the interaction tensor as a learnable parameter
        self.interaction_tensor = nn.Parameter(torch.randn(n_features, n_features, n_features))

    def forward(self, x):
        # x is expected to be of size [batch, length, n_features]
        batch, length, n_features = x.size()

        # Process each [n_features] vector across batch and length
        output = x.new_empty(batch, length, n_features)
        for b in range(batch):
            for l in range(length):
                state_tensor = x[b, l, :]  # Shape: [n_features]
                # Step 1: Multiply state tensor by interaction tensor to get transition tensor
                # We manually implement the multiplication to match your operation
                transition_tensor = torch.einsum('i,ijk->jk', state_tensor, self.interaction_tensor)
                # Step 2: Multiply the transition tensor by the state tensor
                # Resulting shape: [n_features]
                output[b, l, :] = torch.matmul(transition_tensor, state_tensor)

        return output


In [38]:
import torch
import torch.nn as nn

class CustomLayerVectorized(nn.Module):
    def __init__(self, n_features):
        super(CustomLayerVectorized, self).__init__()
        self.n_features = n_features
        # Initialize the interaction tensor as a learnable parameter
        self.interaction_tensor = nn.Parameter(torch.randn(n_features, n_features, n_features))

    def forward(self, x):
        # x is expected to be of size [batch, length, n_features]
        batch, length, n_features = x.size()

        # Pre-allocate output tensor
        output = x.new_empty(batch, length, n_features)

        # Loop over length, but vectorize over the batch
        for l in range(length):
            # Extract all vectors at position l across all batches
            state_tensor = x[:, l, :]  # Shape: [batch, n_features]
            
            # Vectorized operation for all batches
            # Step 1: Calculate the transition tensor
            # Since we cannot directly use einsum for batched operation in this specific scenario,
            # we manually broadcast and multiply to achieve the intended result.
            # This involves expanding dimensions to enable broadcasting.
            state_tensor_expanded = state_tensor.unsqueeze(1).expand(-1, n_features, -1)  # Shape: [batch, n_features, n_features]
            interaction_tensor_expanded = self.interaction_tensor.unsqueeze(0).expand(batch, -1, -1, -1)  # Shape: [batch, n_features, n_features, n_features]
            # Multiply and sum over the last dimension to get the transition tensor
            transition_tensor = torch.einsum('bik,bijk->bij', state_tensor_expanded, interaction_tensor_expanded)
            
            # Step 2: Multiply the transition tensor by the state tensor to get the output
            output[:, l, :] = torch.einsum('bij,bj->bi', transition_tensor, state_tensor)

        return output



In [39]:
# n_features = 16
# batch, length, dim = 2, 64, n_features
# x = torch.randn(batch, length, dim)
# model = CustomLayerVectorized(
#     n_features = n_features
# )
# y = model(x)

# print(x.shape)
# assert y.shape == x.shape

In [40]:
class InteractionModule(nn.Module):
    def __init__(self, n_features):
        super(InteractionModule, self).__init__()
        self.n_features = n_features
        # Initialize a set of interaction tensors, one for the state tensor and one for each column of the transition tensor
        self.interaction_tensors = nn.ParameterList([nn.Parameter(torch.randn(n_features, n_features, n_features)) for _ in range(n_features + 1)])

    def forward(self, state_tensor, previous_transition_tensor):
        # Get batch size
        batch = state_tensor.shape[0]
        # Assuming previous_transition_tensors is a list of transition tensors from the previous step
        candidates = []
        for i in range(self.n_features + 1):
            if i == 0:  # Interaction with the state tensor
                current_tensor = state_tensor
            else:  # Interaction with columns of the previous transition tensor
                current_tensor = previous_transition_tensor[:, :, i - 1]

            interaction_tensor = self.interaction_tensors[i]
            current_tensor_expanded = current_tensor.unsqueeze(1).expand(-1, self.n_features, -1)  # Shape: [batch, n_features, n_features]
            interaction_tensor_expanded = interaction_tensor.unsqueeze(0).expand(batch, -1, -1, -1)  # Shape: [batch, n_features, n_features, n_features]
            # Multiply and sum over the last dimension to get the transition tensor
            candidate = torch.einsum('bik,bijk->bij', current_tensor_expanded, interaction_tensor_expanded)
            
            candidates.append(candidate)
            
        candidates_tensor = torch.stack(candidates, dim = -1)
        return candidates_tensor


In [41]:
import torch.nn.functional as F

class SelectorModule(nn.Module):
    def __init__(self, num_slices):
        super(SelectorModule, self).__init__()
        # A simple linear layer to compute importance scores for each slice
        self.importance = nn.Linear(num_slices, num_slices)

    def forward(self, x):
        # Assuming x is of shape [batch_size, height, width, num_slices]
        batch_size, height, width, num_slices = x.shape

        # Flatten the spatial dimensions and compute importance scores
        x_flat = x.view(batch_size, -1, num_slices)  # New shape: [batch_size, height*width, num_slices]
        scores = self.importance(x_flat)  # Computes a score for each slice
        scores = scores.view(batch_size, height, width, num_slices)  # Reshape scores back

        # Apply softmax to get a distribution over slices
        weights = F.softmax(scores, dim=-1)

        # Use the weights to get a weighted sum of the slices, effectively selecting slices based on importance
        # This step aggregates the slices into a single output per position
        selected = torch.einsum('bhwn,bhwi->bhw', weights, x)

        return selected

# Example usage
batch_size = 10
tensor = torch.rand(batch_size, 5, 5, 6)  # Example tensor
model = SelectorModule(num_slices=6)

result = model(tensor)
print(result.shape)  # Should print torch.Size([10, 5, 5]), indicating the selection/aggregation step was performed


torch.Size([10, 5, 5])


In [42]:
class CustomLayerExtended(nn.Module):
    def __init__(self, n_features):
        super(CustomLayerExtended, self).__init__()
        self.n_features = n_features
        self.interaction_module = InteractionModule(n_features)
        self.selector_module = SelectorModule(n_features + 1)

    def forward(self, x):
        batch, length, n_features = x.size()
        output = x.new_empty(batch, length, n_features)

        # Initialize previous transition tensors (for the first step)
        # Assuming it's a list of zero tensors for simplicity
        previous_transition_tensor = torch.zeros(batch, n_features, n_features)

        for l in range(length):
            state_tensor = x[:, l, :]
            # Generate candidates
            candidates = self.interaction_module(state_tensor, previous_transition_tensor)
            # Select one candidate
            selected_transition_tensor = self.selector_module(candidates)
            # Update the previous_transition_tensors for the next iteration
            previous_transition_tensor = selected_transition_tensor
            # Compute output for this step
            output[:, l, :] = torch.matmul(selected_transition_tensor, state_tensor.unsqueeze(-1)).squeeze(-1)

        return output, selected_transition_tensor


In [43]:
n_features = 16
batch, length, dim = 13, 64, n_features
x = torch.randn(batch, length, dim)
model = CustomLayerExtended(
    n_features = n_features
)
y = model(x)

print(y[0].shape)
assert y[0].shape == x.shape

torch.Size([13, 64, 16])


In [44]:
print(y)

(tensor([[[ 1.3828e+01,  7.3758e+00, -2.3138e+01,  ..., -2.0552e+01,
           1.2001e+01, -3.8707e+01],
         [ 1.1377e+01, -1.4783e+01, -1.3818e+02,  ..., -1.4889e+02,
          -2.9613e+01,  1.0292e+02],
         [-4.6556e+03, -1.4859e+04,  5.3426e+03,  ...,  4.6355e+03,
          -6.7901e+03, -4.6900e+03],
         ...,
         [        nan,         nan,         nan,  ...,         nan,
                  nan,         nan],
         [        nan,         nan,         nan,  ...,         nan,
                  nan,         nan],
         [        nan,         nan,         nan,  ...,         nan,
                  nan,         nan]],

        [[ 3.2364e+01, -2.3503e+01, -1.8020e+01,  ..., -2.9927e+00,
          -1.3689e+01,  2.3359e+01],
         [ 6.8940e+01,  2.3408e+02,  2.2757e+02,  ...,  4.1637e+01,
           6.7449e+01, -2.5223e+02],
         [-2.2979e+03, -2.1993e+03,  3.0195e+03,  ...,  9.1770e+03,
          -3.3622e+03, -8.3801e+02],
         ...,
         [        nan,  

# Example Training

## Simple Data Generation

In [45]:
import numpy as np
import torch

def generate_multivariate_time_series(n_features, total_length, amplitude=1.0):
    t = np.linspace(0, 4 * np.pi, total_length)
    series = np.zeros((total_length, n_features))
    for i in range(n_features):
        series[:, i] = amplitude * np.cos(t * (i + 1) / n_features)
    return series

def segment_time_series(series, length):
    # Assuming series is a numpy array of shape [total_length, n_features]
    total_length, n_features = series.shape
    segments = []
    for start in range(0, total_length - length, length):
        segment = series[start:start + length]
        segments.append(segment)
    return np.stack(segments)


In [46]:
n_features = 12
length = 64  # Segment length
total_length = 1024  # Arbitrary total length for the generated series

# Generate and segment the time series
series = generate_multivariate_time_series(n_features, total_length)
segments = segment_time_series(series, length)

# Convert to tensors
segments_tensor = torch.tensor(segments, dtype=torch.float)

# Prepare inputs and targets
X = segments_tensor
# Shift segments to the right by one timestep to create the targets
Y = torch.cat((X[:, 1:], X[:, :1]), dim=1)


In [47]:
from torch import nn, optim

# Model
model = CustomLayerExtended(
    n_features = n_features
)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 100
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    # Forward pass
    outputs, _ = model(X)
    loss = criterion(outputs, Y)

    # Backward and optimize
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item()}')


Epoch [10/100], Loss: nan
Epoch [20/100], Loss: nan


KeyboardInterrupt: 