In [20]:
import math
import os.path as osp
from itertools import chain
import wandb
from tqdm.auto import tqdm


import numpy as np
import torch
import torch.nn.functional as F
from scipy.sparse.csgraph import shortest_path
from sklearn.metrics import roc_auc_score
from torch.nn import BCEWithLogitsLoss, Conv1d, MaxPool1d, ModuleList

from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MLP, GCNConv, global_sort_pool, GatedGraphConv, GCN
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import k_hop_subgraph, to_scipy_sparse_matrix


from dataset_processing import RNADataset


In [21]:
class SEALDataset(InMemoryDataset):
    def __init__(self, dataset, num_hops, split='train'):
        self.data = dataset
        self.num_hops = num_hops
        super().__init__(dataset.root)
        index = ['train', 'val', 'test'].index(split)
        self.data, self.slices = torch.load(self.processed_paths[index])

    @property
    def processed_file_names(self):
        return ['SEAL_train_data.pt', 
#                 'SEAL_val_data.pt', 'SEAL_test_data.pt'
               ]

    def process(self):
        train_data, val_data, test_data = dataset[0:655], dataset[655:873], dataset[873:]
        self._max_z = 0
        
        train_pos_data_list, train_neg_data_list = [], []
        val_pos_data_list, val_neg_data_list = [], []
        test_pos_data_list, test_neg_data_list = [], []
        
        for g in train_data:
            train_pos_data_list.extend(self.extract_enclosing_subgraphs(
            g.edge_index, g.pos_edge_index, 1, g.x))
            train_neg_data_list.extend(self.extract_enclosing_subgraphs(
            g.edge_index, g.neg_edge_index, 0, g.x))

#         print("train")
#         for g in val_data:
#             val_pos_data_list.extend(self.extract_enclosing_subgraphs(
#             g.edge_index, g.pos_edge_index, 1, g.x))
#             val_neg_data_list.extend(self.extract_enclosing_subgraphs(
#             g.edge_index, g.neg_edge_index, 0, g.x))
            
#         print("val")
#         for g in test_data:
#             test_pos_data_list.extend(self.extract_enclosing_subgraphs(
#             g.edge_index, g.pos_edge_index, 1, g.x))
#             test_neg_data_list.extend(self.extract_enclosing_subgraphs(
#             g.edge_index, g.neg_edge_index, 0, g.x))
#         print("test")

        # Convert node labeling to one-hot features.
#         for data in chain(train_pos_data_list, train_neg_data_list,
#                           val_pos_data_list, val_neg_data_list,
#                           test_pos_data_list, test_neg_data_list):
#             # We solely learn links from structure, dropping any node features:
#             data.x = F.one_hot(data.z, self._max_z + 1).to(torch.float)
        
            
        torch.save(self.collate(train_pos_data_list + train_neg_data_list),
                   self.processed_paths[0])
#         torch.save(self.collate(val_pos_data_list + val_neg_data_list),
#                    self.processed_paths[1])
#         torch.save(self.collate(test_pos_data_list + test_neg_data_list),
#                    self.processed_paths[2])

        
    def extract_enclosing_subgraphs(self, edge_index, edge_label_index, y, x):
        data_list = []
        for src, dst in edge_label_index.t().tolist():
            sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph(
                [src, dst], self.num_hops, edge_index, relabel_nodes=True)
            src, dst = mapping.tolist()

            # Remove target link from the subgraph.
            mask1 = (sub_edge_index[0] != src) | (sub_edge_index[1] != dst)
            mask2 = (sub_edge_index[0] != dst) | (sub_edge_index[1] != src)
            sub_edge_index = sub_edge_index[:, mask1 & mask2]

            # Calculate node labeling.
            z = self.drnl_node_labeling(sub_edge_index, src, dst,
                                        num_nodes=sub_nodes.size(0))

            data = Data(x=x[sub_nodes], z=z,
                        edge_index=sub_edge_index, y=y)
            data_list.append(data)

        return data_list

    def drnl_node_labeling(self, edge_index, src, dst, num_nodes=None):
        # Double-radius node labeling (DRNL).
        src, dst = (dst, src) if src > dst else (src, dst)
        adj = to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes).tocsr()

        idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
        adj_wo_src = adj[idx, :][:, idx]

        idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
        adj_wo_dst = adj[idx, :][:, idx]

        dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True,
                                 indices=src)
        dist2src = np.insert(dist2src, dst, 0, axis=0)
        dist2src = torch.from_numpy(dist2src)

        dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True,
                                 indices=dst - 1)
        dist2dst = np.insert(dist2dst, src, 0, axis=0)
        dist2dst = torch.from_numpy(dist2dst)

        dist = dist2src + dist2dst
        dist_over_2, dist_mod_2 = dist // 2, dist % 2

        z = 1 + torch.min(dist2src, dist2dst)
        z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
        z[src] = 1.
        z[dst] = 1.
        z[torch.isnan(z)] = 0.

        self._max_z = max(int(z.max()), self._max_z)

        return z.to(torch.long)

