In [1]:
import os.path as osp
from tqdm.auto import tqdm
import numpy as np
import wandb

import torch
from sklearn.metrics import roc_auc_score

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GraphUNet
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import DataLoader

from dataset_processing import RNADataset

  from .autonotebook import tqdm as notebook_tqdm


In [104]:
dataset = RNADataset(root="./data2/")
dataset = dataset.shuffle()
train_data, val_data, test_data = dataset[0:655], dataset[655:873], dataset[873:]

train_dataloader = DataLoader(train_data, batch_size=8, shuffle=False)
val_dataloader = DataLoader(val_data, batch_size=1, shuffle=False)


In [105]:
train_data[117].edge_index

tensor([[ 0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8,  8,  9,
          9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18,
         18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27,
         27, 28, 28, 29, 29, 30, 30, 31, 31, 32, 32, 33, 33, 34, 34, 35, 35, 36,
         36, 37, 37, 38, 38, 39, 39, 40, 40, 41, 41, 42, 42, 43, 43, 44, 44, 45,
         45, 46, 46, 47],
        [ 1,  0,  2,  1,  3,  2,  4,  3,  5,  4,  6,  5,  7,  6,  8,  7,  9,  8,
         10,  9, 11, 10, 12, 11, 13, 12, 14, 13, 15, 14, 16, 15, 17, 16, 18, 17,
         19, 18, 20, 19, 21, 20, 22, 21, 23, 22, 24, 23, 25, 24, 26, 25, 27, 26,
         28, 27, 29, 28, 30, 29, 31, 30, 32, 31, 33, 32, 34, 33, 35, 34, 36, 35,
         37, 36, 38, 37, 39, 38, 40, 39, 41, 40, 42, 41, 43, 42, 44, 43, 45, 44,
         46, 45, 47, 46]])

In [3]:
def del_diags(y):
    main_diag = torch.ones([1, y.size(0)])
    diag1 = torch.ones([1, y.size(0) - 1])
    return y * (torch.diag_embed(main_diag).to(device) + 
            torch.diag_embed(diag1, offset=1).to(device) + 
            torch.diag_embed(diag1, offset=-1).to(device) - 1) * (-1)

In [28]:
def precision(y_pred, y_true):
    y_pred = del_diags(y_pred)
    y_pred[(y_pred > 0.5)] = 1
    y_pred[(y_pred <= 0.5)] = 0 
    
    tp = torch.sum(y_pred * y_true)
    fp = torch.sum((1 - y_true) * y_pred)
    
    return tp / (tp + fp + epsilon)

def recall(y_pred, y_true):
    y_pred = del_diags(y_pred)
    y_pred[(y_pred > 0.5)] = 1
    y_pred[(y_pred <= 0.5)] = 0
    
    tp = torch.sum(y_pred * y_true)
    fn = torch.sum(y_true * (1 - y_pred))
    
    return tp / (tp + fn + epsilon)

def f_loss(y_pred, y_true):
    y_pred = del_diags(y_pred)

    tp = torch.sum(y_pred * y_true)
    fn = torch.sum(y_true * (1 - y_pred))
    fp = torch.sum((1 - y_true) * y_pred)
    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)

    k = 1
    k1 = 1 - torch.abs(precision - recall)

#     k2 = 1 - torch.abs(K.mean(precision) - K.mean(recall))
    #calculate upgraded f1 score
    f1 = ((1 + k * k) * precision * recall) / (k * k * precision + recall + epsilon)
#     tw = K.sum(K.cast(y_true * y_pred, ’float32’), axis=[1, 2, 3])
#     fw = K.sum(K.cast((1 - y_true) * y_pred, ’float32’), axis=[1, 2, 3])
#     fb = K.sum(K.cast(y_true * (1 - y_pred), ’float32’), axis=[1, 2, 3])
    return 1 - f1


