In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [10]:
def expand_angle_to_vector(phi: float, n: int = 10) -> np.ndarray:
    """
    Expand an angle into a Fourier-style feature vector of length ``n``.

    The output interleaves sines and cosines of the first ``n // 2`` integer
    harmonics of ``phi``::

        [sin(phi), cos(phi), sin(2*phi), cos(2*phi), ..., sin(k*phi), cos(k*phi)]

    with ``k = n // 2``.

    Args:
        phi (float): the angle in radians.
        n (int): the size of the output feature vector. Must be even and
            non-negative.

    Returns:
        numpy array of shape ``(n,)`` holding the interleaved sine/cosine terms.

    Raises:
        ValueError: if n is odd or negative.

    Notes:
        - Harmonics start at 1 rather than 0: a zero-frequency pair would
          always be (sin(0), cos(0)) = (0, 1) and carry no information
          about phi.
        - Higher harmonics change faster for the same change in phi, so
          later entries respond to finer differences in orientation.
    """
    # Sine and cosine features come in pairs, so the length must be even.
    if n % 2 != 0:
        raise ValueError("n must be an even number.")
    if n < 0:
        raise ValueError("n must be non-negative.")

    # Harmonic multiples 1*phi, 2*phi, ..., (n/2)*phi computed in one shot.
    harmonics = np.arange(1, n // 2 + 1) * phi

    feature_vector = np.zeros(n)
    feature_vector[0::2] = np.sin(harmonics)  # even slots: sines
    feature_vector[1::2] = np.cos(harmonics)  # odd slots: cosines
    return feature_vector

# Example usage: expand a random angle into a feature vector.
# (Substitute a fixed value such as np.pi / 4 for phi to get reproducible output.)
phi = np.random.uniform(0, 2*np.pi)  # Random angle drawn from [0, 2*pi)
n = 10  # Size of the feature vector (must be even)
expanded_feature_vector = expand_angle_to_vector(phi, n)
print("Expanded Feature Vector:", expanded_feature_vector)

Expanded Feature Vector: [-0.7450707  -0.6669855   0.9939027  -0.11026069 -0.58076668  0.81407006
 -0.21917679 -0.97568516  0.87314217  0.48746565]


In [31]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiheadSelfAttention(nn.Module):
    """
    Implements a Multihead Self-Attention mechanism (scaled dot-product).

    Each head gets its own query/key/value projection of the input; the
    per-head results are concatenated and mapped back to ``dim_in``.

    Attributes:
        num_heads (int): Number of attention heads.
        dim_q (int): Dimensionality of query vectors (per head).
        dim_k (int): Dimensionality of key vectors (per head).
        dim_v (int): Dimensionality of value vectors (per head).
        query (nn.Linear): Linear transformation for query vectors.
        key (nn.Linear): Linear transformation for key vectors.
        value (nn.Linear): Linear transformation for value vectors.
        unifyheads (nn.Linear): Linear transformation to unify outputs from all heads.

    Args:
        dim_in (int): Dimensionality of the input feature vector.
        dim_q (int): Dimensionality of query vectors.
        dim_k (int): Dimensionality of key vectors. Must equal ``dim_q``.
        dim_v (int): Dimensionality of value vectors.
        num_heads (int): Number of attention heads. Defaults to 8.

    Raises:
        ValueError: if ``dim_q`` and ``dim_k`` differ.
    """
    def __init__(self, dim_in: int, dim_q: int, dim_k: int, dim_v: int, num_heads: int = 8):
        super().__init__()
        # The attention scores are query-key dot products, so both vectors
        # must have the same length. Fail here rather than inside forward().
        if dim_q != dim_k:
            raise ValueError(
                f"dim_q ({dim_q}) and dim_k ({dim_k}) must be equal for dot-product attention."
            )

        self.num_heads = num_heads
        self.dim_q = dim_q
        self.dim_k = dim_k
        self.dim_v = dim_v

        # One wide projection per role; it is split into heads in forward().
        self.query = nn.Linear(dim_in, dim_q * num_heads)
        self.key = nn.Linear(dim_in, dim_k * num_heads)
        self.value = nn.Linear(dim_in, dim_v * num_heads)

        self.unifyheads = nn.Linear(dim_v * num_heads, dim_in)

    def _split_heads(self, projected: torch.Tensor, head_dim: int) -> torch.Tensor:
        """Reshape (batch, seq, heads * head_dim) into (batch * heads, seq, head_dim)."""
        batch_size, seq_length, _ = projected.size()
        per_head = projected.view(batch_size, seq_length, self.num_heads, head_dim)
        # Move the head axis next to the batch axis, then fold the two together
        # so that a single bmm handles every head of every sample.
        return per_head.transpose(1, 2).reshape(batch_size * self.num_heads, seq_length, head_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the MultiheadSelfAttention mechanism.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_length, dim_in).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_length, dim_in) after applying multihead self-attention.
        """
        batch_size, seq_length, _ = x.size()

        queries = self._split_heads(self.query(x), self.dim_q)
        keys = self._split_heads(self.key(x), self.dim_k)
        values = self._split_heads(self.value(x), self.dim_v)

        # Scaled dot-product scores: (batch * heads, seq, seq).
        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.dim_k ** 0.5)
        # Normalise over the key axis so each query's weights sum to 1.
        attention = F.softmax(scores, dim=-1)

        # Weighted sum of values, then undo the head folding.
        out = torch.bmm(attention, values).view(batch_size, self.num_heads, seq_length, self.dim_v)
        out = out.transpose(1, 2).reshape(batch_size, seq_length, self.dim_v * self.num_heads)

        return self.unifyheads(out)

In [60]:
class DecoderBlock(nn.Module):
    """
    Two-layer feed-forward block: Linear -> ReLU -> Linear.

    Attributes:
        fc (nn.Sequential): The stacked layers, applied in order.

    Args:
        dim_in (int): Number of input features.
        dim_out (int): Number of features produced by both linear layers.
    """
    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # Build the layers first, then chain them; the first layer changes the
        # width from dim_in to dim_out and the second keeps it at dim_out.
        hidden_layer = nn.Linear(dim_in, dim_out)
        output_layer = nn.Linear(dim_out, dim_out)
        self.fc = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the input through the feed-forward stack.

        Args:
            x (torch.Tensor): Input tensor whose last dimension is dim_in.

        Returns:
            torch.Tensor: Result of applying ``fc`` to ``x``.
        """
        transformed = self.fc(x)
        return transformed

