In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch.nn import MultiheadAttention
from sparse_generalization.layers.thresh_mha import MultiHeadAttentionThresh
from sparse_generalization.losses.sparse_loss import L1SparsityWeights, L1SparsityAdjacency


torch.manual_seed(0)

# Learnable inputs, fed as query, key and value to the attention layer.
# Shape (4, 3, 2); with batch_first=True presumably (batch, seq, embed) — TODO confirm.
logits = torch.randn((4, 3, 2), requires_grad=True)

# Custom thresholded layer under test, plus a stock PyTorch layer with matching sizes
# whose weights are later copied into it. Keep this construction order:
# torch's MultiheadAttention initialises its weights from the seeded RNG (and the
# custom layer presumably does too), so reordering would change every value.
mha = MultiHeadAttentionThresh(embed_size=2, num_heads=1, batch_first=True)
mha2 = MultiheadAttention(embed_dim=2, num_heads=1, batch_first=True)

# Only the inputs are optimised; the attention parameters are not handed to Adam.
optimizer = torch.optim.Adam(params=[logits], lr=0.01)
weight = 1
loss_func = L1SparsityWeights()

print("Initial logits:", logits.detach(), sep="\n")

def copy_weights(torch_mha, my_mha):
    """Copy the parameters of a ``torch.nn.MultiheadAttention`` into ``my_mha``.

    ``torch.nn.MultiheadAttention`` stores the query/key/value projections either
    packed into one ``in_proj_weight`` (rows ordered q, k, v) or, when ``kdim``/``vdim``
    differ from ``embed_dim``, as separate ``q_proj_weight``/``k_proj_weight``/
    ``v_proj_weight``. Both layouts are handled. ``bias_k``/``bias_v`` (from
    ``add_bias_kv=True``) are not copied.

    Values are copied in place into the destination's existing tensors, so the
    destination keeps its own storage, dtype and device. No storage is shared
    with ``torch_mha``.

    Args:
        torch_mha: source ``torch.nn.MultiheadAttention``.
        my_mha: destination exposing ``queries``, ``keys``, ``values`` and
            ``projection`` sub-layers with ``weight``/``bias`` tensors shaped like
            the matching torch projections (presumably ``nn.Linear`` — TODO confirm).

    Raises:
        ValueError: if a source/destination pair has mismatched shapes, or the
            source has a bias the destination lacks.
    """

    def _assign(dst, src):
        # Copy src into dst in place; refuse the silent broadcasting copy_ would allow.
        if src is None:
            if dst is not None:
                # Source built with bias=False: that is equivalent to a zero bias.
                dst.zero_()
            return
        if dst is None:
            raise ValueError("destination has no parameter to receive a source tensor")
        if dst.shape != src.shape:
            raise ValueError(
                f"shape mismatch: destination {tuple(dst.shape)} vs source {tuple(src.shape)}"
            )
        dst.copy_(src)

    with torch.no_grad():
        if torch_mha.in_proj_weight is not None:
            w_q, w_k, w_v = torch.chunk(torch_mha.in_proj_weight, 3, dim=0)
        else:
            w_q = torch_mha.q_proj_weight
            w_k = torch_mha.k_proj_weight
            w_v = torch_mha.v_proj_weight

        if torch_mha.in_proj_bias is not None:
            b_q, b_k, b_v = torch.chunk(torch_mha.in_proj_bias, 3, dim=0)
        else:
            b_q = b_k = b_v = None

        pairs = [
            (my_mha.queries.weight, w_q),
            (my_mha.queries.bias, b_q),
            (my_mha.keys.weight, w_k),
            (my_mha.keys.bias, b_k),
            (my_mha.values.weight, w_v),
            (my_mha.values.bias, b_v),
            (my_mha.projection.weight, torch_mha.out_proj.weight),
            (my_mha.projection.bias, torch_mha.out_proj.bias),
        ]
        for dst, src in pairs:
            _assign(dst, src)

