In [1]:
import os.path as osp
from tqdm.auto import tqdm
import numpy as np
import wandb

import torch
from sklearn.metrics import roc_auc_score

import torch.nn.functional as F
# from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
from torch.nn import LayerNorm, Linear, ReLU, Tanh
from torchvision.ops import MLP
from torchvision.utils import save_image


import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, GatedGraphConv, SAGEConv, GAT
from torch_geometric.nn import DeepGCNLayer, GENConv
from torch_geometric.utils import negative_sampling, to_dense_adj
from torch_geometric.loader import DataLoader

from dataset_processing4 import RNADataset


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = RNADataset(root="./data4/")
dataset = dataset.shuffle()
train_data, val_data, test_data = dataset[0:655], dataset[655:873], dataset[873:]
bs = 1

train_dataloader = DataLoader(train_data, batch_size=bs, shuffle=False)
val_dataloader = DataLoader(val_data, batch_size=1, shuffle=False)

In [3]:
train_data[0]

Data(x=[32, 8], edge_index=[2, 284], y=[32], s=12)

In [4]:
dataset = list(filter(lambda x: x.s.item() <= 150, dataset))

In [5]:
dataset

[Data(x=[32, 8], edge_index=[2, 284], y=[32], s=12),
 Data(x=[265, 8], edge_index=[2, 7622], y=[265], s=39),
 Data(x=[1318, 8], edge_index=[2, 90516], y=[1318], s=81),
 Data(x=[118, 8], edge_index=[2, 2220], y=[118], s=26),
 Data(x=[4182, 8], edge_index=[2, 512632], y=[4182], s=144),
 Data(x=[1214, 8], edge_index=[2, 79780], y=[1214], s=77),
 Data(x=[996, 8], edge_index=[2, 57972], y=[996], s=75),
 Data(x=[1075, 8], edge_index=[2, 66582], y=[1075], s=71),
 Data(x=[576, 8], edge_index=[2, 25308], y=[576], s=54),
 Data(x=[2063, 8], edge_index=[2, 176642], y=[2063], s=101),
 Data(x=[322, 8], edge_index=[2, 10474], y=[322], s=43),
 Data(x=[2246, 8], edge_index=[2, 204260], y=[2246], s=110),
 Data(x=[186, 8], edge_index=[2, 4576], y=[186], s=29),
 Data(x=[1829, 8], edge_index=[2, 150110], y=[1829], s=99),
 Data(x=[97, 8], edge_index=[2, 1678], y=[97], s=22),
 Data(x=[2498, 8], edge_index=[2, 237924], y=[2498], s=118),
 Data(x=[560, 8], edge_index=[2, 24332], y=[560], s=55),
 Data(x=[348, 8]

In [6]:
train_data, val_data, test_data = dataset[0:655], dataset[655:873], dataset[873:]

In [7]:
def del_nucl_conn(y_pr, y_tr):    
    y_tr_local = torch.clone(y_tr)
    y_tr_local[(y_tr_local == 1)] = 0
    y_pr = y_pr + y_tr_local
    y_pr[(y_pr < 0)] = 0
    y_tr_local1 = torch.clone(y_tr)
    y_tr_local1[(y_tr_local1 == -1)] = 0
    return y_pr, y_tr_local1


def del_main_diag(y):
    main_diag = torch.ones([1, y.size()[0]])
    return y * (torch.diag_embed(main_diag).to(device) - 1)[0] * (-1)


def adj_mat_split(y):
    y = list(torch.split(y, 196, dim=0))
    for i in range(len(y)):
        y[i] = torch.split(y[i], 196, dim=1)
    
    y1 = []
    for i in range(len(y)):
        y1.append(argm(y[i][i]))
    return torch.cat(y1, 0).to(device)
    

def argm(y):
    max_ind = torch.argmax(y, dim=1)
    y1 = torch.zeros_like(y)
    k = 0
    for i in max_ind:
        y1[k][i] = 1.
        y1[i][k] = 1.
        k += 1
    
    return y1 * y


In [8]:
a = 0.5
beta = 1
def precision(y_pred, y_true):
    y_pred[(y_pred > a)] = 1
    y_pred[(y_pred <= a)] = 0
    y_pred, y_true = del_nucl_conn(y_pred, y_true)
     
#     y_pred = argm(y_pred)
    tp = torch.sum(y_pred * y_true)
    fp = torch.sum((1 - y_true) * y_pred)
    
    return tp / (tp + fp + epsilon)

def recall(y_pred, y_true):
    y_pred[(y_pred > a)] = 1
    y_pred[(y_pred <= a)] = 0 
    y_pred, y_true = del_nucl_conn(y_pred, y_true)
    
#     y_pred = argm(y_pred)
#     y_pred = adj_mat_split(y_pred)

    tp = torch.sum(y_pred * y_true)
    fn = torch.sum(y_true * (1 - y_pred))
    
    return tp / (tp + fn + epsilon)

def f_loss(y_pred, y_true):
    y_pred, y_true = del_nucl_conn(y_pred, y_true)    
    tp = torch.sum(y_pred * y_true)
    fn = torch.sum(y_true * (1 - y_pred))
    fp = torch.sum((1 - y_true) * y_pred)
    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)

    k1 = 1 - torch.abs(precision - recall)

    f1 = ((1 + beta * beta) * precision * recall) / (beta * beta * precision + recall + epsilon)
    
    
    # dp = abs(torch.mean(y_pred) - torch.mean(y_true))

    return (1 - f1)
#     return dp

def mse_loss(y_pred, y_true):
    y_pred, y_true = del_nucl_conn(y_pred, y_true)
    
    return torch.mean((y_pred - y_true) ** 2)

In [9]:
class DeepGCNModel(torch.nn.Module):
    def __init__(self, num_blocks, hidden_channels, num_layers, dr, mlp):
        super(DeepGCNModel, self).__init__()
        self.layers = torch.nn.ModuleList()
        gat1 = GAT(8, hidden_channels, num_layers=num_layers)
        layer = DeepGCNLayer(gat1, block='res', act=torch.relu, dropout=dr)
        self.layers.append(layer)
        for i in range(1, num_blocks):
            conv = GAT(hidden_channels, hidden_channels, num_layers=num_layers, act=torch.relu)            

            layer = DeepGCNLayer(conv, block='res', act=torch.relu, dropout=dr)
            self.layers.append(layer)
            
            
        self.mlp = MLP(hidden_channels, mlp, dropout=0)

        
    def forward(self, x, edge_index):
        x = self.layers[0].conv(x, edge_index)
        for layer in self.layers[1:]:
            x = layer(x, edge_index)
        
        x = self.mlp(x)
        
        return x.tanh()


In [10]:
num_blocks = 7
num_layers = 2
hidden_channels = 70

dr = 0.15


lr = 0.00001
epochs = 150000
mlp = [20, 10, 1]


model = DeepGCNModel(num_blocks, hidden_channels, num_layers, dr,mlp)
# model = torch.load("./models1/ResGATConv__MSE_mlp_[200, 150, 100, 50, 10, 1]_1*2_0.0001_Adam_250_0.05_0.2_1_1500.pt")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
train_data = list(map(lambda x: x.to(device), train_data))
val_data = list(map(lambda x: x.to(device), val_data))
print(device)

# optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# optimizer = torch.optim.Adagrad(model.parameters(), lr=lr)
optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
#optimizer = torch.optim.Adamax(model.parameters(), lr=lr)
# optimizer = torch.optim.Adadelta(model.parameters(), lr=lr)

# optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# criterion = torch.nn.BCELoss()
# criterion = torch.nn.BCEWithLogitsLoss()
# criterion = torch.nn.CrossEntropyLoss()
# criterion = torch.nn.KLDivLoss()

def RMSELoss(y_pred, y_true):
    y_pred, y_true = del_nucl_conn(y_pred, y_true)

    return torch.sqrt(torch.sum((y_pred - y_true) ** 2))

# criterion = mse_loss
# criterion = RMSELoss
criterion = f_loss

epsilon = 1e-30

cuda


In [11]:
model

DeepGCNModel(
  (layers): ModuleList(
    (0-6): 7 x DeepGCNLayer(block=res)
  )
  (mlp): MLP(
    (0): Linear(in_features=70, out_features=20, bias=True)
    (1): ReLU()
    (2): Dropout(p=0, inplace=False)
    (3): Linear(in_features=20, out_features=10, bias=True)
    (4): ReLU()
    (5): Dropout(p=0, inplace=False)
    (6): Linear(in_features=10, out_features=1, bias=True)
    (7): Dropout(p=0, inplace=False)
  )
)

In [12]:
run_name = f'ResGATConv_relu_last_tanh_mlp_{mlp}_{num_blocks}*{num_layers}_{lr}_RMSProp_{hidden_channels}_{a}_{dr}_{beta}'


wandb.init(
    # set the wandb project where this run will be logged
    project="secondary_structure_prediction1",
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": lr,
    "architecture": "ResGATConv",
    "epochs": epochs,
    "optimizer": "RMSProp",
    "out_channels": hidden_channels,
    "loss": "f_loss",
    "beta": beta,
    "train:val:test": "655:218:218",
    "dataset": 4
    },
    name=run_name
)



