In [80]:
import torch
import torch.nn as nn

class CustomLayer(nn.Module):
    """Loop-based reference implementation of a quadratic interaction layer.

    For every feature vector ``s`` in the input, the layer evaluates
    ``out[j] = sum_{i,k} s[i] * T[i, j, k] * s[k]`` where ``T`` is the
    learnable ``interaction_tensor``.
    """

    def __init__(self, n_features):
        super().__init__()
        self.n_features = n_features
        # Learnable rank-3 tensor, initialised from a standard normal distribution.
        self.interaction_tensor = nn.Parameter(
            torch.randn(n_features, n_features, n_features)
        )

    def forward(self, x):
        """Apply the interaction independently at every (batch, position).

        Args:
            x: tensor of shape [batch, length, n_features].

        Returns:
            Tensor of shape [batch, length, n_features], allocated with the
            same dtype and device as ``x``.
        """
        batch, length, n_features = x.shape
        result = x.new_empty(batch, length, n_features)

        # Walk over every individual [n_features] state vector.
        for batch_idx, sequence in enumerate(x):
            for position, state in enumerate(sequence):
                # Contract the state with the first axis of the interaction
                # tensor, giving a [n_features, n_features] transition matrix.
                transition = torch.einsum('i,ijk->jk', state, self.interaction_tensor)
                # Matrix-vector product with the same state -> [n_features].
                result[batch_idx, position, :] = torch.matmul(transition, state)

        return result


In [81]:
import torch
import torch.nn as nn