In [22]:
def del_diags(y):
    main_diag = torch.ones([1, y.size(0)])
    diag1 = torch.ones([1, y.size(0) - 1])
    return y * (torch.diag_embed(main_diag).to(device) + 
            torch.diag_embed(diag1, offset=1).to(device) + 
            torch.diag_embed(diag1, offset=-1).to(device) - 1) * (-1)

Exception in thread SystemMonitor:
Traceback (most recent call last):
  File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/vdshk/SecondaryStructurePredictionGNN/venv/lib/python3.8/site-packages/wandb/sdk/internal/system/system_monitor.py", line 118, in _start
    asset.start()
  File "/home/vdshk/SecondaryStructurePredictionGNN/venv/lib/python3.8/site-packages/wandb/sdk/internal/system/assets/gpu.py", line 355, in start
    self.metrics_monitor.start()
  File "/home/vdshk/SecondaryStructurePredictionGNN/venv/lib/python3.8/site-packages/wandb/sdk/internal/system/assets/interfaces.py", line 168, in start
    logger.info(f"Started {self._process.name}")
AttributeError: 'NoneType' object has no attribute 'name'


In [23]:
def precision(y_pred, y_true):
    y_pred = del_diags(y_pred)
    y_pred[(y_pred > 0.5)] = 1
    y_pred[(y_pred <= 0.5)] = 0 
    
    tp = torch.sum(y_pred * y_true)
    fp = torch.sum((1 - y_true) * y_pred)
    
    return tp / (tp + fp + epsilon)

def recall(y_pred, y_true):
    y_pred = del_diags(y_pred)
    y_pred[(y_pred > 0.5)] = 1
    y_pred[(y_pred <= 0.5)] = 0
    
    tp = torch.sum(y_pred * y_true)
    fn = torch.sum(y_true * (1 - y_pred))
    
    return tp / (tp + fn + epsilon)

def f_loss(y_pred, y_true):
    y_pred = del_diags(y_pred)

    tp = torch.sum(y_pred * y_true)
    fn = torch.sum(y_true * (1 - y_pred))
    fp = torch.sum((1 - y_true) * y_pred)
    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)

    k = 0.001
    k1 = 1 - torch.abs(precision - recall)

#     k2 = 1 - torch.abs(K.mean(precision) - K.mean(recall))
    #calculate upgraded f1 score
    f1 = ((1 + k * k) * precision * recall) / (k * k * precision + recall + epsilon)
#     tw = K.sum(K.cast(y_true * y_pred, ’float32’), axis=[1, 2, 3])
#     fw = K.sum(K.cast((1 - y_true) * y_pred, ’float32’), axis=[1, 2, 3])
#     fb = K.sum(K.cast(y_true * (1 - y_pred), ’float32’), axis=[1, 2, 3])
    return 1 - f1


In [31]:
dataset = RNADataset(root="./data/")

In [37]:
train_data, val_data, test_data = dataset[0:655], dataset[655:873], dataset[873:]

train_data = SEALDataset(train_data, num_hops=2, split="train")

Processing...
Done!


In [26]:
class GCNModel(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_layers, out_channels, dropout):
        super(GCNModel, self).__init__()
        self.gcn = GCN(in_channels, hidden_channels, num_layers, out_channels, dropout=dropout)

    def encode(self, x, edge_index):
        x = self.gcn(x.to(torch.float), edge_index)
        return x
    
    def decode(self, x, edge_index):
        logits = torch.mean((x[edge_index[0]] * x[edge_index[1]]))
        return logits.sum(dim=-1)
    
    def decode_all(self, z):
#         prob_adj = (z @ z.t())
        prob_adj = (z @ z.t()).sigmoid()

        return prob_adj

In [27]:
hidden_channels = 128
out_channels = 64
num_layers = 5
num_features = dataset.num_features


dr = 0.9
lr = 0.00005
epochs = 150