# Attention entries above this value count as "active" in the sparsity curve.
ATTN_THRESHOLD = 0.1
# Number of optimisation steps (1 = a single-step gradient sanity check).
N_STEPS = 1
# Set to True to print per-step diagnostics.
VERBOSE = False

A_sums = []

copy_weights(mha2, mha)

for step in range(N_STEPS):
    optimizer.zero_grad()

    out, attn = mha(logits, logits, logits)

    A = attn
    # Scale by `weight`, as the Bernoulli-attention cell does (weight is 1 here).
    loss = weight * loss_func(A)

    loss.backward()

    # Count active entries per sample, then average over the batch.
    # Assumes A is (batch, tgt_len, src_len) — TODO confirm for MultiHeadAttentionThresh.
    A_sum_value = (A.detach() > ATTN_THRESHOLD).sum(dim=(1, 2)).float().mean().item()
    A_sums.append(A_sum_value)

    if VERBOSE:
        # logits.grad stays None if A does not depend differentiably on logits.
        grad = logits.grad
        print(f"\n--- Step {step} ---")
        print("Loss:", loss.item())
        print("Grad magnitude:", grad.norm().item() if grad is not None else None)
        print("Logits before step:\n", logits.data)
        print("Adjacency A:\n", A.detach())
        print("Adjacency A (sum):", A_sum_value)

    optimizer.step()

    if VERBOSE:
        print("Logits after step:\n", logits.data)

plt.figure(figsize=(6, 4))
# Markers keep the curve visible when N_STEPS is small: a one-point line draws nothing.
plt.plot(A_sums, marker="o")
plt.title(f"Sparsity Curve: Num of weights above {ATTN_THRESHOLD}")
plt.xlabel("Step")
plt.ylabel("Mean over batch")
plt.grid(True)
plt.show()


In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from sparse_generalization.layers.bern_mha import MultiHeadAttentionBern
from sparse_generalization.losses.sparse_loss import L1SparsityAdjacency

torch.manual_seed(0)

# Number of optimisation steps.
N_STEPS = 50
# Print per-step diagnostics (loss, gradient, logits, adjacency).
VERBOSE = True

# Learnable logits; only these are optimised, the attention layer's parameters are not.
logits = torch.randn((4, 3, 2), requires_grad=True)

# Bernoulli attention layer; hard/temp semantics are defined in bern_mha (not visible here).
mha = MultiHeadAttentionBern(embed_size=2, num_heads=1, hard=True, temp=0.5)

optimizer = torch.optim.Adam([logits], lr=0.1)
weight = 1
loss_func = L1SparsityAdjacency()

print("Initial logits:")
print(logits.data)

A_sums = []

for step in range(N_STEPS):
    optimizer.zero_grad()

    out, attn = mha(logits, logits, logits)

    A = attn

    loss = weight * loss_func(A)

    loss.backward()

    A_detached = A.detach()
    # Total adjacency mass per sample, averaged over the batch.
    # Assumes A is (batch, tgt_len, src_len) — TODO confirm for MultiHeadAttentionBern.
    A_sum_value = A_detached.sum(dim=(1, 2)).mean().item()
    A_sums.append(A_sum_value)

    if VERBOSE:
        # Report the L2 norm as the "magnitude"; grad is None if A does not
        # depend differentiably on logits.
        grad = logits.grad
        grad_norm = grad.norm().item() if grad is not None else None
        print(f"\n--- Step {step} ---")
        print("Loss:", loss.item())
        print("Grad magnitude:", grad_norm)
        print("Grad:\n", grad)
        print("Logits before step:\n", logits.data)
        print("Adjacency A:\n", A_detached)
        print("Adjacency A (sum):", A_sum_value)

    optimizer.step()


plt.figure(figsize=(6, 4))
plt.plot(A_sums)
plt.title("Sparsity Curve: A.sum over Optimization Steps")
plt.xlabel("Step")
plt.ylabel("Mean A.sum()")
plt.grid(True)
plt.show()
