In [1]:
# |default_exp models_graph


In [2]:
#| export
from torch_geometric.datasets import MoleculeNet
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch.utils.data import Subset
import torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.utils import degree


from typing import Tuple, Dict, List, Union
from torch.multiprocessing import spawn

import networkx as nx
from torch_geometric.data import Data, Batch
from torch_geometric.utils.convert import to_networkx
from torch_geometric.utils import to_dense_batch
import torch.nn.functional as F

In [None]:
#| export
def decrease_to_max_value(x, max_value):
    """
    Cap every entry of ``x`` at ``max_value``.

    The capping happens in place: ``x`` itself is modified and then returned.

    :param x: tensor (or array) that supports boolean-mask assignment
    :param max_value: upper bound written over every entry that exceeds it
    :return: the same object ``x`` after capping
    """
    exceeds_cap = x > max_value
    x[exceeds_cap] = max_value
    return x

class CentralityEncoding(nn.Module):
    """
    Adds a learned embedding, selected by in-degree and by out-degree, to every node feature vector.
    """

    def __init__(self, max_in_degree: int, max_out_degree: int, node_dim: int):
        """
        :param max_in_degree: number of distinct in-degree embeddings (rows of ``z_in``)
        :param max_out_degree: number of distinct out-degree embeddings (rows of ``z_out``)
        :param node_dim: hidden dimensions of node features
        """
        super().__init__()
        self.max_in_degree = max_in_degree
        self.max_out_degree = max_out_degree
        self.node_dim = node_dim
        self.z_in = nn.Parameter(torch.randn((max_in_degree, node_dim)))
        self.z_out = nn.Parameter(torch.randn((max_out_degree, node_dim)))

    def forward(self, x: torch.Tensor, edge_index: torch.LongTensor) -> torch.Tensor:
        """
        :param x: node feature matrix
        :param edge_index: edge_index of graph (adjacency list)
        :return: torch.Tensor, node embeddings after Centrality encoding
        """
        num_nodes = x.shape[0]

        # z_in / z_out have rows 0 .. max_degree - 1, so the largest usable index is
        # max_degree - 1. Clipping to max_degree (as before) indexed one past the end
        # for any node whose degree reached the maximum.
        in_degree = decrease_to_max_value(
            degree(index=edge_index[1], num_nodes=num_nodes).long(), self.max_in_degree - 1
        )
        out_degree = decrease_to_max_value(
            degree(index=edge_index[0], num_nodes=num_nodes).long(), self.max_out_degree - 1
        )

        # Out-of-place add: the caller's feature tensor is left untouched.
        return x + (self.z_in[in_degree] + self.z_out[out_degree])


class SpatialEncoding(nn.Module):
    """
    Learned scalar bias per (source, destination) node pair, selected by the length of the node path between them.
    """

    def __init__(self, max_path_distance: int):
        """
        :param max_path_distance: max pairwise distance between nodes; also the number of learned biases
        """
        super().__init__()
        self.max_path_distance = max_path_distance

        self.b = nn.Parameter(torch.randn(self.max_path_distance))

    def forward(self, x: torch.Tensor, paths) -> torch.Tensor:
        """
        :param x: node feature matrix (only its first dimension, the node count, is used)
        :param paths: mapping source node -> destination node -> node path
        :return: torch.Tensor, square spatial Encoding matrix; pairs without a path stay 0
        """
        num_nodes = x.shape[0]
        device = next(self.parameters()).device
        spatial_matrix = torch.zeros((num_nodes, num_nodes)).to(device)

        for src, reachable in paths.items():
            for dst, node_path in reachable.items():
                # Path length in nodes, capped at max_path_distance, shifted to a 0-based index.
                bias_index = min(len(node_path), self.max_path_distance) - 1
                spatial_matrix[src, dst] = self.b[bias_index]

        return spatial_matrix


def dot_product(x1, x2) -> torch.Tensor:
    """
    Row-wise dot product: multiply element-wise, then sum over dimension 1.

    :param x1: first operand
    :param x2: second operand, broadcast-compatible with ``x1``
    :return: torch.Tensor with dimension 1 summed out
    """
    elementwise = x1 * x2
    return elementwise.sum(dim=1)