model = GCNModel(num_features, hidden_channels, num_layers, out_channels, dr)
# model = torch.load("./models/GatedGCN_6_180_7e-05_Adam_300.pt")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
print(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# optimizer = torch.optim.Adagrad(model.parameters(), lr=lr)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# criterion = torch.nn.BCELoss()
criterion = torch.nn.BCEWithLogitsLoss()
# criterion = torch.nn.CrossEntropyLoss()
# criterion = torch.nn.KLDivLoss()

# def RMSELoss(y_pred, y_true):
#     return torch.sqrt(torch.mean((y_pred - y_true) ** 2))

# criterion = f_loss

cuda


In [28]:
run_name = "SEAL_GCN" + str(num_layers) + "_" + str(epochs) + "_" + str(lr) + "_" + "Adam_" + str(out_channels) + "_" + str(dr)


wandb.init(
    # set the wandb project where this run will be logged
    project="secondary_structure_prediction1",
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": lr,
    "dataset": "SEAL",
    "architecture": "GatedGCN",
    "epochs": epochs,
    "optimizer": "Adam",
    "hidden_channels": hidden_channels,
    "out_channels": out_channels,
    "loss": "BCEWithLogits",
    "dropout": dr,
#     "beta": 0.001,
    "train:val:test": "655:218:218"
    },
    name=run_name
)

epsilon = 1e-10

0,1
train_loss,▁
val_loss,▁
val_precision,▁
val_recall,▁

0,1
train_loss,0.60606
val_f1,
val_loss,0.96608
val_precision,0.0
val_recall,0.0


In [29]:
def train():
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = []
        train_recall = []
        train_precision = []
        for g in tqdm(train_data, ncols=100):
            g.to(device)
            optimizer.zero_grad()
#             print(g.x[g.edge_index[0]])
            z = model.encode(g.x, g.edge_index)
#             print(z)
            out = model.decode(z, g.edge_index)
#             print(out)
            loss = criterion(out.view(-1), g.y.to(torch.float))         
            
            loss.backward()
            optimizer.step()
            
#             out = out.sigmoid()
            
            train_loss.append(loss.item())
#             train_precision.append(precision(out, g.y).item())
#             train_recall.append(recall(out, g.y).item())
            
#         train_prec = np.mean(train_precision)
#         train_rec = np.mean(train_recall)
#         train_f1 = (2 * train_prec * train_rec) / (train_prec + train_rec)
        print(f'Epoch: {epoch:03d}, loss: {np.mean(train_loss)}')
        
        val_loss = []
        val_recall = []
        val_precision = []
        with torch.no_grad():
            for g in tqdm(val_data, ncols=100):
                g.to(device)
                z = model.encode(g.x, g.edge_index)
                out = model.decode_all(z)
#                 print(g.x.size(), z.t().size())
#                 print(z)
#                 print(out)
                y_true = g.adj_mat
                loss = criterion(out, y_true.to(torch.float32))
                                
                val_loss.append(loss.item())
                
#                 out = out.sigmoid()
                
                val_precision.append(precision(out, y_true).item())
                val_recall.append(recall(out, y_true).item())
 
            
            prec = np.mean(val_precision)
            rec = np.mean(val_recall)
            f1 = (2 * prec * rec) / (prec + rec)
            print(f'val_loss: {np.mean(val_loss)}, val_f1: {f1}, val_precision: {prec}, val_recall: {rec}')
              
            wandb.log({"train_loss": np.mean(train_loss), 
#                        "train_f1": train_f1, "train_precision": train_prec, 
#                        "train_recall": train_rec,
                       "val_loss": np.mean(val_loss), "val_f1": f1, "val_precision": prec, "val_recall": rec})
#     wandb.finish()

In [30]:
train()
torch.save(model, "./models1/" + run_name + ".pt")

100%|██████████████████████████████████████████████████████| 184694/184694 [23:02<00:00, 133.61it/s]


Epoch: 001, loss: 0.6080964769501022


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 188.03it/s]
  f1 = (2 * prec * rec) / (prec + rec)


val_loss: 0.9660842407734023, val_f1: nan, val_precision: 0.0, val_recall: 0.0


100%|██████████████████████████████████████████████████████| 184694/184694 [22:34<00:00, 136.33it/s]


Epoch: 002, loss: 0.723656157618705


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 198.62it/s]


val_loss: 0.9660842407734023, val_f1: nan, val_precision: 0.0, val_recall: 0.0


 85%|██████████████████████████████████████████████        | 157575/184694 [19:18<03:19, 136.05it/s]


KeyboardInterrupt: 