In [108]:
class UNetModel(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, depth, act):
        super(UNetModel, self).__init__()
        self.unet = GraphUNet(in_channels, hidden_channels, out_channels, depth)

    def forward(self, x, edge_index):
        x = self.unet(x.to(torch.float), edge_index)
        prob_adj = (x @ x.t()).sigmoid()

#         prob_adj = (x @ x.t())

#         return (prob_adj > 0).nonzero(as_tuple=False).t()
        return prob_adj


In [126]:
hidden_channels = 128
out_channels = 200
depth = 3
num_features = dataset.num_features


lr = 0.0001
epochs = 150
act = torch.nn.functional.leaky_relu

model = UNetModel(num_features, hidden_channels, out_channels, depth, act)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
train_data = list(map(lambda x: x.to(device), train_data))
val_data = list(map(lambda x: x.to(device), val_data))
print(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# optimizer = torch.optim.Adagrad(model.parameters(), lr=lr)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# criterion = torch.nn.BCELoss()
# criterion = torch.nn.BCEWithLogitsLoss()
# criterion = torch.nn.CrossEntropyLoss()
# criterion = torch.nn.KLDivLoss()

# def RMSELoss(y_pred, y_true):
#     return torch.sqrt(torch.mean((y_pred - y_true) ** 2))

criterion = f_loss

cuda


In [127]:
run_name = "UNet_" + str(depth) + "_" + str(epochs) + "_" + str(lr) + \
    "_" + str(hidden_channels) + "_" + str(out_channels)


wandb.init(
    # set the wandb project where this run will be logged
    project="secondary_structure_prediction1",
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": lr,
    "architecture": "UNet",
    "epochs": epochs,
    "optimizer": "Adam",
    "out_channels": out_channels,
    "loss": "f_loss",
#     "act": "LeakyRelu",
#     "beta": 0.001,
    "train:val:test": "655:218:218"
    },
    name=run_name
)

epsilon = 1e-10

0,1
train_f1,▁▇█▇██▆█
train_loss,█▃▂▁▁▁▂▁
train_precision,▁▇█▇██▆█
train_recall,█▁▂▂▂▂▂▃
val_f1,▁▄██▅▅▆▆
val_loss,█▆▃▁▂▁▁▁
val_precision,▁▄██▅▅▆▆
val_recall,▁▃▅▃▂▇█▇

0,1
train_f1,0.0536
train_loss,0.94901
train_precision,0.02782
train_recall,0.73326
val_f1,0.0547
val_loss,0.94766
val_precision,0.02842
val_recall,0.72779


In [128]:
def train():
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = []
        train_recall = []
        train_precision = []
        for g in tqdm(train_data, ncols=100):
            g.to(device)
            optimizer.zero_grad()

            out = model(g.x, g.edge_index)
            y_true = g.adj_mat
            loss = criterion(out, y_true.to(torch.float32))            
            
            loss.backward()
            optimizer.step()
            
#             out = out.sigmoid()
            
            train_loss.append(loss.item())
            train_precision.append(precision(out, y_true).item())
            train_recall.append(recall(out, y_true).item())
            
        train_prec = np.mean(train_precision)
        train_rec = np.mean(train_recall)
        train_f1 = (2 * train_prec * train_rec) / (train_prec + train_rec)
        print(f'Epoch: {epoch:03d}, loss: {np.mean(train_loss)}, f1: {train_f1}, precision: {train_prec}, recall: {train_rec}')
        
        val_loss = []
        val_recall = []
        val_precision = []
        with torch.no_grad():
            for g in tqdm(val_data, ncols=100):
                g.to(device)
                out = model(g.x, g.edge_index)
        
                y_true = g.adj_mat
                loss = criterion(out, y_true.to(torch.float32))
                                
                val_loss.append(loss.item())
                
