In [1]:
import os
import os.path as osp

import numpy as np
import torch
import torch_geometric.transforms as T
import wandb
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GatedGraphConv
from tqdm.auto import tqdm

from dataset_processing import RNADataset


In [2]:
# Seed torch's global RNG before shuffling so that the train/val/test split is
# the same after every kernel restart (the shuffle below is otherwise random,
# which makes runs and saved checkpoints incomparable).
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when one is available; replace with torch.device('cpu') to force CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

dataset = RNADataset(root="./data/")
dataset = dataset.shuffle()

# Fixed-size split by position in the shuffled dataset.
train_data, val_data, test_data = dataset[0:655], dataset[655:873], dataset[873:]

train_dataloader = DataLoader(train_data, batch_size=8, shuffle=False)
val_dataloader = DataLoader(val_data, batch_size=1, shuffle=False)


cuda


In [3]:
def del_nucl_conn(y):
    """Zero the entries of ``y`` that lie on the first super- and sub-diagonal.

    These are the positions ``(i, j)`` with ``|i - j| == 1``; every other entry
    is returned unchanged.

    Args:
        y: 2-D tensor whose first dimension has size ``n``; the mask that is
            applied is ``n x n``, so ``y`` is expected to be square.

    Returns:
        A new tensor (``y`` is not modified) on the same device as ``y``.
    """
    n = y.size(0)
    # Build the mask on y's own device instead of a notebook-level global, so
    # the helper works for CPU and GPU inputs alike.
    pos = torch.arange(n, device=y.device)
    keep = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs() != 1
    # Multiplying by a 0/1 mask (rather than by a negated mask and then -1)
    # avoids producing "-0." entries for positive inputs.
    return y * keep.to(torch.get_default_dtype())

In [4]:
def del_main_diag(y):
    """Zero the main diagonal of ``y`` and leave every other entry unchanged.

    Args:
        y: 2-D tensor whose first dimension has size ``n``; the mask that is
            applied is ``n x n``, so ``y`` is expected to be square.

    Returns:
        A new tensor (``y`` is not modified) on the same device as ``y``.
    """
    # The mask lives on y's device, so no notebook-level `device` global is needed.
    off_diagonal = 1 - torch.eye(y.size(0), device=y.device)
    return y * off_diagonal

In [5]:
# Scratch check of the masking helpers on a small random matrix.
# NOTE(review): `argm` is defined in a later cell of this notebook, so on a
# fresh kernel executed top to bottom this cell stops with a NameError at the
# `argm(b)` call (see the saved output below). Run the cell that defines
# `argm` first, or move that definition above this cell.
a = torch.rand(4, 4)
print(a)
b = del_nucl_conn(a.to(device))
print(b)
c = argm(b)
print(c)
print(del_main_diag(c))

tensor([[0.6422, 0.4878, 0.6735, 0.1176],
        [0.7981, 0.7673, 0.1410, 0.3133],
        [0.8790, 0.6172, 0.2034, 0.7041],
        [0.2295, 0.8916, 0.4766, 0.6383]])
tensor([[0.6422, -0.0000, 0.6735, 0.1176],
        [-0.0000, 0.7673, -0.0000, 0.3133],
        [0.8790, -0.0000, 0.2034, -0.0000],
        [0.2295, 0.8916, -0.0000, 0.6383]], device='cuda:0')


NameError: name 'argm' is not defined

In [35]:
def argm(y):
    """Keep only each row's arg-max entry of ``y`` and its mirrored entry.

    For every row ``k`` the column ``i`` holding the row maximum is located;
    positions ``(k, i)`` and ``(i, k)`` are kept and all other entries are set
    to zero. The kept entries retain their original values from ``y`` (they
    are not binarised).

    Args:
        y: square 2-D tensor of scores.

    Returns:
        A tensor with the same shape, dtype and device as ``y``.
    """
    selection = torch.zeros_like(y)
    if y.numel() == 0:
        # argmax over an empty dimension raises; there is nothing to select.
        return selection

    rows = torch.arange(y.size(0), device=y.device)
    cols = torch.argmax(y, dim=1)
    # Vectorised replacement for a per-row Python loop: every write stores the
    # same value, so repeated indices are harmless.
    selection[rows, cols] = 1.
    selection[cols, rows] = 1.

    return selection * y