class EdgeEncoding(nn.Module):
    """
    Per node-pair scalar obtained by averaging, along the edge path between the pair,
    the dot product of each edge's features with a learned vector for that path position.
    """

    def __init__(self, edge_dim: int, max_path_distance: int):
        """
        :param edge_dim: edge feature matrix number of dimension
        :param max_path_distance: number of leading path edges taken into account
        """
        super().__init__()
        self.edge_dim = edge_dim
        self.max_path_distance = max_path_distance
        self.edge_vector = nn.Parameter(torch.randn(self.max_path_distance, self.edge_dim))

    def forward(self, x: torch.Tensor, edge_attr: torch.Tensor, edge_paths) -> torch.Tensor:
        """
        :param x: node feature matrix (only its first dimension, the node count, is used)
        :param edge_attr: edge feature matrix
        :param edge_paths: mapping source node -> destination node -> list of edge indexes
        :return: torch.Tensor, Edge Encoding matrix
        """
        num_nodes = x.shape[0]
        device = next(self.parameters()).device
        cij = torch.zeros((num_nodes, num_nodes)).to(device)

        for src, reachable in edge_paths.items():
            for dst, full_path in reachable.items():
                # Only the first max_path_distance edges of the path contribute.
                path_ij = full_path[:self.max_path_distance]
                weight_inds = list(range(len(path_ij)))
                per_edge = dot_product(self.edge_vector[weight_inds], edge_attr[path_ij])
                # An empty path yields NaN here (mean of nothing); it is zeroed below.
                cij[src, dst] = per_edge.mean()

        return torch.nan_to_num(cij)


class GraphormerAttentionHead(nn.Module):
    """
    Single attention head whose logits are biased by the spatial encoding ``b`` and by an
    edge encoding computed inside the head. Attention is restricted to nodes of the same graph.
    """

    def __init__(self, dim_in: int, dim_q: int, dim_k: int, edge_dim: int, max_path_distance: int):
        """
        :param dim_in: node feature matrix input number of dimension
        :param dim_q: query node feature matrix input number dimension
        :param dim_k: key node feature matrix input number of dimension (also used for values)
        :param edge_dim: edge feature matrix number of dimension
        :param max_path_distance: max path length considered by the edge encoding
        """
        super().__init__()
        self.edge_encoding = EdgeEncoding(edge_dim, max_path_distance)

        self.q = nn.Linear(dim_in, dim_q)
        self.k = nn.Linear(dim_in, dim_k)
        self.v = nn.Linear(dim_in, dim_k)

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                edge_attr: torch.Tensor,
                b: torch.Tensor,
                edge_paths,
                ptr) -> torch.Tensor:
        """
        :param query: node feature matrix
        :param key: node feature matrix
        :param value: node feature matrix
        :param edge_attr: edge feature matrix
        :param b: spatial Encoding matrix
        :param edge_paths: pairwise node paths in edge indexes
        :param ptr: batch pointer that shows graph indexes in batch of graphs, or None for a single graph
        :return: torch.Tensor, node embeddings after attention operation
        """
        query = self.q(query)
        key = self.k(key)
        value = self.v(value)

        c = self.edge_encoding(query, edge_attr, edge_paths)
        a = query.mm(key.transpose(0, 1)) / query.size(-1) ** 0.5
        a = a + b + c

        if ptr is not None:
            # Assumes ptr is a PyG-style pointer: ptr[0] == 0, non-decreasing, and
            # ptr[-1] == number of nodes. Node n belongs to graph g when
            # ptr[g] <= n < ptr[g + 1]; bucketize against ptr[1:] gives g directly.
            num_nodes = a.shape[0]
            boundaries = torch.as_tensor(ptr, device=a.device)[1:]
            node_ids = torch.arange(num_nodes, device=a.device)
            graph_id = torch.bucketize(node_ids, boundaries, right=True)
            same_graph = graph_id.unsqueeze(1) == graph_id.unsqueeze(0)

            # Cross-graph logits are set to -inf BEFORE the softmax. The previous code
            # multiplied them by -1e6, which turned negative logits into huge positive
            # ones, and then zeroed entries after the softmax, so rows no longer summed to 1.
            a = a.masked_fill(~same_graph, float("-inf"))

        attention = torch.softmax(a, dim=-1)
        return attention.mm(value)


