In [219]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [275]:
class NeuralCoilLayer(nn.Module):
    def __init__(self, n_features, n_batch, device = "cpu"):
        super(NeuralCoilLayer, self).__init__()
        self.n_features = n_features
        self.attention_weights = nn.Linear(n_features, 1, bias=False)
        self.act = nn.SiLU()
        self.interaction_tensors = nn.Parameter(torch.rand(n_features, n_features, n_features, n_features + 1))
        self.topk_num = 1
        
        starting_tensor = torch.softmax(torch.ones(n_batch, n_features, n_features), dim = 1)
        if device == "cuda":
            self.starting_transition_tensor = starting_tensor.to("cuda")
        else:
            self.starting_transition_tensor = starting_tensor
        
    def step_coil(self, state_tensor, previous_transition_tensor):
        # Establish normalized subgroups
        norm_subgroups = torch.cat((state_tensor.unsqueeze(-1), previous_transition_tensor), dim=2)
        #print(norm_subgroups[5,:,3].sum())
        batch_size, num_groups, n_features = norm_subgroups.shape
        
        # Compute scores for each normalized subgroup
        scores = self.act(self.attention_weights(norm_subgroups.permute(0,2,1))).sum(-1) # [batch_size, num_groups]
        
        weights = torch.softmax(scores, dim = -1) # [batch_size, num_groups]
        #print("Norm Subgroups Sum: ", norm_subgroups.sum(-2))
        selected_norm_tensor = (torch.mul(norm_subgroups, weights.unsqueeze(1))).sum(-1)
        #print("Selected Norm Tensor Shape: ", selected_norm_tensor.shape)
        #print("Selected Norm Tensor Sum: ", selected_norm_tensor.sum(1))
        
        # If you don't want to need to force conservation of states and transition
        selected_interaction_tensors = torch.softmax(self.interaction_tensors * n_features * n_features, dim = 0)
        #selected_interaction_tensors = self.interaction_tensors
        selected_norm_subgroups = norm_subgroups

        selected_transition_tensors = (torch.mul(selected_interaction_tensors, selected_norm_subgroups.unsqueeze(1).unsqueeze(1))).sum(-2)
        
        # We need a single transition tensor so we will average this as well

        selected_transition_tensor = (torch.mul(selected_transition_tensors, weights.unsqueeze(-2).unsqueeze(-2))).sum(-1)
        #selected_transition_tensor = selected_transition_tensors[:,:,:,5]
        
        # Generate state tensor from the transition tensor
        # Unsqueezing state_tensor to make it [batch_size, n_features, 1] for matrix multiplication
        state_tensor_unsqueezed = state_tensor.unsqueeze(2)

        # Performing batch matrix multiplication
        new_state_tensor_bmm = torch.bmm(selected_transition_tensor, state_tensor_unsqueezed)

        # Squeezing the result to get rid of the extra dimension, resulting in [batch_size, n_features]
        new_state_tensor = new_state_tensor_bmm.squeeze(2)
        
        new_state_tensor = torch.softmax(new_state_tensor, dim = 1)
        #selected_transition_tensor = torch.softmax(selected_transition_tensor, dim = 1)
        
        print("Selected Transition Tensor Sum: ", sum(selected_transition_tensor[0,:,1]))
        return new_state_tensor, selected_transition_tensor


    def forward(self, x):
        batch, length, n_features = x.size()
        output = x.new_empty(batch, length, n_features)

        # Initialize previous transition tensors (for the first step)
        # Assuming it's a list of zero tensors for simplicity
        transition_tensor = self.starting_transition_tensor

        for l in range(length):
            state_tensor = x[:, l, :]
            
            # Compute output for this step
            output[:, l, :], transition_tensor = self.step_coil(state_tensor, transition_tensor)

        return output, transition_tensor


# Sequence-to-Sequence Check

In [276]:
n_features = 16
batch, length, dim = 13, 64, n_features
x = torch.softmax(torch.randn(batch, length, dim), dim = 2)
model = NeuralCoilLayer(
    n_features = n_features,
    n_batch=batch
)
y = model(x)


print(y[0].shape)
assert y[0].shape == x.shape

Selected Transition Tensor Sum:  tensor(7.9945, grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(60.6788, grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(462.4091, grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(3658.4668, grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(29068.2168, grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(230490.1406, grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(1856841.6250, grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(14604150., grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(1.3051e+08, grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(9.9343e+08, grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(8.2972e+09, grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(6.2527e+10, grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(4.9473e+11, grad_fn=<AddBackward0>)
Selected Transition 

In [277]:
y[0][0,5,:].sum()

tensor(1., grad_fn=<SumBackward0>)

In [278]:
X= x.to("cuda")
model = model.to("cuda")

In [279]:
import plotly.graph_objects as go

states = []
# Select the batch we want to make predictions for
batch = 10


# Grab starting state tensor
state_tensor = X[:,0,:]

# How many steps do we want to run the coil overall?
max_steps = 50

batch_size = X.shape[0]
transition_tensor = torch.softmax(torch.zeros(batch_size, n_features, n_features), dim = 1).to("cuda")
for step_state in range(max_steps):
    state_tensor, transition_tensor = model.step_coil(state_tensor, transition_tensor)
    states.append(state_tensor[batch,:])

# Move state dynamics to CPU
data = [row.to('cpu').detach().numpy() for row in states]
# Transpose the data to get 5 traces
traces = list(zip(*data))


# Create the figure and add traces
fig = go.Figure()

# Plotting
for i, trace in enumerate(traces):
    model_trace = go.Scatter(y=trace, mode='lines', name=f'State {i}')
    fig.add_trace(model_trace)

# Add title and labels
fig.update_layout(title=f'Self-Perpetuating Coil Dynamics',
                    xaxis_title='Timestep',
                    yaxis_title='Value')

# Show the figure
fig.show()

Selected Transition Tensor Sum:  tensor(7.9945, device='cuda:0', grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(60.7646, device='cuda:0', grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(462.4134, device='cuda:0', grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(3658.1311, device='cuda:0', grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(29066.8066, device='cuda:0', grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(230362.3750, device='cuda:0', grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(1854418.8750, device='cuda:0', grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(14585371., device='cuda:0', grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(1.3035e+08, device='cuda:0', grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(9.9215e+08, device='cuda:0', grad_fn=<AddBackward0>)
Selected Transition Tensor Sum:  tensor(8.2866e+09, device='cuda:0', grad_fn=<