In [11]:
import torch

import torch.nn as nn

import torch.nn.functional as F

import numpy as np



def random_projection(matrix, k):

    """Random projection to reduce dimensionality of matrix to k dimensions."""

    random_matrix = torch.randn(matrix.size(-1), k, device=matrix.device)

    return torch.matmul(matrix, random_matrix)



def cur_decomposition(matrix, k):
    # Ensure k does not exceed the matrix dimensions for both rows and columns
    k = min(k, matrix.size(-1), matrix.size(-2))

    # Generate indices within the valid range
    cols_indices = torch.randperm(matrix.size(-1), device=matrix.device)[:k]
    rows_indices = torch.randperm(matrix.size(-2), device=matrix.device)[:k]

    # Use torch.index_select for safer indexing
    C = torch.index_select(matrix, dim=-1, index=cols_indices)
    R = torch.index_select(matrix, dim=-2, index=rows_indices)
    # Intersection for U needs careful handling to ensure dimensions match
    U = torch.pinverse(C[:, rows_indices, :])

    return C, U, R







class PartitionedLinformerAttentionACT(nn.Module):

    def __init__(self, embed_size, heads, sequence_length, projection_dim, partition_size):

        super().__init__()

        self.embed_size = embed_size

        self.heads = heads

        self.sequence_length = sequence_length

        self.projection_dim = projection_dim

        self.partition_size = partition_size

        self.head_dim = embed_size // heads



        self.values = nn.Linear(embed_size, heads * self.head_dim, bias=False)
        
        self.keys = nn.Linear(embed_size, self.projection_dim, bias=False)  
        
        self.queries = nn.Linear(embed_size, self.projection_dim, bias=False)  

        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)



        # ACT components

        self.ponder = nn.Linear(self.head_dim, 1, bias=True)  # Ponder network to decide halting

        self.sigmoid = nn.Sigmoid()


    '''
    def forward(self, x):
        N, seq_length, _ = x.shape

        values = self.values(x)
        keys = self.keys(x)
        queries = self.queries(x)

        keys_proj = random_projection(keys, self.projection_dim // 2)
        queries_proj = random_projection(queries, self.projection_dim // 2)


        values = values.view(N, seq_length, self.heads, self.head_dim)
        keys_proj = keys_proj.view(N, seq_length, self.heads, -1)
        queries_proj = queries_proj.view(N, seq_length, self.heads, -1)

        attention_scores = torch.zeros(N, self.heads, seq_length, self.projection_dim // 2, device=x.device)

        for i in range(0, seq_length, self.partition_size):
            partition_end = min((i+1) * self.partition_size, seq_length)
            keys_partition = keys_proj[:, i:partition_end, :, :]
            queries_partition = queries_proj[:, i:partition_end, :, :]

            C_keys, U, R_queries = cur_decomposition(keys_partition, k=self.projection_dim // 2)

            queries_flattened = queries_partition.view(N, -1, self.head_dim)
            ponder_scores = self.sigmoid(self.ponder(queries_flattened).squeeze(-1))

            energy = torch.einsum('nshd,nshp->nhdp', R_queries, C_keys.transpose(-2, -1))
            attention = F.softmax(energy / (self.embed_size ** (1 / 2)), dim=-1)

            # Now that `attention` is computed, we can correctly shape `ponder_scores_expanded`
            ponder_scores_expanded = ponder_scores.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, attention.size(-1))
            print(f"attention shape: {attention.shape}")
            print(f"ponder_scores shape: {ponder_scores.shape}")
            print(f"ponder_scores_expanded shape: {ponder_scores_expanded.shape}")
            print(f"attention_scores shape: {attention_scores.shape}")
            # Apply ponder_scores correctly
            attention_scores[:, :, i:partition_end, :] = attention * ponder_scores_expanded

        # Reshape and apply final linear transformation to compute output
        out = torch.einsum('nhdp,nshd->nshp', attention_scores, values).reshape(N, seq_length, -1)
        out = self.fc_out(out)

        return out'''
    

    def forward(self, x):
        N, seq_length, _ = x.shape

        values = self.values(x)
        keys = self.keys(x)
        queries = self.queries(x)

        keys_proj = random_projection(keys, self.projection_dim // 2)
        queries_proj = random_projection(queries, self.projection_dim // 2)

        values = values.view(N, seq_length, self.heads, self.head_dim)
        keys_proj = keys_proj.view(N, seq_length, self.heads, -1)
        queries_proj = queries_proj.view(N, seq_length, self.heads, -1)

        attention_scores = torch.zeros(N, self.heads, seq_length, self.projection_dim // 2, device=x.device)

        for i in range(0, seq_length, self.partition_size):
            partition_end = min((i+1) * self.partition_size, seq_length)
            keys_partition = keys_proj[:, i:partition_end, :, :]
            queries_partition = queries_proj[:, i:partition_end, :, :]

            C_keys, U, R_queries = cur_decomposition(keys_partition, k=self.projection_dim // 2)

            # Adjusted calculation of ponder_scores
            # Calculate ponder_scores per head instead of per flattened queries
            # This involves reshaping queries_partition to align with heads and sequence segments
            # Note: This step may require further adjustment based on the specific model behavior you're aiming for
            queries_reshaped = queries_partition.view(N, -1, self.heads, self.head_dim)
            ponder_scores = self.sigmoid(self.ponder(queries_reshaped).squeeze(-1))

            # Ensure ponder_scores are compatible with attention's dimensions
            # This involves adjusting ponder_scores' shape to [N, heads, partition_size, projection_dim // 2]
            # Note: This step assumes partition_size is a valid segmentation of the sequence for attention calculations
            ponder_scores_expanded = ponder_scores.unsqueeze(-1).expand(-1, -1, partition_end - i, -1)

            #ponder_scores_expanded = ponder_scores.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, attention.size(-1))

            energy = torch.einsum('nshd,nshp->nhdp', R_queries, C_keys.transpose(-2, -1))
            attention = F.softmax(energy / (self.embed_size ** (1 / 2)), dim=-1)
            print(f"attention shape: {attention.shape}")
            print(f"ponder_scores shape: {ponder_scores.shape}")
            print(f"ponder_scores_expanded shape: {ponder_scores_expanded.shape}")
            print(f"attention_scores shape: {attention_scores.shape}")
            # Correctly apply ponder_scores to attention
            attention_scores[:, :, i:partition_end, :] = attention * ponder_scores_expanded

        # Reshape and apply final linear transformation to compute output
        out = torch.einsum('nhdp,nshd->nshp', attention_scores, values).reshape(N, seq_length, -1)
        out = self.fc_out(out)

        return out










# Example usage

model = PartitionedLinformerAttentionACT(embed_size=512, heads=8, sequence_length=1024, projection_dim=256, partition_size=128)

input_tensor = torch.rand(2, 1024, 512)

output = model(input_tensor)

print(output.shape)  # Expected output shape: [2, sequence_length, embed_size]

RuntimeError: The expanded size of the tensor (128) must match the existing size (8) at non-singleton dimension 2.  Target sizes: [-1, -1, 128, -1].  Tensor sizes: [2, 32, 8, 1]

In [14]:
import torch

import torch.nn as nn

import torch.nn.functional as F

import numpy as np



def random_projection(matrix, k):

    """Random projection to reduce dimensionality of matrix to k dimensions."""

    random_matrix = torch.randn(matrix.size(-1), k, device=matrix.device)

    return torch.matmul(matrix, random_matrix)



def cur_decomposition(matrix, k):
    # Ensure k does not exceed the matrix dimensions for both rows and columns
    k = min(k, matrix.size(-1), matrix.size(-2))

    # Generate indices within the valid range
    cols_indices = torch.randperm(matrix.size(-1), device=matrix.device)[:k]
    rows_indices = torch.randperm(matrix.size(-2), device=matrix.device)[:k]

    # Use torch.index_select for safer indexing
    C = torch.index_select(matrix, dim=-1, index=cols_indices)
    R = torch.index_select(matrix, dim=-2, index=rows_indices)
    # Intersection for U needs careful handling to ensure dimensions match
    U = torch.pinverse(C[:, rows_indices, :])

    return C, U, R







class PartitionedLinformerAttentionACT(nn.Module):

    def __init__(self, embed_size, heads, sequence_length, projection_dim, partition_size):

        super().__init__()

        self.embed_size = embed_size

        self.heads = heads

        self.sequence_length = sequence_length

        self.projection_dim = projection_dim

        self.partition_size = partition_size

        self.head_dim = embed_size // heads



        self.values = nn.Linear(embed_size, heads * self.head_dim, bias=False)
        
        self.keys = nn.Linear(embed_size, self.projection_dim, bias=False)  
        
        self.queries = nn.Linear(embed_size, self.projection_dim, bias=False)  

        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)



        # ACT components

        self.ponder = nn.Linear(self.head_dim, 1, bias=True)  # Ponder network to decide halting

        self.sigmoid = nn.Sigmoid()


    def forward(self, x):
        N, seq_length, _ = x.shape

        values = self.values(x)
        keys = self.keys(x)
        queries = self.queries(x)

        keys_proj = random_projection(keys, self.projection_dim // 2)
        queries_proj = random_projection(queries, self.projection_dim // 2)


        values = values.view(N, seq_length, self.heads, self.head_dim)
        keys_proj = keys_proj.view(N, seq_length, self.heads, -1)
        queries_proj = queries_proj.view(N, seq_length, self.heads, -1)

        attention_scores = torch.zeros(N, self.heads, seq_length, self.projection_dim // 2, device=x.device)

        for i in range(0, seq_length, self.partition_size):
            partition_end = min((i+1) * self.partition_size, seq_length)

            keys_partition = keys_proj[:, i:partition_end, :, :]
            queries_partition = queries_proj[:, i:partition_end, :, :]

            C_keys, U, R_queries = cur_decomposition(keys_partition, k=self.projection_dim // 2)

            queries_partition_flattened = queries_partition.reshape(N * partition_end * self.heads, -1)
            ponder_scores_raw = self.ponder(queries_partition_flattened)
            ponder_scores = self.sigmoid(ponder_scores_raw)

            # Reshape ponder_scores back to match the original partitioning and heads
            # Then, calculate the mean across the dimension representing head_dim
            ponder_scores_reshaped = ponder_scores.view(N, partition_end, self.heads, -1).mean(dim=-1)

            # Correctly reshape ponder_scores for broadcasting with attention
            # Assuming attention mechanism is applied per head and per partitioned segment of the sequence
            ponder_scores_expanded = ponder_scores_reshaped.unsqueeze(-1).unsqueeze(-1)

            energy = torch.einsum('nshd,nshp->nhdp', R_queries, C_keys.transpose(-2, -1))
            attention = F.softmax(energy / (self.embed_size ** (1 / 2)), dim=-1)
            print(f"attention shape: {attention.shape}")
            print(f"ponder_scores shape: {ponder_scores.shape}")
            print(f"ponder_scores_expanded shape: {ponder_scores_expanded.shape}")
            print(f"attention_scores shape: {attention_scores.shape}")
            # Apply ponder_scores correctly
            attention_scores[:, :, i:partition_end, :] = attention * ponder_scores_expanded


        # Reshape and apply final linear transformation to compute output
        out = torch.einsum('nhdp,nshd->nshp', attention_scores, values).reshape(N, seq_length, -1)
        out = self.fc_out(out)

        return out
    

# Example usage

model = PartitionedLinformerAttentionACT(embed_size=512, heads=8, sequence_length=1024, projection_dim=256, partition_size=128)

input_tensor = torch.rand(2, 1024, 512)

output = model(input_tensor)

print(output.shape)  # Expected output shape: [2, sequence_length, embed_size]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (2048x16 and 64x1)

# v3

In [None]:
import torch

import torch.nn as nn

import torch.nn.functional as F

import numpy as np



def random_projection(matrix, k):

    """Random projection to reduce dimensionality of matrix to k dimensions."""

    random_matrix = torch.randn(matrix.size(-1), k, device=matrix.device)

    return torch.matmul(matrix, random_matrix)



def cur_decomposition(matrix, k):

    """Performs approximate CUR decomposition on a matrix to identify k columns and rows."""

    # For simplicity, using a randomized approach to select columns and rows

    cols_indices = np.random.choice(matrix.shape[-1], k, replace=False)

    rows_indices = np.random.choice(matrix.shape[-2], k, replace=False)

    

    C = matrix[:, :, cols_indices]

    R = matrix[:, rows_indices, :]

    U = torch.pinverse(matrix[:, rows_indices][:, :, cols_indices])

    

    return C, U, R



class PartitionedLinformerAttentionACT(nn.Module):

    def __init__(self, embed_size, heads, sequence_length, projection_dim, partition_size):

        super().__init__()

        self.embed_size = embed_size

        self.heads = heads

        self.sequence_length = sequence_length

        self.projection_dim = projection_dim

        self.partition_size = partition_size

        self.head_dim = embed_size // heads



        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)

        self.keys = nn.Linear(self.head_dim, self.projection_dim, bias=False)

        self.queries = nn.Linear(self.head_dim, self.projection_dim, bias=False)

        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)



        # For ACT

        self.ponder = nn.Linear(self.head_dim, 1)  # Ponder network to decide halting

        

        # Special initialization for head specialization

        for head in range(self.heads):

            nn.init.normal_(self.keys.weight[head * self.head_dim:(head + 1) * self.head_dim], mean=0, std=0.02 * (head + 1))

            nn.init.normal_(self.queries.weight[head * self.head_dim:(head + 1) * self.head_dim], mean=0, std=0.02 * (head + 1))



    def forward(self, x):

        N, seq_length, _ = x.shape

        values = self.values(x).view(N, seq_length, self.heads, self.head_dim)

        keys = self.keys(x).view(N, seq_length, self.heads, self.projection_dim)

        queries = self.queries(x).view(N, seq_length, self.heads, self.projection_dim)



        keys = random_projection(keys, self.projection_dim // 2)

        queries = random_projection(queries, self.projection_dim // 2)



        attention_scores = torch.zeros(N, self.heads, seq_length, self.projection_dim // 2, device=x.device)



        for i in range(0, seq_length, self.partition_size):

            partition_start = i * self.partition_size

            partition_end = min(i + self.partition_size, seq_length)

            keys_partition = keys[:, partition_start:partition_end, :, :]

            queries_partition = queries[:, partition_start:partition_end, :, :]



            # Apply CUR decomposition within each partition for keys and queries

            C_keys, U, R_queries = cur_decomposition(keys_partition, k=keys_partition.size(-1) // 2)



            # Compute attention scores using CUR components

            energy = torch.einsum('nshd,nshp->nhdp', R_queries, C_keys.transpose(-2, -1))

            attention = F.softmax(energy / (self.embed_size ** (1 / 2)), dim=-1)

            attention_scores[:, :, partition_start:partition_end, :] = attention



        out = torch.einsum('nhdp,nshd->nshp', attention_scores, values)

        out = out.reshape(N, seq_length, self.heads * self.head_dim)

        return self.fc_out(out)



# Example Usage

model = PartitionedLinformerAttentionACT(embed_size=512, heads=8, sequence_length=1024, projection_dim=256, partition_size=128)

input_tensor = torch.rand(2, sequence_length, embed_size)  # Batch size of 2

output = model(input_tensor)



####################




# v4

In [56]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam



def random_projection(matrix, k):
    """Random projection to reduce dimensionality of matrix to k dimensions."""
    random_matrix = torch.randn(matrix.size(-1), k, device=matrix.device)
    return torch.matmul(matrix, random_matrix)

def cur_decomposition(matrix, k):
    """Applies CUR decomposition on each head for each batch."""
    batch_size, seq_length, heads, dim = matrix.shape
    # Ensuring k does not exceed dimensions
    k = min(k, dim)
    C = torch.zeros(batch_size, seq_length, heads, k, device=matrix.device)
    R = torch.zeros(batch_size, k, heads, dim, device=matrix.device)
    
    for b in range(batch_size):
        for h in range(heads):
            col_indices = np.random.choice(dim, k, replace=False)
            row_indices = np.random.choice(seq_length, k, replace=False)
            C[b, :, h] = matrix[b, :, h, col_indices]
            R[b, :, h] = matrix[b, row_indices, h]
    return C, R

class PartitionedLinformerAttentionACT(nn.Module):

    def __init__(self, embed_size, heads, sequence_length, projection_dim, partition_size):

        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.sequence_length = sequence_length
        self.projection_dim = projection_dim
        self.partition_size = partition_size
        self.head_dim = embed_size // heads

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.projection_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.projection_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
        # ACT components
        #self.ponder = nn.Linear(self.partition_size, 1, bias=True)  # Ponder network to decide halting
        self.ponder = nn.Linear(self.head_dim, 1, bias=True)  # Ponder network to decide halting

        self.sigmoid = nn.Sigmoid()



    def forward(self, x):

        N, seq_length, _ = x.shape
        print(f"x shape as input to forward: {x.shape}")
        print(f"embed size as input to forward: {self.embed_size}")
        print(f"values before reshaping: {self.values}")
        print(f"keys before reshaping: {self.keys}")
        print(f"queries before reshaping: {self.queries}")
        x_reshaped = x.view(N, seq_length, self.heads, self.head_dim)
        print(f"x_reshaped : {x_reshaped.shape}")
        values = self.values(x_reshaped).view(N, seq_length, self.heads, self.head_dim)
        keys = self.keys(x_reshaped).view(N, seq_length, self.heads, self.projection_dim)
        queries = self.queries(x_reshaped).view(N, seq_length, self.heads, self.projection_dim)
        print(f"values after reshaping: {values.shape}")
        print(f"keys after reshaping: {keys.shape}")
        print(f"queries after reshaping: {queries.shape}")
        # Apply random projection
        keys = random_projection(keys, self.projection_dim // 2)
        queries = random_projection(queries, self.projection_dim // 2)
        print(f"keys after random_projection: {keys.shape}")
        print(f"queries after random_projection: {queries.shape}")
        attention_scores = torch.zeros(N, self.heads, seq_length, self.projection_dim // 2, device=x.device)
        print(f"attention_scores : {attention_scores.shape}")

        for i in range(0, seq_length, self.partition_size):
            print(f"START OF LOOP")
            partition_start = i * self.partition_size
            partition_end = min(i + self.partition_size, seq_length)
            keys_partition = keys[:, partition_start:partition_end, :, :]
            queries_partition = queries[:, partition_start:partition_end, :, :]
            print(f"partition_start : {partition_start}, partition_end : {partition_end} keys_partition = {keys_partition.shape}, queries_partition = {queries_partition.shape}")
            # Apply CUR decomposition
            C_keys,  R_queries = cur_decomposition(keys_partition, k=keys_partition.size(-1) // 2)
            # Apply ponder to each head's queries. Note: you might need to adjust this part based on your specific needs
            queries_reshaped = queries_partition.reshape(N * seq_length, self.heads, -1)
            ponder_scores = []
            for h in range(self.heads):
                head_queries = queries_reshaped[:, h, :]  # Select queries for head h
                scores = self.sigmoid(self.ponder(head_queries)).squeeze(-1)  # Apply ponder and squeeze
                ponder_scores.append(scores)
            # Combine ponder scores from all heads
            ponder_scores = torch.stack(ponder_scores, dim=1)
            # Reshape ponder_scores back to match the original structure
            ponder_scores = ponder_scores.view(N, seq_length, self.heads, -1)
            # Use ponder scores to weigh the importance of each partition's computation
            energy = torch.einsum('nshd,nshp->nhdp', R_queries, C_keys.transpose(-2, -1))
            attention = F.softmax(energy / (self.embed_size ** (1 / 2)), dim=-1) * ponder_scores.unsqueeze(-1)
            attention_scores[:, :, partition_start:partition_end, :] = attention
        out = torch.einsum('nhdp,nshd->nshp', attention_scores, values)
        out = out.reshape(N, seq_length, self.heads * self.head_dim)

        return self.fc_out(out)



# Example usage

model = PartitionedLinformerAttentionACT(embed_size=512, heads=8, sequence_length=1024, projection_dim=256, partition_size=128)

input_tensor = torch.rand(2, model.sequence_length, model.embed_size)

output = model(input_tensor)

print(output.shape)  # Expected output shape: [2, sequence_length, embed_size]


class MyDataset(Dataset):

    def __init__(self):

        # Initialize your data here

        pass

    

    def __len__(self):

        # Return the size of your dataset

        return 100  # Example size



    def __getitem__(self, idx):

        # Implement logic to get a single item at a given index

        # For simplicity, let's return random tensors

        return torch.randn(10, 64), torch.randint(0, 2, (10,))  # Example data


# Create a DataLoader with pinned memory

dataset = MyDataset()
loader = DataLoader(dataset, batch_size=32, shuffle=True, pin_memory=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()

for epoch in range(10):  # Example: 10 epochs

    for inputs, targets in loader:

        inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)



        optimizer.zero_grad()

        

        with autocast():

            predictions = model(inputs)

            loss = nn.functional.cross_entropy(predictions, targets)



        scaler.scale(loss).backward()

        scaler.step(optimizer)

        scaler.update()

x shape as input to forward: torch.Size([2, 1024, 512])
embed size as input to forward: 512
values before reshaping: Linear(in_features=64, out_features=64, bias=False)
keys before reshaping: Linear(in_features=64, out_features=256, bias=False)
queries before reshaping: Linear(in_features=64, out_features=256, bias=False)
x_reshaped : torch.Size([2, 1024, 8, 64])
values after reshaping: torch.Size([2, 1024, 8, 64])
keys after reshaping: torch.Size([2, 1024, 8, 256])
queries after reshaping: torch.Size([2, 1024, 8, 256])
keys after random_projection: torch.Size([2, 1024, 8, 128])
queries after random_projection: torch.Size([2, 1024, 8, 128])
attention_scores : torch.Size([2, 8, 1024, 128])
START OF LOOP
partition_start : 0, partition_end : 128 keys_partition = torch.Size([2, 128, 8, 128]), queries_partition = torch.Size([2, 128, 8, 128])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (2048x16 and 64x1)

# v5

In [47]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

def random_projection(matrix, k):
    """Applies a random projection to reduce dimensionality of the last dimension of the matrix."""
    _, _, _, dim = matrix.shape
    random_matrix = torch.randn(dim, k, device=matrix.device)
    return torch.matmul(matrix, random_matrix)

def apply_cur_decomposition(matrix, k):
    """Applies CUR decomposition on each head for each batch."""
    batch_size, seq_length, heads, dim = matrix.shape
    # Ensuring k does not exceed dimensions
    k = min(k, dim)
    C = torch.zeros(batch_size, seq_length, heads, k, device=matrix.device)
    R = torch.zeros(batch_size, k, heads, dim, device=matrix.device)
    
    for b in range(batch_size):
        for h in range(heads):
            col_indices = np.random.choice(dim, k, replace=False)
            row_indices = np.random.choice(seq_length, k, replace=False)
            C[b, :, h] = matrix[b, :, h, col_indices]
            R[b, :, h] = matrix[b, row_indices, h]
    return C, R

class CustomLinformer(nn.Module):
    def __init__(self, embed_size, heads, sequence_length, projection_dim, partition_size):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.sequence_length = sequence_length
        self.projection_dim = projection_dim
        self.partition_size = partition_size
        self.head_dim = embed_size // heads

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.projection_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.projection_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, x):
        N, seq_length, _ = x.shape
        x = x.view(N, seq_length, self.heads, self.head_dim)
        
        # Applying linear transformations
        values = self.values(x)
        keys = self.keys(x)
        queries = self.queries(x)
        
        # Reducing dimensionality
        keys_reduced = random_projection(keys, self.projection_dim // 2)
        queries_reduced = random_projection(queries, self.projection_dim // 2)
        
        # Applying CUR decomposition
        C, R = apply_cur_decomposition(keys_reduced, k=self.projection_dim // 4)
        
        # Skipping the computation of U for simplification and directly computing the attention scores
        attention_scores = torch.einsum('bnhd,bnhd->bnhd', queries_reduced, C)

        
        # Normalize attention scores
        attention_scores = F.softmax(attention_scores, dim=-1)
        
        # Aggregate values
        out = torch.einsum('bnhdm,bndh->bnhm', attention_scores, values)
        out = out.reshape(N, seq_length, -1)
        
        return self.fc_out(out)

# Example usage
embed_size = 512
heads = 8
sequence_length = 1024
projection_dim = 256
partition_size = 128

model = CustomLinformer(embed_size, heads, sequence_length, projection_dim, partition_size)
input_tensor = torch.rand(2, sequence_length, embed_size)
output = model(input_tensor)
print(output.shape)  # Expected output shape: [2, sequence_length, embed_size]


RuntimeError: einsum(): subscript d has size 64 for operand 1 which does not broadcast with previously seen size 128