In [1]:
import os
import math
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.nn.parameter import Parameter
from torch.autograd import Variable
from sklearn import metrics
from sklearn.model_selection import KFold, train_test_split
from scipy.stats import pearsonr
#from dataset import Dataset, collate_fn
SEED = 2333

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os

import pickle


i = 0

#feature_dir = 'Feature_vector//Feature_vector//'
graph_dir = 'tg//'
label_dir = 'tma_l//'




folders = os.listdir(graph_dir)
sequence_features = []
sequence_graphs = []
labels = []
sequence_names=[]
for i in range(len(folders)):
    sequence_features.append(pickle.load(open(graph_dir + folders[i] +  '//graph.pkl'  , "rb"))['node_features'])
    sequence_graphs.append(pickle.load(open(graph_dir + folders[i] +  '//graph.pkl'  , "rb"))['Adjacency_matrix'])
    labels.append(pickle.load(open(label_dir + folders[i] +  '//label.pkl'  , "rb"))[0][-1])
    sequence_names.append(folders[i])
    

#vectors = pickle.load(open(feature_dir + folders[i] +  '//vectors.pkl'  , "rb"))

#labels = pickle.load(open(label_dir + folders[i] +  '//label.pkl'  , "rb"))[0,-1]

In [17]:
Model_Path = './tm/'
Result_Path = './tm/'

In [4]:
MG = dict(zip(sequence_names, sequence_graphs))
WF=dict(zip(sequence_names, sequence_features))
zipped = list(zip(sequence_names, labels))
ds = pd.DataFrame(zipped, columns=['names', 'stability'])

In [5]:
def NodeM(name):
    return WF[name]
def GraphM(name):
    return MG[name]

In [6]:
names = ds['names'].values.tolist()

In [7]:
from torch.utils.data.sampler import Sampler
  
class Dataset(Dataset):

    def __init__(self, dataframe):
        self.names = dataframe['names'].values.tolist()
        self.labels = dataframe['stability'].values.tolist()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        
        sequence_name = self.names[index]
        label = self.labels[index]
        
        sequence_feature = NodeM(sequence_name)

        # L * L
        sequence_graph = GraphM(sequence_name)
        
        
        sample = {'sequence_feature': sequence_feature,\
                  'sequence_graph': sequence_graph, \
                  'label': label, \
                  'sequence_name': sequence_name, \
                  }
        return sample


def collate_fn(batch):
    sequence_feature = []
    sequence_graph = []
    sequence_names = [] 
    labels=[]   
    for i in range(len(batch)):
        sequence_feature.append(batch[i]['sequence_feature'])
        sequence_feature=np.asarray(sequence_feature)
        sequence_graph.append(batch[i]['sequence_graph'])
        sequence_graph=np.asarray(sequence_graph)
        sequence_names.append(batch[i]['sequence_name'])
        labels.append(batch[i]['label'])
        labels= np.asarray(labels)

    sequence_feature = torch.from_numpy(sequence_feature).float()
    sequence_graph = torch.from_numpy(sequence_graph).float()
    labels= torch.from_numpy(labels)

    return sequence_feature,sequence_graph, labels, sequence_names

In [8]:
NUMBER_EPOCHS = 150
LEARNING_RATE = 1E-5
WEIGHT_DECAY = 10E-7
BATCH_SIZE = 1
NUM_CLASSES = 6

# GCN parameters
GCN_FEATURE_DIM = 2560
GCN_HIDDEN_DIM = 512
GCN_HIDDEN_DIM1 = 256
GCN_HIDDEN_DIM2 = 128
GCN_OUTPUT_DIM = 32

# Attention parameters
DENSE_DIM = 16
ATTENTION_HEADS = 4

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb
#from utils.utils import initialize_weights
import numpy as np


