In [11]:
import torch

import torch.nn as nn

import torch.nn.functional as F

import numpy as np



def random_projection(matrix, k):

    """Random projection to reduce dimensionality of matrix to k dimensions."""

    random_matrix = torch.randn(matrix.size(-1), k, device=matrix.device)

    return torch.matmul(matrix, random_matrix)



def cur_decomposition(matrix, k):
    # Ensure k does not exceed the matrix dimensions for both rows and columns
    k = min(k, matrix.size(-1), matrix.size(-2))

    # Generate indices within the valid range
    cols_indices = torch.randperm(matrix.size(-1), device=matrix.device)[:k]
    rows_indices = torch.randperm(matrix.size(-2), device=matrix.device)[:k]

    # Use torch.index_select for safer indexing
    C = torch.index_select(matrix, dim=-1, index=cols_indices)
    R = torch.index_select(matrix, dim=-2, index=rows_indices)
    # Intersection for U needs careful handling to ensure dimensions match
    U = torch.pinverse(C[:, rows_indices, :])

    return C, U, R







class PartitionedLinformerAttentionACT(nn.Module):

    def __init__(self, embed_size, heads, sequence_length, projection_dim, partition_size):

        super().__init__()

        self.embed_size = embed_size

        self.heads = heads

        self.sequence_length = sequence_length

        self.projection_dim = projection_dim

        self.partition_size = partition_size

        self.head_dim = embed_size // heads



        self.values = nn.Linear(embed_size, self.head_dim, bias=False)  
        
        self.keys = nn.Linear(embed_size, self.projection_dim, bias=False)  
        
        self.queries = nn.Linear(embed_size, self.projection_dim, bias=False)  

        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)



        # ACT components

        self.ponder = nn.Linear(self.head_dim, 1, bias=True)  # Ponder network to decide halting

        self.sigmoid = nn.Sigmoid()



    def forward(self, x):
        N, seq_length, _ = x.shape

        # Linear transformation without additional reshaping for values
        values = self.values(x)  # Output shape: [batch_size, seq_length, head_dim]
        keys = self.keys(x)  # Output shape after projection: [batch_size, seq_length, projection_dim]
        queries = self.queries(x)  # Same as keys

        # Apply random projection to reduce dimensionality
        keys_proj = random_projection(keys, self.projection_dim // 2)
        queries_proj = random_projection(queries, self.projection_dim // 2)

        # Prepare for attention computation by reshaping
        # Ensure the projection dimension is correctly used
        values = values.view(N, seq_length, self.heads, -1)
        keys_proj = keys_proj.view(N, seq_length, self.heads, -1)
        queries_proj = queries_proj.view(N, seq_length, self.heads, -1)

        attention_scores = torch.zeros(N, self.heads, seq_length, self.projection_dim // 2, device=x.device)

        for i in range(0, seq_length, self.partition_size):
            partition_start = i * self.partition_size
            partition_end = min(i + self.partition_size, seq_length)
            keys_partition = keys_proj[:, partition_start:partition_end, :, :]
            queries_partition = queries_proj[:, partition_start:partition_end, :, :]

            # Apply CUR decomposition
            k = min(keys_partition.size(-1) // 2, keys_partition.size(-2), keys_partition.size(-1))
            C_keys, U, R_queries = cur_decomposition(keys_partition, k=min(keys_partition.size(-1) // 2, keys_partition.size(-2), keys_partition.size(-1)))

            # Compute ponder scores
            ponder_scores = self.sigmoid(self.ponder(queries_partition).squeeze(-1))

            # Use ponder scores to weigh the importance of each partition's computation
            energy = torch.einsum('nshd,nshp->nhdp', R_queries, C_keys.transpose(-2, -1))
            attention = F.softmax(energy / (self.embed_size ** (1 / 2)), dim=-1) * ponder_scores.unsqueeze(-1)
            attention_scores[:, :, partition_start:partition_end, :] = attention

        # Attention aggregation and final output shape adjustment
        out = torch.einsum('nhdp,nshd->nshp', attention_scores, values.reshape(N, seq_length, -1))
        out = self.fc_out(out.reshape(N, seq_length, -1))

        return out





# Example usage

model = PartitionedLinformerAttentionACT(embed_size=512, heads=8, sequence_length=1024, projection_dim=256, partition_size=128)

input_tensor = torch.rand(2, 1024, 512)

output = model(input_tensor)

print(output.shape)  # Expected output shape: [2, sequence_length, embed_size]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (2048x16 and 64x1)