In [22]:
import math
import os
import os.path as osp

import numpy as np
import torch
import torch_geometric.transforms as T
import wandb
from sklearn.metrics import roc_auc_score
from torch.nn import BCEWithLogitsLoss, Conv1d, MaxPool1d, ModuleList
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, MLP, global_sort_pool
from tqdm.auto import tqdm

from dataset_processing import RNADataset


In [23]:
# Seed torch's RNG before shuffling so the train/val/test split is the same on
# every "Restart & Run All"; otherwise the held-out test graphs change from run
# to run and a saved model cannot be evaluated on the split it was trained with.
# (Assumes RNADataset.shuffle() draws its permutation from torch's global RNG,
# as torch_geometric datasets do -- TODO confirm in dataset_processing.)
SEED = 42
torch.manual_seed(SEED)

dataset = RNADataset(root="./data/")
dataset = dataset.shuffle()

# Fixed split boundaries: first 655 graphs for training, next 218 for
# validation, everything from index 873 onwards for testing.
train_data, val_data, test_data = dataset[0:655], dataset[655:873], dataset[873:]

# One graph per batch: DGCNN.forward scores every node pair of its input, so
# batching several graphs together would also score cross-graph pairs.
# NOTE(review): the training loader does not reshuffle between epochs.
train_dataloader = DataLoader(train_data, batch_size=1, shuffle=False)
val_dataloader = DataLoader(val_data, batch_size=1, shuffle=False)


In [5]:
def del_diags(y):
    """Zero the main diagonal and the two diagonals adjacent to it.

    Entries ``y[i, j]`` with ``|i - j| <= 1`` are multiplied by 0, all other
    entries are kept unchanged.

    The mask is created directly on ``y.device``, so the function no longer
    depends on a module-level ``device`` variable being defined first and does
    not allocate on the CPU and copy on every call.

    Args:
        y: square 2-D tensor of shape ``[n, n]``.

    Returns:
        Tensor of shape ``[1, n, n]``. The leading singleton dimension is what
        this function has always returned and is kept for compatibility; it
        broadcasts against ``[n, n]`` targets.
    """
    n = y.size(0)
    # Default-dtype ones, so type promotion with `y` is the same as before.
    ones = torch.ones(n, n, device=y.device)
    # 1 where |i - j| >= 2 (strictly outside the tri-diagonal band), 0 inside it.
    keep = torch.triu(ones, diagonal=2) + torch.tril(ones, diagonal=-2)
    return y * keep.unsqueeze(0)

In [6]:
def _binarize_scores(y_pred, threshold):
    """Mask the near-diagonal band of `y_pred` and threshold it to hard 0/1 values.

    The scores are detached first: metrics are for reporting only, so there is
    no reason to extend the autograd graph of the model output. `del_diags`
    returns a new tensor, so the caller's `y_pred` is never modified.
    """
    scores = del_diags(y_pred.detach())
    # Compute both masks before writing so the second comparison cannot see
    # values already overwritten by the first (NaN scores match neither mask
    # and therefore stay NaN, as before).
    above = scores > threshold
    below = scores <= threshold
    scores[above] = 1
    scores[below] = 0
    return scores


def precision(y_pred, y_true, threshold=0.5, eps=1e-10):
    """Precision of thresholded pair predictions: TP / (TP + FP).

    Args:
        y_pred: pairwise scores; entries above `threshold` count as predicted
            positives once the tri-diagonal band is removed by `del_diags`.
        y_true: 0/1 target matrix broadcastable against the masked scores.
        threshold: decision threshold (default 0.5, the previous fixed value).
        eps: added to the denominator to avoid division by zero. It used to be
            read from a module-level `epsilon` defined in a later cell; the
            default is that same value.

    Returns:
        0-dim tensor holding the precision.
    """
    hard = _binarize_scores(y_pred, threshold)

    tp = torch.sum(hard * y_true)
    fp = torch.sum((1 - y_true) * hard)

    return tp / (tp + fp + eps)


