In [156]:
import torch
import torch.nn as nn

class CustomLayer(nn.Module):
    """Quadratic interaction layer driven by a learnable 3-D tensor.

    For every feature vector ``s = x[b, l, :]`` the layer computes

        out[b, l, j] = sum_{i, k} s[i] * T[i, j, k] * s[k]

    where ``T`` is ``self.interaction_tensor``. This is the same two-step
    computation as "contract ``s`` with the first axis of ``T`` to get a
    transition matrix, then multiply that matrix by ``s``", evaluated for
    all batch entries and positions at once.
    """

    def __init__(self, n_features):
        """Create the layer with a randomly initialised interaction tensor.

        Args:
            n_features: Size of the feature axis; the interaction tensor has
                shape ``[n_features, n_features, n_features]``.
        """
        super(CustomLayer, self).__init__()
        self.n_features = n_features
        # Initialize the interaction tensor as a learnable parameter
        self.interaction_tensor = nn.Parameter(torch.randn(n_features, n_features, n_features))

    def forward(self, x):
        """Apply the quadratic interaction to every feature vector of ``x``.

        Args:
            x: Tensor of size ``[batch, length, n_features]``.

        Returns:
            Tensor of size ``[batch, length, n_features]``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected x of size [batch, length, n_features], got {tuple(x.shape)}"
            )
        # A single einsum replaces the Python loop over (batch, length):
        #   'bli,ijk->bljk' builds the per-vector transition matrix,
        #   '...jk,blk->blj' multiplies it by the same vector.
        return torch.einsum('bli,ijk,blk->blj', x, self.interaction_tensor, x)


In [157]:
import torch
import torch.nn as nn

class CustomLayerVectorized(nn.Module):
    """Loop-free quadratic interaction layer.

    For every feature vector ``s = x[b, l, :]`` the layer computes

        out[b, l, i] = sum_{j, k} T[i, j, k] * s[j] * s[k]

    where ``T`` is ``self.interaction_tensor``.

    NOTE(review): the output index here is the *first* axis of ``T``, while
    ``CustomLayer`` in this notebook uses the *middle* axis
    (``out[j] = sum_{i,k} s[i] T[i,j,k] s[k]``). The two classes therefore
    produce different outputs for identical weights; the contraction below
    is kept exactly as it was so existing results do not change.
    """

    def __init__(self, n_features):
        """Create the layer with a randomly initialised interaction tensor.

        Args:
            n_features: Size of the feature axis; the interaction tensor has
                shape ``[n_features, n_features, n_features]``.
        """
        super(CustomLayerVectorized, self).__init__()
        self.n_features = n_features
        # Initialize the interaction tensor as a learnable parameter
        self.interaction_tensor = nn.Parameter(torch.randn(n_features, n_features, n_features))

    def forward(self, x):
        """Apply the quadratic interaction to every feature vector of ``x``.

        Args:
            x: Tensor of size ``[batch, length, n_features]``.

        Returns:
            Tensor of size ``[batch, length, n_features]``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected x of size [batch, length, n_features], got {tuple(x.shape)}"
            )
        # einsum handles the batch and length axes directly, so neither the
        # loop over positions nor the expanded copies of the operands are needed:
        #   step 1: transition[b, l, i, j] = sum_k T[i, j, k] * x[b, l, k]
        #   step 2: out[b, l, i]           = sum_j transition[b, l, i, j] * x[b, l, j]
        return torch.einsum('ijk,blk,blj->bli', self.interaction_tensor, x, x)



In [158]:
# n_features = 16
# batch, length, dim = 2, 64, n_features
# x = torch.randn(batch, length, dim)
# model = CustomLayerVectorized(
#     n_features = n_features
# )
# y = model(x)

# print(x.shape)
# assert y.shape == x.shape

In [159]:
import torch

def torch_randnorm(size, dim=0):
    """Return a uniform random tensor rescaled to sum to 1 along ``dim``.

    Args:
        size: Shape passed to ``torch.rand``.
        dim: Dimension along which the entries are normalised (default 0).

    Returns:
        Tensor of shape ``size`` whose entries sum to 1 along ``dim``.
    """
    samples = torch.rand(size)
    # keepdim=True keeps the reduced axis so the division broadcasts back
    totals = samples.sum(dim=dim, keepdim=True)
    return samples / totals

