In [202]:
import os.path as osp
from tqdm.auto import tqdm
import numpy as np
import wandb

import torch
from PIL import Image
import torchvision.transforms as T
from torchvision.utils import save_image

import torch.nn.functional as F
# from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
from torch.nn import LayerNorm, Linear, ReLU, Tanh

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, GatedGraphConv, GCN
from torch_geometric.nn import DeepGCNLayer, GENConv, MessageNorm
from torch_geometric.utils import negative_sampling, to_dense_adj
from torch_geometric.loader import DataLoader

from dataset_processing import RNADataset


In [2]:
dataset = RNADataset(root="./data/")
dataset = dataset.shuffle()
train_data, val_data, test_data = dataset[0:655], dataset[655:873], dataset[873:]
bs = 4

train_dataloader = DataLoader(train_data, batch_size=bs, shuffle=False)
val_dataloader = DataLoader(val_data, batch_size=1, shuffle=False)


In [203]:
save_image(torch.rand([200, 200]), "./pictures/test.jpg")

In [4]:
def del_nucl_conn(y):
#     main_diag = torch.zeros([1, y.size()[0] // bs])
#     diag1 = torch.ones([1, (y.size()[0] // bs) - 1])
    
    main_diag = torch.zeros([1, y.size()[0]])
    diag1 = torch.ones([1, y.size()[0] - 1])
#     batch_diags = []
#     for i in range(bs):
#         batch_diags.append((torch.diag_embed(main_diag).to(device) +
#             torch.diag_embed(diag1, offset=1).to(device) + 
#             torch.diag_embed(diag1, offset=-1).to(device) - 1)[0])

#     return y * torch.cat(batch_diags, 0).to(device) * (-1)
    return y * (torch.diag_embed(main_diag).to(device) + 
                    torch.diag_embed(diag1, offset=1).to(device) + 
                torch.diag_embed(diag1, offset=-1).to(device) - 1)[0] * (-1)


def del_main_diag(y):
#     main_diag = torch.ones([1, y.size()[0] // bs])
    main_diag = torch.ones([1, y.size()[0]])

#     batch_diags = []
#     for i in range(bs):
#         batch_diags.append((torch.diag_embed(main_diag).to(device) - 1)[0])
    
#     return y * torch.cat(batch_diags, 0).to(device) * (-1)
    return y * (torch.diag_embed(main_diag).to(device) - 1)[0] * (-1)


def adj_mat_split(y):
    y = list(torch.split(y, 196, dim=0))
#     print(y)
    for i in range(len(y)):
        y[i] = torch.split(y[i], 196, dim=1)
    
    y1 = []
    for i in range(len(y)):
        y1.append(y[i][i])
    
    return torch.cat(y1, 0).to(device)
    

def argm(y):
    max_ind = torch.argmax(y, dim=1)
    y1 = torch.zeros_like(y)
    k = 0
    for i in max_ind:
        y1[k][i] = 1.
        y1[i][k] = 1.
        k += 1
    
    return y1 * y


In [272]:
a = 0.8

def precision(y_pred, y_true, y_in):
    y_pred = y_pred + y_in
    y_pred = del_nucl_conn(y_pred)
    y_pred[(y_pred > a)] = 1
    y_pred[(y_pred <= a)] = 0 
#     y_pred = argm(y_pred)
    y_pred = del_main_diag(y_pred)
    y_pred = adj_mat_split(y_pred)

    tp = torch.sum(y_pred * y_true)
    fp = torch.sum((1 - y_true) * y_pred)
    
    return tp / (tp + fp + epsilon)

def recall(y_pred, y_true, y_in):
    y_pred = y_pred + y_in
    y_pred = del_nucl_conn(y_pred)
    y_pred[(y_pred > a)] = 1
    y_pred[(y_pred <= a)] = 0 
#     y_pred = argm(y_pred)
    y_pred = del_main_diag(y_pred)
    y_pred = adj_mat_split(y_pred)

    tp = torch.sum(y_pred * y_true)
    fn = torch.sum(y_true * (1 - y_pred))
    
    return tp / (tp + fn + epsilon)

def f_loss(y_pred, y_true, y_in):
    y_pred = y_pred + y_in
    y_pred = del_nucl_conn(y_pred)
#     y_pred = argm(y_pred)
    y_pred = del_main_diag(y_pred)
    y_pred = adj_mat_split(y_pred)

    tp = torch.sum(y_pred * y_true)
    fn = torch.sum(y_true * (1 - y_pred))
    fp = torch.sum((1 - y_true) * y_pred)
    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)

    k = 1
    k1 = 1 - torch.abs(precision - recall)

#     k2 = 1 - torch.abs(K.mean(precision) - K.mean(recall))
    #calculate upgraded f1 score
    f1 = ((1 + k * k) * precision * recall) / (k * k * precision + recall + epsilon)