def recall(y_pred, y_true, threshold=0.5, eps=1e-10):
    """Recall of thresholded pair predictions: TP / (TP + FN).

    Takes the same arguments as `precision`. Note that positives of `y_true`
    lying on the masked tri-diagonal band can never be predicted and therefore
    always count as false negatives.

    Returns:
        0-dim tensor holding the recall.
    """
    hard = _binarize_scores(y_pred, threshold)

    tp = torch.sum(hard * y_true)
    fn = torch.sum(y_true * (1 - hard))

    return tp / (tp + fn + eps)

def f1_loss(y_pred, y_true, eps=1e-10):
    """Differentiable (soft) F1 loss: ``1 - F1`` computed on raw scores.

    `y_pred` is not thresholded, so true/false positives and false negatives
    are soft counts and gradients flow back to the model. The tri-diagonal band
    of `y_pred` is removed with `del_diags` before counting.

    Args:
        y_pred: pairwise scores (the model in this notebook emits sigmoid
            outputs, i.e. values in [0, 1]).
        y_true: 0/1 target matrix broadcastable against the masked scores.
        eps: added to every denominator to avoid division by zero. It used to
            be read from a module-level `epsilon` defined in a later cell; the
            default is that same value.

    Returns:
        0-dim tensor ``1 - F1``; equals 1 when there is no overlap at all.
    """
    masked = del_diags(y_pred)

    tp = torch.sum(masked * y_true)
    fn = torch.sum(y_true * (1 - masked))
    fp = torch.sum((1 - y_true) * masked)

    # Local names deliberately differ from the module-level `precision` and
    # `recall` functions so those are not shadowed inside this function.
    soft_precision = tp / (tp + fp + eps)
    soft_recall = tp / (tp + fn + eps)

    soft_f1 = 2 * soft_precision * soft_recall / (soft_precision + soft_recall + eps)
    return 1 - soft_f1


In [83]:
class DGCNN(torch.nn.Module):
    """Stack of graph convolutions whose forward pass scores every node pair.

    The constructor also builds a sort-pooling read-out head (``conv1``,
    ``maxpool1d``, ``conv2``, ``mlp``), but the lines of ``forward`` that would
    use it are commented out, so those layers currently hold parameters that
    never receive gradients.

    NOTE: ``__init__`` reads the module-level ``train_data`` and ``dataset``
    variables, so both must exist before the class is instantiated.
    """
    def __init__(self, hidden_channels, num_layers, GNN=GCNConv, k=0.6):
        """Build the layers.

        Args:
            hidden_channels: output width of every convolution except the last.
            num_layers: number of ``hidden_channels``-wide convolutions; one
                extra convolution with a single output channel is appended.
            GNN: convolution class, called as ``GNN(in_channels, out_channels)``.
            k: when < 1 it is treated as a percentile of the node counts in
                ``train_data`` and converted to a node count (at least 10);
                otherwise it is used as given. Only used to size ``self.mlp``.
        """
        super().__init__()
        if k < 1:  # Transform percentile to number.
            num_nodes = sorted([data.num_nodes for data in train_data])
            k = num_nodes[int(math.ceil(k * len(num_nodes))) - 1]
            k = max(10, k)
        self.k = int(k)

        # num_layers convolutions of width hidden_channels, then one of width 1.
        self.convs = ModuleList()
        self.convs.append(GNN(dataset.num_features, hidden_channels))
        for i in range(0, num_layers - 1):
            self.convs.append(GNN(hidden_channels, hidden_channels))
        self.convs.append(GNN(hidden_channels, 1))

        # Read-out head; not used by the current forward().
        conv1d_channels = [16, 32]
        # Width of the concatenated outputs of all convolutions in self.convs.
        total_latent_dim = hidden_channels * num_layers + 1
        conv1d_kws = [total_latent_dim, 5]
        self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0],
                            conv1d_kws[0])
        self.maxpool1d = MaxPool1d(2, 2)
        self.conv2 = Conv1d(conv1d_channels[0], conv1d_channels[1],
                            conv1d_kws[1], 1)
        dense_dim = int((self.k - 2) / 2 + 1)
        dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]
        self.mlp = MLP([dense_dim, 128, 1], dropout=0.5, norm=None)

    def forward(self, x, edge_index, batch):
        """Return sigmoid scores for all node pairs.

        Each convolution is followed by ``tanh``; the outputs of all
        convolutions are concatenated per node and the result is
        ``sigmoid(x @ x.T)``, one score per pair of rows of ``x``.

        ``batch`` is accepted but not used: if the input contained several
        graphs, pairs across different graphs would be scored as well.
        """
        xs = [x]
        for conv in self.convs:
            xs += [conv(xs[-1], edge_index).tanh()]
        # Drop the raw input features (xs[0]); keep only the convolution outputs.
        x = torch.cat(xs[1:], dim=-1)

        # Global pooling.
        # (Disabled read-out path kept for reference.)