class PoolingBlock(nn.Module):
    """
    Collapse dimension 1 of the input with an element-wise maximum, then
    apply a linear layer to the pooled features.

    Attributes:
        pooling (nn.Linear): Linear layer applied to the pooled features.

    Args:
        dim_in (int): Number of input features.
        num_outputs (int): Number of features produced by the linear layer.
    """
    def __init__(self, dim_in: int, num_outputs: int):
        super().__init__()
        self.pooling = nn.Linear(dim_in, num_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Max-pool over dimension 1 and project the result.

        Args:
            x (torch.Tensor): Input tensor; dimension 1 is the one reduced.

        Returns:
            torch.Tensor: Linear projection of the max-pooled input.
        """
        # max(dim=...) yields (values, indices); only the values are needed.
        pooled, _ = x.max(dim=1)
        return self.pooling(pooled)

class SetTransformer(nn.Module):
    """
    Simplified Set Transformer: an encoder, a pooling stage and a decoder
    applied one after another.

    Args:
        encoder_block (nn.Module): First stage, typically a MultiheadSelfAttention module.
        pooling_block (nn.Module): Second stage, applied to the encoder output.
        decoder_block (nn.Module): Final stage, applied to the pooled output.
    """
    def __init__(
        self,
        encoder_block: nn.Module,
        pooling_block: nn.Module,
        decoder_block: nn.Module
    ):
        super().__init__()
        self.encoder, self.pooling, self.decoder = encoder_block, pooling_block, decoder_block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the three stages in sequence.

        Args:
            x (torch.Tensor): Input tensor accepted by the encoder block.

        Returns:
            torch.Tensor: Output of the decoder block.
        """
        encoded = self.encoder(x)
        pooled = self.pooling(encoded)
        return self.decoder(pooled)

# Example usage: build a SetTransformer and run it on a random batch of 2-D points.

# Parameters (module-level; later cells reuse these names)
batch_size = 5
seq_length = 10  # Number of coordinates
input_dim = 2  # Dimension of each coordinate (Xfov, Yfov)
dim_q = dim_k = dim_v = 64  # Dimension for query, key, value
num_heads = 4  # Number of attention heads
num_outputs = 1  # Width of the pooled representation; adjust based on your specific needs
output_dim = 128  # Example output dimension

# Generate a batch of random coordinates, shape (batch_size, seq_length, input_dim)
random_coordinates = torch.rand(batch_size, seq_length, input_dim)

# The attention block maps its output back to dim_in, so the pooling block
# also takes input_dim features; the decoder then widens num_outputs to output_dim.
set_transformer = SetTransformer(
    encoder_block=MultiheadSelfAttention(dim_in=input_dim, dim_q=dim_q, dim_k=dim_k, dim_v=dim_v, num_heads=num_heads),
    pooling_block=PoolingBlock(dim_in=input_dim, num_outputs=num_outputs),
    decoder_block=DecoderBlock(dim_in=num_outputs, dim_out=output_dim)
)

# Pass the random coordinates through the SetTransformer module
output = set_transformer(random_coordinates)

# The sequence axis is removed by the pooling block, leaving (batch_size, output_dim)
print("Output shape:", output.shape)

Output shape: torch.Size([5, 128])


In [54]:
class Decoder(nn.Module):
    """
    Decode embeddings into a fixed-length set of points plus a mask tensor.

    The embeddings go through an attention module; its output feeds two
    separate heads: ``projection`` (reshaped into points) and
    ``mask_projection``.
    """
    def __init__(self, attention, projection, mask_projection, max_sequence_len, output_dim):
        """
        Initializes the decoder module.

        Args:
            attention (nn.Module): The multi-head self-attention module.
            projection (nn.Module): The linear projection module to the output space.
                Its output must be reshapeable to (-1, max_sequence_len, output_dim).
            mask_projection (nn.Module): The module that produces the mask from
                the attention output.
            max_sequence_len (int): The maximum length of the output sequence.
            output_dim (int): The dimensionality of the output space (e.g., 2 for 2D coordinates).
        """
        super().__init__()
        self.max_sequence_len = max_sequence_len
        self.output_dim = output_dim

        self.attention = attention
        # Maps the attention output to max_sequence_len * output_dim values
        self.projection = projection
        self.mask_projection = mask_projection

    def forward(self, embeddings):
        """
        Forward pass of the decoder.

        Args:
            embeddings (torch.Tensor): Embeddings of shape (batch_size, embedding_dim)
                or (batch_size, seq_length, embedding_dim). A 2-D input is
                treated as a sequence of length 1.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(output, mask)`` where ``output``
            has shape (-1, max_sequence_len, output_dim) and ``mask`` is whatever
            ``mask_projection`` returns for the attention output.
        """
        # The attention module works on (batch, seq, dim); promote plain
        # (batch, dim) embeddings to a one-element sequence.
        if embeddings.dim() == 2:
            embeddings = embeddings.unsqueeze(1)

        # Apply multihead self-attention
        attention_output = self.attention(embeddings)

        # Project the attention output and unpack it into individual points
        projected_output = self.projection(attention_output)
        output = projected_output.reshape(-1, self.max_sequence_len, self.output_dim)

        # The mask head reads the same attention output as the point head
        mask = self.mask_projection(attention_output)

        return output, mask

In [84]:
# Number of points the decoder emits per input embedding
max_sequence_len = 20

# Encoder: same configuration as the SetTransformer example, built from the
# module-level parameters (input_dim, dim_q, dim_k, dim_v, num_heads, ...).
encoder = SetTransformer(
    encoder_block=MultiheadSelfAttention(dim_in=input_dim, dim_q=dim_q, dim_k=dim_k, dim_v=dim_v, num_heads=num_heads),
    pooling_block=PoolingBlock(dim_in=input_dim, num_outputs=num_outputs),
    decoder_block=DecoderBlock(dim_in=num_outputs, dim_out=output_dim)
)

decoder = Decoder(
    # Attention operates directly on the output_dim-wide embeddings
    attention=MultiheadSelfAttention(dim_in=output_dim, dim_q=dim_q, dim_k=dim_k, dim_v=dim_v, num_heads=num_heads),
    # One input_dim-sized point for each of the max_sequence_len slots
    projection=nn.Linear(output_dim, input_dim*max_sequence_len),
    # One softmax-normalised score per slot, flattened from dim 1 onward and
    # then reshaped to (max_sequence_len, 1).
    # NOTE(review): the Unflatten only fits if the flattened size equals
    # max_sequence_len, i.e. the embedding sequence has length 1 — confirm.
    mask_projection=nn.Sequential(
        nn.Linear(output_dim, max_sequence_len),
        nn.Softmax(dim=-1),
        nn.Flatten(1),
        nn.Unflatten(-1, (max_sequence_len, 1))
    ),
    max_sequence_len=max_sequence_len,
    output_dim=input_dim
)

In [85]:
# unsqueeze(1) inserts a length-1 sequence axis: the decoder's attention module
# unpacks a 3-D (batch, seq, dim) size from its input.
# NOTE(review): `output` is presumably the SetTransformer result from the
# earlier example cell — confirm, since later cells rebind `output`.
output_, mask = decoder(output.unsqueeze(1))

In [98]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorToMatrixDeconv(nn.Module):
    """
    Expand each input vector into a matrix with ``n_channels`` rows using a
    transposed convolution.

    Every vector of length ``input_len`` is treated as a single-channel image
    of height 1 and width ``input_len``. A transposed convolution with kernel
    and stride ``(n_channels, 1)`` stretches the height from 1 to
    ``n_channels`` while leaving the width untouched, so row ``i`` of the
    result is the input vector scaled by the i-th kernel weight plus the bias.

    Args:
        n_channels (int): Number of rows in each output matrix.
    """
    def __init__(self, n_channels):
        super().__init__()
        self.n_channels = n_channels

        # Single input and output channel; only the height axis is upsampled.
        self.deconv = nn.ConvTranspose2d(in_channels=1, out_channels=1,
                                         kernel_size=(n_channels, 1), stride=(n_channels, 1))

    def forward(self, x):
        """
        Forward pass to unwrap the input vector into a matrix.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, input_len).
                Any leading dimensions are folded into the batch dimension.

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, n_channels, input_len).
        """
        # Reshape to the (N, C, H, W) layout ConvTranspose2d expects, with
        # C = H = 1. reshape (not view) so non-contiguous inputs also work.
        x = x.reshape(-1, 1, 1, x.shape[-1])

        # Apply the deconvolution operation: (N, 1, 1, W) -> (N, 1, n_channels, W)
        x = self.deconv(x)

        # Drop the singleton channel dimension
        return x.squeeze(1)

# Example usage
input_len = 100  # Example input length
n_channels = 10  # Kernel height/stride of the transposed convolution

model = VectorToMatrixDeconv(n_channels=n_channels)

# Create a dummy batch of input vectors
x = torch.randn(10, input_len)  # 10 vectors of length input_len

# Unwrap each vector into a matrix
output = model(x)
print("Output shape:", output.shape)

Output shape: torch.Size([10, 10, 100])


In [102]:
class MatrixToMatrixDeconv(nn.Module):
    """
    Upsample a single-channel (m, n) matrix to (p, q) with one transposed
    convolution.
    """
    def __init__(self, input_shape, output_shape):
        """
        Initializes the deconvolution layer to transform a matrix of size (m, n) into a matrix of size (p, q).

        Args:
            input_shape (tuple): The shape of the input matrix (m, n).
            output_shape (tuple): The desired shape of the output matrix (p, q).
                Each output size must be at least the matching input size.

        Raises:
            ValueError: if an input size is not positive or an output size is
                smaller than the corresponding input size.
        """
        super().__init__()
        self.input_shape = input_shape
        self.output_shape = output_shape
        m, n = input_shape
        p, q = output_shape

        if m < 1 or n < 1:
            raise ValueError(f"input_shape must contain positive sizes, got {input_shape}.")
        if p < m or q < n:
            raise ValueError(
                f"output_shape {output_shape} must be at least as large as input_shape {input_shape}."
            )

        # Without padding a transposed convolution yields
        #     out = (in - 1) * stride + kernel
        # along each axis. Take the largest integer stride that fits and size
        # the kernel to cover the remainder, so the result is exactly (p, q).
        # When p is a multiple of m (q of n) this gives kernel == stride == p // m.
        stride = (p // m, q // n)
        kernel_size = (p - (m - 1) * stride[0], q - (n - 1) * stride[1])
        padding = (0, 0)

        self.deconv = nn.ConvTranspose2d(in_channels=1, out_channels=1,
                                         kernel_size=kernel_size, stride=stride,
                                         padding=padding)

    def forward(self, x):
        """
        Forward pass to transform the input matrix into the desired output matrix shape.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, 1, m, n).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, 1, p, q).
        """
        # Apply the deconvolution operation
        return self.deconv(x)

# Example usage
input_shape = (10, 10)  # Shape of the input matrix (m, n)
output_shape = (20, 20)  # Desired shape of the output matrix (p, q)

model = MatrixToMatrixDeconv(input_shape=input_shape, output_shape=output_shape)

# Create a dummy input matrix
x = torch.randn(1, *input_shape)  # 3-D tensor (1, m, n): one leading axis, no separate batch and channel axes

# Transform the matrix
output = model(x)
print("Output shape:", output.shape)

Output shape: torch.Size([1, 20, 20])


In [None]:
class NeuralPopulationDecoder(nn.Module):
    """
    Decode an embedding into three outputs: populations 1-4, population 5
    and population 6.

    A shared two-layer MLP feeds three heads:

    * populations 1-4: a linear layer followed by ``VectorToMatrixDeconv(4)``,
      giving a tensor of shape (batch, 4, hidden_dims[1]);
    * population 5: a linear layer with softmax, shape (batch, n);
    * population 6: a coarse (p//8, q//8) map that is upsampled in three
      steps to (batch, p, q).

    Args:
        input_dim (int): Size of the input embedding.
        m (int): Stored on the module; not used by any layer.
        n (int): Size of the population 5 output.
        p (int): Height of the population 6 output. Must be a positive multiple of 8.
        q (int): Width of the population 6 output. Must be a positive multiple of 8.
        hidden_dims (tuple, optional): Widths of the two shared hidden layers.
            Defaults to (512, 256).

    Raises:
        ValueError: if ``p`` or ``q`` is not a positive multiple of 8.
    """
    def __init__(self, input_dim, m, n, p, q, hidden_dims = None):
        super().__init__()
        # Three successive upsampling steps start from (p // 8, q // 8); with
        # sizes that are not multiples of 8 the steps would not land on (p, q).
        if p < 8 or q < 8 or p % 8 != 0 or q % 8 != 0:
            raise ValueError("p and q must be positive multiples of 8.")

        self.hidden_dims = (512, 256) if hidden_dims is None else hidden_dims
        self.m, self.n, self.p, self.q = m, n, p, q
        self.fc1 = nn.Linear(input_dim, self.hidden_dims[0])
        self.fc2 = nn.Linear(self.hidden_dims[0], self.hidden_dims[1])

        # Output layers for each population
        self.fc_pop1_4 = nn.Linear(self.hidden_dims[1], self.hidden_dims[1])
        self.pop1_4 = VectorToMatrixDeconv(4)
        self.pop5 = nn.Linear(self.hidden_dims[1], n)  # Population 5

        # Population 6: coarse-to-fine upsampling on a single-channel map.
        # Every stage works on (batch, 1, H, W), so each BatchNorm2d
        # normalises exactly one channel.
        coarse = (p // 8, q // 8)
        quarter = (p // 4, q // 4)
        half = (p // 2, q // 2)
        self.pop6 = nn.Sequential(
            nn.Linear(self.hidden_dims[1], coarse[0] * coarse[1]),
            nn.Unflatten(-1, (1, coarse[0], coarse[1])),  # (batch, 1, p//8, q//8)
            nn.BatchNorm2d(1),
            nn.ReLU(),
            MatrixToMatrixDeconv(coarse, quarter),  # (batch, 1, p//4, q//4)
            nn.BatchNorm2d(1),
            nn.ReLU(),
            MatrixToMatrixDeconv(quarter, half),  # (batch, 1, p//2, q//2)
            nn.BatchNorm2d(1),
            nn.ReLU(),
            MatrixToMatrixDeconv(half, (p, q)),  # (batch, 1, p, q)
            nn.BatchNorm2d(1),
            nn.ReLU()
        )

    def forward(self, embeddings):
        """
        Decode a batch of embeddings.

        Args:
            embeddings (torch.Tensor): Input of shape (batch_size, input_dim).

        Returns:
            tuple: ``(outputs_1_4, output_5, output_6)`` with shapes
            (batch_size, 4, hidden_dims[1]), (batch_size, n) and
            (batch_size, p, q) respectively.
        """
        x = F.relu(self.fc1(embeddings))
        x = F.relu(self.fc2(x))

        # Generate outputs for each population
        outputs_1_4 = self.pop1_4(self.fc_pop1_4(x))
        output_5 = F.softmax(self.pop5(x), dim=-1)  # Assuming n represents categories
        # Drop the singleton channel axis: (batch, 1, p, q) -> (batch, p, q)
        output_6 = self.pop6(x).reshape(-1, self.p, self.q)

        return outputs_1_4, output_5, output_6

In [None]:
class PopulationEmbeddingsDecoder(nn.Module):
    """
    Map the three population outputs (1-4, 5 and 6) to one embedding vector.

    Each input goes through its own branch (1-D convolution, linear layer,
    2-D convolution); the branch outputs are concatenated and projected to
    ``embedding_dim``.
    """
    def __init__(
        self,
        n_filters,
        filter_length,
        pop5_dim,
        pop6_dims,
        embedding_dim,
        seq_length_1_4=None
    ):
        """
        Initializes the decoder module.

        Args:
            n_filters (int): Number of filters for the convolutional layers.
            filter_length (int): Length of the 1D convolutional filters for populations 1-4.
            pop5_dim (int): Dimensionality of population 5.
            pop6_dims (tuple): Dimensions of population 6 matrix, expected to be a tuple (p, q).
            embedding_dim (int): Desired dimensionality of the output embeddings.
            seq_length_1_4 (int, optional): Sequence length of the populations 1-4
                input. When given, the flattened size of that branch is derived
                from it. When omitted, the legacy size
                ``n_filters * ((filter_length - 1) // 2)`` is used, which only
                matches inputs whose length is about ``2 * filter_length - 2``.

        Raises:
            ValueError: if ``seq_length_1_4`` is too short to survive the
                convolution and the pooling.
        """
        super().__init__()
        p, q = pop6_dims  # Unpack dimensions for population 6

        # Convolution for populations 1-4
        self.conv1_4 = nn.Conv1d(in_channels=4, out_channels=n_filters, kernel_size=filter_length)
        self.pool1_4 = nn.MaxPool1d(kernel_size=2)

        # Linear for population 5
        self.linear5 = nn.Linear(pop5_dim, 128)

        # CNN for population 6 (single input channel; padding=1 keeps H and W)
        self.conv6 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1)
        self.pool6 = nn.MaxPool2d(kernel_size=2)
        # The 2x2 pooling halves both spatial sizes (rounding down)
        self.fc6 = nn.Linear(64 * (p//2) * (q//2), 128)

        # Flattened size of the populations 1-4 branch
        if seq_length_1_4 is None:
            # Legacy behaviour, kept for backward compatibility
            self.flattened_size_1_4 = n_filters * ((filter_length - 1) // 2)
        else:
            # An unpadded convolution shortens the sequence by filter_length - 1;
            # the pooling then halves it (rounding down).
            conv_length = seq_length_1_4 - filter_length + 1
            if conv_length < 2:
                raise ValueError(
                    "seq_length_1_4 must be at least filter_length + 1 so that "
                    "the pooled sequence is non-empty."
                )
            self.flattened_size_1_4 = n_filters * (conv_length // 2)

        # Final projection to embeddings
        self.final_projection = nn.Linear(self.flattened_size_1_4 + 128 + 128, embedding_dim)

    def forward(self, pop1_4, pop5, pop6):
        """
        Compute embeddings from the three population inputs.

        Args:
            pop1_4 (torch.Tensor): Shape (batch_size, sequence_length, 4).
            pop5 (torch.Tensor): Shape (batch_size, pop5_dim).
            pop6 (torch.Tensor): Shape (batch_size, p, q).

        Returns:
            torch.Tensor: Embeddings of shape (batch_size, embedding_dim).
        """
        # Process populations 1-4; Conv1d wants channels first: (batch, 4, sequence_length)
        pop1_4 = pop1_4.permute(0, 2, 1)
        x1_4 = self.pool1_4(F.relu(self.conv1_4(pop1_4)))
        x1_4 = x1_4.view(x1_4.size(0), -1)  # Flatten

        # Process population 5
        x5 = F.relu(self.linear5(pop5))

        # Process population 6
        x6 = self.pool6(F.relu(self.conv6(pop6.unsqueeze(1))))  # Add channel dimension
        x6 = x6.view(x6.size(0), -1)  # Flatten
        x6 = F.relu(self.fc6(x6))

        # Concatenate and project to embeddings
        x = torch.cat([x1_4, x5, x6], dim=1)
        embeddings = self.final_projection(x)

        return embeddings