In [1]:
import torch

In [46]:
class NegativeSampler:
    """Draw negative (user, item) edges for a bipartite interaction graph.

    Positive supervision edges are read from
    ``data['user', 'interacts', 'item'].edge_label_index`` and node counts from
    ``data['user'].num_nodes`` / ``data['item'].num_nodes``.

    Keys read from ``cfg``:
        negative_sampling_ratio: negatives requested per positive edge (required).
        negative_sampling_method: ``'batch_random'`` (default) or ``'pairwise_random'``.
        negative_sampling_max_retries: number of resampling rounds for negatives
            that collide with a positive edge (optional, default 3).
    """

    _EDGE_TYPE = ('user', 'interacts', 'item')

    def __init__(self, cfg):
        self.cfg = cfg

    def sample(self, data):
        """Sample negative edges and their all-zero labels.

        Negatives that coincide with a positive edge are redrawn for a bounded
        number of rounds; any collision still present afterwards is returned
        as-is, so the number (and ordering) of negatives never changes.

        Returns:
            (neg_edge_index, neg_edge_label): a ``[2, N]`` index tensor and a
            float32 zero vector of length ``N``, both on the device of the
            positive edge index.
        """
        pos_edge_index = data[self._EDGE_TYPE].edge_label_index
        # Work on the positives' device rather than on a module-level `device`
        # global, so the class is usable outside the notebook session.
        device = pos_edge_index.device
        num_neg_samples, _ = self.obtain_num_neg_samples(
            data['user'].num_nodes, data['item'].num_nodes, pos_edge_index.size(1)
        )
        sampling_func = self.get_sampling_func()
        max_retries = self.cfg.get('negative_sampling_max_retries', 3)

        # An all-True mask asks the strategy for a fresh draw of every slot.
        mask = torch.ones(num_neg_samples, dtype=torch.bool, device=device)
        neg_edge_index = sampling_func(data, mask)
        mask = self.collision_check(pos_edge_index, neg_edge_index)
        attempts = 0
        # Bounded retries: a dense graph may leave no collision-free choice.
        while mask.any() and attempts < max_retries:
            neg_edge_index = sampling_func(data, mask, neg_edge_index[0], neg_edge_index[1])
            mask = self.collision_check(pos_edge_index, neg_edge_index)
            attempts += 1

        neg_edge_label = torch.zeros(neg_edge_index.size(1), dtype=torch.float32, device=device)
        return neg_edge_index, neg_edge_label

    def get_sampling_func(self):
        """Return the bound sampling method selected by ``cfg``.

        Raises:
            ValueError: if ``negative_sampling_method`` is not a known strategy.
        """
        sampling_strategy = self.cfg.get('negative_sampling_method', 'batch_random')
        if sampling_strategy == 'batch_random':
            return self.batch_random_sample
        elif sampling_strategy == 'pairwise_random':
            return self.pairwise_random_sample
        else:
            raise ValueError(f"Unknown sampling strategy: {sampling_strategy}")

    def eval_sample(self, data):
        """Return directly val/test data with negative samples included.

        The positive edges/labels stored on ``data`` are replaced in place by
        the concatenation ``[positives, negatives]``; ``data`` is also returned.
        """
        neg_edge_index, neg_edge_label = self.sample(data)
        store = data[self._EDGE_TYPE]
        pos_edge_index, pos_edge_label = store.edge_label_index, store.edge_label
        store.edge_label_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        # Labels may live on a different device than the indices; align before concatenating.
        store.edge_label = torch.cat(
            [pos_edge_label, neg_edge_label.to(pos_edge_label.device)], dim=0
        )
        return data

    def batch_random_sample(self, data, mask, negative_users=None, negative_items=None):
        """Custom negative sampling: draw users and items uniformly at random.

        Args:
            data: graph object providing the user/item node counts.
            mask: boolean vector; True marks the slots to (re)draw. Its device
                is used for the new tensors.
            negative_users, negative_items: previous draw to patch in place at
                the masked slots. When either is missing a fresh draw of
                ``mask.sum()`` pairs is produced.

        Returns:
            ``[2, N]`` tensor of (user, item) index pairs.
        """
        num_users = data['user'].num_nodes
        num_items = data['item'].num_nodes
        device = mask.device
        num_neg_samples = int(mask.sum().item())
        if negative_users is None or negative_items is None:
            negative_users = torch.randint(0, num_users, (num_neg_samples,), device=device)
            negative_items = torch.randint(0, num_items, (num_neg_samples,), device=device)
        else:
            negative_users[mask] = torch.randint(0, num_users, (num_neg_samples,), device=device)
            negative_items[mask] = torch.randint(0, num_items, (num_neg_samples,), device=device)
        return torch.stack([negative_users, negative_items], dim=0)

    def pairwise_random_sample(self, data, mask, negative_users=None, negative_items=None):
        """Pairwise negative sampling: for each positive edge's user, sample negative items.

        On a fresh draw every positive user is repeated ``mask.sum() // num_pos``
        times (consecutively), so negatives ``k*i .. k*i + k - 1`` belong to
        positive edge ``i``. If the requested count is not a multiple of the
        number of positives it is rounded down to the nearest multiple.
        On a redraw only the items at the masked slots change; users are kept.

        Returns:
            ``[2, N]`` tensor of (user, item) index pairs.
        """
        users = data[self._EDGE_TYPE].edge_label_index[0]
        num_items = data['item'].num_nodes
        device = mask.device
        num_neg_samples = int(mask.sum().item())
        if negative_users is None or negative_items is None:
            num_pos = users.size(0)
            # Guard the empty-positive case, which would otherwise divide by zero.
            negs_per_pos = num_neg_samples // num_pos if num_pos > 0 else 0
            negative_users = users.repeat_interleave(negs_per_pos)
            # Size the item draw from the users actually produced so both rows always match.
            negative_items = torch.randint(0, num_items, (negative_users.size(0),), device=device)
        else:
            negative_items[mask] = torch.randint(0, num_items, (num_neg_samples,), device=device)
        return torch.stack([negative_users, negative_items], dim=0)

    def obtain_num_neg_samples(self, num_users, num_items, num_pos_edges):
        """Compute how many negatives to draw.

        The request is ``ratio * num_pos_edges``; if that exceeds the number of
        distinct (user, item) pairs the ratio is lowered to
        ``(num_users * num_items) // num_pos_edges``.

        Returns:
            (num_neg_samples, negative_sampling_ratio) after any capping.
        """
        negative_sampling_ratio = self.cfg['negative_sampling_ratio']
        max_edges = num_users * num_items
        num_neg_samples = int(negative_sampling_ratio * num_pos_edges)
        if num_neg_samples > max_edges:
            negative_sampling_ratio = max_edges // num_pos_edges
            num_neg_samples = negative_sampling_ratio * num_pos_edges  # Otherwise sampling error
        return num_neg_samples, negative_sampling_ratio

    def collision_check(self, pos_edge_index, neg_edge_index):
        """Flag negatives that coincide with a positive edge.

        Each (user, item) pair is hashed to the single integer
        ``user + item * edge_max``. By the quotient-remainder theorem (for any
        integer a and positive integer b there are unique q, r with
        a = b*q + r and 0 <= r < b) the hash is collision-free as long as
        ``user < edge_max``, which holds because ``edge_max`` is one more than
        the largest index present.

        Returns:
            Boolean vector, True where the negative edge is also a positive edge.
        """
        num_neg = neg_edge_index.size(1)
        # `.max()` is undefined on empty tensors; nothing can collide in that case.
        if num_neg == 0 or pos_edge_index.size(1) == 0:
            return torch.zeros(num_neg, dtype=torch.bool, device=neg_edge_index.device)
        edge_max = max(pos_edge_index.max().item(), neg_edge_index.max().item()) + 1
        hashd_pos = pos_edge_index[0] + pos_edge_index[1] * edge_max
        hashd_neg = neg_edge_index[0] + neg_edge_index[1] * edge_max
        return torch.isin(hashd_neg, hashd_pos)


