In [None]:
import torch
def get_unique_connections(in_dim, out_dim, device='cuda', generator=None):
    """Build ``out_dim`` distinct index pairs ``(a[i], b[i])`` over ``range(in_dim)``.

    Candidate pairs are enumerated in a fixed order so that every input is used
    as early as possible, and the first ``out_dim`` of them are kept:

    1. adjacent pairs starting at even indices: (0, 1), (2, 3), (4, 5), ...
    2. adjacent pairs starting at odd indices:  (1, 2), (3, 4), (5, 6), ...
    3. pairs at distance 2, 3, ...:             (0, 2), (1, 3), ..., (0, 3), ...

    Every returned pair satisfies ``a[i] < b[i]`` and no pair repeats. The kept
    pairs are then shuffled with ``torch.randperm``.

    Args:
        in_dim: Number of inputs; indices are drawn from ``range(in_dim)``.
        out_dim: Number of pairs (neurons) to return.
        device: Device of the returned tensors.
        generator: Optional ``torch.Generator`` used for the shuffle. ``None``
            uses PyTorch's global RNG.

    Returns:
        Tuple ``(a, b)`` of contiguous int64 tensors of shape ``(out_dim,)``.

    Raises:
        AssertionError: if ``2 * out_dim < in_dim`` (not every input could be
            used), or if ``out_dim`` exceeds the ``in_dim * (in_dim - 1) // 2``
            distinct pairs that exist.
    """
    # Explicit raise (not ``assert``) so the checks survive ``python -O``.
    if out_dim * 2 < in_dim:
        raise AssertionError(
            'The number of neurons ({}) must not be smaller than half of the number of inputs '
            '({}) because otherwise not all inputs could be used or considered.'.format(
                out_dim, in_dim
            )
        )

    x = torch.arange(in_dim, dtype=torch.int64)

    def _truncate(u, v):
        # Keep a/b aligned when one strided slice is one element longer.
        m = min(u.shape[-1], v.shape[-1])
        return u[..., :m], v[..., :m]

    # Adjacent pairs starting at even indices: (0, 1), (2, 3), (4, 5), ...
    a, b = _truncate(x[0::2], x[1::2])

    # If this was not enough, adjacent pairs starting at odd indices: (1, 2), (3, 4), ...
    if a.shape[-1] < out_dim:
        a, b = _truncate(torch.cat([a, x[1::2]]), torch.cat([b, x[2::2]]))

    # If still not enough, pairs (i, i + offset) for offset = 2, 3, ...
    # Every offset below in_dim contributes in_dim - offset new, unique pairs.
    offset = 2
    while a.shape[-1] < out_dim and offset < in_dim:
        a = torch.cat([a, x[:-offset]])
        b = torch.cat([b, x[offset:]])
        offset += 1

    if a.shape[-1] < out_dim:
        raise AssertionError(
            'Cannot build {} unique connections from {} inputs: only {} distinct pairs exist.'.format(
                out_dim, in_dim, a.shape[-1]
            )
        )

    perm = torch.randperm(out_dim, generator=generator)
    a = a[:out_dim][perm].to(device).contiguous()
    b = b[:out_dim][perm].to(device).contiguous()
    return a, b


In [12]:
# Fall back to CPU so this demo cell also runs on machines without a GPU.
get_unique_connections(4, 3, device='cuda' if torch.cuda.is_available() else 'cpu')

(tensor([1, 0, 2], device='cuda:0'), tensor([2, 1, 3], device='cuda:0'))

In [1]:
import torch
import torch.nn.functional as F
import time

# Reproducible demo input: 16 standard-normal logits that track gradients.
torch.manual_seed(42)
logits = torch.randn(16, requires_grad=True)

print("=" * 60, "COMPARISON: LogSumExp vs Softmax in PyTorch", "=" * 60, sep="\n")
print(f"\nInput logits (16 values):\n{logits}", f"Shape: {logits.shape}", sep="\n")