In [1]:
def precision(y_pred, y_true):
    """Precision of the arg-max-selected predictions against ``y_true``.

    The prediction is passed through ``del_nucl_conn``, ``argm`` and
    ``del_main_diag`` before scoring. The selected scores are used as they are
    (no 0.5 thresholding), so the counts below are weighted by score.

    Returns:
        Scalar tensor ``tp / (tp + fp + epsilon)``.
    """
    selected = del_main_diag(argm(del_nucl_conn(y_pred)))

    true_pos = (selected * y_true).sum()
    false_pos = ((1 - y_true) * selected).sum()

    # `epsilon` is a notebook-level constant guarding against division by zero.
    return true_pos / (true_pos + false_pos + epsilon)

def recall(y_pred, y_true):
    """Recall of the arg-max-selected predictions against ``y_true``.

    The prediction is passed through ``del_nucl_conn``, ``argm`` and
    ``del_main_diag`` before scoring. The selected scores are used as they are
    (no 0.5 thresholding), so the counts below are weighted by score.

    Returns:
        Scalar tensor ``tp / (tp + fn + epsilon)``.
    """
    selected = del_main_diag(argm(del_nucl_conn(y_pred)))

    true_pos = (selected * y_true).sum()
    false_neg = (y_true * (1 - selected)).sum()

    # `epsilon` is a notebook-level constant guarding against division by zero.
    return true_pos / (true_pos + false_neg + epsilon)

def f_loss(y_pred, y_true):
    """Differentiable ``1 - F1`` loss computed on the masked prediction.

    Unlike ``precision``/``recall`` no arg-max selection is applied: the
    prediction is only passed through ``del_nucl_conn`` and ``del_main_diag``,
    and its values are used directly as soft counts.

    Args:
        y_pred: predicted score matrix.
        y_true: target matrix of the same shape.

    Returns:
        Scalar tensor ``1 - F1``.
    """
    y_pred = del_main_diag(del_nucl_conn(y_pred))

    tp = torch.sum(y_pred * y_true)
    fn = torch.sum(y_true * (1 - y_pred))
    fp = torch.sum((1 - y_true) * y_pred)

    # Local names deliberately differ from the module-level precision()/recall()
    # functions so those are not shadowed inside this function.
    prec = tp / (tp + fp + epsilon)
    rec = tp / (tp + fn + epsilon)

    # Weight of recall relative to precision in the F-score; 1 gives plain F1.
    k = 1
    f1 = ((1 + k * k) * prec * rec) / (k * k * prec + rec + epsilon)
    return 1 - f1