In [6]:
from torch_geometric.data import HeteroData
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

data = HeteroData()
data['user'].num_nodes = 3
data['item'].num_nodes = 4
# Toy supervision edges as (user, item) pairs: (0, 1), (1, 2), (2, 3), (0, 0).
# The index tensor is created on the same device as the labels so that
# downstream tensor ops never mix CPU and GPU tensors.
edge_index = torch.tensor([[0, 1, 2, 0], [1, 2, 3, 0]], dtype=torch.long, device=device)
edge_label = torch.ones(edge_index.size(1), dtype=torch.float32, device=device)
data['user', 'interacts', 'item'].edge_label_index = edge_index
data['user', 'interacts', 'item'].edge_label = edge_label
data

HeteroData(
  user={ num_nodes=3 },
  item={ num_nodes=4 },
  (user, interacts, item)={
    edge_label_index=[2, 4],
    edge_label=[4],
  }
)

In [47]:
# Request two negatives per positive edge, each keeping the positive's user.
cfg = {
    'negative_sampling_method': 'pairwise_random',
    'negative_sampling_ratio': 2,
}
sampler = NegativeSampler(cfg)
neg_edge_index, neg_edge_label = sampler.sample(data)
neg_edge_index

tensor([[0, 0, 1, 1, 2, 2, 0, 0],
        [2, 2, 3, 0, 0, 0, 2, 2]])

In [48]:
# Pull the positive supervision edges and their labels back out of the graph.
pos_edge_index = data['user', 'interacts', 'item'].edge_label_index
pos_edge_label = data['user', 'interacts', 'item'].edge_label
pos_edge_index

tensor([[0, 1, 2, 0],
        [1, 2, 3, 0]])

In [14]:
# Labels laid out as [positives..., negatives...]; random scores stand in for model output.
edge_label = torch.cat((pos_edge_label, neg_edge_label), dim=0)
preds = torch.rand(edge_label.shape[0], device=device)
preds

tensor([0.2625, 0.8114, 0.6996, 0.9176, 0.2163, 0.2267, 0.6081, 0.9973, 0.4342,
        0.8890, 0.1594, 0.3710])

