In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn

Using backend: pytorch


In [2]:
class LabelPropagation(nn.Module):
    r"""Label propagation over a graph.

    Description
    -----------
    Introduced in `Learning from Labeled and Unlabeled Data with Label Propagation <https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.14.3864&rep=rep1&type=pdf>`_.
    The symmetric-normalized, :math:`\alpha`-weighted update used here follows
    the formulation of Zhou et al., *Learning with Local and Global Consistency*.

    Starting from the (masked) seed labels :math:`\mathbf{Y}^{(0)}`, each layer computes

    .. math::
        \mathbf{Y}^{(k+1)} = \alpha \cdot \mathbf{D}^{-1/2} \mathbf{A}
        \mathbf{D}^{-1/2} \mathbf{Y}^{(k)} + (1 - \alpha) \mathbf{Y}^{(0)},

    so unlabeled nodes are inferred from labeled ones via propagation, while the
    seed labels are re-injected at every step.

    Parameters
    ----------
        num_layers: int
            The number of propagation steps.
        alpha: float
            The :math:`\alpha` coefficient: weight of the propagated signal
            relative to the re-injected seed labels.
    """
    def __init__(self, num_layers, alpha):
        super().__init__()

        self.num_layers = num_layers
        self.alpha = alpha

    @torch.no_grad()
    def forward(self, g, labels, mask=None, post_step=lambda y: y.clamp_(0., 1.), num_classes=-1):
        """Propagate labels over ``g``.

        Parameters
        ----------
        g : DGLGraph
            Graph to propagate over. Both sides of every edge are normalized by
            in-degree, which equals the symmetric normalization only when the
            graph is undirected (every edge stored in both directions).
        labels : torch.Tensor
            Either integer class ids (``torch.long``; flattened to one id per
            node and one-hot encoded) or a float matrix with one row per node.
        mask : index tensor or bool tensor, optional
            Rows of ``labels`` used as seeds (e.g. training nodes). All other
            rows start at zero. If ``None``, every row is a seed.
        post_step : callable, optional
            Applied to the result of every propagation step. The default clamps
            values into ``[0, 1]`` in place.
        num_classes : int, optional
            One-hot width for integer labels. ``-1`` (default) infers it as
            ``labels.max() + 1``, which drops trailing classes that never occur.

        Returns
        -------
        torch.Tensor
            Propagated label scores, one row per node.
        """
        with g.local_scope():
            if labels.dtype == torch.long:
                labels = F.one_hot(labels.view(-1), num_classes).to(torch.float32)

            # Seed matrix Y^(0): keep only the masked rows (e.g. train_idx) and
            # zero out every other row.
            y = labels
            if mask is not None:
                y = torch.zeros_like(labels)
                y[mask] = labels[mask]

            # (1 - alpha) * Y^(0) is computed once. The seed labels, not the
            # current iterate, are re-injected at every step; recomputing this
            # from the iterate would wash the seeds out as num_layers grows.
            last = (1 - self.alpha) * y
            # D^{-1/2} from in-degrees. Isolated nodes are clamped to degree 1
            # to avoid division by zero.
            degs = g.in_degrees().float().clamp(min=1)
            norm = torch.pow(degs, -0.5).to(labels.device).unsqueeze(1)

            for _ in range(self.num_layers):
                # Assumes the graph is undirected, so in-degree == out-degree.
                # D^{-1/2} Y
                g.ndata['h'] = y * norm
                # A D^{-1/2} Y: each node sums the features of its source neighbours.
                g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
                # alpha * D^{-1/2} A D^{-1/2} Y + (1 - alpha) Y^(0)
                y = last + self.alpha * g.ndata.pop('h') * norm
                y = post_step(y)

            return y