def distance_loss(y_pred, y_true):
    """Symmetric nearest-neighbour distance between two sets of matrix entries.

    Both matrices are first passed through ``del_nucl_conn`` and
    ``del_main_diag``. Each one is then turned into a set of
    ``(row, col, value)`` points taken from the part of the matrix at or above
    the second super-diagonal. Pairwise distances between the two point sets
    are computed with ``torch.cdist``; the result is the mean squared distance
    from every prediction point to its nearest target point plus the mean
    squared distance from every target point to its nearest prediction point.

    Args:
        y_pred: predicted score matrix.
        y_true: target matrix (the training loop passes it as float32).

    Returns:
        Scalar tensor with the summed mean squared nearest-neighbour distances.

    NOTE(review): if either point set ends up empty (for example a matrix with
    fewer than three rows, or a prediction whose upper triangle is exactly
    zero) the ``min`` / division below presumably fails — confirm.
    """
    y_pred = del_nucl_conn(y_pred)
    y_pred = del_main_diag(y_pred)
    y_true = del_nucl_conn(y_true)
    y_true = del_main_diag(y_true)

    # Small offset added before the sparse conversion and subtracted again
    # afterwards: to_sparse() only stores non-zero entries, so the offset keeps
    # target entries that are exactly zero in the point set.
    c = 0.001

    y_true_local = y_true + c
    # diagonal=2 drops the main diagonal, the first super-diagonal and
    # everything below them.
    y_true_local = torch.triu(y_true_local, diagonal=2)
    y_true_local = y_true_local.to_sparse()
    # Stack indices and (offset-corrected) values into one 3 x N tensor:
    # row index, column index, value.
    y_true_local = torch.cat([y_true_local.indices(), torch.unsqueeze((y_true_local.values() - c) , 0)], dim=0)

    # No offset here: only the non-zero predicted entries become points.
    y_pred_local = torch.triu(y_pred, diagonal=2)
    y_pred_local = y_pred_local.to_sparse()
    y_pred_local = torch.cat([y_pred_local.indices(), torch.unsqueeze(y_pred_local.values(), 0)], dim=0)
    # dist2[i, j] is the distance between target point i and prediction point j.
    dist2 = torch.cdist(y_true_local.t(), y_pred_local.t())
    # x_min: for each prediction point, distance to the closest target point.
    x_min,ind1 = torch.min(dist2, dim=0)
    # y_min: for each target point, distance to the closest prediction point.
    y_min,ind1 = torch.min(dist2, dim=1)
    # Alternative kept for reference (mean of square roots instead of squares):
    #     return torch.sum(torch.sqrt(x_min))/(x_min.size()[0]) + torch.sum(torch.sqrt(y_min))/(y_min.size()[0])
    return torch.sum(x_min ** 2)/(x_min.size()[0]) + torch.sum(y_min ** 2)/(y_min.size()[0])

In [45]:
class GatedGCNModel(torch.nn.Module):
    """Gated graph network that scores every pair of nodes.

    Node embeddings come from a single ``GatedGraphConv`` module; the score for
    a pair of nodes is the sigmoid of the dot product of their embeddings.
    """

    def __init__(self, num_layers, out_channels):
        super().__init__()
        # Attribute name is kept as `ggcn` so previously saved models still load.
        self.ggcn = GatedGraphConv(out_channels=out_channels, num_layers=num_layers)

    def forward(self, x, edge_index, edge_weight):
        """Return the matrix of pairwise scores in (0, 1).

        ``edge_weight`` is accepted for interface compatibility but is not
        passed to the convolution.
        """
        node_emb = self.ggcn(x, edge_index)
        pair_logits = node_emb @ node_emb.t()
        return torch.sigmoid(pair_logits)


In [51]:
# --- Model hyperparameters ---------------------------------------------------
# NOTE(review): hidden_channels, num_features and dr are not read by any code
# visible in this notebook; they are kept because other cells may rely on them.
hidden_channels = 128
out_channels = 200
num_layers = 20
num_features = dataset.num_features

# --- Optimisation hyperparameters ---------------------------------------------
dr = 0.7
lr = 2e-6
epochs = 1500
beta = 1

# --- Model ---------------------------------------------------------------------
# To resume from a checkpoint instead, load it with torch.load(<path>) here.
model = GatedGCNModel(num_layers, out_channels).to(device)

# Move every graph to the target device once, up front. This replaces the
# dataset slices with plain Python lists.
train_data = [graph.to(device) for graph in train_data]
val_data = [graph.to(device) for graph in val_data]
print(device)

# --- Optimiser and loss --------------------------------------------------------
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.99))

# Loss used by train(); f_loss is the other loss defined in this notebook.
criterion = distance_loss

cuda


In [52]:
run_name = f"GatedGCN_argm_{num_layers}_{lr}_Adam_{out_channels}_{beta}"

# Log the loss that is actually in use instead of a hand-written label, so the
# run metadata cannot drift from `criterion`. Loss modules (instances) have no
# __name__, hence the fallback to the class name.
loss_name = getattr(criterion, "__name__", type(criterion).__name__)

wandb.init(
    # wandb project where this run will be logged
    project="secondary_structure_prediction1",

    # hyperparameters and run metadata
    config={
        "learning_rate": lr,
        "architecture": "GatedGCN",
        "epochs": epochs,
        "optimizer": "Adam",
        "out_channels": out_channels,
        "loss": loss_name,
        "beta": beta,
        # Derived from the actual split sizes rather than hard-coded.
        "train:val:test": f"{len(train_data)}:{len(val_data)}:{len(test_data)}",
    },
    name=run_name,
)