# ============================================================
# 1. BASIC COMPUTATIONS
# ============================================================
print("\n" + "=" * 60, "1. BASIC COMPUTATIONS", "=" * 60, sep="\n")

# logsumexp reduces dim 0 to a single value: log(sum(exp(logits))).
lse = torch.logsumexp(logits, dim=0)
print(f"\nLogSumExp (scalar): {lse.item():.6f}")

# softmax keeps the input's shape; the printed sum shows it normalises to 1.
softmax_out = F.softmax(logits, dim=0)
print(
    f"\nSoftmax (distribution):\n{softmax_out}",
    f"Sum of softmax: {softmax_out.sum().item():.6f}",
    sep="\n",
)

# ============================================================
# 2. MATHEMATICAL RELATIONSHIP
# ============================================================
print("\n" + "=" * 60, "2. MATHEMATICAL RELATIONSHIP", "=" * 60, sep="\n")

# softmax(x)_i = exp(x_i) / sum_j exp(x_j) = exp(x_i - logsumexp(x)),
# so the built-in and the logsumexp route should agree up to rounding.
softmax_v1 = F.softmax(logits, dim=0)   # built-in
lse = torch.logsumexp(logits, dim=0)
softmax_v2 = (logits - lse).exp()       # via logsumexp

print(
    f"\nSoftmax using F.softmax:\n{softmax_v1}",
    f"\nSoftmax using exp(x - logsumexp(x)):\n{softmax_v2}",
    sep="\n",
)
print(
    f"\nAre they equal? {torch.allclose(softmax_v1, softmax_v2)}",
    f"Max difference: {torch.max(torch.abs(softmax_v1 - softmax_v2)).item():.2e}",
    sep="\n",
)

# ============================================================
# 3. NUMERICAL STABILITY TEST
# ============================================================
print("\n" + "=" * 60)
print("3. NUMERICAL STABILITY TEST")
print("=" * 60)

# Test with extreme values: exp() of the large entries overflows float32 to inf.
extreme_logits = torch.tensor([1000.0, 999.0, 998.0, 997.0, -1000.0, -999.0,
                               -998.0, -997.0, 500.0, 0.0, -500.0, 100.0,
                               -100.0, 50.0, -50.0, 1.0], requires_grad=True)

print(f"\nExtreme logits:\n{extreme_logits}")

# Stable: torch.logsumexp is computed in a numerically stabilized way.
lse_stable = torch.logsumexp(extreme_logits, dim=0)
print(f"\nLogsumexp (stable): {lse_stable.item():.6f}")

# Stable: F.softmax stays finite on these inputs (checked on the next lines).
softmax_stable = F.softmax(extreme_logits, dim=0)
print(f"\nSoftmax (stable):\n{softmax_stable}")
print(f"No NaN/Inf: {torch.isfinite(softmax_stable).all().item()}")

# Unsafe: manual exp then sum. PyTorch does not raise on overflow -- exp() silently
# returns inf, and inf / inf gives NaN -- so the values are checked explicitly
# instead of relying on an exception that never comes.
manual_exp = torch.exp(extreme_logits)
manual_sum = manual_exp.sum()
manual_softmax = manual_exp / manual_sum
print(f"\nManual exp/sum (unstable):\n{manual_softmax}")
print(f"Contains Inf? {torch.isinf(manual_exp).any()}")
print(f"Contains NaN? {torch.isnan(manual_softmax).any()}")

# ============================================================
# 4. LOG-SOFTMAX COMPARISON
# ============================================================
print("\n" + "=" * 60, "4. LOG-SOFTMAX (RELATED OPERATION)", "=" * 60, sep="\n")

# log(softmax(x)) = x - logsumexp(x): compare the fused op with the identity.
log_softmax_v1 = F.log_softmax(logits, dim=0)
log_softmax_v2 = logits - logits.logsumexp(dim=0)

