check the smart idea here. https://math.stackexchange.com/questions/1945329/can-you-transpose-a-matrix-using-matrix-multiplication
The only problem is that this method takes way too much memory and perhaps we need to figure out a better way to make transpose operation trainable

In [15]:
import torch
import torch.nn as nn

class TrainableTensorTransposeLayer(nn.Module):
    def __init__(self, n):
        super(TrainableTensorTransposeLayer, self).__init__()

        # Create the initial tensor T close to the transpose operation for n x n matrices
        I = torch.eye(n)
        initial_T = torch.zeros(n * n, n * n)
        for i in range(n):
            for j in range(n):
                row_idx = i * n + j
                col_idx = j * n + i
                initial_T[row_idx, col_idx] = 1

        self.weights = nn.Parameter(initial_T, requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(n * n, 1), requires_grad=True)

    def forward(self, x):
        x_reshaped = x.reshape(-1, 1)  # Reshape to column vector
        out_reshaped = torch.mm(self.weights, x_reshaped) + self.bias
        return out_reshaped.reshape(x.size())  # Reshape back to n x n

# Function to test the layer
def test_trainable_tensor_transpose_layer():
    n = 1024
    layer = TrainableTensorTransposeLayer(n)

    # Create a random tensor of size n x n
    A = torch.rand(n, n)

    # Check initially if it's performing approximately a transpose operation
    output = layer(A)
    expected = A.t()

    diff = torch.abs(output - expected)
    max_diff = torch.max(diff).item()
    mean_diff = torch.mean(diff).item()

    if not torch.allclose(output, expected, atol=1e-2):
        print(f"Initialization discrepancy detected! Max difference: {max_diff}, Mean difference: {mean_diff}")
    else:
        print("Initialization passed!")

    # Additional tests or training can be added here

test_trainable_tensor_transpose_layer()


RuntimeError: [enforce fail at C:\cb\pytorch_1000000000000\work\c10\core\impl\alloc_cpu.cpp:81] data. DefaultCPUAllocator: not enough memory: you tried to allocate 4398046511104 bytes.

A workaround method should be maintaining a K_t layer inside transformer other than K layer, so we want the layer to directly learn K_t other than K because we will need it later anyway. But there might be problems, we need to figure them out.

In [None]:
import torch
import torch.nn as nn

class TransformerAttention(nn.Module):
    def __init__(self, embed_dim):
        super(TransformerAttention, self).__init__()

        # Linear layers for Q and K-transpose
        self.query_layer = nn.Linear(embed_dim, embed_dim, bias=True)
        self.key_transpose_layer = nn.Linear(embed_dim, embed_dim, bias=True)

    def forward(self, x):
        Q = self.query_layer(x)

        # Get K-transpose directly
        K_T = self.key_transpose_layer(x)

        # Attention calculation
        attention_scores = torch.matmul(Q, K_T)

        return attention_scores

# Test
embed_dim = 512
model = TransformerAttention(embed_dim)
input_embeddings = torch.rand(10, embed_dim)  # batch of 10 sequences
output = model(input_embeddings)