In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import confusion_matrix
import math
from sklearn.model_selection import KFold

# NOTE(review): machine-specific absolute path; point DATA_DIR at your own copy.
DATA_DIR = '/Users/jiaming/Desktop/self_attention/datas'

# An earlier version first read pos_domain_encoding.csv / neg_domain_encoding.csv
# and immediately overwrote them with the files below, so only these are loaded.
# Swap the file names here to experiment with the domain encoding instead.
pos = pd.read_csv(f'{DATA_DIR}/pos_encoding_OH_ND.csv')

# Keep exactly as many negatives as positives (1892 in the original run) so the
# two classes are balanced.
neg = pd.read_csv(f'{DATA_DIR}/neg_encoding_OH_ND.csv')[:len(pos)]

In [None]:
# Discard the first CSV column of both frames (presumably a saved row index —
# verify against the CSV files).
pos, neg = (frame.iloc[:, 1:] for frame in (pos, neg))

In [None]:
# Feature columns per sample; the reshape below expects 41 positions x 5 features = 205.
len(pos.columns)

In [None]:
# Each CSV row holds N_POSITIONS x N_FEATURES values laid out position-major:
# columns 0-4 are position 0, columns 5-9 are position 1, and so on.
# (The sequence length was originally 9 and was changed to 41.)
N_POSITIONS = 41
N_FEATURES = 5


def _to_position_array(frame):
    """Reshape a (n_samples, N_POSITIONS * N_FEATURES) frame to (n_samples, N_POSITIONS, N_FEATURES).

    The previous split -> stack -> reshape built a (positions, samples, features)
    array and reshaped it to (samples, positions, features) without transposing,
    which mixed positions from different samples into the same row. A plain
    C-order reshape of the row values gives the intended layout.

    Raises:
        ValueError: if the frame does not have N_POSITIONS * N_FEATURES columns.
    """
    values = frame.to_numpy()
    expected_cols = N_POSITIONS * N_FEATURES
    if values.shape[1] != expected_cols:
        raise ValueError(f"expected {expected_cols} feature columns, got {values.shape[1]}")
    return values.reshape(values.shape[0], N_POSITIONS, N_FEATURES)


pos_arr = _to_position_array(pos)
neg_arr = _to_position_array(neg)

In [None]:
# Positives first, then negatives; labels follow the same order (1 = positive, 0 = negative).
raw_data = np.concatenate([pos_arr, neg_arr])
raw_labels = np.repeat([1, 0], [len(pos_arr), len(neg_arr)])

In [None]:
# Fixed seed so the shuffle (and therefore the train/test split) is reproducible.
np.random.seed(123)
indices = np.random.permutation(len(raw_labels))
# The same permutation is applied to samples and labels so pairs stay aligned.
data_with_extra, labels = raw_data[indices], raw_labels[indices]

In [None]:
# Insert a singleton channel axis for Conv2d: (N, 41, 5) -> (N, 1, 41, 5).
data_with_extra = data_with_extra[:, np.newaxis]
# Labels become a column vector (N, 1) to match the model's (N, 1) output.
labels = labels[:, np.newaxis]

In [None]:
def train_test_split(data, label, train_size=0.8):
    """Split ``data``/``label`` into a leading train part and a trailing test part.

    No shuffling happens here (the samples are shuffled in an earlier cell).

    Args:
        data: array indexed along axis 0 by sample.
        label: array with the same number of samples as ``data``.
        train_size: fraction in [0, 1] of samples placed in the training part.

    Returns:
        ``(train_data, train_labels, test_data, test_labels)``.

    Raises:
        ValueError: if the sample counts differ or ``train_size`` is outside [0, 1].
    """
    if data.shape[0] != label.shape[0]:
        # Previously this returned None, which only surfaced later as a confusing
        # "cannot unpack NoneType" error at the call site.
        raise ValueError(
            f"data has {data.shape[0]} samples but label has {label.shape[0]}")
    if not 0.0 <= train_size <= 1.0:
        raise ValueError(f"train_size must be in [0, 1], got {train_size}")

    n_train = int(data.shape[0] * train_size)
    return (data[:n_train], label[:n_train], data[n_train:], label[n_train:])
    