#         x = global_sort_pool(x, batch, self.k)
#         x = x.unsqueeze(1)  # [num_graphs, 1, k * hidden]
#         x = self.conv1(x).relu()
#         x = self.maxpool1d(x)
#         x = self.conv2(x).relu()
#         x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]

#         return self.mlp(x)
        return (x @ x.t()).sigmoid()

In [84]:
# Hyper-parameters.
# `hidden_channels` and `num_layers` hold the values the model is actually built
# with. Previously they were declared as 128 / 4 but never passed, while the
# model was constructed with the literals 32 / 3, so every run name and logged
# config derived from `num_layers` described a different network.
hidden_channels = 32
out_channels = 150  # only used as a label in the run name / logged config
num_layers = 3
num_features = dataset.num_features

dr = 0.2
lr = 0.0001
epochs = 500

model = DGCNN(hidden_channels=hidden_channels, num_layers=num_layers)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Device-resident lists of the individual graphs. The DataLoaders created
# earlier still iterate over the original dataset slices, not these lists.
train_data = [graph.to(device) for graph in train_data]
val_data = [graph.to(device) for graph in val_data]
print(device)

# Adam with a fixed learning rate (Adagrad and SGD were tried as alternatives).
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Soft-F1 loss on the sigmoid pair scores (BCE, BCE-with-logits, cross-entropy,
# KL-divergence and RMSE were tried as alternatives).
criterion = f1_loss

cuda


In [26]:
# Derive the labels from the live objects so the logged metadata cannot drift
# from what is actually being trained. The config used to report "CELoss"
# although the criterion is `f1_loss`.
optimizer_name = type(optimizer).__name__
loss_name = getattr(criterion, "__name__", type(criterion).__name__)

run_name = f"DGCNN_{num_layers}_{epochs}_{lr}_{optimizer_name}_{out_channels}"

wandb.init(
    # set the wandb project where this run will be logged
    project="secondary_structure_prediction1",

    # track hyperparameters and run metadata
    config={
        "learning_rate": lr,
        "architecture": "DGCNN",
        "epochs": epochs,
        "optimizer": optimizer_name,
        "out_channels": out_channels,
        "loss": loss_name,
        # Actual split sizes rather than a hand-typed string.
        "train:val:test": f"{len(train_data)}:{len(val_data)}:{len(test_data)}",
    },
    name=run_name,
)

# Small constant added to metric / loss denominators to avoid division by zero.
epsilon = 1e-10

[34m[1mwandb[0m: Currently logged in as: [33mchi-vinny0702[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [85]:
def _f1_score(prec, rec):
    """Harmonic mean of precision and recall; 0.0 when both are exactly zero."""
    total = prec + rec
    if total == 0:
        # Avoid 0/0 (which would log NaN); NaN inputs still propagate as NaN.
        return 0.0
    return (2 * prec * rec) / total


def _run_epoch(loader, training):
    """Run one pass over `loader` and return mean (loss, precision, recall).

    When `training` is True the optimizer is stepped after every graph;
    otherwise the pass is forward-only. The caller is responsible for putting
    the model in train/eval mode and for wrapping evaluation in `no_grad`.
    """
    losses, precisions, recalls = [], [], []
    for g in tqdm(loader, ncols=100):
        g = g.to(device)
        # Dense pairwise target matrix stored on the graph.
        y_true = g.edge_label_index

        if training:
            optimizer.zero_grad()

        out = model(g.x, g.edge_index, g.batch)
        loss = criterion(out, y_true.to(torch.float32))

        if training:
            loss.backward()
            optimizer.step()

        losses.append(loss.item())
        # Metrics are for reporting only: compute them on a detached copy so
        # they do not build on the autograd graph of the model output.
        scores = out.detach()
        precisions.append(precision(scores, y_true).item())
        recalls.append(recall(scores, y_true).item())

    return np.mean(losses), np.mean(precisions), np.mean(recalls)


def train():
    """Train `model` for `epochs` epochs, validating and logging after each one.

    Uses the module-level `model`, `optimizer`, `criterion`, `device`,
    `train_dataloader`, `val_dataloader` and `epochs`. Per-epoch means of loss,
    precision and recall (plus the F1 derived from the mean precision/recall)
    are printed and sent to Weights & Biases. The run is left open; call
    `wandb.finish()` afterwards if desired.
    """
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss, train_prec, train_rec = _run_epoch(train_dataloader, training=True)
        train_f1 = _f1_score(train_prec, train_rec)
        print(f'Epoch: {epoch:03d}, loss: {train_loss}, f1: {train_f1}, precision: {train_prec}, recall: {train_rec}')

        # Switch to inference behaviour for validation; without this any
        # dropout / normalisation layer in the model would stay in training
        # mode while validation metrics are computed.
        model.eval()
        with torch.no_grad():
            val_loss, prec, rec = _run_epoch(val_dataloader, training=False)
        f1 = _f1_score(prec, rec)
        print(f'val_loss: {val_loss}, val_f1: {f1}, val_precision: {prec}, val_recall: {rec}')

        wandb.log({"train_loss": train_loss, "train_f1": train_f1, "train_precision": train_prec,
                   "train_recall": train_rec,
                   "val_loss": val_loss, "val_f1": f1, "val_precision": prec, "val_recall": rec})

In [77]:
# Display the first training graph to check its fields; in the recorded output
# below, `edge_label_index` is a square node-by-node matrix (72 x 72 for a
# 72-node graph) rather than a 2 x E index list.
train_data[0]

Data(x=[72, 4], edge_index=[2, 378], edge_label_index=[72, 72])

In [86]:
train()

# Create the output directory if needed so a missing folder cannot make the
# save fail after the whole training run has finished.
os.makedirs("./models1/", exist_ok=True)
# Pickles the entire module; loading it later requires the DGCNN class to be
# importable under the same name.
torch.save(model, "./models1/" + run_name + ".pt")

100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 122.81it/s]


Epoch: 001, loss: 0.9585552319315553, f1: 0.04421297741962869, precision: 0.02263992811012871, recall: 0.9382308747022207


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 206.54it/s]


val_loss: 0.9567201517590689, val_f1: 0.04918989755258823, val_precision: 0.02529288530819665, val_recall: 0.8913091522291166


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 113.59it/s]


Epoch: 002, loss: 0.9509755328411365, f1: 0.05350269325079246, precision: 0.02761827220621284, recall: 0.8522369039877681


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 193.06it/s]


val_loss: 0.9526484739889792, val_f1: 0.05123210499452371, val_precision: 0.026412869501354838, val_recall: 0.8491202475280937


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 122.25it/s]


Epoch: 003, loss: 0.9483764396369002, f1: 0.05545002265935311, precision: 0.02867613679814714, recall: 0.8359004610594902


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 211.34it/s]