#     tw = K.sum(K.cast(y_true * y_pred, ’float32’), axis=[1, 2, 3])
#     fw = K.sum(K.cast((1 - y_true) * y_pred, ’float32’), axis=[1, 2, 3])
#     fb = K.sum(K.cast(y_true * (1 - y_pred), ’float32’), axis=[1, 2, 3])
    return 1 - f1


def mse_loss(y_pred, y_true):
#    y_pred = torch.tanh((y_pred*3)**5)
#     y_true = y_true - y_in
    y_pred = del_nucl_conn(y_pred)
#     y_pred[(y_pred > a)] = 1
#     y_pred[(y_pred <= a)] = 0 
#     y_pred = argm(y_pred)
    y_pred = del_main_diag(y_pred)
    y_true = del_nucl_conn(y_true)
    y_true = del_main_diag(y_true)
    
#     y_pred = adj_mat_split(y_pred)
#     return torch.sqrt(torch.mean((y_pred - y_true) ** 2))
    return (torch.sum((y_pred - y_true) ** 2))


def distance_loss(y_pred, y_true):
   # print(">>>>>>>>>>>>>>>>>>>>>>>>>>")
    #y_pred = torch.tanh((y_pred*3)**5)
    y_pred = del_nucl_conn(y_pred)
    y_pred = del_main_diag(y_pred)
    y_true = del_nucl_conn(y_true)
    y_true = del_main_diag(y_true)
    
    
    
   # print(y_pred)
    #print(y_true)
    
    #print ("---------------------------------")
    
    
    c = 0.001
    
#     y_pred[0][0] = 0.01
#     y_true[0][0] = 0.01
#     y_pred[(torch.abs(y_pred) < c)] = 0
#    c = torch.rand(y_true.size()).to(device)
#     k = 0.4
#     c[(c > k)] = 0
#     c += y_true
#     y_true += c
    #y_pred.where(y_pred > c or y_pred < -c, y_pred, 0)
    
    #print(y_pred)
    
    #print("<<<<<<<<<<<<<<<<<<<<<<<<<")
    
    #c = 0.0001
    y_true_local = y_true + c
    y_true_local = torch.triu(y_true_local, diagonal=2)
    y_true_local = y_true_local.to_sparse()
    #c_sparse = torch.triu(c, diagonal=2).to_sparse()
    y_true_local = torch.cat([y_true_local.indices(), torch.unsqueeze((y_true_local.values() - c) , 0)], dim=0)
#     y_true_local = torch.cat([y_true_local.indices(), torch.unsqueeze(y_true_local.values() * 100, 0)], dim=0)
    
    y_pred_local = torch.triu(y_pred, diagonal=2)
    y_pred_local = y_pred_local.to_sparse()
    y_pred_local = torch.cat([y_pred_local.indices(), torch.unsqueeze(y_pred_local.values(), 0)], dim=0)
    dist2 = torch.cdist(y_true_local.t(), y_pred_local.t())
    x_min,ind1 = torch.min(dist2, dim=0)
    y_min,ind1 = torch.min(dist2, dim=1)
#     return torch.sum(torch.sqrt(x_min))/(x_min.size()[0]) + torch.sum(torch.sqrt(y_min))/(y_min.size()[0])
    return torch.sum(x_min ** 2)/(x_min.size()[0]) + torch.sum(y_min ** 2)/(y_min.size()[0])

    #return torch.sum(x_min)/(x_min.size()[0])
    


In [279]:
class DeepGCNModel(torch.nn.Module):
    def __init__(self, num_layers, out_channels):
        super(DeepGCNModel, self).__init__()
        self.layers = torch.nn.ModuleList()
#         conv1 = GCN(dataset.num_features, out_channels, num_layers=2)
#         layer = DeepGCNLayer(conv1, block='res', dropout=0)

#         self.layers.append(layer)
        for i in range(1, num_layers + 1):

            conv = GatedGraphConv(out_channels, num_layers=6)

            layer = DeepGCNLayer(conv, block='res', dropout=0.7)
            self.layers.append(layer)

        self.lin = Linear(out_channels, 1)

    def forward(self, x, edge_index):
        x = self.layers[0].conv(x, edge_index)
        for layer in self.layers[1:]:
            x = layer(x, edge_index)
            
        x1, x2 = torch.split(x, out_channels // 2, dim=1)
        #!!return torch.tanh(((torch.tanh(x1 @ x2.t()))*20)**11)
#         return torch.tanh(((torch.tanh(x1 @ x2.t())))**11)
#         return (x1 @ x2.t()).sigmoid()

        
        return torch.tanh(((torch.tanh(x1 @ x2.t()))*1000))
        #return torch.tanh(torch.tanh(x @ x.t())**11)

In [280]:
num_layers = 3
# num_features = dataset.num_features
out_channels = 200


lr = 0.0001
epochs = 1500
beta = 1


model = DeepGCNModel(num_layers, out_channels)
# model = torch.load("./models1/GatedGCN_6_4e-05_Adam_200_2_60.pt")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
train_data = list(map(lambda x: x.to(device), train_data))
val_data = list(map(lambda x: x.to(device), val_data))
print(device)

# optimizer = torch.optim.Adam(model.parameters(), lr=lr)
#optimizer = torch.optim.Adagrad(model.parameters(), lr=lr)
optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, momentum=0)
#optimizer = torch.optim.Adamax(model.parameters(), lr=lr)
# optimizer = torch.optim.Adadelta(model.parameters(), lr=lr)

# optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# criterion = torch.nn.BCELoss()
# criterion = torch.nn.BCEWithLogitsLoss()
# criterion = torch.nn.CrossEntropyLoss()

# def RMSELoss(y_pred, y_true):
#     return torch.sqrt(torch.mean((y_pred - y_true) ** 2))

# criterion = f_loss
# criterion = mse_loss
criterion = distance_loss

cuda


In [281]:
run_name = "ResGatedGCN_dist(2)_diff_mult_dist_01_dropout_07_" \
+ str(num_layers) + "*6_" + str(lr) + "_" + "RMSprop(0)_" + str(out_channels) + "_" + str(a)


wandb.init(
    # set the wandb project where this run will be logged
    project="secondary_structure_prediction1",
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": lr,
    "architecture": "ResGatedGCN",
    "epochs": epochs,
    "optimizer": "RMSProp",
    "out_channels": out_channels,
    "loss": "dist_loss",
    "beta": beta,
    "train:val:test": "655:218:218"
    },
    name=run_name
)

epsilon = 1e-10

In [282]:
def train():
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = []
        train_recall = []
        train_precision = []  
        k = 0
        for g in tqdm(train_data):
            g.to(device)
            optimizer.zero_grad()

            out = model(g.x, g.edge_index)
            y_true = g.adj_mat
#             print(out.size(),y_true.size())
            loss = criterion(out, y_true.to(torch.float32) - torch.squeeze(to_dense_adj(g.edge_index)))          
#             loss = criterion(out, y_true.to(torch.float32))          

            loss.backward()
            optimizer.step()
            
#             out = out.sigmoid()
            if k <= 2:
                save_image(y_true, f"./pictures/{k}/{epoch}_t.jpg")
                save_image(torch.squeeze(to_dense_adj(g.edge_index)), f"./pictures/{k}/{epoch}_p.jpg")
                save_image(out + torch.squeeze(to_dense_adj(g.edge_index)), f"./pictures/{k}/{epoch}_nn.jpg")
            k += 1
            train_loss.append(loss.item())
            train_precision.append(precision(out, y_true, torch.squeeze(to_dense_adj(g.edge_index))).item())
            train_recall.append(recall(out, y_true, torch.squeeze(to_dense_adj(g.edge_index))).item())
            
        train_prec = np.mean(train_precision)
        train_rec = np.mean(train_recall)
        train_f1 = (2 * train_prec * train_rec) / (train_prec + train_rec)
        print(f'Epoch: {epoch:03d}, loss: {np.mean(train_loss)}, f1: {train_f1}, precision: {train_prec}, recall: {train_rec}')
        
        val_loss = []
        val_recall = []
        val_precision = []
        with torch.no_grad():
            for g in val_data:
                g.to(device)
                out = model(g.x, g.edge_index)
        
                y_true = g.adj_mat
                loss = criterion(out, y_true.to(torch.float32) - torch.squeeze(to_dense_adj(g.edge_index)))
#                 loss = criterion(out, y_true.to(torch.float32))          

                                
                val_loss.append(loss.item())
                
#                 out = out.sigmoid()
                
                val_precision.append(precision(out, y_true, torch.squeeze(to_dense_adj(g.edge_index))).item())
                val_recall.append(recall(out, y_true, torch.squeeze(to_dense_adj(g.edge_index))).item())
 
            
            prec = np.mean(val_precision)
            rec = np.mean(val_recall)
            f1 = (2 * prec * rec) / (prec + rec)
            print(f'val_loss: {np.mean(val_loss)}, val_f1: {f1}, val_precision: {prec}, val_recall: {rec}')
              
            wandb.log({"train_loss": np.mean(train_loss), "train_f1": train_f1, "train_precision": train_prec, 
                       "train_recall": train_rec,
                       "val_loss": np.mean(val_loss), "val_f1": f1, "val_precision": prec, "val_recall": rec})
#     wandb.finish()

In [283]:
train()
torch.save(model, "./models1/" + run_name + "_" + str(epochs) + ".pt")