# Example usage: normalise a 5x6 random tensor along dim=1, so each of the
# 5 rows sums to 1 (confirmed by the second print).
normalized_tensor = torch_randnorm([5,6], dim=1)
print(normalized_tensor)
print(normalized_tensor.sum(dim=1))

tensor([[0.3178, 0.0869, 0.1055, 0.0041, 0.1740, 0.3116],
        [0.1301, 0.0681, 0.2564, 0.1927, 0.1367, 0.2159],
        [0.0445, 0.2007, 0.1324, 0.2821, 0.1587, 0.1816],
        [0.0806, 0.1735, 0.2114, 0.1402, 0.2615, 0.1327],
        [0.3120, 0.0053, 0.2000, 0.2640, 0.0095, 0.2092]])
tensor([1., 1., 1., 1., 1.])


In [160]:
class InteractionModule(nn.Module):
    """Generate candidate transition tensors from the state and the previous transition.

    The module owns ``n_features + 1`` interaction tensors of shape
    ``[n_features, n_features, n_features]``. Tensor 0 is contracted with the
    state tensor; tensor ``i`` (``i >= 1``) is contracted with column
    ``i - 1`` of the previous transition tensor.
    """

    def __init__(self, n_features):
        """Create the interaction tensors.

        Args:
            n_features: Size of the feature axis. Each interaction tensor is
                initialised with ``torch_randnorm(..., dim=1)``.
        """
        super(InteractionModule, self).__init__()
        self.n_features = n_features
        # One interaction tensor for the state tensor and one for each column
        # of the transition tensor
        self.interaction_tensors = nn.ParameterList([
            nn.Parameter(torch_randnorm([n_features, n_features, n_features], dim=1))
            for _ in range(n_features + 1)
        ])

    def forward(self, state_tensor, previous_transition_tensor):
        """Compute one candidate transition tensor per interaction tensor.

        Args:
            state_tensor: Tensor of shape ``[batch, n_features]``.
            previous_transition_tensor: Tensor of shape
                ``[batch, n_features, n_features]``; its columns
                ``[:, :, i - 1]`` feed interaction tensor ``i``.

        Returns:
            Tensor of shape ``[batch, n_features, n_features, n_features + 1]``
            with the candidates stacked along the last axis.
        """
        candidates = []
        for i in range(self.n_features + 1):
            if i == 0:  # Interaction with the state tensor
                source = state_tensor
            else:  # Interaction with columns of the previous transition tensor
                source = previous_transition_tensor[:, :, i - 1]

            # candidate[b, i, j] = sum_k source[b, k] * W[i, j, k].
            # einsum broadcasts W over the batch itself, so no expanded
            # copies of either operand are required.
            candidates.append(
                torch.einsum('bk,ijk->bij', source, self.interaction_tensors[i])
            )

        return torch.stack(candidates, dim=-1)


In [161]:
import torch.nn.functional as F