val_loss: 0.9513494096217899, val_f1: 0.05186729872458531, val_precision: 0.026758832801475164, val_recall: 0.840969600535314


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 125.70it/s]


Epoch: 004, loss: 0.9473468919746748, f1: 0.0560411753367791, precision: 0.02899743144343543, recall: 0.8317860958685401


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 210.72it/s]


val_loss: 0.9507056168459971, val_f1: 0.052186086194936136, val_precision: 0.026933471491739767, val_recall: 0.8362119095314533


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 127.66it/s]


Epoch: 005, loss: 0.9467337723906714, f1: 0.05627201974528552, precision: 0.029125385811438656, recall: 0.828271752527652


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 209.66it/s]


val_loss: 0.9502872015358111, val_f1: 0.05252237358589019, val_precision: 0.027114496312163974, val_recall: 0.8344672553309607


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 126.73it/s]


Epoch: 006, loss: 0.9463461511917697, f1: 0.056662006185841614, precision: 0.029333153040368706, recall: 0.8292547952810316


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 208.29it/s]


val_loss: 0.95000690905326, val_f1: 0.05244466960504162, val_precision: 0.027075765294518425, val_recall: 0.8319245553618178


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 127.21it/s]


Epoch: 007, loss: 0.9460882300638971, f1: 0.05670034882251457, precision: 0.02935563510226953, recall: 0.8277174657765236


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 209.84it/s]


