check the smart idea here. https://math.stackexchange.com/questions/1945329/can-you-transpose-a-matrix-using-matrix-multiplication
The only problem is that this method takes way too much memory (1048576x1048576 I think), and perhaps we need to figure out a better way to make transpose operation trainable

In [15]:
import torch
import torch.nn as nn

class TrainableTensorTransposeLayer(nn.Module):
    def __init__(self, n):
        super(TrainableTensorTransposeLayer, self).__init__()

        # Create the initial tensor T close to the transpose operation for n x n matrices
        I = torch.eye(n)
        initial_T = torch.zeros(n * n, n * n)
        for i in range(n):
            for j in range(n):
                row_idx = i * n + j
                col_idx = j * n + i
                initial_T[row_idx, col_idx] = 1

        self.weights = nn.Parameter(initial_T, requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(n * n, 1), requires_grad=True)

    def forward(self, x):
        x_reshaped = x.reshape(-1, 1)  # Reshape to column vector
        out_reshaped = torch.mm(self.weights, x_reshaped) + self.bias
        return out_reshaped.reshape(x.size())  # Reshape back to n x n

# Function to test the layer
def test_trainable_tensor_transpose_layer():
    n = 1024
    layer = TrainableTensorTransposeLayer(n)

    # Create a random tensor of size n x n
    A = torch.rand(n, n)

    # Check initially if it's performing approximately a transpose operation
    output = layer(A)
    expected = A.t()

    diff = torch.abs(output - expected)
    max_diff = torch.max(diff).item()
    mean_diff = torch.mean(diff).item()

    if not torch.allclose(output, expected, atol=1e-2):
        print(f"Initialization discrepancy detected! Max difference: {max_diff}, Mean difference: {mean_diff}")
    else:
        print("Initialization passed!")

    # Additional tests or training can be added here

test_trainable_tensor_transpose_layer()


RuntimeError: [enforce fail at C:\cb\pytorch_1000000000000\work\c10\core\impl\alloc_cpu.cpp:81] data. DefaultCPUAllocator: not enough memory: you tried to allocate 4398046511104 bytes.

A workaround method should be maintaining a K_t layer inside transformer other than K layer, so we want the layer to directly learn K_t other than K because we will need it later anyway. But there might be problems, we need to figure them out.

In [None]:
import torch
import torch.nn as nn

class TransformerAttention(nn.Module):
    def __init__(self, embed_dim):
        super(TransformerAttention, self).__init__()

        # Linear layers for Q and K-transpose
        self.query_layer = nn.Linear(embed_dim, embed_dim, bias=True)
        self.key_transpose_layer = nn.Linear(embed_dim, embed_dim, bias=True)

    def forward(self, x):
        Q = self.query_layer(x)

        # Get K-transpose directly
        K_T = self.key_transpose_layer(x)

        # Attention calculation
        attention_scores = torch.matmul(Q, K_T)

        return attention_scores

# Test
embed_dim = 512
model = TransformerAttention(embed_dim)
input_embeddings = torch.rand(10, embed_dim)  # batch of 10 sequences
output = model(input_embeddings)

In [5]:
import torch

# Check for GPU availability
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

dim = 1024
token_dim = 1024
model = torch.nn.Linear(token_dim, token_dim, bias=False).to(device)
criterion = torch.nn.MSELoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 100
batch_size = 8

for outer_iter in range(100):  # Outer loop

    # Creating dataset directly on the GPU
    num_samples = 100
    inputs = torch.rand(num_samples, dim, token_dim, device=device)
    K_layer = torch.nn.Linear(token_dim, token_dim, bias=False).to(device)

    # Passing through K
    K_outputs = K_layer(inputs)

    # Transposing last two dimensions, operation happens on the GPU
    K_transposed_targets = K_outputs.transpose(1, 2)

    # Free up some VRAM by deleting K_outputs which is no longer needed
    del K_outputs

    # Split dataset into train and validation
    train_size = int(0.8 * num_samples)
    inputs_train, inputs_val = inputs[:train_size], inputs[train_size:]
    targets_train, targets_val = K_transposed_targets[:train_size], K_transposed_targets[train_size:]

    for epoch in range(num_epochs):
        for i in range(0, train_size, batch_size):
            optimizer.zero_grad()

            outputs = model(inputs_train[i:i+batch_size])
            loss = criterion(outputs, targets_train[i:i+batch_size])

            loss.backward(retain_graph=True)

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()

            # Memory cleanup
            del outputs, loss

        # Validation loss with batch processing
        val_loss_sum = 0.0
        num_batches = 0
        with torch.no_grad():
            for i in range(0, len(inputs_val), batch_size):
                val_outputs = model(inputs_val[i:i+batch_size])
                val_loss = criterion(val_outputs, targets_val[i:i+batch_size])
                val_loss_sum += val_loss.item()
                num_batches += 1

        avg_val_loss = val_loss_sum / num_batches
        print(f"Outer Iteration {outer_iter+1}, Epoch {epoch+1}, Avg Val Loss: {avg_val_loss}")


Outer Iteration 1, Epoch 1, Avg Val Loss: 0.14446668326854706
Outer Iteration 1, Epoch 2, Avg Val Loss: 0.14215830465157828
Outer Iteration 1, Epoch 3, Avg Val Loss: 0.139599879582723
Outer Iteration 1, Epoch 4, Avg Val Loss: 0.13855704168478647
Outer Iteration 1, Epoch 5, Avg Val Loss: 0.13806120057900748
Outer Iteration 1, Epoch 6, Avg Val Loss: 0.1376983126004537
Outer Iteration 1, Epoch 7, Avg Val Loss: 0.1373924563328425
Outer Iteration 1, Epoch 8, Avg Val Loss: 0.13709504902362823
Outer Iteration 1, Epoch 9, Avg Val Loss: 0.13679800430933634
Outer Iteration 1, Epoch 10, Avg Val Loss: 0.13649570445219675
Outer Iteration 1, Epoch 11, Avg Val Loss: 0.1361877123514811
Outer Iteration 1, Epoch 12, Avg Val Loss: 0.13587446510791779
Outer Iteration 1, Epoch 13, Avg Val Loss: 0.13555688659350076
Outer Iteration 1, Epoch 14, Avg Val Loss: 0.13523600002129874
Outer Iteration 1, Epoch 15, Avg Val Loss: 0.13491245607535043
Outer Iteration 1, Epoch 16, Avg Val Loss: 0.13458707928657532
Outer 


KeyboardInterrupt