class SelectorModule(nn.Module):
    """Collapse the trailing "slice" axis of a 4-D tensor to one value per position.

    Two modes are available:

    * ``use_learned_weights=False`` (default): return slice 0 of the input.
      This is the placeholder behaviour the module has always had; the
      learned scores are not evaluated and ``self.importance`` receives no
      gradient in this mode.
    * ``use_learned_weights=True``: score the slices with a linear layer,
      softmax the scores over the slice axis and return the weighted sum of
      the slices.
    """

    def __init__(self, num_slices, use_learned_weights=False):
        """Create the selector.

        Args:
            num_slices: Size of the last axis of the tensors passed to ``forward``.
            use_learned_weights: When True, aggregate slices with learned
                softmax weights instead of returning slice 0.
        """
        super(SelectorModule, self).__init__()
        self.use_learned_weights = use_learned_weights
        # A simple linear layer to compute importance scores for each slice
        self.importance = nn.Linear(num_slices, num_slices)

    def forward(self, x):
        """Reduce ``x`` over its last axis.

        Args:
            x: Tensor of shape ``[batch_size, height, width, num_slices]``.

        Returns:
            Tensor of shape ``[batch_size, height, width]``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected x of shape [batch_size, height, width, num_slices], got {tuple(x.shape)}"
            )

        if not self.use_learned_weights:
            # Placeholder selection: always pick the first slice.
            return x[:, :, :, 0]

        # nn.Linear acts on the last axis, so no flattening of the spatial
        # axes is needed to score the slices.
        scores = self.importance(x)
        # Distribution over slices at every position
        weights = F.softmax(scores, dim=-1)
        # Weighted sum over the slice axis. Both operands must share the
        # index 'n'; with distinct indices the two axes would be summed
        # independently and the weights would cancel out.
        return torch.einsum('bhwn,bhwn->bhw', weights, x)

# Example usage: feed a [batch, 5, 5, 6] tensor through the selector and
# check that the trailing slice axis has been reduced away.
batch_size = 10
tensor = torch.rand(batch_size, 5, 5, 6)  # Example tensor
model = SelectorModule(num_slices=6)

result = model(tensor)
print(result.shape)  # Should print torch.Size([10, 5, 5]): one value per position, slice axis removed


torch.Size([10, 5, 5])


In [162]:
class CustomLayerExtended(nn.Module):
    """Recurrent layer that picks a transition tensor per timestep and applies it to the state.

    At every position the ``InteractionModule`` proposes candidate
    transition tensors from the current state and the previously selected
    transition tensor, the ``SelectorModule`` reduces them to one
    ``[batch, n_features, n_features]`` tensor, and that tensor is
    multiplied by the state to give the output for the position.
    """

    def __init__(self, n_features):
        """Build the candidate generator and the selector.

        Args:
            n_features: Size of the feature axis. The selector is created
                for ``n_features + 1`` candidates.
        """
        super(CustomLayerExtended, self).__init__()
        self.n_features = n_features
        self.interaction_module = InteractionModule(n_features)
        self.selector_module = SelectorModule(n_features + 1)

    def forward(self, x):
        """Run the recurrence over the length axis of ``x``.

        Args:
            x: Tensor of size ``[batch, length, n_features]``.

        Returns:
            Tuple ``(output, transition)`` where ``output`` has size
            ``[batch, length, n_features]`` and ``transition`` is the
            transition tensor selected at the last position (the all-zero
            initial tensor when ``length == 0``).
        """
        batch, length, n_features = x.size()

        # Start from an all-zero transition tensor. new_zeros keeps it on the
        # same device and in the same dtype as the input.
        transition = x.new_zeros(batch, n_features, n_features)

        step_outputs = []
        for state_tensor in x.unbind(dim=1):
            # Generate candidates from the state and the previous selection
            candidates = self.interaction_module(state_tensor, transition)
            # Select one candidate; it also feeds the next iteration
            transition = self.selector_module(candidates)
            # Compute output for this step: transition @ state
            step_outputs.append(
                torch.matmul(transition, state_tensor.unsqueeze(-1)).squeeze(-1)
            )

        if step_outputs:
            output = torch.stack(step_outputs, dim=1)
        else:
            # No positions to process: nothing to stack
            output = x.new_empty(batch, 0, n_features)

        return output, transition


In [163]:
# Smoke test: run CustomLayerExtended on random input and check the output shape.
n_features = 16
batch, length, dim = 13, 64, n_features
x = torch.randn(batch, length, dim)
model = CustomLayerExtended(
    n_features = n_features
)
# forward returns a tuple: (output, selected_transition_tensor)
y = model(x)

# y[0] is the per-position output and must match the input shape
print(y[0].shape)
assert y[0].shape == x.shape

torch.Size([13, 64, 16])


In [164]:
# `y` is the (output, selected_transition_tensor) tuple returned by the model, so both tensors are printed
print(y)

(tensor([[[-1.1534e-01,  4.4807e-01,  1.8964e-02,  ...,  9.8030e-02,
           3.2010e-01,  1.8228e-01],
         [ 6.8518e-01,  1.9095e+00,  1.1681e+00,  ...,  9.6666e-01,
          -1.8480e-01,  8.9525e-01],
         [-3.5934e-01,  6.4750e-01, -1.9537e-01,  ...,  8.2446e-02,
           4.7636e-01, -3.6858e-01],
         ...,
         [ 4.0059e+00,  3.4427e+00,  3.9373e+00,  ...,  3.5967e+00,
           4.5405e+00,  3.9445e+00],
         [ 5.8643e-01, -3.0941e-02,  1.1523e+00,  ...,  3.0185e-01,
           3.9314e-01,  1.9763e+00],
         [-4.9121e-01, -1.4573e+00,  1.5179e-03,  ...,  1.0944e+00,
           8.2960e-01, -5.3260e-01]],

        [[ 9.1830e-02, -4.3638e-01,  1.4714e-01,  ..., -2.3673e-01,
          -1.3980e-03,  1.0617e-01],
         [ 9.9361e-01, -4.0394e-01, -6.1018e-01,  ..., -4.3683e-01,
           3.1980e-01, -7.2388e-01],
         [ 1.1947e+00,  3.5739e-01,  1.0453e+00,  ...,  1.2544e+00,
           6.1229e-01,  1.5554e+00],
         ...,
         [-5.5806e-02,  

# Example Training

## Simple Data Generation

In [165]:
import numpy as np
import torch

def generate_normalized_multivariate_time_series(n_features, total_length, amplitude=1.0):
    """Build cosine features of increasing frequency, normalised per timestep.

    Feature ``i`` is ``amplitude * cos(t * (i + 1) / n_features) + 10`` with
    ``t`` sampled at ``total_length`` evenly spaced points in ``[0, 100*pi]``.

    Args:
        n_features: Number of feature columns.
        total_length: Number of timesteps (rows).
        amplitude: Amplitude of each cosine before the +10 offset.

    Returns:
        NumPy array of shape ``[total_length, n_features]`` in which every
        row sums to 1.
    """
    time_axis = np.linspace(0, 100 * np.pi, total_length)
    multipliers = np.arange(1, n_features + 1)

    # Broadcast the time column against the per-feature multipliers; the
    # multiplication happens before the division, as in t * (i + 1) / n_features.
    waves = amplitude * np.cos(time_axis[:, None] * multipliers / n_features) + 10

    # Normalize such that each timestep's values sum to 1
    return waves / waves.sum(axis=1, keepdims=True)

def segment_time_series(series, length):
    """Cut a series into consecutive, non-overlapping segments of ``length`` rows.

    Every full segment is returned, including the last one when
    ``total_length`` is an exact multiple of ``length``; trailing rows that
    do not fill a whole segment are dropped.

    Args:
        series: NumPy array of shape ``[total_length, n_features]``.
        length: Number of rows per segment; must be positive.

    Returns:
        New NumPy array of shape ``[n_segments, length, n_features]`` with
        ``n_segments = total_length // length`` (zero segments when the
        series is shorter than ``length``).

    Raises:
        ValueError: If ``length`` is not positive.
    """
    if length <= 0:
        raise ValueError(f"length must be a positive integer, got {length}")

    total_length, n_features = series.shape
    n_segments = total_length // length

    # Keep only the rows that fill whole segments, then fold them into the
    # segment axis. copy() keeps the result independent of the input array.
    usable = series[:n_segments * length]
    return usable.reshape(n_segments, length, n_features).copy()


In [166]:
# Build the training set: inputs are the series at time t, targets the series at time t + 1.
n_features = 4
length = 64  # Segment length
total_length = 1024  # Arbitrary total length for the generated series

# Generate and segment the time series
series = generate_normalized_multivariate_time_series(n_features, total_length)
# Inputs drop the last timestep, targets drop the first, so row r of
# series_y is the timestep that follows row r of series_x.
series_x = series[:-1,]
series_y = series[1:,]

segments_x = segment_time_series(series_x, length)
segments_y = segment_time_series(series_y, length)

# Convert to tensors
segments_tensor_x = torch.tensor(segments_x, dtype=torch.float)
segments_tensor_y = torch.tensor(segments_y, dtype=torch.float)

# Prepare inputs and targets
X = segments_tensor_x
# Targets are the same segments advanced by one timestep (next-step prediction)
Y =  segments_tensor_y


In [167]:
from torch import nn, optim

# Model
model = CustomLayerExtended(
    n_features = n_features
)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop: every epoch is one full-batch step over all segments in X
epochs = 100
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    # Forward pass; the second return value (last selected transition tensor) is not used
    outputs, _ = model(X)
    loss = criterion(outputs, Y)

    # Backward and optimize
    loss.backward()
    optimizer.step()

    # Report the loss every 10th epoch
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item()}')


Epoch [10/100], Loss: 0.00024314230540767312
Epoch [20/100], Loss: 0.00023051956668496132
Epoch [30/100], Loss: 0.0002183038741350174
Epoch [40/100], Loss: 0.0002065545559162274
Epoch [50/100], Loss: 0.00019528964185155928
Epoch [60/100], Loss: 0.0001845124934334308
Epoch [70/100], Loss: 0.00017421766824554652
Epoch [80/100], Loss: 0.00016439496539533138
Epoch [90/100], Loss: 0.00015503217582590878
Epoch [100/100], Loss: 0.00014611614460591227


In [168]:
# Side-by-side view for segment index 1, feature index 1: column 0 is the input X, column 1 the target Y
torch.stack([X[1,:,1],Y[1,:,1]], dim = -1)

tensor([[0.2304, 0.2339],
        [0.2339, 0.2378],
        [0.2378, 0.2419],
        [0.2419, 0.2460],
        [0.2460, 0.2497],
        [0.2497, 0.2529],
        [0.2529, 0.2553],
        [0.2553, 0.2569],
        [0.2569, 0.2577],
        [0.2577, 0.2576],
        [0.2576, 0.2570],
        [0.2570, 0.2558],
        [0.2558, 0.2545],
        [0.2545, 0.2530],
        [0.2530, 0.2518],
        [0.2518, 0.2508],
        [0.2508, 0.2502],
        [0.2502, 0.2500],
        [0.2500, 0.2503],
        [0.2503, 0.2511],
        [0.2511, 0.2522],
        [0.2522, 0.2535],
        [0.2535, 0.2549],
        [0.2549, 0.2562],
        [0.2562, 0.2572],
        [0.2572, 0.2577],
        [0.2577, 0.2575],
        [0.2575, 0.2565],
        [0.2565, 0.2546],
        [0.2546, 0.2520],
        [0.2520, 0.2486],
        [0.2486, 0.2447],
        [0.2447, 0.2406],
        [0.2406, 0.2365],
        [0.2365, 0.2327],
        [0.2327, 0.2295],
        [0.2295, 0.2269],
        [0.2269, 0.2253],
        [0.2

In [174]:
import plotly.graph_objects as go

def plot_model_output_vs_target(model_outputs, targets, batch_index=0, feature_index=0):
    """Plot one feature of one batch entry for the model output and the target.

    Args:
        model_outputs: Tensor indexed as ``[batch, timestep, feature]``.
        targets: Tensor indexed the same way as ``model_outputs``.
        batch_index: Batch entry to plot.
        feature_index: Feature to plot.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    def _as_series(values):
        # detach() drops any autograd history and cpu() moves the data to
        # host memory; Tensor.numpy() requires both.
        return values[batch_index, :, feature_index].detach().cpu().numpy()

    # Extract the specified feature for the given batch from both the model outputs and targets
    model_output_series = _as_series(model_outputs)
    target_series = _as_series(targets)

    # Create a range for the x-axis (timesteps)
    timesteps = list(range(model_output_series.shape[0]))

    # Create the figure and add one line per series
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=timesteps, y=model_output_series, mode='lines', name='Model Output'))
    fig.add_trace(go.Scatter(x=timesteps, y=target_series, mode='lines', name='Target'))

    # Add title and labels
    fig.update_layout(title=f'Model Output vs Target for Feature {feature_index}, Batch {batch_index}',
                      xaxis_title='Timestep',
                      yaxis_title='Value')

    # Show the figure
    fig.show()

# `outputs` holds the model outputs from the last training epoch and `Y` the targets
# Adjust batch_index and feature_index as needed
plot_model_output_vs_target(outputs, Y, batch_index=2, feature_index=2)