100%|█████████████████████████████████████████| 655/655 [00:15<00:00, 42.92it/s]


Epoch: 001, loss: 0.26824436614872843, f1: 0.01631741928560638, precision: 0.010228219247451549, recall: 0.04032311365719288
val_loss: 0.22105226886176735, val_f1: 0.00579749487533227, val_precision: 0.0034384734580430007, val_recall: 0.01846726996601995


100%|█████████████████████████████████████████| 655/655 [00:13<00:00, 47.05it/s]


Epoch: 002, loss: 0.25562456257868815, f1: 0.004032807952499705, precision: 0.0023466979510698266, recall: 0.014326301538420997
val_loss: 0.28729069379047245, val_f1: 0.002580898211542894, val_precision: 0.001454842029088134, val_recall: 0.011420197156898746


100%|█████████████████████████████████████████| 655/655 [00:13<00:00, 50.08it/s]


Epoch: 003, loss: 0.33279641516620423, f1: 0.007465740388807624, precision: 0.005544600227380256, recall: 0.011424038102779225
val_loss: 0.3806228638976949, val_f1: 0.003910365277731591, val_precision: 0.0025465087810266386, val_recall: 0.008419870868872066


100%|█████████████████████████████████████████| 655/655 [00:12<00:00, 51.29it/s]


Epoch: 004, loss: 0.41222274216133675, f1: 0.004999181691735127, precision: 0.003411631005681084, recall: 0.009350116373189077
val_loss: 0.4407611657788447, val_f1: 0.004909096339119875, val_precision: 0.003642552843052518, val_recall: 0.007525914344389778


100%|█████████████████████████████████████████| 655/655 [00:13<00:00, 50.04it/s]


Epoch: 005, loss: 0.49926225784739464, f1: 0.0037771129171830765, precision: 0.0024504255897984024, recall: 0.008236378930395341
val_loss: 0.5499841307520593, val_f1: 0.003273617598170051, val_precision: 0.0024282539267639776, val_recall: 0.0050219367770905346


100%|█████████████████████████████████████████| 655/655 [00:12<00:00, 50.75it/s]


Epoch: 006, loss: 0.6033689463269392, f1: 0.004006912220368212, precision: 0.0033356238405383726, recall: 0.0050164673804440575
val_loss: 0.6614232811171117, val_f1: 0.001287442095198129, val_precision: 0.0007761924745463723, val_recall: 0.003771767575601372


100%|█████████████████████████████████████████| 655/655 [00:12<00:00, 53.81it/s]


Epoch: 007, loss: 0.721413021506244, f1: 0.0023506424725435134, precision: 0.0014911033098500592, recall: 0.005549793772911297
val_loss: 0.7582952199896814, val_f1: 0.0029434152734194104, val_precision: 0.002322601613699706, val_recall: 0.004017175611006011


100%|█████████████████████████████████████████| 655/655 [00:12<00:00, 52.01it/s]


Epoch: 008, loss: 0.816151365799872, f1: 0.001173348701672996, precision: 0.000667741017252605, recall: 0.0048323996180227695
val_loss: 0.8958599817845116, val_f1: 0.0023223052224478295, val_precision: 0.001718957900975344, val_recall: 0.0035782601778318575


100%|█████████████████████████████████████████| 655/655 [00:12<00:00, 52.09it/s]


Epoch: 009, loss: 0.8962363801754385, f1: 0.002936480921867253, precision: 0.002225399989712955, recall: 0.00431536840317586
val_loss: 0.931377995767793, val_f1: 0.0014171611528358513, val_precision: 0.0008718691960643601, val_recall: 0.0037834209073847585


100%|█████████████████████████████████████████| 655/655 [00:11<00:00, 54.65it/s]


Epoch: 010, loss: 0.9532344544747176, f1: 0.0010328402687096254, precision: 0.0005831083304156329, recall: 0.004515475003432682
val_loss: 1.0050909169308773, val_f1: 0.0012867339626459907, val_precision: 0.0008030653890615347, val_recall: 0.0032352592770454533


100%|█████████████████████████████████████████| 655/655 [00:11<00:00, 55.50it/s]


Epoch: 011, loss: 1.0700062165051016, f1: 0.003207871066841501, precision: 0.002881297342237512, recall: 0.003617937500297113
val_loss: 1.1273870291786456, val_f1: 0.002081820703774592, val_precision: 0.0015303393924147438, val_recall: 0.0032547028957430374


100%|█████████████████████████████████████████| 655/655 [00:12<00:00, 54.55it/s]


Epoch: 012, loss: 1.1659509768752196, f1: 0.00133983739837746, precision: 0.0008545636577986693, recall: 0.0031004809355007785


KeyboardInterrupt: 

In [None]:
wandb.finish()