# Small constant added to denominators in precision / recall / f_loss.
epsilon = 1e-10

In [53]:
def _f1_score(prec, rec):
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    denom = prec + rec
    if denom == 0:
        return 0.0
    return (2 * prec * rec) / denom


def _run_split(graphs, training):
    """Run one pass over ``graphs`` and return mean (loss, precision, recall).

    When ``training`` is true the optimizer is stepped after every graph.
    """
    losses, precisions, recalls = [], [], []
    for g in tqdm(graphs):
        g.to(device)
        if training:
            optimizer.zero_grad()

        out = model(g.x, g.edge_index, g.edge_weight)
        y_true = g.adj_mat
        loss = criterion(out, y_true.to(torch.float32))

        if training:
            loss.backward()
            optimizer.step()

        losses.append(loss.item())
        # Metrics are for reporting only: compute them on a detached tensor so
        # no autograd graph is built for them.
        with torch.no_grad():
            scores = out.detach()
            precisions.append(precision(scores, y_true).item())
            recalls.append(recall(scores, y_true).item())

    return np.mean(losses), np.mean(precisions), np.mean(recalls)


def train():
    """Train ``model`` for ``epochs`` epochs, validating and logging each epoch.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``,
    ``train_data``, ``val_data`` and ``epochs``. Per-epoch metrics are printed
    and sent to wandb; the wandb run is left open (call ``wandb.finish()``
    separately).
    """
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss, train_prec, train_rec = _run_split(train_data, training=True)
        train_f1 = _f1_score(train_prec, train_rec)
        print(f'Epoch: {epoch:03d}, loss: {train_loss}, f1: {train_f1}, precision: {train_prec}, recall: {train_rec}')

        # Switch to inference mode for validation and disable gradient tracking.
        model.eval()
        with torch.no_grad():
            val_loss, prec, rec = _run_split(val_data, training=False)
        f1 = _f1_score(prec, rec)
        print(f'val_loss: {val_loss}, val_f1: {f1}, val_precision: {prec}, val_recall: {rec}')

        wandb.log({"train_loss": train_loss, "train_f1": train_f1, "train_precision": train_prec,
                   "train_recall": train_rec,
                   "val_loss": val_loss, "val_f1": f1, "val_precision": prec, "val_recall": rec})

In [54]:
# Create the checkpoint directory before the long training run so that a
# missing folder cannot make torch.save fail after training has finished.
save_dir = "./models1/"
os.makedirs(save_dir, exist_ok=True)

train()
torch.save(model, save_dir + run_name + "_" + str(epochs) + ".pt")

  0%|          | 0/655 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Manual checkpoint of the current model under a hand-built name.
# NOTE(review): "140" is a hard-coded tag, presumably the number of epochs
# completed before the run was interrupted — confirm before reusing.
run_name = f"GatedGCN_{num_layers}_140_{lr}_Adam_{out_channels}"

# Make sure the target directory exists; torch.save does not create it.
os.makedirs("./models/", exist_ok=True)
torch.save(model, "./models/" + run_name + ".pt")

In [50]:
# Close the active wandb run; train() itself leaves the run open.
wandb.finish()

VBox(children=(Label(value='0.002 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.196361…

0,1
train_f1,▁▄▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇██▇██
train_loss,█▅▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁
train_precision,▁▄▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇██▇██
train_recall,▁▄▆▇▇▇▇▇▇▇▇█▇█████████
val_f1,▁▄▅▅▄▅▆▆▆▆▇▇███▇███▇▇█
val_loss,█▅▄▄▅▄▃▃▃▃▂▂▂▁▁▂▁▁▁▂▂▂
val_precision,▁▄▄▅▄▅▅▅▆▆▆▇▇██▇███▇▇█
val_recall,▁▅▅▆▅▆▆▆▇▇▇▇███▇███▇▇█

0,1
train_f1,0.03285
train_loss,0.96788
train_precision,0.02357
train_recall,0.0542
val_f1,0.03687
val_loss,0.96397
val_precision,0.02605
val_recall,0.06308
