In [17]:
import torch
import torch.nn as nn

# Side length of every square matrix and of the latent vector used below.
d = 32
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def get_im_embedding(matrix_im, latent):
    """Return the image embedding: ``latent`` mapped through ``matrix_im``.

    Args:
        matrix_im: Left operand of the matrix product.
        latent: Right operand of the matrix product.

    Returns:
        The product ``matrix_im @ latent``.
    """
    return matrix_im @ latent

def get_text_embeddings(matrix_text_1, matrix_text_2, latent):
    """Build two text embeddings from ``latent`` and a linearly mixed version of them.

    Each embedding is ``(matrix_text_k @ latent).T``. The two are stacked along
    dim 0 and left-multiplied by a fixed 2x2 matrix, so the transformed rows are
    ``e1 + 0.01 * e2`` and ``0.5 * e2``.

    Args:
        matrix_text_1: Projection producing the first text embedding.
        matrix_text_2: Projection producing the second text embedding.
        latent: Shared latent input. The 2x2 mix needs exactly two stacked rows,
            so this assumes ``latent`` is a single column vector.

    Returns:
        A tuple ``(text_embeddings_transformed, text_embedding_1, text_embedding_2)``;
        the last two are the untransformed embeddings.
    """
    text_embedding_1 = (matrix_text_1 @ latent).T
    text_embedding_2 = (matrix_text_2 @ latent).T
    stacked = torch.cat([text_embedding_1, text_embedding_2], dim=0)

    # Build the constant on the same device as the data it multiplies rather
    # than on a module-level global, so the function works for inputs on any
    # device and does not depend on hidden global state.
    transform_matrix = torch.tensor([[1.0, 0.01], [0.0, 0.5]], device=stacked.device)

    text_embeddings_transformed = transform_matrix @ stacked
    return text_embeddings_transformed, text_embedding_1, text_embedding_2

def forward(text_embeddings, im_embedding, query_matrix, key_matrix, value_matrix, readout_weights_1, readout_weights_2):
    """Attend from the image embedding over the text embeddings, then apply two readouts.

    Queries come from ``im_embedding``; keys and values both come from
    ``text_embeddings``. The attended result is projected by each readout
    matrix independently.

    Args:
        text_embeddings: Source of keys and values (one row per text embedding).
        im_embedding: Source of the queries.
        query_matrix: Projection applied to ``im_embedding``.
        key_matrix: Projection applied to ``text_embeddings`` for the keys.
        value_matrix: Projection applied to ``text_embeddings`` for the values.
        readout_weights_1: First readout projection.
        readout_weights_2: Second readout projection.

    Returns:
        A tuple ``(readout_1, readout_2)``.
    """
    q = im_embedding @ query_matrix
    k, v = text_embeddings @ key_matrix, text_embeddings @ value_matrix

    # Raw dot-product scores (no 1/sqrt(d) scaling), normalised along dim 1,
    # i.e. across the rows of text_embeddings.
    weights = (q @ k.T).softmax(dim=1)
    attended = weights @ v

    return attended @ readout_weights_1, attended @ readout_weights_2


# Fixed random d x d matrices (plain tensors, not nn.Parameter) that map the
# latent to the image embedding and to the two text embeddings.
matrix_im, matrix_text_1, matrix_text_2 = (
    torch.randn(d, d, device=device) for _ in range(3)
)

# Trainable attention projections, drawn in query -> key -> value order.
query_matrix, key_matrix, value_matrix = (
    nn.Parameter(torch.randn(d, d, device=device)) for _ in range(3)
)

# Trainable readout matrices applied to the attention output.
readout_weights_1, readout_weights_2 = (
    nn.Parameter(torch.randn(d, d, device=device)) for _ in range(2)
)


# Every tensor the optimizer may update: both readouts plus the Q/K/V projections.
trainable_params = [readout_weights_1, readout_weights_2, query_matrix, key_matrix, value_matrix]

optimizer = torch.optim.Adam(trainable_params, lr=0.001)
# NOTE(review): step_size equals the number of loop iterations below, so the
# halving only takes effect after the last optimizer.step(); every update in
# this loop runs at lr=0.001. Confirm whether a smaller step_size was intended.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20000, gamma=0.5)
for epoch in range(20000):
    optimizer.zero_grad()
    # One fresh random (d, 1) latent per iteration, i.e. a single sample per step.
    latent = torch.randn(d,1, device=device)
    # Transpose the (d, 1) product into a (1, d) row so it can act as the query.
    im_embedding = get_im_embedding(matrix_im, latent).T
    text_embeddings, text_embedding_1, text_embedding_2 = get_text_embeddings(matrix_text_1, matrix_text_2, latent) 

    readout_1, readout_2 = forward(text_embeddings, im_embedding, query_matrix, key_matrix, value_matrix, readout_weights_1, readout_weights_2)

    # Targets are the untransformed text embeddings; the attention only sees
    # the mixed version returned as text_embeddings.
    loss = torch.nn.functional.mse_loss(readout_1, text_embedding_1) + torch.nn.functional.mse_loss(readout_2, text_embedding_2)
    loss.backward()
    optimizer.step()
    scheduler.step()
    # Log every 50th iteration; the loss shown is that single iteration's value.
    if (epoch + 1) % 50 == 0:
        print(f"epoch {epoch+1:03d} | loss: {loss.item():.6f} | lr: {scheduler.get_last_lr()[0]:.6f}")

epoch 050 | loss: 122196.765625 | lr: 0.001000
epoch 100 | loss: 13640.766602 | lr: 0.001000
epoch 150 | loss: 45010.648438 | lr: 0.001000
epoch 200 | loss: 73837.250000 | lr: 0.001000
epoch 250 | loss: 13244.179688 | lr: 0.001000
epoch 300 | loss: 94504.351562 | lr: 0.001000
epoch 350 | loss: 20004.126953 | lr: 0.001000
epoch 400 | loss: 15447.396484 | lr: 0.001000
epoch 450 | loss: 11261.937500 | lr: 0.001000
epoch 500 | loss: 28022.597656 | lr: 0.001000
epoch 550 | loss: 12588.578125 | lr: 0.001000
epoch 600 | loss: 9768.969727 | lr: 0.001000
epoch 650 | loss: 11284.344727 | lr: 0.001000
epoch 700 | loss: 42582.859375 | lr: 0.001000
epoch 750 | loss: 8473.489258 | lr: 0.001000
epoch 800 | loss: 6918.554688 | lr: 0.001000
epoch 850 | loss: 17165.023438 | lr: 0.001000
epoch 900 | loss: 11497.678711 | lr: 0.001000
epoch 950 | loss: 28800.306641 | lr: 0.001000
epoch 1000 | loss: 4237.684082 | lr: 0.001000
epoch 1050 | loss: 22338.453125 | lr: 0.001000
epoch 1100 | loss: 2881.075195 | lr