# FIX: sparse attention instead of regular attention, due to specificity of GNNs(all nodes in batch will exchange attention)
class GraphormerMultiHeadAttention(nn.Module):
    """
    Runs several attention heads on the same input, concatenates their outputs along the
    feature dimension and projects the result back to ``dim_in``.
    """

    def __init__(self, num_heads: int, dim_in: int, dim_q: int, dim_k: int, edge_dim: int, max_path_distance: int):
        """
        :param num_heads: number of attention heads
        :param dim_in: node feature matrix input number of dimension
        :param dim_q: query node feature matrix input number dimension
        :param dim_k: key node feature matrix input number of dimension
        :param edge_dim: edge feature matrix number of dimension
        :param max_path_distance: max path length forwarded to every head
        """
        super().__init__()
        self.heads = nn.ModuleList(
            GraphormerAttentionHead(dim_in, dim_q, dim_k, edge_dim, max_path_distance)
            for _ in range(num_heads)
        )
        self.linear = nn.Linear(num_heads * dim_k, dim_in)

    def forward(self,
                x: torch.Tensor,
                edge_attr: torch.Tensor,
                b: torch.Tensor,
                edge_paths,
                ptr) -> torch.Tensor:
        """
        :param x: node feature matrix (used as query, key and value of every head)
        :param edge_attr: edge feature matrix
        :param b: spatial Encoding matrix
        :param edge_paths: pairwise node paths in edge indexes
        :param ptr: batch pointer that shows graph indexes in batch of graphs
        :return: torch.Tensor, node embeddings after all attention heads
        """
        head_outputs = []
        for attention_head in self.heads:
            head_outputs.append(attention_head(x, x, x, edge_attr, b, edge_paths, ptr))

        concatenated = torch.cat(head_outputs, dim=-1)
        return self.linear(concatenated)


class GraphormerEncoderLayer(nn.Module):
    """
    One pre-LayerNorm Graphormer block: multi-head attention with a residual connection,
    followed by a single linear layer with a residual connection.
    """

    def __init__(self, node_dim, edge_dim, n_heads, max_path_distance):
        """
        :param node_dim: node feature matrix input number of dimension
        :param edge_dim: edge feature matrix input number of dimension
        :param n_heads: number of attention heads
        :param max_path_distance: max path length forwarded to the attention heads
        """
        super().__init__()

        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.n_heads = n_heads

        self.attention = GraphormerMultiHeadAttention(
            dim_in=node_dim,
            dim_k=node_dim,
            dim_q=node_dim,
            num_heads=n_heads,
            edge_dim=edge_dim,
            max_path_distance=max_path_distance,
        )
        self.ln_1 = nn.LayerNorm(node_dim)
        self.ln_2 = nn.LayerNorm(node_dim)
        # The "feed-forward" part of the block is a single linear map.
        self.ff = nn.Linear(node_dim, node_dim)

    def forward(self,
                x: torch.Tensor,
                edge_attr: torch.Tensor,
                b: torch.Tensor,
                edge_paths,
                ptr) -> torch.Tensor:
        """
        h′(l) = MHA(LN(h(l−1))) + h(l−1)
        h(l) = FFN(LN(h′(l))) + h′(l)

        :param x: node feature matrix
        :param edge_attr: edge feature matrix
        :param b: spatial Encoding matrix
        :param edge_paths: pairwise node paths in edge indexes
        :param ptr: batch pointer that shows graph indexes in batch of graphs
        :return: torch.Tensor, node embeddings after Graphormer layer operations
        """
        x_prime = self.attention(self.ln_1(x), edge_attr, b, edge_paths, ptr) + x
        x_new = self.ff(self.ln_2(x_prime)) + x_prime

        return x_new



