In [1]:
import torch
import torch_geometric

In [2]:
prune_method = 'thresh_abs'
edge_top_perc = 0.2
K = 3
thresh = 0.1

In [3]:
from model.graph_learner import *
hidden_dim = 384
num_nodes = 4
embed_dim = 16
num_dynamic_graphs = 2
resolution = 98
#TODO: add support for those two parameters
undirected_graph = True
regularizations = ["feature_smoothing", "degree", "sparse"]
attn_layers = GraphLearner(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_nodes=num_nodes,
            embed_dim=embed_dim,
            metric_type="self_attention",
        )

In [4]:
def get_knn_graph(x, k, dist_measure="cosine", undirected=undirected_graph):

    if dist_measure == "euclidean":
        dist = torch.cdist(x, x, p=2.0)
        dist = (dist - dist.min()) / (dist.max() - dist.min())
        knn_val, knn_ind = torch.topk(
            dist, k, dim=-1, largest=False
        )  # smallest distances
    elif dist_measure == "cosine":
        norm = torch.norm(x, dim=-1, p="fro")[:, :, None]
        x_norm = x / norm
        dist = torch.matmul(x_norm, x_norm.transpose(1, 2))
        knn_val, knn_ind = torch.topk(
            dist, k, dim=-1, largest=True
        )  # largest similarities
    else:
        raise NotImplementedError
    adj_mat = (torch.ones_like(dist) * 0).scatter_(-1, knn_ind, knn_val).to(x.device)

    adj_mat = torch.clamp(adj_mat, min=0.0)  # remove negatives

    if undirected:
        adj_mat = (adj_mat + adj_mat.transpose(1, 2)) / 2

    # add self-loop
    I = (
        torch.eye(adj_mat.shape[-1], adj_mat.shape[-1])
        .unsqueeze(0)
        .repeat(adj_mat.shape[0], 1, 1)
        .to(bool)
    ).to(x.device)
    adj_mat = adj_mat * (~I) + I

    # to sparse graph
    edge_index, edge_weight = torch_geometric.utils.dense_to_sparse(adj_mat)

    return edge_index, edge_weight, adj_mat

def prune_adj_mat(adj_mat, num_nodes, method="thresh", edge_top_perc=None, knn=None, thresh=None):
    
    if method == "thresh":
        sorted, indices = torch.sort(
            adj_mat.reshape(-1, num_nodes * num_nodes),
            dim=-1,
            descending=True,
        )
        K = int((num_nodes**2) * edge_top_perc)
        mask = adj_mat > sorted[:, K].unsqueeze(1).unsqueeze(2)
        adj_mat = adj_mat * mask
    elif method == "knn":
        knn_val, knn_ind = torch.topk(
            adj_mat, knn, dim=-1, largest=True
        )
        adj_mat = (torch.ones_like(adj_mat) * 0).scatter_(-1, knn_ind, knn_val).to(adj_mat.device)
    elif method == "thresh_abs":
        mask = (adj_mat > thresh).float()
        adj_mat = adj_mat * mask
    else:
        raise NotImplementedError

    return adj_mat

def calculate_normalized_laplacian(adj):
    """
    Args:
        adj: torch tensor, shape (batch, num_nodes, num_nodes)

    L = D^-1/2 (D-A) D^-1/2 = I - D^-1/2 A D^-1/2
    D = diag(A)
    """

    batch, num_nodes, _ = adj.shape
    d = adj.sum(-1)  # (batch, num_nodes)
    d_inv_sqrt = torch.pow(d, -0.5)
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0
    d_mat_inv_sqrt = torch.diag_embed(d_inv_sqrt)  # (batch, num_nodes, num_nodes)

    identity = (torch.eye(num_nodes).unsqueeze(0).repeat(batch, 1, 1)).to(
        adj.device
    )  # (batch, num_nodes, num_nodes)
    normalized_laplacian = identity - torch.matmul(
        torch.matmul(d_mat_inv_sqrt, adj), d_mat_inv_sqrt
    )

def feature_smoothing(adj, X):

    # normalized laplacian
    L = calculate_normalized_laplacian(adj)

    feature_dim = X.shape[-1]
    mat = torch.matmul(torch.matmul(X.transpose(1, 2), L), X) / (feature_dim**2)
    loss = mat.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)  # batched trace
    return loss
#TODO: add support for those two parameters
def regularization_loss(x, adj, reduce="mean"):
    """
    Referred to https://github.com/hugochan/IDGL/blob/master/src/core/model_handler.py#L1116
    """
    batch, num_nodes, _ = x.shape
    n = num_nodes

    loss = {}

    if "feature_smoothing" in regularizations:
        curr_loss = feature_smoothing(adj=adj, X=x) / (n**2)
        if reduce == "mean":
            loss["feature_smoothing"] = torch.mean(curr_loss)
        elif reduce == "sum":
            loss["feature_smoothing"] = torch.sum(curr_loss)
        else:
            loss["feature_smoothing"] = curr_loss

    if "degree" in regularizations:
        ones = torch.ones(batch, num_nodes, 1).to(x.device)
        curr_loss = -(1 / n) * torch.matmul(
            ones.transpose(1, 2), torch.log(torch.matmul(adj, ones))
        ).squeeze(-1).squeeze(-1)
        if reduce == "mean":
            loss["degree"] = torch.mean(curr_loss)
        elif reduce == "sum":
            loss["degree"] = torch.sum(curr_loss)
        else:
            loss["degree"] = curr_loss

    if "sparse" in regularizations:
        curr_loss = (
            1 / (n**2) * torch.pow(torch.norm(adj, p="fro", dim=(-1, -2)), 2)
        )

        if reduce == "mean":
            loss["sparse"] = torch.mean(curr_loss)
        elif reduce == "sum":
            loss["sparse"] = torch.sum(curr_loss)
        else:
            loss["sparse"] = curr_loss

    if "symmetric" in regularizations and undirected_graph:
        curr_loss = torch.norm(adj - adj.transpose(1, 2), p="fro", dim=(-1, -2))
        if reduce == "mean":
            loss["symmetric"] = torch.mean(curr_loss)
        elif reduce == "sum":
            loss["symmetric"] = torch.sum(curr_loss)
        else:
            loss["symmetric"] = curr_loss

    return loss

