In [None]:
import numpy as np
import pandas as pd
import random

In [None]:
import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Data, DataLoader

In [None]:
import torch_geometric
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops
from torch_geometric.nn import GraphConv, TopKPooling, GatedGraphConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
import torch.nn.functional as F

In [None]:
from sklearn.metrics import roc_auc_score, average_precision_score

In [None]:
import plotly.express as px

In [None]:
# Column of each nucleotide in the one-hot node features; T and U share a
# column so DNA and RNA spellings encode identically. 'N' (unknown base) is
# deliberately absent and becomes an all-zero row.
_BASE_TO_COLUMN = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'U': 3}


def _one_hot_nodes(seq):
    """Return an (L, 4) float32 tensor one-hot encoding ``seq`` (case-insensitive).

    Raises:
        ValueError: on any character other than A/C/G/T/U/N.
    """
    features = np.zeros((len(seq), 4), dtype=np.float32)
    for pos, base in enumerate(seq.upper()):
        if base == 'N':
            continue
        if base not in _BASE_TO_COLUMN:
            raise ValueError(
                "unexpected character {!r} at position {} in sequence".format(base, pos))
        features[pos, _BASE_TO_COLUMN[base]] = 1.0
    return torch.from_numpy(features)


def _build_edges(n_node, pairs):
    """Return (edge_index, edge_attr) for the backbone chain plus base pairs.

    Every edge is added in both directions. ``edge_attr`` is 0 for backbone
    (i, i+1) edges and 1 for base-pair ("hydrogen bond") edges.
    """
    backbone_from = list(range(n_node - 1))
    backbone_to = list(range(1, n_node))
    edge_from = backbone_from + backbone_to
    edge_to = backbone_to + backbone_from
    n_backbone = len(edge_from)
    for i, j in pairs:
        edge_from += [i, j]
        edge_to += [j, i]
    n_pair = len(edge_from) - n_backbone
    edge_index = torch.tensor([edge_from, edge_to], dtype=torch.long)
    # FIXME dtype: integer edge categories for now, to be encoded later
    edge_attr = torch.tensor([0] * n_backbone + [1] * n_pair, dtype=torch.long).unsqueeze(1)
    return edge_index, edge_attr


def _pair_matrix(n_node, pairs):
    """Return an (L, L) float64 array with 1 at every (i, j) in ``pairs``, 0 elsewhere.

    An empty ``pairs`` must be special-cased: ``arr[tuple(zip(*[]))]`` is
    ``arr[()]``, which addresses the WHOLE array and would set every cell to 1.
    """
    mat = np.zeros((n_node, n_node))
    if len(pairs) > 0:
        rows, cols = zip(*pairs)
        mat[list(rows), list(cols)] = 1
    return mat


def make_dataset(df):
    """Convert rows of sequences + predicted stem base pairs into graph samples.

    Reads columns ``seq`` (nucleotide string) and ``stem_bb_bps`` (iterable
    of (i, j) index pairs into ``seq``). For each row a ``Data`` object is
    built with:

    - ``x``: (L, 4) one-hot node features (A, C, G, T/U; N -> all zeros)
    - ``edge_index`` / ``edge_attr``: bidirectional backbone (attr 0) and
      base-pair (attr 1) edges
    - ``y``: (L, L) numpy target, 1 at the G-C / C-G pairs among
      ``stem_bb_bps`` (toy task), 0 elsewhere
    - ``m``: (L, L) numpy mask, 1 at every ``stem_bb_bps`` pair; cells with
      0 are don't-cares

    Only the (i, j) cell listed in ``stem_bb_bps`` is set in ``y``/``m``;
    (j, i) is not mirrored.

    Returns:
        list of ``Data``, one per row of ``df``.
    """
    data_list = []
    for _, row in df.iterrows():
        seq = row['seq']
        stem_bb_bps = row['stem_bb_bps']
        n_node = len(seq)

        node_features = _one_hot_nodes(seq)
        edge_index, edge_attr = _build_edges(n_node, stem_bb_bps)

        # toy task: positives are the G-C / C-G pairs among the predicted stem pairs
        seq_upper = seq.upper()
        target_bps = [(i, j) for i, j in stem_bb_bps
                      if {seq_upper[i], seq_upper[j]} == {'G', 'C'}]

        y = _pair_matrix(n_node, target_bps)
        m = _pair_matrix(n_node, stem_bb_bps)

        data_list.append(Data(x=node_features, edge_index=edge_index,
                              edge_attr=edge_attr, y=y, m=m))
    return data_list