print(
    f"\nLog-Softmax using F.log_softmax:\n{log_softmax_v1}",
    f"\nLog-Softmax using x - logsumexp(x):\n{log_softmax_v2}",
    f"\nAre they equal? {torch.allclose(log_softmax_v1, log_softmax_v2)}",
    sep="\n",
)

# ============================================================
# 5. PERFORMANCE COMPARISON
# ============================================================
print("\n" + "=" * 60)
print("5. PERFORMANCE COMPARISON")
print("=" * 60)

# Create larger tensor for timing
large_logits = torch.randn(10000, requires_grad=True)
n_iterations = 10000

# Time logsumexp
start = time.time()
for _ in range(n_iterations):
    _ = torch.logsumexp(large_logits, dim=0)
lse_time = time.time() - start

# Time softmax
start = time.time()
for _ in range(n_iterations):
    _ = F.softmax(large_logits, dim=0)
softmax_time = time.time() - start

# Time log_softmax
start = time.time()
for _ in range(n_iterations):
    _ = F.log_softmax(large_logits, dim=0)
log_softmax_time = time.time() - start

print(f"\nTiming (10000 iterations, 10000 elements):")
print(f"LogSumExp:   {lse_time:.4f}s ({lse_time/n_iterations*1e6:.2f} µs/iter)")
print(f"Softmax:     {softmax_time:.4f}s ({softmax_time/n_iterations*1e6:.2f} µs/iter)")
print(f"Log-Softmax: {log_softmax_time:.4f}s ({log_softmax_time/n_iterations*1e6:.2f} µs/iter)")

# ============================================================
# 6. GRADIENT COMPUTATION
# ============================================================
print("\n" + "=" * 60)
print("6. GRADIENT COMPUTATION")
print("=" * 60)

# Fresh leaf tensor, so .grad holds only this backward pass's gradient.
logits_grad = torch.randn(16, requires_grad=True)

# d/dx_i logsumexp(x) = exp(x_i) / sum_j exp(x_j) = softmax(x)_i,
# so the gradient is a probability vector that sums to 1.
lse_out = torch.logsumexp(logits_grad, dim=0)
lse_out.backward()
print("\nGradient of LogSumExp:")
print(f"{logits_grad.grad}")
print(f"Sum of gradients: {logits_grad.grad.sum().item():.6f}")
print("(Note: LSE gradient equals softmax)")

# Verify: gradient of logsumexp should equal softmax of the same input.
softmax_check = F.softmax(logits_grad.detach(), dim=0)
print(f"\nSoftmax of same input:\n{softmax_check}")
print(f"Gradients match softmax? {torch.allclose(logits_grad.grad, softmax_check)}")

# ============================================================
# SUMMARY
# ============================================================
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
# Recap of the identities and behaviour checked in the sections above.
for finding in (
    "- softmax(x) == exp(x - logsumexp(x))                    (section 2)",
    "- naive exp()/sum() gives inf/NaN on large logits;",
    "  logsumexp and softmax stay finite                      (section 3)",
    "- log_softmax(x) == x - logsumexp(x)                     (section 4)",
    "- gradient of logsumexp(x) w.r.t. x == softmax(x)        (section 6)",
):
    print(finding)



COMPARISON: LogSumExp vs Softmax in PyTorch

Input logits (16 values):
tensor([ 1.9269,  1.4873,  0.9007, -2.1055,  0.6784, -1.2345, -0.0431, -1.6047,
        -0.7521,  1.6487, -0.3925, -1.4036, -0.7279, -0.5594, -0.7688,  0.7624],
       requires_grad=True)
Shape: torch.Size([16])

1. BASIC COMPUTATIONS

LogSumExp (scalar): 3.316051