In [5]:
inputs = torch.randn(32, 4, 196, 384)  # (batch_size, num_nodes, seq_len, hidden_dim)
batch,num_nodes,seq_len,hidden_dim = inputs.size()
#这一段是加入的
x = inputs.permute(0,2,1,3).contiguous() #(batch, seq_len, num_nodes, hidden_dim)
x_ = []
for t in range(num_dynamic_graphs):
    start = t * resolution
    stop = start + resolution
    curr_x = torch.mean(x[:, start:stop, :, :], dim=1)
    x_.append(curr_x)
x_ = torch.stack(
    x_, dim=1
)  # (batch, num_dynamic_graphs, num_nodes, hidden_dim)
x = x_.reshape(
    -1, num_nodes, hidden_dim
)  # (batch * num_dynamic_graphs, num_nodes, hidden_dim)


In [6]:
x.shape

torch.Size([64, 4, 384])

In [7]:
edge_index, edge_weight, adj_mat = get_knn_graph(
    x,
    K,
    dist_measure="cosine",
    undirected=True,
)
# edge_index = edge_index.to(x.device)
# edge_weight = edge_weight.to(x.device)
# adj_mat = adj_mat.to(x.device)

In [8]:
x.shape

torch.Size([64, 4, 384])

In [9]:
# learn adj mat
attn_weight = attn_layers(
    x
)  # (batch*num_dynamic_graphs, num_nodes, num_nodes)

# to undirected
attn_weight = (attn_weight + attn_weight.transpose(1, 2)) / 2
raw_attn_weight = attn_weight.clone()

In [10]:
attn_weight.shape

torch.Size([64, 4, 4])

In [11]:
# add residual
if len(adj_mat.shape) == 2:
    adj_mat = torch.cat([adj_mat] * num_dynamic_graphs * batch, dim=0)
elif len(adj_mat.shape) == 3 and (adj_mat.shape != attn_weight.shape):
    adj_mat = torch.cat([adj_mat] * num_dynamic_graphs, dim=0)

In [12]:
# knn graph weight (aka residual weight) decay
# if self.decay_residual_weight:
#     assert (epoch is not None) and (epoch_total is not None)
#     residual_weight = calculate_cosine_decay_weight(
#         max_weight=self.residual_weight, epoch=epoch, epoch_total=epoch_total, min_weight=0
#     )
# else:
#     residual_weight = 0.6
# add knn graph
residual_weight = 0.6
adj_mat = (
    residual_weight * adj_mat + (1 - residual_weight) * attn_weight
)

In [13]:
# prune graph
adj_mat = prune_adj_mat(
    adj_mat,
    num_nodes,
    method=prune_method,
    edge_top_perc=edge_top_perc,
    knn=K,
    thresh=thresh,
)

# regularization loss
#TODO add support for those two parameters
reg_losses = regularization_loss(x, adj=adj_mat)

# back to sparse graph
edge_index, edge_weight = torch_geometric.utils.dense_to_sparse(adj_mat)

TypeError: matmul(): argument 'other' (position 2) must be Tensor, not NoneType

In [None]:
# add self-loop
edge_index, edge_weight = torch_geometric.utils.remove_self_loops(
    edge_index=edge_index, edge_attr=edge_weight
)
edge_index, edge_weight = torch_geometric.utils.add_self_loops(
    edge_index=edge_index,
    edge_attr=edge_weight,
    fill_value=1,
)

x = x.view(
    batch * num_dynamic_graphs * num_nodes, -1
)  # (batch * num_dynamic_graphs * num_nodes, hidden_dim)
for i in range(len(self.gnn_layers)):
    # gnn layer
    x = self.gnn_layers[i](
        x, edge_index=edge_index, edge_attr=edge_weight.reshape(-1, 1)
    )
    x = self.dropout(
        self.activation(x)
    )  # (batch * num_dynamic_graphs * num_nodes, hidden_dim)
x = x.view(batch * num_dynamic_graphs, num_nodes, -1).view(
    batch, num_dynamic_graphs, num_nodes, -1
)  
 # (batch, num_dynamic_graphs, num_nodes, hidden_dim)

# temporal pool
if self.temporal_pool == "last":
    x = x[:, -1, :, :]  # (batch, num_nodes, hidden_dim)
elif self.temporal_pool == "mean":
    x = torch.mean(x, dim=1)
else:
    raise NotImplementedError