In [None]:
class Net(torch.nn.Module):
    """Scores every ordered node pair (i, j) of a single graph.

    A graph-attention layer embeds each node. Each pair is then represented
    by the concatenation [h_i, h_j] and scored by two 1x1 convolutions (a
    shared per-pair MLP) followed by a sigmoid.

    Args:
        n_hid: width of the node embedding and of the hidden pair layer.
        n_in: number of input node features (4 matches the one-hot encoding
            built by ``make_dataset``).
    """

    def __init__(self, n_hid=10, n_in=4):
        super().__init__()
        self.conv1 = torch_geometric.nn.conv.GATConv(n_in, n_hid)
        self.act1 = torch.nn.ReLU()
        self.act2 = torch.nn.Sigmoid()

        # node-pair NN: 1x1 convs apply the same small MLP to every (i, j) cell
        self.node_pair_conv1 = torch.nn.Conv2d(n_hid * 2, n_hid, kernel_size=1, stride=1,
                                               padding=0)
        self.node_pair_conv2 = torch.nn.Conv2d(n_hid, 1, kernel_size=1, stride=1,
                                               padding=0)

    def forward(self, data):
        """Return an (L, L) tensor of pair probabilities for one graph.

        ``data`` must provide ``x`` (L x n_in node features) and
        ``edge_index``. All L nodes are paired with each other, so batched
        graphs are not supported.
        """
        h = self.act1(self.conv1(data.x, data.edge_index))
        n_node = h.size(0)

        # outer concat: pair[i, j] = [h_i, h_j] -> L x L x 2f
        # expand() creates views; torch.cat materialises the result once
        left = h.unsqueeze(1).expand(n_node, n_node, -1)
        right = h.unsqueeze(0).expand(n_node, n_node, -1)
        pair = torch.cat([left, right], dim=2)

        # Conv2d expects (N, C_in, H_in, W_in)
        pair = pair.permute(2, 0, 1).unsqueeze(0)
        pair = self.act1(self.node_pair_conv1(pair))
        pair = self.act2(self.node_pair_conv2(pair))
        # Drop only the batch and channel dims: squeeze() would also collapse
        # H/W when L == 1 and break the shape match with the L x L target.
        return pair[0, 0]

In [None]:
# Element-wise binary cross-entropy; reduction is done by the functions below.
loss_b = torch.nn.BCELoss(reduction='none')


def debug_loss(x, y, m):
    """Unmasked BCE averaged over every element of ``x``.

    ``m`` is accepted only so the signature matches ``masked_loss_b``; it is
    ignored.
    """
    return loss_b(x, y).mean()


def masked_loss_b(x, y, m):
    """Binary cross-entropy averaged over the positions selected by a mask.

    Args:
        x: predicted probabilities in [0, 1] (L x L from ``Net``).
        y: binary targets, same shape as ``x``.
        m: mask broadcastable to ``x``; positions with 0 are don't-cares.

    Returns:
        0-dim tensor equal to sum(bce * m) / sum(m). When the mask is all
        zeros there is nothing to supervise. In that case the (zero) masked
        sum is returned, still attached to the graph, instead of 0/0 = NaN,
        which would push NaN gradients into the model.
    """
    per_cell = loss_b(x, y)
    masked_sum = torch.sum(per_cell * m)
    n_valid = torch.sum(m)
    if n_valid == 0:
        return masked_sum
    return masked_sum / n_valid

In [None]:
def roc_prc(x, y, m):
    """ROC-AUC and average precision over the masked-in positions.

    Note the argument order: ``x`` holds the TRUE binary labels and ``y``
    the predicted SCORES.

    Args:
        x: binary ground-truth tensor.
        y: predicted scores, same shape as ``x``.
        m: mask; only positions where ``m == 1`` are scored.

    Returns:
        (roc_auc, average_precision), or (nan, nan) when no position is
        selected or the selected labels contain a single class. Both metrics
        are undefined in those cases.
    """
    keep = m.eq(1)
    labels = x.masked_select(keep).detach().cpu().numpy()
    scores = y.masked_select(keep).detach().cpu().numpy()
    if labels.size == 0 or np.all(labels == labels[0]):
        # np.nan, not np.NaN: the capitalised alias was removed in NumPy 2.0
        return np.nan, np.nan
    return roc_auc_score(labels, scores), average_precision_score(labels, scores)

In [None]:
# NOTE: read_pickle unpickles arbitrary objects -- only load trusted files.
# The path is relative to the notebook's working directory.
DATA_PATH = 'data/debug_training_len20_200_100_s1_pred_stem_bps.pkl.gz'
df = pd.read_pickle(DATA_PATH)

In [None]:
# data_loader = DataLoader(make_dataset(df), batch_size=1)

In [None]:
# One graph sample per DataFrame row.
data_list = make_dataset(df=df)

In [None]:
# Seed both RNGs this notebook relies on: torch for weight initialisation,
# `random` for the per-epoch shuffle in the training loop. Without this,
# reruns give different models and metrics.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

model = Net(n_hid=10)
model.train()

In [None]:
# Adam over all model parameters with a fixed learning rate.
LEARNING_RATE = 1e-3
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
N_EPOCHS = 10

for epoch in range(N_EPOCHS):
    random.shuffle(data_list)

    loss_all = []
    auc_all = []
    prc_all = []

    for data in data_list:
        y = torch.from_numpy(data.y).float()
        m = torch.from_numpy(data.m).float()
        optimizer.zero_grad()
        # pred is computed before the skip below so that the plotting cells
        # after this loop always see a matching (pred, m, y) triple
        pred = model(data)
        if m.sum() == 0:
            # Nothing is supervised for this graph: the masked mean is 0/0 and
            # there is no signal, so do not take an optimizer step on it.
            continue
        loss = masked_loss_b(pred, y, m)
        loss_all.append(loss.item())
        # roc_prc takes (true, score, mask) and returns NaN when the masked
        # labels contain a single class
        auc, prc = roc_prc(y, pred, m)
        auc_all.append(auc)
        prc_all.append(prc)

        loss.backward()
        optimizer.step()

    # nanmean: a single NaN metric (single-class graph) would otherwise turn
    # the whole epoch's mean into NaN
    print("Epoch {}, mean loss {}, mean AUC {}, mean PRC {}".format(
        epoch, np.mean(loss_all), np.nanmean(auc_all), np.nanmean(prc_all)))


In [None]:
# Predicted pair probabilities restricted to the masked-in cells, for the
# last graph visited by the training loop above.
px.imshow((pred * m).detach().numpy())

In [None]:
# Target matrix (1 = G-C / C-G pair among the predicted stem pairs) for the
# last graph visited by the training loop above.
px.imshow(y)