val_loss: 0.9498136636860873, val_f1: 0.05256624431050416, val_precision: 0.02714207061510875, val_recall: 0.8305252187569206


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 118.76it/s]


Epoch: 008, loss: 0.9459057170016165, f1: 0.056814377095324865, precision: 0.02941719386691071, recall: 0.8273815561114377


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 201.78it/s]


val_loss: 0.9496767458018907, val_f1: 0.052567701407044656, val_precision: 0.02714342388661937, val_recall: 0.8299859919405859


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 124.25it/s]


Epoch: 009, loss: 0.9457669664885252, f1: 0.056866999850075234, precision: 0.029446346882574317, recall: 0.8266428562066027


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 206.99it/s]


val_loss: 0.9495778403697758, val_f1: 0.05254714954814287, val_precision: 0.02713446370902719, val_recall: 0.8281200190476321


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 124.72it/s]


Epoch: 010, loss: 0.9456552489113261, f1: 0.05691830663601604, precision: 0.02947416887280908, recall: 0.8264010210528628


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 201.51it/s]


val_loss: 0.9495058330374027, val_f1: 0.05257855480192001, val_precision: 0.02715192409659471, val_recall: 0.8274586434484622


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 121.22it/s]


Epoch: 011, loss: 0.9455626704310642, f1: 0.05694524261919406, precision: 0.029489372673494215, recall: 0.8258063579788645


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 194.27it/s]


val_loss: 0.9494533270870874, val_f1: 0.0524140187362104, val_precision: 0.027065592179279944, val_recall: 0.8261383975591134


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 123.91it/s]


Epoch: 012, loss: 0.9454855789665048, f1: 0.057032541021097664, precision: 0.02953594691983872, recall: 0.8260022929606546


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 208.59it/s]


val_loss: 0.9494144282209764, val_f1: 0.052436418233784636, val_precision: 0.027077281298628506, val_recall: 0.8263773853899142


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 123.70it/s]


Epoch: 013, loss: 0.9454220103853531, f1: 0.057017259739134706, precision: 0.029528639204124744, recall: 0.8253071977891995


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 204.27it/s]


val_loss: 0.9493840265711513, val_f1: 0.05240812509870732, val_precision: 0.027063045441386623, val_recall: 0.8255830728406206


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 125.61it/s]


Epoch: 014, loss: 0.9453691315104943, f1: 0.05705083280128625, precision: 0.029546374047374112, recall: 0.8255216320962396


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 206.09it/s]


val_loss: 0.94935913561681, val_f1: 0.05242959244891146, val_precision: 0.027074247917989184, val_recall: 0.8258124306934689


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 125.15it/s]


Epoch: 015, loss: 0.9453238129615784, f1: 0.057135450724313384, precision: 0.0295919464705583, recall: 0.8253829107484745


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 207.85it/s]


val_loss: 0.9493382245028784, val_f1: 0.05246184197882124, val_precision: 0.027090821224008473, val_recall: 0.8263949263806737


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 125.38it/s]


Epoch: 016, loss: 0.9452844545131421, f1: 0.057114652372647616, precision: 0.02958155241848204, recall: 0.8247885232663337


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 207.65it/s]


val_loss: 0.9493201227363096, val_f1: 0.052434989422944216, val_precision: 0.02707663359154665, val_recall: 0.8262709492663725


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 125.84it/s]


Epoch: 017, loss: 0.9452499381458487, f1: 0.057158791281914804, precision: 0.029605019377110853, recall: 0.8249551291684158


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 206.71it/s]


