In [2]:
import torch

class Sampler:
    """Pick ``k`` distinct column indices per row of a batch of logit matrices.

    The logits are scaled by ``exp(clamp(temperature, -5, 5))``, perturbed
    with ``-log(-log(q))`` noise (``q`` uniform), negated, and the ``k``
    largest entries along the last dimension are kept.

    Attributes:
        temperature: Log-scale factor; a tensor or a plain Python number.
        k: Number of indices selected per row.
        sparse: When true, the edges of all batches are merged into a single
            ``(2, b * n * k)`` index tensor with per-batch offsets applied.
    """

    def __init__(self, temperature, k, sparse=False):
        self.temperature = temperature
        self.k = k
        self.sparse = sparse

    def sample_without_replacement(self, logits):
        """Select ``k`` indices per row and return them as an edge list.

        Every intermediate tensor is printed to stdout as a debugging trace.

        Args:
            logits: Floating-point tensor with three dimensions ``(b, n, m)``;
                the selection runs over the last dimension.

        Returns:
            A pair ``(edges, logprobs)``. ``logprobs`` has shape ``(b, n, k)``
            and holds the top-k values of the negated perturbed logits.
            ``edges`` has shape ``(b, 2, n * k)`` with the selected indices in
            row 0 and the originating row numbers in row 1; if ``self.sparse``
            is set it is instead flattened to ``(2, b * n * k)`` with
            ``batch * n`` added to both rows.
        """
        print("Initial logits:", logits)

        b, n, _ = logits.shape
        print("b (batch size):", b)
        print("n (number of items):", n)

        # torch.clamp rejects plain Python numbers, so wrap them in a tensor;
        # an existing tensor is used untouched (keeps any autograd history).
        temperature = self.temperature
        if not torch.is_tensor(temperature):
            temperature = torch.tensor(
                float(temperature), dtype=logits.dtype, device=logits.device
            )

        # The clamp bounds the scale factor to [exp(-5), exp(5)].
        logits = logits * torch.exp(torch.clamp(temperature, -5, 5))
        print("Logits after applying temperature:", logits)

        # The small offset keeps q away from 0 so log(q) stays finite.
        q = torch.rand_like(logits) + 1e-8
        print("Random values q:", q)

        lq = (logits - torch.log(-torch.log(q)))
        print("lq (logits minus Gumbel noise):", lq)

        # Negating before topk selects the k smallest perturbed logits;
        # topk returns distinct positions, hence "without replacement".
        logprobs, indices = torch.topk(-lq, self.k)
        print("logprobs (top k log probabilities):", logprobs)
        print("indices (top k indices):", indices)

        # rows[batch, i, j] == i: the row each selected index came from.
        rows = torch.arange(n, device=logits.device).view(1, n, 1).repeat(b, 1, self.k)
        print("rows (arange n, repeated):", rows)

        edges = torch.stack((indices.view(b, -1), rows.view(b, -1)), -2)
        print("edges (stacked indices and rows):", edges)

        if self.sparse:
            # Shift batch i by i * n so indices are unique across the batch,
            # then lay all batches side by side in one (2, b * n * k) tensor.
            offsets = (torch.arange(b, device=logits.device) * n)[:, None, None]
            result = (edges + offsets).transpose(0, 1).reshape(2, -1)
            print("Final result (sparse):", result)
            return result, logprobs
        print("Final result (dense):", edges)
        return edges, logprobs

# Demo run: exercise Sampler on a small random batch and inspect what it prints.
torch.manual_seed(42)  # fixed seed so the random draws repeat between runs

# Inputs: 2 batches of 5 points with 3 features each; the logits are the
# pairwise squared distances between the points of a batch.
x = torch.randn(2, 5, 3)
logits = torch.cdist(x, x).pow(2)
temperature = torch.tensor(0.5)
k = 2
sparse = False

# Build the sampler with the settings above.
sampler = Sampler(temperature, k, sparse)

# Run it; the method prints every intermediate tensor along the way.
edges, logprobs = sampler.sample_without_replacement(logits)


Initial logits: tensor([[[ 0.0000, 21.4742, 16.1729,  8.9208, 21.5902],
         [21.4742,  0.0000,  9.6989, 15.2697,  4.6017],
         [16.1729,  9.6989,  0.0000,  4.7560,  4.0100],
         [ 8.9208, 15.2697,  4.7560,  0.0000,  6.5103],
         [21.5902,  4.6017,  4.0100,  6.5103,  0.0000]],

        [[ 0.0000, 11.0390,  1.7975,  7.4636,  9.1241],
         [11.0390,  0.0000,  6.5223,  3.7378, 12.3459],
         [ 1.7975,  6.5223,  0.0000,  6.0772,  3.4406],
         [ 7.4636,  3.7378,  6.0772,  0.0000, 15.0821],
         [ 9.1241, 12.3459,  3.4406, 15.0821,  0.0000]]])
b (batch size): 2
n (number of items): 5
Logits after applying temperature: tensor([[[ 0.0000, 35.4049, 26.6646, 14.7079, 35.5962],
         [35.4049,  0.0000, 15.9908, 25.1756,  7.5869],
         [26.6646, 15.9908,  0.0000,  7.8413,  6.6115],
         [14.7079, 25.1756,  7.8413,  0.0000, 10.7336],
         [35.5962,  7.5869,  6.6115, 10.7336,  0.0000]],

        [[ 0.0000, 18.2003,  2.9636, 12.3054, 15.0431],
      