class CustomLayerVectorized(nn.Module):
    """Vectorized quadratic interaction layer.

    Computes, for every position of the input,
    ``out[j] = sum_{i,k} s[i] * T[i, j, k] * s[k]`` where ``s`` is the
    feature vector at that position and ``T`` is the learnable
    ``interaction_tensor``. This is the same contraction as the loop-based
    ``CustomLayer`` (``'i,ijk->jk'`` followed by a matrix-vector product),
    evaluated for all batches and positions at once.
    """

    def __init__(self, n_features):
        super(CustomLayerVectorized, self).__init__()
        self.n_features = n_features
        # Initialize the interaction tensor as a learnable parameter
        self.interaction_tensor = nn.Parameter(torch.randn(n_features, n_features, n_features))

    def forward(self, x):
        """Apply the interaction to every position of ``x`` without Python loops.

        Args:
            x: tensor of shape [batch, length, n_features].

        Returns:
            Tensor of shape [batch, length, n_features].

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                "expected x of shape [batch, length, n_features], got %d dimension(s)" % x.dim()
            )

        # Step 1: contract the state with the FIRST axis of the interaction
        # tensor to get one transition matrix per position.
        # Shape: [batch, length, n_features, n_features]
        transition_tensor = torch.einsum('bli,ijk->bljk', x, self.interaction_tensor)

        # Step 2: multiply each transition matrix by its own state vector.
        # Shape: [batch, length, n_features]
        return torch.einsum('bljk,blk->blj', transition_tensor, x)



In [82]:
# Smoke test: run the vectorized layer on random input and check that the
# output keeps the input's [batch, length, n_features] shape.
n_features = 16
batch = 2
length = 64
dim = n_features
x = torch.randn(batch, length, dim)
model = CustomLayerVectorized(n_features=n_features)
y = model(x)

print(x.size())
assert y.size() == x.size()

torch.Size([2, 64, 16])


In [83]:
class InteractionModule(nn.Module):
    """Produce candidate transition tensors from a state and a previous transition.

    Holds ``n_features + 1`` learnable interaction tensors of shape
    [n_features, n_features, n_features]: index 0 interacts with the state
    tensor, and index ``c + 1`` interacts with column ``c`` of the previous
    transition tensor.
    """

    def __init__(self, n_features):
        super(InteractionModule, self).__init__()
        self.n_features = n_features
        # One interaction tensor for the state tensor and one for each column of the transition tensor
        self.interaction_tensors = nn.ParameterList([
            nn.Parameter(torch.randn(n_features, n_features, n_features))
            for _ in range(n_features + 1)
        ])

    def forward(self, state_tensor, previous_transition_tensor):
        """Compute one candidate transition tensor per interaction tensor.

        Each candidate is ``candidate[b, i, j] = sum_k v[b, k] * T[i, j, k]``
        where ``v`` is either the state tensor or one column of the previous
        transition tensor, and ``T`` is the matching interaction tensor.

        Args:
            state_tensor: tensor of shape [batch, n_features].
            previous_transition_tensor: tensor of shape
                [batch, n_features, n_features] from the previous step.

        Returns:
            List of ``n_features + 1`` tensors, each of shape
            [batch, n_features, n_features].
        """
        candidates = []
        # Iterate over the module's own parameters so the loop never depends
        # on notebook-level globals such as `n_features` or `batch`.
        for index, interaction_tensor in enumerate(self.interaction_tensors):
            if index == 0:  # Interaction with the state tensor
                current_tensor = state_tensor
            else:  # Interaction with columns of the previous transition tensor
                current_tensor = previous_transition_tensor[:, :, index - 1]

            # Contract the vector with the last axis of the interaction tensor;
            # einsum broadcasts over the batch axis, so no explicit expand is needed.
            candidate = torch.einsum('bk,ijk->bij', current_tensor, interaction_tensor)
            candidates.append(candidate)
        return candidates


In [84]:
class SelectorModule(nn.Module):
    """Placeholder selector that picks one candidate transition tensor."""

    def __init__(self, n_features):
        # `n_features` is accepted but not used yet; the module currently
        # holds no parameters.
        super().__init__()

    def forward(self, candidates):
        """Return the chosen candidate.

        The real selection criterion is not implemented yet; for
        demonstration purposes the first candidate is always returned.
        """
        return candidates[0]


In [85]:
class CustomLayerExtended(nn.Module):
    """Recurrent interaction layer with a candidate/selection step.

    At every position the interaction module proposes candidate transition
    tensors from the current state and the previously selected transition
    tensor; the selector module picks one, which is applied to the state
    and carried over to the next position.
    """

    def __init__(self, n_features):
        super(CustomLayerExtended, self).__init__()
        self.n_features = n_features
        self.interaction_module = InteractionModule(n_features)
        self.selector_module = SelectorModule(n_features)

    def forward(self, x):
        """Run the layer over a sequence.

        Args:
            x: tensor of shape [batch, length, n_features].

        Returns:
            Tuple ``(output, selected_transition_tensor)`` where ``output``
            has shape [batch, length, n_features] and
            ``selected_transition_tensor`` is the transition tensor selected
            at the last position (all zeros when ``length`` is 0).
        """
        batch, length, n_features = x.size()
        output = x.new_empty(batch, length, n_features)

        # Initial "previous" transition tensor: zeros allocated from `x` so
        # they share its dtype and device.
        previous_transition_tensor = x.new_zeros(batch, n_features, n_features)
        # Defined up front so the return value exists even for empty sequences.
        selected_transition_tensor = previous_transition_tensor

        for l in range(length):
            state_tensor = x[:, l, :]
            # Generate candidates
            candidates = self.interaction_module(state_tensor, previous_transition_tensor)
            # Select one candidate
            selected_transition_tensor = self.selector_module(candidates)
            # Carry the selection over to the next iteration
            previous_transition_tensor = selected_transition_tensor
            # Batched matrix-vector product: [batch, n, n] @ [batch, n, 1] -> [batch, n]
            output[:, l, :] = torch.matmul(selected_transition_tensor, state_tensor.unsqueeze(-1)).squeeze(-1)

        return output, selected_transition_tensor


In [91]:
# Smoke test: run the extended layer on random input. The model returns a
# tuple whose first element is the per-position output.
n_features = 16
batch = 2
length = 64
dim = n_features
x = torch.randn(batch, length, dim)
model = CustomLayerExtended(n_features=n_features)
y = model(x)

print(y[0].size())
assert y[0].size() == x.size()

torch.Size([2, 64, 16])