#                 out = out.sigmoid()
                
                val_precision.append(precision(out, y_true).item())
                val_recall.append(recall(out, y_true).item())
 
            
            prec = np.mean(val_precision)
            rec = np.mean(val_recall)
            f1 = (2 * prec * rec) / (prec + rec)
            print(f'val_loss: {np.mean(val_loss)}, val_f1: {f1}, val_precision: {prec}, val_recall: {rec}')
              
            wandb.log({"train_loss": np.mean(train_loss), "train_f1": train_f1, "train_precision": train_prec, 
                       "train_recall": train_rec,
                       "val_loss": np.mean(val_loss), "val_f1": f1, "val_precision": prec, "val_recall": rec})
#     wandb.finish()

In [129]:
train()
torch.save(model, "./models1/" + run_name + ".pt")

100%|█████████████████████████████████████████████████████████████| 655/655 [00:13<00:00, 48.00it/s]


Epoch: 001, loss: 0.9564117250551705, f1: 0.048399164260043266, precision: 0.02490677572680168, recall: 0.8522893413332583


100%|█████████████████████████████████████████████████████████████| 218/218 [00:02<00:00, 73.89it/s]


val_loss: 0.9501886709567604, val_f1: 0.053723170987499624, val_precision: 0.02790109275780413, val_recall: 0.7209835028566352


100%|█████████████████████████████████████████████████████████████| 655/655 [00:13<00:00, 47.46it/s]


Epoch: 002, loss: 0.950272004203942, f1: 0.05350514683116522, precision: 0.027816967974694174, recall: 0.6991537786167087


100%|█████████████████████████████████████████████████████████████| 218/218 [00:02<00:00, 74.66it/s]


val_loss: 0.9542467588678413, val_f1: 0.04821797242790296, val_precision: 0.025040142260568745, val_recall: 0.64832575243274


100%|█████████████████████████████████████████████████████████████| 655/655 [00:13<00:00, 49.02it/s]


Epoch: 003, loss: 0.9524099703963477, f1: 0.051561857446933014, precision: 0.026798567948294164, recall: 0.678916410953944


100%|█████████████████████████████████████████████████████████████| 218/218 [00:02<00:00, 75.33it/s]


val_loss: 0.9513495370335535, val_f1: 0.05169321310033403, val_precision: 0.026860359759783424, val_recall: 0.6848305326399453


100%|█████████████████████████████████████████████████████████████| 655/655 [00:13<00:00, 49.40it/s]


Epoch: 004, loss: 0.9496321328723704, f1: 0.05363965683474845, precision: 0.02788679541287702, recall: 0.7009767605148199


100%|█████████████████████████████████████████████████████████████| 218/218 [00:02<00:00, 75.59it/s]


val_loss: 0.9482684263942438, val_f1: 0.05397583405069812, val_precision: 0.028068606467242567, val_recall: 0.7009536625321852


100%|█████████████████████████████████████████████████████████████| 655/655 [00:13<00:00, 49.24it/s]


Epoch: 005, loss: 0.9482178109292766, f1: 0.054722209929577825, precision: 0.02845593225957385, recall: 0.7111493777455264


100%|█████████████████████████████████████████████████████████████| 218/218 [00:02<00:00, 74.82it/s]


val_loss: 0.9472603398725528, val_f1: 0.05489706000859135, val_precision: 0.028549068827404205, val_recall: 0.7120420977883383


100%|█████████████████████████████████████████████████████████████| 655/655 [00:13<00:00, 49.20it/s]


Epoch: 006, loss: 0.9469472504754103, f1: 0.05595627319148581, precision: 0.02910099888277304, recall: 0.7251038094513289


100%|█████████████████████████████████████████████████████████████| 218/218 [00:02<00:00, 75.98it/s]


val_loss: 0.947360385175145, val_f1: 0.05480991394630593, val_precision: 0.028491302206471496, val_recall: 0.7187428886600591


100%|█████████████████████████████████████████████████████████████| 655/655 [00:13<00:00, 49.07it/s]


Epoch: 007, loss: 0.9471746417402311, f1: 0.05566771355137401, precision: 0.02894954941669152, recall: 0.7222218580373371


100%|█████████████████████████████████████████████████████████████| 218/218 [00:03<00:00, 70.30it/s]