Softmax (distribution):
tensor([0.2493, 0.1606, 0.0893, 0.0044, 0.0715, 0.0106, 0.0348, 0.0073, 0.0171,
        0.1888, 0.0245, 0.0089, 0.0175, 0.0207, 0.0168, 0.0778],
       grad_fn=<SoftmaxBackward0>)
Sum of softmax: 1.000000

2. MATHEMATICAL RELATIONSHIP

Softmax using F.softmax:
tensor([0.2493, 0.1606, 0.0893, 0.0044, 0.0715, 0.0106, 0.0348, 0.0073, 0.0171,
        0.1888, 0.0245, 0.0089, 0.0175, 0.0207, 0.0168, 0.0778],
       grad_fn=<SoftmaxBackward0>)

Softmax using exp(x - logsumexp(x)):
tensor([0.2493, 0.1606, 0.0893, 0.0044, 0.0715, 0.0106, 0.0348, 0.0073, 0.0171,
        0.1888, 0.0245, 0.0089, 0.0175, 0.0207, 0.0168, 0.0778],
       grad_fn=

In [4]:
# Trainable 3x16 parameter initialised from a standard normal distribution.
weights = torch.nn.Parameter(torch.randn(3, 16))

In [5]:
# Display the (3, 16) parameter created above (last expression renders as the cell output).
weights

Parameter containing:
tensor([[-0.1453,  0.8568, -1.0604, -0.4739,  0.8355,  0.5523,  0.9878,  0.2472,
         -1.3898,  0.7879,  0.2619,  0.3239, -0.0066, -0.6608,  1.0100,  1.5781],
        [ 0.3673,  0.9278,  0.3392, -0.2632, -0.8080,  0.0591, -0.4664, -0.8863,
         -0.3549, -0.2807,  0.1187, -1.6083, -0.0530, -0.7546,  1.1250,  0.2149],
        [ 1.1615,  1.4790, -0.2460,  0.7622, -0.1894,  0.0621,  0.2685,  0.6053,
          0.0126, -0.1968,  1.6451,  1.3394,  1.2698, -0.4576,  0.6375,  0.3438]],
       requires_grad=True)

In [27]:
from entmax import normmax_bisect

# Row-wise (dim=1) normmax mapping of `weights` with alpha=1.5.
# The captured output shows many exact zeros in each row.
normmax_bisect(weights, dim=1, alpha=1.5)

tensor([[0.0000, 0.0411, 0.0000, 0.0000, 0.0339, 0.0000, 0.1004, 0.0000, 0.0000,
         0.0204, 0.0000, 0.0000, 0.0000, 0.0000, 0.1130, 0.6912],
        [0.0117, 0.3687, 0.0069, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6127, 0.0000],
        [0.0641, 0.2564, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.4082, 0.1560, 0.1152, 0.0000, 0.0000, 0.0000]],
       grad_fn=<NormmaxBisectFunctionBackward>)

In [15]:
# Pass `dim` explicitly: an implicit dim is deprecated and emits a UserWarning.
# For 2-D input the implicit choice was dim=1 (each row of the output sums to 1),
# so the result is unchanged.
torch.nn.functional.softmax(weights, dim=1)

  torch.nn.functional.softmax(weights)


tensor([[0.0327, 0.0891, 0.0131, 0.0236, 0.0872, 0.0657, 0.1016, 0.0484, 0.0094,
         0.0832, 0.0492, 0.0523, 0.0376, 0.0195, 0.1039, 0.1834],
        [0.0839, 0.1469, 0.0816, 0.0447, 0.0259, 0.0616, 0.0364, 0.0239, 0.0407,
         0.0439, 0.0654, 0.0116, 0.0551, 0.0273, 0.1790, 0.0720],
        [0.0943, 0.1296, 0.0231, 0.0633, 0.0244, 0.0314, 0.0386, 0.0541, 0.0299,
         0.0243, 0.1530, 0.1127, 0.1051, 0.0187, 0.0559, 0.0416]],
       grad_fn=<SoftmaxBackward0>)