def floyd_warshall_source_to_all(G, source, cutoff=None, edges=None):
    """
    Shortest paths from ``source`` to every node reachable from it.

    Despite the name this is a level-by-level breadth-first search, not Floyd-Warshall:
    the first path that reaches a node is kept.

    :param G: graph supporting ``source in G``, ``G[v]`` (neighbours of v) and ``G.edges()``
    :param source: start node
    :param cutoff: optional maximum number of BFS levels to expand
    :param edges: optional precomputed mapping ``(u, v) -> edge index``; built from
        ``G.edges()`` when omitted. Passing it avoids rebuilding the mapping per source.
    :return: ``(node_paths, edge_paths)`` where ``node_paths[w]`` is the list of nodes from
        ``source`` to ``w`` and ``edge_paths[w]`` the matching list of edge indexes
    :raises nx.NodeNotFound: if ``source`` is not in ``G``
    """
    if source not in G:
        raise nx.NodeNotFound("Source {} not in G".format(source))

    if edges is None:
        edges = {edge: i for i, edge in enumerate(G.edges())}

    level = 0  # number of BFS levels expanded so far
    frontier = [source]  # nodes to expand at the current level
    node_paths = {source: [source]}  # paths dictionary (paths to key from source)
    edge_paths = {source: []}

    while frontier:
        next_frontier = []
        for v in frontier:
            for w in G[v]:
                if w not in node_paths:
                    node_paths[w] = node_paths[v] + [w]
                    # The last hop of the path to w is the edge (v, w).
                    edge_paths[w] = edge_paths[v] + [edges[(v, w)]]
                    next_frontier.append(w)
        frontier = next_frontier

        level = level + 1

        if cutoff is not None and cutoff <= level:
            break

    return node_paths, edge_paths


def all_pairs_shortest_path(G) -> Tuple[Dict[int, Dict[int, List[int]]], Dict[int, Dict[int, List[int]]]]:
    """
    Run :func:`floyd_warshall_source_to_all` from every node of ``G``.

    :param G: graph, see :func:`floyd_warshall_source_to_all`
    :return: ``(node_paths, edge_paths)``, both keyed ``[source][destination]``
    """
    # Build the edge -> index map once instead of once per source node.
    edges = {edge: i for i, edge in enumerate(G.edges())}

    node_paths = {}
    edge_paths = {}
    for n in G:
        node_paths[n], edge_paths[n] = floyd_warshall_source_to_all(G, n, edges=edges)
    return node_paths, edge_paths


def shortest_path_distance(data: Data) -> Tuple[Dict[int, List[int]], Dict[int, List[int]]]:
    """
    Convert a single graph to networkx and compute all-pairs shortest paths on it.

    :param data: single graph
    :return: ``(node_paths, edge_paths)`` as returned by ``all_pairs_shortest_path``
        (both keyed by source node, then by destination node)
    """
    return all_pairs_shortest_path(to_networkx(data))


def batched_shortest_path_distance(data) -> Tuple[Dict[int, List[int]], Dict[int, List[int]]]:
    """
    All-pairs shortest paths for every graph of a batch, expressed in batch-global indexes.

    Each sub-graph is processed on its own. Its node ids are shifted by the number of nodes
    in the preceding graphs and its edge ids by the number of edges in the preceding graphs,
    so both can index the batched node / edge tensors directly.

    :param data: batch of graphs exposing ``to_data_list()``
    :return: ``(node_paths, edge_paths)``, both keyed by source node, then destination node
    """
    node_paths = {}
    edge_paths = {}
    node_shift = 0
    edge_shift = 0

    for sub_data in data.to_data_list():
        graph = to_networkx(sub_data)
        num_nodes = graph.number_of_nodes()
        relabeled = nx.relabel_nodes(graph, {node: node + node_shift for node in range(num_nodes)})

        sub_node_paths, sub_edge_paths = all_pairs_shortest_path(relabeled)
        node_paths.update(sub_node_paths)

        # Edge ids returned above are local to this sub-graph (0 .. E_g - 1). The batched
        # edge_attr concatenates the per-graph edge features, so the ids must be offset by
        # the edges of all preceding graphs; without the offset every graph after the
        # first would read the first graphs' edge features.
        # NOTE(review): this still relies on the networkx edge order matching the column
        # order of edge_index inside each sub-graph - confirm for unsorted edge_index.
        for src, per_destination in sub_edge_paths.items():
            edge_paths[src] = {
                dst: [edge_id + edge_shift for edge_id in path]
                for dst, path in per_destination.items()
            }

        node_shift += num_nodes
        edge_shift += sub_data.num_edges

    return node_paths, edge_paths