# First 80% of the shuffled samples for training (incl. CV), last 20% held out for testing.
train_data, train_labels, test_data, test_labels = train_test_split(
    data_with_extra, labels, train_size=0.8
)
print(train_data.shape)

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
class Net_sa(nn.Module):
    """CNN + self-attention binary classifier for (batch, 1, seq_len, n_features) inputs.

    Pipeline: Conv2d(1 -> conv_channels) + ReLU + Dropout, then scaled
    dot-product self-attention over the ``seq_len`` positions (each position is
    a token of conv_channels * n_features values), then an MLP head ending in a
    sigmoid, giving a (batch, 1) probability.

    NOTE(review): the previous version flattened each sample to one vector and
    ran attention on a 2-D (batch, features) tensor, so the attention scores
    were (batch, batch) and every prediction depended on the other samples in
    the batch. Attention is now computed within each sample. The attention
    projections changed shape (410x410 -> 10x10 with the defaults), so
    checkpoints saved by the old class cannot be loaded into this one.
    """

    class SelfAttention(nn.Module):
        """Single-head scaled dot-product self-attention.

        Takes ``x`` of shape (..., seq_len, input_dim) and returns a tensor of the
        same shape; attention is taken over the second-to-last axis.
        """

        def __init__(self, input_dim):
            super(Net_sa.SelfAttention, self).__init__()
            self.input_dim = input_dim
            self.query = nn.Linear(input_dim, input_dim)
            self.key = nn.Linear(input_dim, input_dim)
            self.value = nn.Linear(input_dim, input_dim)
            self.softmax = nn.Softmax(dim=-1)

        def forward(self, x):
            query = self.query(x)
            key = self.key(x)
            value = self.value(x)
            # Scale by sqrt(d) so the softmax logits do not grow with the token width.
            scores = torch.matmul(query, key.transpose(-2, -1)) / (self.input_dim ** 0.5)
            attention_weights = self.softmax(scores)
            return torch.matmul(attention_weights, value)

    def __init__(self, seq_len=41, n_features=5, conv_channels=2):
        """Build the network; the defaults reproduce the original 41 x 5 input layout."""
        super(Net_sa, self).__init__()
        self.seq_len = seq_len
        self.n_features = n_features
        self.conv_channels = conv_channels

        # kernel_size=3 with padding=1 and stride=1 keeps the (seq_len, n_features) grid.
        self.conv = nn.Sequential(
            nn.Conv2d(1, conv_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout(0.6)
        )
        # One token per sequence position, each conv_channels * n_features wide.
        self.attention = self.SelfAttention(conv_channels * n_features)
        self.seq = nn.Sequential(
            nn.Linear(conv_channels * seq_len * n_features, 25),
            nn.ReLU(),
            nn.Dropout(0.6),
            nn.Linear(25, 3),
            nn.ReLU(),
            nn.Dropout(0),  # no-op; kept so the indices of later layers in `seq` are unchanged
            nn.Linear(3, 1),
            nn.Sigmoid()  )

    def forward(self, x):
        """Map (batch, 1, seq_len, n_features) inputs to (batch, 1) probabilities."""
        x = self.conv(x)                                # (B, C, L, F)
        x = x.permute(0, 2, 1, 3).flatten(start_dim=2)  # (B, L, C*F): one token per position
        x = self.attention(x)                           # attention within each sample
        x = x.flatten(start_dim=1)                      # (B, L*C*F)
        return self.seq(x)

In [None]:
# K-fold cross-validation on the training split, followed by a final model fitted
# on the whole training split (that final `model` is used for the test set below).
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

num_epochs = 50
kf = KFold(n_splits=5)
criterion = nn.BCELoss()


def _to_tensor(array):
    """Convert a numpy array to a float tensor on `device`."""
    return torch.from_numpy(array).to(device).type(torch.float)


def _fit_fresh_model(data_t, labels_t, losses):
    """Train a newly initialised Net_sa full-batch for `num_epochs` epochs.

    Each call builds its own model, optimizer and scheduler. The previous loop
    reused one model across folds, so every fold's validation samples had
    already been trained on in an earlier fold. Per-epoch losses are appended
    to `losses`. Returns (model, optimizer, scheduler).
    """
    net = Net_sa().to(device)
    opt = optim.Adam(net.parameters(), lr=0.001)
    sched = lr_scheduler.StepLR(opt, step_size=10, gamma=0.93)
    net.train()
    for epoch in range(num_epochs):
        outputs = net(data_t)
        loss = criterion(outputs, labels_t)
        opt.zero_grad()
        loss.backward()
        opt.step()
        sched.step()
        losses.append(loss.item())
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")
    return net, opt, sched


train_losses = []  # losses of every epoch of every fit, in training order
fold_aucs = []
for fold, (train_indx, val_indx) in enumerate(kf.split(X=train_data), start=1):
    train_data_splited = _to_tensor(train_data[train_indx])
    train_labels_splited = _to_tensor(train_labels[train_indx])
    # Validation inputs must live on the same device as the model.
    val_data_splited = _to_tensor(train_data[val_indx])
    # Kept as numpy: only consumed by the sklearn-based metrics.
    val_labels_splited = train_labels[val_indx]

    print(f"Fold {fold}/{kf.get_n_splits()}")
    model, optimizer, scheduler = _fit_fresh_model(
        train_data_splited, train_labels_splited, train_losses)

    # Evaluate once after training, with dropout disabled.
    model.eval()
    with torch.no_grad():
        preds = model(val_data_splited).cpu().numpy()
    fold_fpr, fold_tpr, _ = roc_curve(val_labels_splited.ravel(), preds.ravel())
    fold_aucs.append(auc(fold_fpr, fold_tpr))
    print(f"Fold {fold} validation AUC: {fold_aucs[-1]:.4f}")

print(f"CV validation AUC: {np.mean(fold_aucs):.4f} +/- {np.std(fold_aucs):.4f}")

# Final model: fresh weights trained on the full training split. After this
# cell, `preds` / `val_labels_splited` still refer to the last CV fold.
model, optimizer, scheduler = _fit_fresh_model(
    _to_tensor(train_data), _to_tensor(train_labels), train_losses)
model.eval()

In [None]:
def predict(model, input):
    """Run ``model`` on a numpy batch and return its outputs as a numpy array.

    Inference runs in eval mode (dropout disabled) and without autograd, on the
    device holding the model's parameters (CPU if the model has none). The
    model's previous train/eval mode is restored afterwards.

    Args:
        model: a torch ``nn.Module``.
        input: numpy array accepted by ``model`` once converted to a float tensor.

    Returns:
        numpy array with the model outputs, on the CPU.
    """
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    data = torch.from_numpy(input).to(target_device).type(torch.float)

    was_training = model.training
    model.eval()  # without this, Dropout layers randomly zero activations at inference
    try:
        with torch.no_grad():
            output = model(data)
    finally:
        model.train(was_training)
    return output.cpu().numpy()

# Sigmoid scores for the held-out test split, one row per sample.
test_preds = predict(model=model, input=test_data)

In [None]:
def metrics_output(preds,labels):
    """Compute the ROC curve and threshold-based metrics for binary predictions.

    The decision threshold is the ROC point that maximises Youden's J
    (TPR - FPR). It is chosen on the same ``preds``/``labels`` that are scored.

    Args:
        preds: predicted scores, shape (n,) or (n, 1); higher means more positive.
            Anything ``np.asarray`` accepts (numpy array, CPU tensor).
        labels: ground-truth {0, 1} labels, shape (n,) or (n, 1).

    Returns:
        ``(fpr, tpr, roc_auc, sensitivity, specificity, accuracy, F1, MCC)``.
        A ratio whose denominator is zero is returned as NaN.
    """
    scores = np.asarray(preds).reshape(-1)
    truth = np.asarray(labels).reshape(-1)

    metrics_fpr, metrics_tpr, thresholds = roc_curve(truth, scores)
    roc_auc = auc(metrics_fpr, metrics_tpr)

    best_threshold = thresholds[np.argmax(metrics_tpr - metrics_fpr)]
    # roc_curve counts score >= threshold as positive; use the same rule so the
    # confusion matrix matches the chosen ROC point. A strict ">" would flip the
    # sample(s) whose score equals the threshold to negative.
    pred_binary = (scores >= best_threshold).astype(int)

    # labels=[0, 1] keeps the matrix 2x2 even when one class is absent.
    metrics_tn, metrics_fp, metrics_fn, metrics_tp = confusion_matrix(
        truth, pred_binary, labels=[0, 1]).ravel()

    def _ratio(num, den):
        # Zero denominators give NaN explicitly instead of via a numpy RuntimeWarning.
        return float(num) / float(den) if den != 0 else float('nan')

    metrics_sn = _ratio(metrics_tp, metrics_tp + metrics_fn)
    metrics_sp = _ratio(metrics_tn, metrics_tn + metrics_fp)
    metrics_ACC = _ratio(metrics_tp + metrics_tn,
                         metrics_tn + metrics_fp + metrics_fn + metrics_tp)
    metrics_pre = _ratio(metrics_tp, metrics_tp + metrics_fp)
    metrics_F1 = _ratio(2 * metrics_pre * metrics_sn, metrics_pre + metrics_sn)
    metrics_MCC = _ratio(metrics_tp * metrics_tn - metrics_fp * metrics_fn,
                         math.sqrt(float(metrics_tp + metrics_fp) *
                                   float(metrics_tp + metrics_fn) *
                                   float(metrics_tn + metrics_fp) *
                                   float(metrics_tn + metrics_fn)))

    return (metrics_fpr, metrics_tpr, roc_auc, metrics_sn, metrics_sp, metrics_ACC, metrics_F1, metrics_MCC)


In [None]:
# Validation metrics come from the last CV fold (`preds`, `val_labels_splited`);
# test metrics come from the held-out split. metrics_output picks its decision
# threshold on the data it scores, so Sn/Sp/ACC/MCC are somewhat optimistic.
(_, _, val_auc, val_sn, val_sp, val_ACC, val_F1, val_MCC) = metrics_output(preds, val_labels_splited)
print("Validation  AUC / Sn / Sp / ACC / MCC:", val_auc, val_sn, val_sp, val_ACC, val_MCC)

# The test results are unpacked into the metrics_* names because the ROC plot below uses them.
metrics_fpr, metrics_tpr, roc_auc, metrics_sn, metrics_sp, metrics_ACC, metrics_F1, metrics_MCC = metrics_output(test_preds, test_labels)
print("Test        AUC / Sn / Sp / ACC / MCC:", roc_auc, metrics_sn, metrics_sp, metrics_ACC, metrics_MCC)


In [None]:
# ROC curve of the held-out test predictions against the chance diagonal.
fig, ax = plt.subplots()
ax.plot(metrics_fpr, metrics_tpr, color='darkorange', label='DL model (AUC = %0.2f)' % roc_auc)
ax.plot([0, 1], [0, 1], color='navy', linestyle='--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver Operating Characteristic')
ax.legend(loc="lower right")
plt.show()

In [None]:
# Persist only the learned parameters; rebuild with Net_sa() and load_state_dict() to reuse them.
trained_weights = model.state_dict()
torch.save(trained_weights, 'self_attention_seq_parameters.pth')