val_loss: 0.9475909345740572, val_f1: 0.05436464613058118, val_precision: 0.028274343712169917, val_recall: 0.7037983644719518


100%|█████████████████████████████████████████████████████████████| 655/655 [00:13<00:00, 48.66it/s]


Epoch: 008, loss: 0.9474131363948793, f1: 0.05537100600472769, precision: 0.028802175380694787, recall: 0.7140883292633159


100%|█████████████████████████████████████████████████████████████| 218/218 [00:02<00:00, 76.03it/s]


val_loss: 0.9470388656909313, val_f1: 0.05506990345676788, val_precision: 0.028640552948297367, val_recall: 0.7132917626188435


100%|█████████████████████████████████████████████████████████████| 655/655 [00:13<00:00, 49.25it/s]


Epoch: 009, loss: 0.947046730536541, f1: 0.05569566062539156, precision: 0.028961581735736656, recall: 0.7241447212359378


100%|█████████████████████████████████████████████████████████████| 218/218 [00:02<00:00, 73.91it/s]


val_loss: 0.9467635958566578, val_f1: 0.05520206601797091, val_precision: 0.028703304802264094, val_recall: 0.7187345880980885


100%|█████████████████████████████████████████████████████████████| 655/655 [00:13<00:00, 49.22it/s]


Epoch: 010, loss: 0.9469556265204917, f1: 0.05578122270758847, precision: 0.029017562827962726, recall: 0.7181476628052369


100%|█████████████████████████████████████████████████████████████| 218/218 [00:02<00:00, 75.97it/s]


val_loss: 0.9478037548721383, val_f1: 0.054109489015644704, val_precision: 0.028126353636136227, val_recall: 0.7101015579809836


100%|█████████████████████████████████████████████████████████████| 655/655 [00:13<00:00, 48.87it/s]


Epoch: 011, loss: 0.9474598158406847, f1: 0.05496229301044916, precision: 0.028578916048090763, recall: 0.7154337483507986


100%|█████████████████████████████████████████████████████████████| 218/218 [00:03<00:00, 68.55it/s]


val_loss: 0.945672396672975, val_f1: 0.05696230243812407, val_precision: 0.029659688729075116, val_recall: 0.7167714838314494


100%|█████████████████████████████████████████████████████████████| 655/655 [00:13<00:00, 47.43it/s]


Epoch: 012, loss: 0.94767626951669, f1: 0.055004042679420065, precision: 0.02860963919594571, recall: 0.7103739825596336


100%|█████████████████████████████████████████████████████████████| 218/218 [00:02<00:00, 74.70it/s]


val_loss: 0.9476039054196909, val_f1: 0.05471734206046391, val_precision: 0.02846596182383802, val_recall: 0.7033300462665908


100%|█████████████████████████████████████████████████████████████| 655/655 [00:13<00:00, 49.31it/s]


Epoch: 013, loss: 0.9478034346158268, f1: 0.054655071457296825, precision: 0.028432710203413234, recall: 0.7030527049689802


100%|█████████████████████████████████████████████████████████████| 218/218 [00:02<00:00, 76.05it/s]


val_loss: 0.9492816104801423, val_f1: 0.05292332293884149, val_precision: 0.027524616882747568, val_recall: 0.6852094498932908


100%|█████████████████████████████████████████████████████████████| 655/655 [00:13<00:00, 48.79it/s]


Epoch: 014, loss: 0.9472767717965687, f1: 0.05542295997608934, precision: 0.02882821135656581, recall: 0.715366671535805


100%|█████████████████████████████████████████████████████████████| 218/218 [00:03<00:00, 68.10it/s]


val_loss: 0.9486001754572632, val_f1: 0.05343754266613376, val_precision: 0.027792466235124942, val_recall: 0.6916122528801271


 51%|███████████████████████████████▍                             | 337/655 [00:07<00:06, 46.66it/s]


KeyboardInterrupt: 