class Graphormer(nn.Module):
    """
    Stack of Graphormer encoder layers applied to node features after centrality encoding,
    with spatial and edge encodings derived from shortest paths.
    """

    def __init__(self,
                 num_layers: int,
                 node_dim: int,
                 input_edge_dim: int,
                 edge_dim: int,
                 n_heads: int,
                 max_in_degree: int,
                 max_out_degree: int,
                 max_path_distance: int):
        """
        :param num_layers: number of Graphormer layers
        :param node_dim: dimension of node features (input and hidden; nodes are not projected)
        :param input_edge_dim: input dimension of edge features
        :param edge_dim: hidden dimensions of edge features
        :param n_heads: number of attention heads
        :param max_in_degree: max in degree of nodes
        :param max_out_degree: max out degree of nodes
        :param max_path_distance: max pairwise distance between two nodes
        """
        super().__init__()

        self.num_layers = num_layers
        self.node_dim = node_dim
        self.input_edge_dim = input_edge_dim
        self.edge_dim = edge_dim
        self.n_heads = n_heads
        self.max_in_degree = max_in_degree
        self.max_out_degree = max_out_degree
        self.max_path_distance = max_path_distance

        self.edge_in_lin = nn.Linear(self.input_edge_dim, self.edge_dim)

        self.centrality_encoding = CentralityEncoding(
            max_in_degree=self.max_in_degree,
            max_out_degree=self.max_out_degree,
            node_dim=self.node_dim
        )

        self.spatial_encoding = SpatialEncoding(
            max_path_distance=max_path_distance,
        )

        self.layers = nn.ModuleList([
            GraphormerEncoderLayer(
                node_dim=self.node_dim,
                edge_dim=self.edge_dim,
                n_heads=self.n_heads,
                max_path_distance=self.max_path_distance) for _ in range(self.num_layers)
        ])

    def forward(self, data: Union[Data, Batch]) -> torch.Tensor:
        """
        :param data: input graph or batch of graphs
        :return: torch.Tensor, output node embeddings
        """
        x = data.x.float()
        edge_index = data.edge_index.long()
        edge_attr = data.edge_attr.float()

        # isinstance instead of an exact type comparison: a plain Data (or any Data
        # subclass that is not a Batch) is treated as a single graph, and only real
        # batches take the batched path, which needs ``ptr`` and ``to_data_list()``.
        if isinstance(data, Batch):
            ptr = data.ptr
            node_paths, edge_paths = batched_shortest_path_distance(data)
        else:
            ptr = None
            node_paths, edge_paths = shortest_path_distance(data)

        edge_attr = self.edge_in_lin(edge_attr)

        x = self.centrality_encoding(x, edge_index)
        # The spatial bias depends only on the paths, so it is computed once and shared by all layers.
        b = self.spatial_encoding(x, node_paths)

        for layer in self.layers:
            x = layer(x, edge_attr, b, edge_paths, ptr)

        return x
        

def to_graph_batch_edge_attr(seq, mask, adj_matrix, p=0.49):
    """
    Turn a dense batch into a batch of graphs.

    For sample ``i`` the nodes are the rows of ``seq[i]`` selected by ``mask[i]``; an edge is
    created for every entry of ``adj_matrix[i]`` strictly greater than ``p``, and that entry
    becomes the edge's single feature.

    :param seq: per-sample node feature matrices
    :param mask: per-sample boolean selection of valid positions
    :param adj_matrix: per-sample square matrices of pairwise scores
    :param p: threshold an entry must exceed to become an edge
    :return: ``Batch`` built from one ``Data`` per sample
    """
    def build_graph(idx):
        keep = adj_matrix[idx] > p
        return Data(
            x=seq[idx][mask[idx]],
            edge_index=keep.nonzero().T,
            edge_attr=adj_matrix[idx][keep].view(-1, 1),
        )

    return Batch.from_data_list([build_graph(idx) for idx in range(len(seq))])