In [10]:
class GraphConvolution(nn.Module):

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = input @ self.weight    # X * W
        output = adj @ support           # A * X * W
        if self.bias is not None:        # A * X * W + b
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class GCN(nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.gc1 = GraphConvolution(GCN_FEATURE_DIM, GCN_HIDDEN_DIM)
        self.ln1 = nn.LayerNorm(GCN_HIDDEN_DIM)
        self.gc3 = GraphConvolution(GCN_HIDDEN_DIM, GCN_HIDDEN_DIM1)
        self.ln3 = nn.LayerNorm(GCN_HIDDEN_DIM1)
        self.gc2 = GraphConvolution(GCN_HIDDEN_DIM1, GCN_OUTPUT_DIM)
        self.ln2 = nn.LayerNorm(GCN_OUTPUT_DIM)
        self.relu1 = nn.LeakyReLU(0.2,inplace=True)
        self.relu3 = nn.LeakyReLU(0.2,inplace=True)
        self.relu4 = nn.LeakyReLU(0.2,inplace=True)
        self.relu2 = nn.LeakyReLU(0.2,inplace=True)

    def forward(self, x, adj):  			# x.shape = (seq_len, GCN_FEATURE_DIM); adj.shape = (seq_len, seq_len)
        x = self.gc1(x, adj)  				# x.shape = (seq_len, GCN_HIDDEN_DIM)
        x = self.relu1(self.ln1(x))
        x = self.gc3(x, adj)  				# x.shape = (seq_len, GCN_HIDDEN_DIM)
        x = self.relu3(self.ln3(x))
        x = self.gc2(x, adj)
        output = self.relu2(self.ln2(x))	# output.shape = (seq_len, GCN_OUTPUT_DIM)
        return output


class Attention(nn.Module):
    def __init__(self, input_dim, dense_dim, n_heads):
        super(Attention, self).__init__()
        self.input_dim = input_dim
        self.dense_dim = dense_dim
        self.n_heads = n_heads
        self.fc1 = nn.Linear(self.input_dim, self.dense_dim)
        self.fc2 = nn.Linear(self.dense_dim, self.n_heads)

    def softmax(self, input, axis=1):
        input_size = input.size()
        trans_input = input.transpose(axis, len(input_size) - 1)
        trans_size = trans_input.size()
        input_2d = trans_input.contiguous().view(-1, trans_size[-1])
        soft_max_2d = torch.softmax(input_2d, dim=1)
        soft_max_nd = soft_max_2d.view(*trans_size)
        return soft_max_nd.transpose(axis, len(input_size) - 1)

    def forward(self, input):  				# input.shape = (1, seq_len, input_dim)
        x = torch.tanh(self.fc1(input))  	# x.shape = (1, seq_len, dense_dim)
        x = self.fc2(x)  					# x.shape = (1, seq_len, attention_hops)
        x = self.softmax(x, 1)
        attention = x.transpose(1, 2)  		# attention.shape = (1, attention_hops, seq_len)
        return attention


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.gcn = GCN()
        self.attention = Attention(GCN_OUTPUT_DIM, DENSE_DIM, ATTENTION_HEADS)
        self.fc_final = nn.Linear(GCN_OUTPUT_DIM, NUM_CLASSES)

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    def forward(self, x, adj):  											# x.shape = (seq_len, FEATURE_DIM); adj.shape = (seq_len, seq_len)
        x = x.float()
        x = self.gcn(x, adj)  												# x.shape = (seq_len, GAT_OUTPUT_DIM)

        x = x.unsqueeze(0).float()  										# x.shape = (1, seq_len, GAT_OUTPUT_DIM)
        att = self.attention(x)  											# att.shape = (1, ATTENTION_HEADS, seq_len)
        node_feature_embedding = att @ x 									# output.shape = (1, ATTENTION_HEADS, GAT_OUTPUT_DIM)
        node_feature_embedding_avg = torch.sum(node_feature_embedding,
                                               1) / self.attention.n_heads  # node_feature_embedding_avg.shape = (1, GAT_OUTPUT_DIM)
        logits = torch.sigmoid(self.fc_final(node_feature_embedding_avg))  	# output.shape = (1, NUM_CLASSES)
        Y_hat = torch.topk(logits, 1, dim = 1)[1]
        Y_prob = torch.softmax(logits, dim = 1)
        return logits, Y_hat, Y_prob

In [11]:
class GraphAttentionLayer(nn.Module):
    def __init__(self, inp, out, slope):
        
        super(GraphAttentionLayer, self).__init__()
        self.W = nn.Linear(inp, out, bias=False)
        self.a = nn.Linear(out*2, 1, bias=False)
        self.leakyrelu = nn.LeakyReLU(slope)
        self.softmax = nn.Softmax(dim=1)
  
    def forward(self, h, adj):
        Wh = self.W(h)
        Whcat = self.Wh_concat(Wh, adj)
        e = self.leakyrelu(self.a(Whcat).squeeze(2))
        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = self.softmax(attention)
        h_hat = torch.mm(attention, Wh)

        return h_hat
 
    def Wh_concat(self, Wh, adj):
        N = Wh.size(0)
        Whi = Wh.repeat_interleave(N, dim=0)
        Whj = Wh.repeat(N, 1)
        WhiWhj = torch.cat([Whi, Whj], dim=1)
        WhiWhj = WhiWhj.view(N, N, Wh.size(1)*2)

        return WhiWhj
 
class MultiHeadGAT(nn.Module):
    def __init__(self, inp, out, heads, slope):
        super(MultiHeadGAT, self).__init__()
        self.attentions = nn.ModuleList([GraphAttentionLayer(inp, out, slope) for _ in range(heads)])
        self.tanh = nn.Tanh()
  
    def forward(self, h, adj):
        heads_out = [att(h, adj) for att in self.attentions]
        out = torch.stack(heads_out, dim=0).mean(0)
    
        return self.tanh(out)
 
class GAT(nn.Module):
    def __init__(self, inp= 2560, out = 256, out2 =64, heads=4, slope=0.01):
        super(GAT, self).__init__()
        self.gat1 = MultiHeadGAT(inp, out, heads, slope)
        self.gat2 = MultiHeadGAT(out, out2, heads, slope)
        self.attention = Attention(out2, DENSE_DIM, ATTENTION_HEADS)
        #self.fc_final = nn.Linear(GCN_OUTPUT_DIM, NUM_CLASSES)
        self.fc_final = nn.Linear(out2, NUM_CLASSES)

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
  
    def forward(self, h, adj):
        out = self.gat1(h, adj)
        out = self.gat2(out, adj)
        att = self.attention(out.unsqueeze(0).float())  											# att.shape = (1, ATTENTION_HEADS, seq_len)
        node_feature_embedding = att @ out 									# output.shape = (1, ATTENTION_HEADS, GAT_OUTPUT_DIM)
        node_feature_embedding_avg = torch.sum(node_feature_embedding,
                                               1) / self.attention.n_heads  # node_feature_embedding_avg.shape = (1, GAT_OUTPUT_DIM)
        logits = torch.sigmoid(self.fc_final(node_feature_embedding_avg))  	# output.shape = (1, NUM_CLASSES)
        Y_hat = torch.topk(logits, 1, dim = 1)[1]
        Y_prob = torch.softmax(logits, dim = 1)
        return logits, Y_hat, Y_prob
        #return out

In [12]:
def train_one_epoch(model, data_loader, epoch):

    epoch_loss_train = 0.0
    n_batches = 0
    for data in tqdm(data_loader):
        model.optimizer.zero_grad()
        sequence_feature,sequence_graph, labels, sequence_names = data

        sequence_feature = torch.squeeze(sequence_feature)
        sequence_graph = torch.squeeze(sequence_graph)
        if torch.cuda.is_available():
            features = Variable(sequence_feature.cuda())
            graphs = Variable(sequence_graph.cuda())
            y_true = Variable(labels.cuda())
        else:
            features = Variable(sequence_feature)
            graphs = Variable(sequence_graph)
            y_true = Variable(labels)

        logits,y_pred,Y_hat= model(features, graphs)
        #print(logits)
        #print(y_pred)
        #print(Y_hat)
        y_true = y_true.float()
        y_pred = y_pred.squeeze(0)
        #Y_hat = torch.argmax(y_pred)
        #print(y_pred)
        #print(logits)

        # calculate loss
        loss = model.criterion(logits, y_true.to(dtype=torch.long,non_blocking=False))
        #l2_lambda = 0.001
        #l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())
        #loss = loss + l2_lambda * l2_norm
        #print(loss)
        model.optimizer.zero_grad()
        #loss.requires_grad = True
        loss.backward()
       # clip_grad_norm_(model.parameters(), max_norm=10)

      #update
        model.optimizer.step()

        epoch_loss_train += loss.item()
        n_batches += 1

    epoch_loss_train_avg = epoch_loss_train / n_batches
    return epoch_loss_train_avg



In [13]:
def evaluate(model, data_loader):
    model.eval()

    epoch_loss = 0.0
    n_batches = 0
    valid_pred = []
    valid_true = []
    valid_name = []
    valid_logit = []

    for data in tqdm(data_loader):
        with torch.no_grad():
            sequence_feature,sequence_graph, labels, sequence_names = data

            sequence_feature = torch.squeeze(sequence_feature)
            sequence_graph = torch.squeeze(sequence_graph)
            if torch.cuda.is_available():
                features = Variable(sequence_feature.cuda())
                graphs = Variable(sequence_graph.cuda())
                y_true = Variable(labels.cuda())
            else:
                features = Variable(sequence_feature)
                graphs = Variable(sequence_graph)
                y_true = Variable(labels)

            logits, y_pred, Y_hat = model(features, graphs)
            #logits = logits.mean(dim=0)
            #Y_hat = torch.argmax(Y_prob)
            y_true = y_true.float()
            y_pred = y_pred.squeeze(0)

            loss = model.criterion(logits,y_true.to(dtype=torch.long,non_blocking=False))
            #l2_lambda = 0.001
            #l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())
            #loss = loss + l2_lambda * l2_norm
            #print(loss)
                
            y_pred = y_pred.cpu().detach().numpy().tolist()
            y_true = y_true.cpu().detach().numpy().tolist()
            valid_pred.append(y_pred)
            valid_true.append(y_true)
            valid_name.extend(sequence_names)
            valid_logit.append(logits)
            epoch_loss += loss.item()
            n_batches += 1
    epoch_loss_avg = epoch_loss / n_batches
    print(epoch_loss_avg)
    return epoch_loss_avg, valid_true, valid_pred, valid_name, valid_logit


In [14]:
def train(model, dataframe,valid_dataframe,fold=0):
    train_loader = DataLoader(dataset=Dataset(dataframe) ,batch_size=1, shuffle=True, num_workers=0,collate_fn=collate_fn)
    valid_loader = DataLoader(dataset=Dataset(valid_dataframe) ,batch_size=1, shuffle=True, num_workers=0,collate_fn=collate_fn)

    train_losses = []
    train_binary_acc = []

    valid_losses = []
    valid_binary_acc = []

    best_val_loss = 1000
    best_epoch = 0

    for epoch in range(NUMBER_EPOCHS):
        print("\n========== Train epoch " + str(epoch + 1) + " ==========")
        model.train()

        epoch_loss_train_avg = train_one_epoch(model, train_loader, epoch + 1)
        print(epoch_loss_train_avg)
        #print("========== Evaluate Train set ==========")
        #_, train_true, train_pred, _, _ = evaluate(model, train_loader)
        #print(len(train_pred), len(train_pred)
        #result_train,binpt,bintt = analysis(train_true, train_pred)
        #print("Train binary acc: ", result_train['binary_acc'])

        #train_binary_acc.append(result_train['binary_acc'])
        print("========== Evaluate Valid set ==========")
        epoch_loss_valid_avg, valid_true, valid_pred, valid_name ,valid_logit= evaluate(model, valid_loader)
        result_valid, binp,bint = analysis(valid_true, valid_pred)
        print("Valid binary acc: ", result_valid['binary_acc'])
        valid_binary_acc.append(result_valid['binary_acc'])
        if best_val_loss > epoch_loss_valid_avg:
            best_val_loss = epoch_loss_valid_avg
            best_epoch = epoch + 1
            checkpoint = {'state_dict': model.state_dict()}
            torch.save(checkpoint,  os.path.join(Model_Path, 'Fold' + str(fold) + '_score_best_model.pkl'))    
            #torch.save(model.state_dict(), os.path.join(Model_Path, 'Fold' + str(fold) + '_best_model.pkl'))
            valid_detail_dataframe = pd.DataFrame({'names': valid_name, 'stability': valid_true, 'prediction': valid_pred})
            valid_detail_dataframe.sort_values(by=['names'], inplace=True)
            valid_detail_dataframe.to_csv(Result_Path + 'Fold' + str(fold) + "_binary_valid_detail.csv", header=True, sep=',')
 
    #result_all = {
     #   'Train_binary_acc': train_binary_acc,
     #   'Valid_binary_acc': valid_binary_acc,}
    #result = pd.DataFrame(result_all)
    #print("Fold", str(fold), "Best epoch at", str(best_epoch))
    #result.to_csv('result.csv')

def analysis(y_true, y_pred):
    #print(len(y_pred))
    binary_pred = y_pred
    #print(binary_pred, y_true)
    binary_true= y_true


    binary_acc = metrics.accuracy_score(binary_true, binary_pred)
    
    result = {

        'binary_acc': binary_acc,

    }
    return result , binary_pred, binary_true

def cross_validation(all_dataframe,fold_number=10):
    print("split_seed: ", SEED)
    sequence_names = all_dataframe['names'].values
    sequence_labels = all_dataframe['stability'].values
    kfold = KFold(n_splits=fold_number, shuffle=True)
    fold = 0

    for train_index, valid_index in kfold.split(sequence_names, sequence_labels):
        print("\n========== Fold " + str(fold + 1) + " ==========")
        train_dataframe = all_dataframe.iloc[train_index, :]
        valid_dataframe = all_dataframe.iloc[valid_index, :]
        print("Training on", str(train_dataframe.shape[0]), "examples, Validation on", str(valid_dataframe.shape[0]),
              "examples")
        model= Model()

        train(model, train_dataframe, valid_dataframe, fold + 1)
        fold += 1

In [21]:
cross_validation(ds,fold_number=5)

split_seed:  2333

Training on 5188 examples, Validation on 1298 examples



100%|█████████████████████████████████████████████████████████████████████████████| 6486/6486 [00:59<00:00, 109.35it/s]


0.475210611320483


100%|█████████████████████████████████████████████████████████████████████████████| 1298/1298 [00:03<00:00, 410.87it/s]


0.48435246088012524
Valid binary acc:  0.8197226502311248



100%|█████████████████████████████████████████████████████████████████████████████| 6486/6486 [01:00<00:00, 106.71it/s]


0.4856039162103754


100%|█████████████████████████████████████████████████████████████████████████████| 1298/1298 [00:03<00:00, 408.42it/s]


0.47971504244763974
Valid binary acc:  0.8204930662557781



100%|█████████████████████████████████████████████████████████████████████████████| 6486/6486 [00:59<00:00, 109.16it/s]


0.48935707377093535


100%|█████████████████████████████████████████████████████████████████████████████| 1298/1298 [00:03<00:00, 421.41it/s]


0.5080279653492987
Valid binary acc:  0.7912172573189522



100%|█████████████████████████████████████████████████████████████████████████████| 6486/6486 [01:01<00:00, 106.07it/s]


0.49518576868606723


100%|█████████████████████████████████████████████████████████████████████████████| 1298/1298 [00:03<00:00, 400.73it/s]


0.486120246201524
Valid binary acc:  0.8204930662557781



100%|█████████████████████████████████████████████████████████████████████████████| 6486/6486 [01:01<00:00, 104.93it/s]


0.49612602906263137


100%|█████████████████████████████████████████████████████████████████████████████| 1298/1298 [00:03<00:00, 416.19it/s]


0.4931870181406224
Valid binary acc:  0.8073959938366718



100%|█████████████████████████████████████████████████████████████████████████████| 6486/6486 [01:01<00:00, 105.45it/s]


0.5017467832771157


100%|█████████████████████████████████████████████████████████████████████████████| 1298/1298 [00:03<00:00, 403.32it/s]


0.5050909720785997
Valid binary acc:  0.8050847457627118



100%|█████████████████████████████████████████████████████████████████████████████| 6486/6486 [01:01<00:00, 105.84it/s]


0.499015933090509


100%|█████████████████████████████████████████████████████████████████████████████| 1298/1298 [00:03<00:00, 424.27it/s]


0.49706059101917344
Valid binary acc:  0.802773497688752



100%|█████████████████████████████████████████████████████████████████████████████| 6486/6486 [01:04<00:00, 100.87it/s]


0.5061654080552082


100%|█████████████████████████████████████████████████████████████████████████████| 1298/1298 [00:03<00:00, 389.67it/s]


0.5331651047905347
Valid binary acc:  0.7596302003081664



100%|██████████████████████████████████████████████████████████████████████████████| 6486/6486 [01:06<00:00, 97.99it/s]


0.5067933029436308


100%|█████████████████████████████████████████████████████████████████████████████| 1298/1298 [00:03<00:00, 378.20it/s]


0.5041512264446779
Valid binary acc:  0.7912172573189522



 16%|████████████▍                                                                | 1052/6486 [00:10<00:52, 103.44it/s]


KeyboardInterrupt: 

In [26]:
def test(test_dataframe):
    test_loader = DataLoader(dataset=Dataset(test_dataframe) ,batch_size=1, shuffle=False, num_workers=0,collate_fn=collate_fn)
    test_result = {}
    for model_name in sorted(os.listdir(Model_Path)):
        print(model_name)
        model_s = Model()
        if torch.cuda.is_available():
            model_s.cuda()
        model_state = torch.load(Model_Path + model_name)['state_dict']
        model_s.load_state_dict(model_state, strict=True)
        model_s.eval()

    epoch_loss_valid_avg, valid_true, valid_pred, valid_name, valid_score = evaluate(model_s, test_loader)
    #test_detail_dataframe = pd.DataFrame({'names': valid_name.cpu(), 'target': valid_true.cpu(), 'prediction': valid_pred.cpu(), 'Attention': valid_score.cpu()})

    return valid_true,valid_pred,valid_score

In [24]:
def test(test_dataframe):
    test_loader = DataLoader(dataset=Dataset(test_dataframe) ,batch_size=1, shuffle=False, num_workers=0,collate_fn=collate_fn)
    test_result = {}
    for model_name in sorted(os.listdir(Model_Path)):
        print(model_name)
        model_s = Model()
        if torch.cuda.is_available():
            model_s.cuda()
        model_s.load_state_dict(torch.load(Model_Path + model_name),strict=True)
        #model_s.load_state_dict(model_state, strict=True)
        model_s.eval()

    epoch_loss_valid_avg, valid_true, valid_pred, valid_name, valid_score = evaluate(model_s, test_loader)
    #test_detail_dataframe = pd.DataFrame({'names': valid_name.cpu(), 'target': valid_true.cpu(), 'prediction': valid_pred.cpu(), 'Attention': valid_score.cpu()})

    return valid_true,valid_pred,valid_score