In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def torch_randnorm(size, dim=0):
    """Draw a uniform random tensor whose entries sum to 1 along ``dim``.

    Args:
        size: Shape of the result, passed unchanged to ``torch.rand``.
        dim: Dimension along which the entries are normalized.

    Returns:
        A tensor of shape ``size``: ``torch.rand`` samples divided by their
        sum along ``dim``, so every slice along ``dim`` sums to 1.
    """
    samples = torch.rand(size)
    # keepdim=True keeps the reduced axis so the division broadcasts back.
    return samples / samples.sum(dim=dim, keepdim=True)

# Example usage: every row of a 5x6 draw should sum to 1 along dim=1.
normalized_tensor = torch_randnorm([5, 6], dim=1)
print(normalized_tensor)
print(normalized_tensor.sum(dim=1))
# Fail loudly instead of relying on eyeballing the printed sums.
assert torch.allclose(normalized_tensor.sum(dim=1), torch.ones(5))

tensor([[0.0966, 0.0143, 0.2557, 0.0420, 0.2699, 0.3216],
        [0.2561, 0.0221, 0.1642, 0.1490, 0.2198, 0.1888],
        [0.0931, 0.1505, 0.2764, 0.2581, 0.0254, 0.1965],
        [0.1804, 0.1271, 0.1946, 0.1877, 0.2196, 0.0906],
        [0.0242, 0.0941, 0.1834, 0.2081, 0.1398, 0.3503]])
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


In [3]:
class InteractionModule(nn.Module):
    """Produce ``n_features + 1`` candidate transition matrices per sample.

    Candidate 0 is built from the state vector. Candidate ``c`` (``c >= 1``)
    is built from column ``c - 1`` of the previous transition matrix. Each
    candidate is the contraction of its source vector with its own learned
    ``[n_features, n_features, n_features]`` interaction tensor over the
    last axis.
    """

    def __init__(self, n_features):
        super(InteractionModule, self).__init__()
        self.n_features = n_features
        # One interaction tensor for the state plus one per column of the
        # previous transition matrix; each is initialized to sum to 1 along dim 1.
        self.interaction_tensors = nn.ParameterList(
            [
                nn.Parameter(torch_randnorm([n_features, n_features, n_features], dim=1))
                for _ in range(n_features + 1)
            ]
        )

    def forward(self, state_tensor, previous_transition_tensor):
        """Compute all candidate transition matrices in a single contraction.

        Args:
            state_tensor: ``[batch, n_features]`` state vectors.
            previous_transition_tensor: ``[batch, n_features, n_features]``
                transition matrices from the previous step (a tensor, not a list).

        Returns:
            ``[batch, n_features, n_features, n_features + 1]`` tensor with
            ``out[b, i, j, c] = sum_k source[b, k, c] * T_c[i, j, k]``, where
            ``source[..., 0]`` is the state vector and ``source[..., c]`` is
            column ``c - 1`` of the previous transition matrix.
        """
        # Source vectors laid out as columns: [batch, n_features, n_features + 1].
        sources = torch.cat(
            [state_tensor.unsqueeze(-1), previous_transition_tensor], dim=-1
        )
        # Stacked interaction tensors: [n_features + 1, n_features, n_features, n_features].
        weights = torch.stack(tuple(self.interaction_tensors), dim=0)
        # Contract directly rather than expanding each interaction tensor over
        # the batch; expanded (stride-0) views may be copied when einsum
        # reshapes them, multiplying memory by the batch size.
        return torch.einsum('bkc,cijk->bijc', sources, weights)

In [4]:
class SelectorModule(nn.Module):
    """Blend the slices of a ``[batch, height, width, num_slices]`` tensor.

    A linear layer scores each slice from its spatial mean. The scores are
    turned into mixing weights, and the slices are summed with those weights.

    In training mode the weights come from Gumbel-Softmax, which draws fresh
    random noise on every call, so repeated calls on the same input differ.
    In eval mode the noise is dropped and a plain temperature-scaled softmax
    is used. The output is then deterministic, and each sample's result is
    independent of the other samples in the batch.
    """

    def __init__(self, num_slices):
        super(SelectorModule, self).__init__()
        # A simple linear layer to compute importance scores for each slice
        self.importance = nn.Linear(num_slices, num_slices)

    def forward(self, x, temperature=1):
        """Return the weighted sum of slices, shape ``[batch, height, width]``.

        Args:
            x: ``[batch, height, width, num_slices]`` tensor.
            temperature: Softmax temperature (``tau`` for Gumbel-Softmax).

        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected x of shape [batch, height, width, num_slices], got {tuple(x.shape)}"
            )

        # Spatial mean gives one summary value per slice: [batch, num_slices].
        x_reduced = x.mean(dim=[1, 2])
        scores = self.importance(x_reduced)

        if self.training:
            # Differentiable, noisy approximation of a discrete slice choice.
            weights = F.gumbel_softmax(scores, tau=temperature, hard=False, dim=-1)
        else:
            # Same relaxation without the random Gumbel noise, so inference
            # is reproducible.
            weights = F.softmax(scores / temperature, dim=-1)

        # Broadcast weights over height/width and sum out the slice axis.
        return (x * weights[:, None, None, :]).sum(dim=-1)

# Example usage: blend 6 slices of 5x5 maps into one 5x5 map per sample.
batch_size = 10
tensor = torch.rand(batch_size, 5, 5, 6)  # [batch, height, width, num_slices]
model = SelectorModule(num_slices=6)

result = model(tensor)
print(result.shape)
# Pin the expected shape instead of relying on a "should print" comment.
assert result.shape == (batch_size, 5, 5)

torch.Size([10, 5, 5])


In [5]:
class NeuralCoilLayer(nn.Module):
    """Recurrent layer that evolves a state with a learned, per-step transition matrix.

    At each time step the interaction module proposes ``n_features + 1``
    candidate transition matrices, built from the current input and the
    previous transition matrix. The selector blends them into one matrix,
    and the new state is that matrix applied to the current input vector.
    """

    def __init__(self, n_features):
        super(NeuralCoilLayer, self).__init__()
        self.n_features = n_features
        self.interaction_module = InteractionModule(n_features)
        self.selector_module = SelectorModule(n_features + 1)

    def step_coil(self, state_tensor, previous_transition_tensor):
        """Advance the layer by one time step.

        Args:
            state_tensor: ``[batch, n_features]`` input vector for this step.
            previous_transition_tensor: ``[batch, n_features, n_features]``.

        Returns:
            ``(new_state_tensor, selected_transition_tensor)`` with shapes
            ``[batch, n_features]`` and ``[batch, n_features, n_features]``.
        """
        # Candidates: [batch, n_features, n_features, n_features + 1].
        candidates = self.interaction_module(state_tensor, previous_transition_tensor)
        # Blend the candidates into a single transition matrix per sample.
        selected_transition_tensor = self.selector_module(candidates)
        # Batched matrix-vector product.
        new_state_tensor = torch.matmul(selected_transition_tensor, state_tensor.unsqueeze(-1)).squeeze(-1)
        return new_state_tensor, selected_transition_tensor

    def forward(self, x):
        """Run the layer over a whole sequence.

        Args:
            x: ``[batch, length, n_features]`` input sequence.

        Returns:
            ``(output, transition_tensor)``. ``output`` holds the per-step
            states and has the same shape as ``x``. ``transition_tensor`` is
            the transition matrix from the last step, or all zeros when
            ``length == 0``.
        """
        batch, length, n_features = x.size()
        # Start from an all-zero transition matrix on x's device and dtype.
        # A bare torch.zeros would use the global default dtype/device instead.
        transition_tensor = x.new_zeros(batch, n_features, n_features)
        if length == 0:
            return x.new_empty(batch, 0, n_features), transition_tensor

        # Collect the per-step states and stack them once, instead of writing
        # each step in place into a preallocated buffer.
        states = []
        for step in range(length):
            new_state, transition_tensor = self.step_coil(x[:, step, :], transition_tensor)
            states.append(new_state)
        return torch.stack(states, dim=1), transition_tensor


## Sequence-to-Sequence Check

Run the layer over a full sequence and confirm that the output has the same shape as the input.

In [6]:
torch.manual_seed(0)  # reproducible parameter init, input, and selector noise
n_features = 16
batch, length, dim = 13, 64, n_features
x = torch.randn(batch, length, dim)
model = NeuralCoilLayer(
    n_features = n_features
)
y = model(x)
# forward returns (per-step outputs, final transition matrix)
output, final_transition = y

print(output.shape)
assert output.shape == x.shape
assert final_transition.shape == (batch, n_features, n_features)

torch.Size([13, 64, 16])


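## Perpetuation Check

Take a single `step_coil` from an all-zero transition matrix, and check that a sample's result does not depend on which other samples share its batch.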

In [35]:
torch.manual_seed(0)  # make parameter init and the input reproducible
n_features = 16
batch, length, dim = 13, 64, n_features
x = torch.randn(batch, length, dim)
model = NeuralCoilLayer(
    n_features = n_features
)

# One step for the full batch, starting from an all-zero transition matrix.
# NOTE: SelectorModule calls F.gumbel_softmax, which draws fresh random noise,
# whenever the model is in training mode. That is the nn.Module default, and
# .eval() is never called here, so repeated step_coil calls are not
# deterministic.
l = 1
state_tensor = x[:, l, :]
transition_tensor = torch.zeros(batch, n_features, n_features)

y = model.step_coil(state_tensor, transition_tensor)
y[0][1,:]

tensor([ 0.0045,  0.0068, -0.0005,  0.0030,  0.0029,  0.0057,  0.0088,  0.0014,
         0.0010, -0.0017,  0.0041,  0.0032, -0.0030,  0.0022,  0.0006, -0.0015],
       grad_fn=<SliceBackward0>)

Each sample should be processed independently of the others. Stepping only the first two samples should therefore reproduce row 1 of the output above.

In [36]:
# Repeat the step for the first two samples only. Row 1 here is the same
# sample as row 1 above, so the two should match if samples are processed
# independently and no randomness differs between the calls.
l = 1
state_tensor = x[0:2, l, :]
batch = state_tensor.shape[0]  # derived from the slice so the two cannot disagree
transition_tensor = torch.zeros(batch, n_features, n_features)

y = model.step_coil(state_tensor, transition_tensor)
y[0][1,:]

tensor([ 0.0284,  0.0425, -0.0034,  0.0188,  0.0179,  0.0359,  0.0551,  0.0085,
         0.0062, -0.0109,  0.0255,  0.0203, -0.0186,  0.0136,  0.0038, -0.0095],
       grad_fn=<SliceBackward0>)

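This isn't the same, so something is wrong with this development. Note, however, that the two rows differ only by a constant factor of about 6.3× (e.g. 0.0045 vs 0.0284, 0.0068 vs 0.0425, 0.0088 vs 0.0551). That pattern is exactly what `SelectorModule`'s Gumbel noise alone would produce:

- With an all-zero transition tensor, every candidate except the state-derived one (index 0) is zero.
- The selected transition matrix is therefore candidate 0 scaled by its selection weight.
- `F.gumbel_softmax` draws fresh random noise for that weight on every call.

The selection noise must be removed or fixed before this comparison can reveal a genuine batch-dependence bug.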