In [49]:
# Split the scores by label: 1 marks a positive edge, everything else a negative.
pos_mask = edge_label.eq(1)
pos_preds, neg_preds = preds[pos_mask], preds[pos_mask.logical_not()]
print(f"Positive predictions: {pos_preds}")
print(f"Negative predictions: {neg_preds}")

Positive predictions: tensor([0.2625, 0.8114, 0.6996, 0.9176])
Negative predictions: tensor([0.2163, 0.2267, 0.6081, 0.9973, 0.4342, 0.8890, 0.1594, 0.3710])


In [50]:
# All-pairs score gap via broadcasting: entry [i, j] = pos_preds[i] - neg_preds[j].
diff = pos_preds[:, None] - neg_preds[None, :]
diff

tensor([[ 0.0462,  0.0358, -0.3457, -0.7349, -0.1718, -0.6266,  0.1030, -0.1086],
        [ 0.5952,  0.5847,  0.2033, -0.1859,  0.3772, -0.0776,  0.6520,  0.4404],
        [ 0.4834,  0.4729,  0.0915, -0.2977,  0.2654, -0.1894,  0.5402,  0.3286],
        [ 0.7014,  0.6910,  0.3095, -0.0797,  0.4834,  0.0286,  0.7582,  0.5466]])

In [51]:
# Positive scores as a column vector, ready to broadcast against per-row negatives.
pos_preds[:, None]

tensor([[0.2625],
        [0.8114],
        [0.6996],
        [0.9176]])

In [56]:
# One row per positive edge; the columns hold that row's share of the negative scores.
neg_preds_re = neg_preds.view(pos_preds.shape[0], -1)
neg_preds_re

tensor([[0.2163, 0.2267],
        [0.6081, 0.9973],
        [0.4342, 0.8890],
        [0.1594, 0.3710]])

In [39]:
# Row-wise score gap: each positive score minus the negatives in its own row.
diff = pos_preds[:, None] - neg_preds_re
diff

tensor([[ 0.0462,  0.0358],
        [ 0.2033, -0.1859],
        [ 0.2654, -0.1894],
        [ 0.7582,  0.5466]])

In [54]:
# Same row-wise gaps computed on flat vectors: repeat every positive score
# consecutively so it lines up with its own block of negatives.
pos_preds_expanded = torch.repeat_interleave(
    pos_preds, neg_preds.shape[0] // pos_preds.shape[0]
)
print(f"{pos_preds_expanded}")
print(f"{neg_preds}")
diff2 = pos_preds_expanded - neg_preds
print(diff2)
# Fold the flat result back to one row per positive for comparison.
diff2 = diff2.reshape(pos_preds.shape[0], -1)
diff2

tensor([0.2625, 0.2625, 0.8114, 0.8114, 0.6996, 0.6996, 0.9176, 0.9176])
tensor([0.2163, 0.2267, 0.6081, 0.9973, 0.4342, 0.8890, 0.1594, 0.3710])
tensor([ 0.0462,  0.0358,  0.2033, -0.1859,  0.2654, -0.1894,  0.7582,  0.5466])


tensor([[ 0.0462,  0.0358],
        [ 0.2033, -0.1859],
        [ 0.2654, -0.1894],
        [ 0.7582,  0.5466]])

In [10]:
def bpr_loss(preds, edge_label):
    """
    Compute the Bayesian Personalized Ranking (BPR) loss.
    It encourages the model to rank positive edges higher than negative edges.
    Formula: mean(-log(sigmoid(pos_pred - neg_pred))) over (positive, negative) pairs.

    Args:
        preds: 1-D tensor of scores, one per edge.
        edge_label: tensor aligned with ``preds``; entries equal to 1 mark
            positive edges, everything else is treated as negative.

    The negatives must be grouped per positive: with ``k`` negatives per
    positive, negatives ``k*i .. k*i + k - 1`` are paired with positive ``i``
    (the layout obtained by repeating each positive's user consecutively).

    Returns:
        Scalar tensor with the mean pairwise loss.

    Raises:
        ValueError: if there are no positives, or the number of negatives is
            not a multiple of the number of positives.
    """
    # Local import keeps the function self-contained; `F` is not imported elsewhere.
    import torch.nn.functional as F

    pos_mask = edge_label == 1
    pos_preds = preds[pos_mask]
    neg_preds = preds[~pos_mask]

    num_pos, num_neg = pos_preds.size(0), neg_preds.size(0)
    if num_pos == 0 or num_neg % num_pos != 0:
        raise ValueError(
            f"Cannot pair {num_neg} negative predictions with {num_pos} positive predictions"
        )

    # repeat_interleave (not repeat): each positive must be duplicated in place
    # so it lines up with its own consecutive block of negatives.
    pos_preds_expanded = pos_preds.repeat_interleave(num_neg // num_pos)

    # Compute the difference between positive and negative predictions
    diff = pos_preds_expanded - neg_preds  # Shape: [num_neg]

    # Apply the BPR loss formula
    loss = F.softplus(-diff).mean()  # -log(sigmoid(x)) = softplus(-x)
    return loss

In [None]:
# Display the sampled negative edges again (row 0: users, row 1: items).
neg_edge_index