[34m[1mwandb[0m: Currently logged in as: [33mchi-vinny0702[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [13]:
def train():
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = []
        train_recall = []
        train_precision = []
        for g in tqdm(train_data):

            optimizer.zero_grad()

            out = model(g.x, g.edge_index)
            y_true = g.y.unsqueeze(dim=1)
            loss = criterion(out, y_true)            
            
            loss.backward()
            optimizer.step()
            
            train_loss.append(loss.item())
            train_precision.append(precision(out, y_true).item())
            train_recall.append(recall(out, y_true).item())
            
        train_prec = np.mean(train_precision)
        train_rec = np.mean(train_recall)
        train_f1 = (2 * train_prec * train_rec) / (train_prec + train_rec)
        print(f'Epoch: {epoch:03d}, loss: {np.mean(train_loss)}, f1: {train_f1}, precision: {train_prec}, recall: {train_rec}')
        
        val_loss = []
        val_recall = []
        val_precision = []
        with torch.no_grad():
            for g in val_data:
                g.to(device)
                out = model(g.x, g.edge_index)
        
                y_true = g.y.unsqueeze(dim=1)
                loss = criterion(out, y_true)
                                
                val_loss.append(loss.item())
                                
                val_precision.append(precision(out, y_true).item())
                val_recall.append(recall(out, y_true).item())
 
            
            prec = np.mean(val_precision)
            rec = np.mean(val_recall)
            f1 = (2 * prec * rec) / (prec + rec)
            print(f'val_loss: {np.mean(val_loss)}, val_f1: {f1}, val_precision: {prec}, val_recall: {rec}')
              
            wandb.log({"train_loss": np.mean(train_loss), "train_f1": train_f1, "train_precision": train_prec, 
                       "train_recall": train_rec,
                       "val_loss": np.mean(val_loss), "val_f1": f1, "val_precision": prec, "val_recall": rec})
#     wandb.finish()

In [14]:
train()
torch.save(model, "./models1/" + run_name + "_" + str(epochs) + ".pt")

100%|█████████████████████████████████████████| 655/655 [00:33<00:00, 19.60it/s]
  train_f1 = (2 * train_prec * train_rec) / (train_prec + train_rec)


Epoch: 001, loss: 0.949062467440394, f1: nan, precision: 0.0, recall: 0.0


  f1 = (2 * prec * rec) / (prec + rec)


val_loss: 0.9309547024035673, val_f1: nan, val_precision: 0.0, val_recall: 0.0


100%|█████████████████████████████████████████| 655/655 [00:33<00:00, 19.78it/s]


Epoch: 002, loss: 0.9239199612886851, f1: 0.035707757985505025, precision: 0.019552992066752366, recall: 0.20545822385101373
val_loss: 0.9166770629379728, val_f1: 0.0851242840866743, val_precision: 0.04470101979116849, val_recall: 0.8895184195369755


 83%|█████████████████████████████████▉       | 543/655 [00:27<00:05, 20.94it/s]

In [None]:
torch.save(model, "./models1/" + run_name + "_" + str(22) + ".pt")

In [87]:
wandb.finish()