class PytorchBatchWrapper(nn.Module):
    """
    Adapter that lets a graph model consume dense (seq, mask, adj_matrix) batches:
    the inputs are converted to a graph batch, run through the model and densified again.
    """

    def __init__(self, md):
        """
        :param md: wrapped model, called with the graph batch
        """
        super().__init__()
        self.md = md

    def forward(self, seq, mask, adj_matrix):
        """
        :param seq: per-sample node feature matrices
        :param mask: per-sample boolean selection of valid positions
        :param adj_matrix: per-sample square matrices of pairwise scores
        :return: dense output of ``to_dense_batch`` for the model's node embeddings
        """
        graph_batch = to_graph_batch_edge_attr(seq, mask, adj_matrix)
        node_out = self.md(graph_batch)
        # to_dense_batch also returns a validity mask, which is not needed here.
        dense_out = to_dense_batch(node_out, graph_batch.batch)[0]
        return dense_out


class Conv1D(nn.Conv1d):
    """
    ``nn.Conv1d`` that takes channels-last input and can zero out padded positions before convolving.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Last padding mask seen; reused by later calls that pass no mask.
        self.src_key_padding_mask = None

    def forward(self, x, src_key_padding_mask=None):
        """
        :param x: input with channels in the last dimension
        :param src_key_padding_mask: optional mask, truthy at positions to zero out;
            when given it is stored on the module and stays in effect for later calls
        :return: convolution output, channels in the last dimension
        """
        if src_key_padding_mask is not None:
            self.src_key_padding_mask = src_key_padding_mask

        pad_mask = self.src_key_padding_mask
        if pad_mask is not None:
            # Broadcast the per-position mask over the channel dimension.
            x = x.masked_fill(pad_mask.unsqueeze(-1).bool(), 0)

        channels_first = x.permute(0, 2, 1)
        return super().forward(channels_first).permute(0, 2, 1)


class ResBlock(nn.Sequential):
    """
    Residual block: LayerNorm -> GELU -> Conv1D, added back onto the input.
    """

    def __init__(self, d_model):
        """
        :param d_model: feature dimension (input and output)
        """
        super().__init__(
            nn.LayerNorm(d_model),
            nn.GELU(),
            Conv1D(d_model, d_model, 3, padding=1),
        )
        self.src_key_padding_mask = None

    def forward(self, x, src_key_padding_mask=None):
        """
        :param x: input features
        :param src_key_padding_mask: optional padding mask; when omitted, the mask stored
            on this block is used instead
        :return: ``x`` plus the output of the wrapped sequence
        """
        if src_key_padding_mask is None:
            active_mask = self.src_key_padding_mask
        else:
            active_mask = src_key_padding_mask
        # Hand the mask to the last layer through its attribute before running the sequence.
        self[-1].src_key_padding_mask = active_mask

        branch = super().forward(x)
        return x + branch


class Extractor(nn.Sequential):
    """
    Token embedding followed by a Conv1D and a ResBlock.
    """

    def __init__(self, d_model, in_ch=4):
        """
        :param d_model: output feature dimension (the embedding uses ``d_model // 4``)
        :param in_ch: number of distinct input tokens
        """
        embed_dim = d_model // 4
        super().__init__(
            nn.Embedding(in_ch, embed_dim),
            Conv1D(embed_dim, d_model, 7, padding=3),
            ResBlock(d_model),
        )

    def forward(self, x, src_key_padding_mask=None):
        """
        :param x: input token ids
        :param src_key_padding_mask: padding mask stored on the two layers after the
            embedding (``None`` clears it on both)
        :return: output of the wrapped sequence
        """
        # Layers 1 and 2 read the mask from their attribute, not from an argument.
        for masked_layer in (self[1], self[2]):
            masked_layer.src_key_padding_mask = src_key_padding_mask
        return super().forward(x)
    
    
class GraphformerV0(nn.Module):
    """
    Sequence model: convolutional feature extractor -> Graphormer layers on a graph built
    from the ``bpp`` matrix -> linear projection to 2 outputs per position.
    """

    def __init__(self, dim=192, depth=12, head_size=32, graph_layers=4, **kwargs):
        """
        :param dim: feature dimension used throughout the model
        :param depth: accepted but not used by this model
        :param head_size: accepted but not used by this model
        :param graph_layers: number of Graphormer layers
        :param kwargs: accepted and ignored
        """
        super().__init__()

        self.extractor = Extractor(dim)

        graphormer = Graphormer(
            num_layers=graph_layers,
            node_dim=dim,
            input_edge_dim=1,
            edge_dim=8,
            n_heads=4,
            max_in_degree=5,
            max_out_degree=5,
            max_path_distance=5,
        )
        self.graph_layers = PytorchBatchWrapper(graphormer)

        self.proj_out = nn.Linear(dim, 2)

    def forward(self, x0):
        """
        :param x0: mapping with keys ``"mask"``, ``"seq"`` and ``"bpp"``
        :return: per-position outputs, padded along dimension 1 back to the width of ``x0["mask"]``
        """
        full_mask = x0["mask"]
        padded_len = full_mask.shape[1]
        # Trim the batch to the largest number of valid positions in any sample.
        longest = full_mask.sum(-1).max()
        mask = full_mask[:, :longest]
        tokens = x0["seq"][:, :longest]

        features = self.extractor(tokens, src_key_padding_mask=~mask)
        features = self.graph_layers(features, mask, x0["bpp"])
        projected = self.proj_out(features)

        # Pad dimension 1 (sequence) on the right back to the original width.
        return F.pad(projected, (0, 0, 0, padded_len - longest, 0, 0))
    
    
        
        

In [None]:
#|hide
#|eval: false
# Run nbdev's export step. Presumably this writes the cells marked `#| export` to the
# module named by the `default_exp` directive at the top of the notebook - confirm
# against the project's nbdev settings.
from nbdev.doclinks import nbdev_export
nbdev_export()

In [None]:
# from pathlib import Path
# import sys
# sys.path.append('..')
# import os
# import pandas as pd
# from rnacomp.dataset import LenMatchBatchSampler, RNA_DatasetBaselineSplitbppV0, DeviceDataLoader, RNA_DatasetBaselineSplitssV1, RNA_DatasetBaselineSplit, RNA_DatasetBaselineSplitbppV2

# class CFG:
#     path = Path("../data/")
#     pathbb = Path("../data/Ribonanza_bpp_files")
#     split_id = Path('../eda/fold_split.csv')
#     pathss = Path("../eda/train_ss_vienna_rna.parquet")
#     bs = 16
#     num_workers = 8
#     device = 'cpu'
#     adjnact_prob = 0.5



# fns = list(CFG.pathbb.rglob("*.txt"))
# bpp_df = pd.DataFrame({"bpp": fns})
# bpp_df['sequence_id'] = bpp_df['bpp'].apply(lambda x: x.stem)
# ss = pd.read_parquet(CFG.pathss)[["sequence_id", "ss_full"]]
# df = pd.read_parquet(CFG.path/'train_data.parquet')
# split = pd.read_csv(CFG.split_id)
# df = pd.merge(df, split, on='sequence_id')
# df = pd.merge(df, bpp_df, on='sequence_id')
# df = pd.merge(df, ss, on='sequence_id')
# df_train = df.query('is_train==True').reset_index(drop=True)
# df_valid = df.query('is_train==False').reset_index(drop=True)

# ds_val = RNA_DatasetBaselineSplitbppV2(df_valid, mode='eval')
# ds_val_len = RNA_DatasetBaselineSplitbppV2(df_valid, mode='eval', mask_only=True)
# sampler_val = torch.utils.data.SequentialSampler(ds_val_len)
# len_sampler_val = LenMatchBatchSampler(sampler_val, batch_size=CFG.bs, 
#                drop_last=False)
# dl_val= DeviceDataLoader(torch.utils.data.DataLoader(ds_val, 
#                batch_sampler=len_sampler_val, num_workers=CFG.num_workers), CFG.device)

# batch = next(iter(dl_val))[0]

In [None]:
# with torch.no_grad():
#      out =  GraphformerV0()(batch)