val_loss: 0.9493037153274642, val_f1: 0.052399735527316554, val_precision: 0.027058184377575685, val_recall: 0.8259432954252313


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 125.56it/s]


Epoch: 018, loss: 0.9452193695170279, f1: 0.05713961951610673, precision: 0.029594795299580194, recall: 0.8249069252996954


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 203.28it/s]


val_loss: 0.9492878634995277, val_f1: 0.052386968727589696, val_precision: 0.027051375936417748, val_recall: 0.8259432954252313


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 125.96it/s]


Epoch: 019, loss: 0.9451921554012153, f1: 0.05710170126005583, precision: 0.02957445373616487, recall: 0.8249054795003119


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 205.85it/s]


val_loss: 0.9492716693550075, val_f1: 0.0523675359714032, val_precision: 0.027041364147182587, val_recall: 0.8256156413106743


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 120.45it/s]


Epoch: 020, loss: 0.9451680329919772, f1: 0.05707230540228944, precision: 0.029558843955440263, recall: 0.8247803833648448


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 168.88it/s]


val_loss: 0.9492548887335926, val_f1: 0.05240358749864364, val_precision: 0.027060352058636897, val_recall: 0.8258376347089033


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 125.57it/s]


Epoch: 021, loss: 0.9451468714320933, f1: 0.057041738575311765, precision: 0.029542445759324513, recall: 0.8247803833648448


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 211.17it/s]


val_loss: 0.9492379828877405, val_f1: 0.05238472314790908, val_precision: 0.027050659583813145, val_recall: 0.8254949312417879


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 121.04it/s]


Epoch: 022, loss: 0.9451283741543312, f1: 0.05706086148535071, precision: 0.02955229511459621, recall: 0.8250994282824392


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 196.73it/s]


val_loss: 0.9492217045311534, val_f1: 0.052370523380393344, val_precision: 0.027043377237240655, val_recall: 0.825224426224691


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 122.55it/s]


Epoch: 023, loss: 0.9451119863349973, f1: 0.05703565193101438, precision: 0.029538648234651858, recall: 0.8251955850433758


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 197.49it/s]


val_loss: 0.949206773841053, val_f1: 0.05238713097529629, val_precision: 0.027051810290527726, val_recall: 0.8256191730225851


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 122.13it/s]


Epoch: 024, loss: 0.9450971013717069, f1: 0.05703945776931649, precision: 0.029540771488562396, recall: 0.8251318707720924


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 196.17it/s]


val_loss: 0.949193574419809, val_f1: 0.0524188237313437, val_precision: 0.027068039268861876, val_recall: 0.8262459210572987


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 122.99it/s]


Epoch: 025, loss: 0.9450833981273739, f1: 0.05704774403982662, precision: 0.029544355146202992, recall: 0.8258043503033295


100%|████████████████████████████████████████████████████████████| 218/218 [00:01<00:00, 202.59it/s]


val_loss: 0.9491820111187226, val_f1: 0.05242147140057856, val_precision: 0.027069451264960163, val_recall: 0.8262459210572987


100%|████████████████████████████████████████████████████████████| 655/655 [00:05<00:00, 123.96it/s]


Epoch: 026, loss: 0.9450709595935035, f1: 0.057072415362430504, precision: 0.029556692852318742, recall: 0.8265048326881788


 58%|██████████████████████████████████▋                         | 126/218 [00:00<00:00, 188.75it/s]


KeyboardInterrupt: 

In [None]:
# Manual checkpoint of the current `model` under a hand-written name.
# NOTE(review): the prefix says "GatedGCN" and the epoch field is the literal
# "140", but the model built in this notebook is a DGCNN and the recorded run
# above was interrupted during epoch 26 -- confirm the intended file name.
run_name = "GatedGCN_" + str(num_layers) + "_" + "140" + "_" + str(lr) + "_" + "Adam_" + str(out_channels)

# Create the output directory if needed so the save cannot fail on a missing folder.
os.makedirs("./models/", exist_ok=True)
torch.save(model, "./models/" + run_name + ".pt")